diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index dbc466f036ee..ec927173cd1b 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -11,7 +11,7 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, const std::optional& in_bias, - const int64_t CuCount); + const int64_t CuCount, const bool fast_skinny_gemm); void wvSplitKQ(const at::Tensor& in_a, const at::Tensor& in_b, const std::optional& in_bias, at::Tensor& out_c, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 15ebcc776ad7..8d390f605809 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -386,12 +386,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // YTILE represents how many column of weight matrix // are being worked on by each wave. //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) + for (int i = 0; i < YTILE; i++) { + for (int n = 0; n < N; n++) { if constexpr (!use_mfma) sum[n][i] = 0; else sum4[n][i] = {0, 0, 0, 0}; + } + } bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; @@ -412,7 +414,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - // for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { // Fetch the weight matrix from memory! #pragma unroll @@ -452,16 +453,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (uint32_t n = 0; n < N; n++) { #pragma unroll for (int y = 0; y < YTILE; y++) { - if constexpr (!use_mfma) + if constexpr (!use_mfma) { #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) } - else + } else { #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 4; b++) + for (uint32_t b = 0; b < A_CHUNK / 4; b++) { sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); + } + } } } } @@ -514,9 +517,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { #pragma unroll for (int y = 0; y < YTILE; y++) { - // float accm1 = 0; - // for (int i=0; i<64; i++) - // accm1 += __shfl(sum4[n][y][i%4], i); float accm = sum4[n][y][0]; asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm) @@ -684,12 +684,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // YTILE represents how many column of weight matrix // are being worked on by each wave. //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) + for (int i = 0; i < YTILE; i++) { + for (int n = 0; n < N; n++) { if constexpr (!use_mfma) sum[n][i] = 0; else sum4[n][i] = {0, 0, 0, 0}; + } + } bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; @@ -752,16 +754,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (int y = 0; y < YTILE; y++) { - if constexpr (!use_mfma) + if constexpr (!use_mfma) { #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) } - else + } else { #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 4; b++) + for (uint32_t b = 0; b < A_CHUNK / 4; b++) { sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); + } + } } } } @@ -816,10 +820,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { #pragma unroll for (int y = 0; y < YTILE; y++) { - // float accm1 = 0; - // for (int i=0; i<64; i++) - // accm1 += __shfl(sum4[n][y][i%4], i); - float accm = sum4[n][y][0]; asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " : "=v"(accm) @@ -1021,12 +1021,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // YTILE represents how many column of weight matrix // are being worked on by each wave. //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) + for (int i = 0; i < YTILE; i++) { + for (int n = 0; n < N; n++) { if constexpr (!use_mfma) sum[n][i] = 0; else sum4[n][i] = {0, 0, 0, 0}; + } + } bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; @@ -1112,16 +1114,18 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (int y = 0; y < YTILE; y++) { - if constexpr (!use_mfma) + if constexpr (!use_mfma) { #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) } - else + } else { #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 4; b++) + for (uint32_t b = 0; b < A_CHUNK / 4; b++) { sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); + } + } } } } @@ -1261,8 +1265,9 @@ int mindiv(int N, int div1, int div2) { rnds[i] = (N + nPrRnd - 1) / nPrRnd; nPrRnd -= div1; } - for (int i = 12; i >= 0; i--) + for (int i = 12; i >= 0; i--) { if (rnds[0] == rnds[i]) return (div2 - i); + } return 0; } @@ -1366,21 +1371,19 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, } // This version targets cases skinny where CUs are not filled -// Wave-SplitK is used with reduction done via atomics. +// Wave-SplitK is used with reduction done via deterministic store-then-reduce. #if defined(__gfx950__) #define WVSPLITKRC_1KPASS template __global__ void __launch_bounds__(WvPrGrp* THRDS) __attribute__((amdgpu_waves_per_eu(1, 1))) - wvSplitKrc_(const int actlN, const int K, const int M, const int Bx, - const int By, const scalar_t* __restrict__ B, - const scalar_t* __restrict__ A, - const scalar_t* __restrict__ BIAS, float* glbl, scalar_t* C, - const int CuCount) { - // Use upper half of glbl buffer for atomic reduce counting - int* cntr = (int*)(&glbl[M * N]); - + wvSplitKrc_deterministic_(const int actlN, const int K, const int Kap, + const int M, const int Bx, const int By, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + const scalar_t* __restrict__ BIAS, float* glbl, + scalar_t* C, const int CuCount) { constexpr int NTILE = 16; constexpr int APAD = 1; constexpr int ASTRD = 64; @@ -1412,7 +1415,10 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) unsigned int* myStg = (unsigned int*)(&stg[WVLDS * (threadIdx.y / GrpsShrB)]); __shared__ scalar_t s[max_lds_len - WvPrGrp * WVLDS / GrpsShrB]; - #ifndef WVSPLITKRC_1KPASS + #ifdef WVSPLITKRC_1KPASS + int constexpr kFit = 512 / CHUNKK; + int constexpr kfitsPerRdc = 1; + #else constexpr int TUC_ = (THRDS * UNRL * A_CHUNK); // find biggest k size that fits padded into LDS constexpr uint32_t kFit__ = (max_lds_len - WvPrGrp * WVLDS / GrpsShrB) / N; @@ -1422,7 +1428,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // find best k split to fill the CUs if (((K + kfitsPerRdc * kFit - 1) / (kfitsPerRdc * kFit)) * numCuWithFullK <= - CuCount) + CuCount) { while (true) { while (kFit > TUC_) { uint32_t kFit_ = kFit - TUC_; @@ -1439,9 +1445,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) else break; } - #else - int constexpr kFit = 512 / CHUNKK; - int constexpr kfitsPerRdc = 1; + } #endif bool doRdc = true; // Assuming (kfitsPerRdc * kFit < K) is always true @@ -1458,6 +1462,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc; const uint32_t k_rnd = (K + kFit * kfitsPerRdc - 1) / (kFit * kfitsPerRdc); + int* cntr = (int*)(&glbl[M * N * k_rnd]); + scalar8 sum4[N / NTILE / GrpsShrB][1] = {0}; bigType bigB_[YTILE / GrpsShrB / CHUNKK][UNRL]; const uint32_t bLoader = (threadIdx.y % GrpsShrB); @@ -1471,8 +1477,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #ifdef WVSPLITKRC_1KPASS // Early glbl init, B[] loading, if 1KPASS if constexpr (FAST_UNSAFE_RDC_INIT) { - if (m + (threadIdx.x % 16) < M) - if (doRdc) + if (m + (threadIdx.x % 16) < M) { + if (doRdc) { if (k_str == 0) { int mindx = m + (threadIdx.x % 16); int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + @@ -1490,6 +1496,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } + } + } } // Load first B[] chunk @@ -1499,22 +1507,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k_ = k + (threadIdx.x % (THRDS / CHUNKK)) * A_CHUNK; const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)]; #pragma unroll - for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK) + for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK) { bigB_[y / CHUNKK][k2].h8 = (loadnt( (scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB + bLoader + m, M - 1) * K]))); + } } { #else while (m < Mmod) { - #endif - - #ifndef WVSPLITKRC_1KPASS if constexpr (FAST_UNSAFE_RDC_INIT) { - if (m + (threadIdx.x % 16) < M) - if (doRdc) + if (m + (threadIdx.x % 16) < M) { + if (doRdc) { if (k_str == 0) { int mindx = m + (threadIdx.x % 16); int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + @@ -1532,6 +1538,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } + } + } } #endif @@ -1571,11 +1579,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff); for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) { __builtin_amdgcn_global_load_lds( - (int*)(&A[min__( - K * actlN - A_CHUNK, - kOffcp + K * (n / CHUNKK + - (N / CHUNKK) * (threadIdx.x / (64 / CHUNKK)) + - (threadIdx.y % sprdN)))]), + (int*)(&A[min__(Kap * actlN - A_CHUNK, + kOffcp + Kap * (n / CHUNKK + + (N / CHUNKK) * (threadIdx.x / + (64 / CHUNKK)) + + (threadIdx.y % sprdN)))]), (int*)(&s[(k + kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]), 16, 0, 0); @@ -1604,21 +1612,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #ifndef WVSPLITKRC_1KPASS // Fire load of next B[] chunk... if ((k1 + THRDS * A_CHUNK * UNRL < k_end) && - (k1 + THRDS * A_CHUNK * UNRL < K)) + (k1 + THRDS * A_CHUNK * UNRL < K)) { #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + THRDS * A_CHUNK * UNRL + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)]; #pragma unroll - for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK) + for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK) { bigB_[y / CHUNKK][k2].h8 = (loadnt( (scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB + bLoader + m, M - 1) * K]))); + } } + } #endif // B[] staging is cooperative across GrpsShrB, so sync here before reading @@ -1643,7 +1653,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK - kBase - k_str; #pragma unroll - for (uint32_t nt = 0; nt < N / GrpsShrB; nt += NTILE) + for (uint32_t nt = 0; nt < N / GrpsShrB; nt += NTILE) { #pragma unroll for (uint32_t n = 0; n < NTILE / CHUNKK; n++) { uint32_t idxa = @@ -1656,6 +1666,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) A_CHUNK * ((threadIdx.x / NTILE) + n * 4) + k; bigA[nt / CHUNKK + n][k2] = *((const bigType*)(&(s[idxa]))); } + } } // Do the MFMAs @@ -1679,31 +1690,63 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } + union flt4 { + scalar8 s8; + float2 f2[2]; + float4 f4; + }; if (m + (threadIdx.x % 16) < M) { int my_cntr; int mindx = m + (threadIdx.x % 16); int g_mindx = m * 4 + (threadIdx.x % 64); // coalesced atomic reduction scalar_t biases[N / NTILE / GrpsShrB][4] = {}; // Atomic add the output, read biases - for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) - for (uint32_t j = 0; j < 4; j++) { - // int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + - // (N / GrpsShrB) * (threadIdx.y % GrpsShrB); - // int adr = mindx + M * nindx; - int g_nindx = - j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; - int g_adr = g_mindx + M * g_nindx * 4; - atomicAdd(&glbl[g_adr], sum4[nt][0][j]); - } + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + int g_nindx = + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; + int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4; + flt4 flt4_ = {.s8 = sum4[nt][0]}; + __hip_atomic_store((float2*)&glbl[g_adr + M * N * (m0 / Mmod)], + flt4_.f2[0], __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + __hip_atomic_store((float2*)&glbl[g_adr + 2 + M * N * (m0 / Mmod)], + flt4_.f2[1], __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + + __atomic_signal_fence(__ATOMIC_SEQ_CST); + asm volatile("s_waitcnt vmcnt(0)" ::: "memory"); + __atomic_signal_fence(__ATOMIC_SEQ_CST); + int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); int adr_ = mindx + M * nindx_ / 4; - // Update the complete counter my_cntr = atomicAdd(&cntr[adr_], 1); - float vals[N / NTILE / GrpsShrB][4] = {}; + + __syncthreads(); // make sure LDS is free for write out staging + + // Update the complete counter + flt4 vals[N / NTILE / GrpsShrB] = {}; // If we're the last k-shard, read back the value and convert... if (my_cntr + 1 == k_rnd) { - if (BIAS) + cntr[adr_] = 0; // clear for next round + + #pragma unroll + for (int ks = 0; ks < k_rnd; ks++) { + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + int g_nindx = + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; + int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4; + __builtin_amdgcn_global_load_lds( + (float4*)(&glbl[g_adr + M * N * ks]), + &(((float4*)s)[(threadIdx.y * THRDS) + ks * THRDS * 4 + + nt * THRDS * 4 * k_rnd]), + 16, 0, 0); + *(float4*)(&glbl[g_adr + M * N * ks]) = + {}; // clear out for next round + } + } + if (BIAS) { for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { for (uint32_t j = 0; j < 4; j++) { int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + @@ -1711,12 +1754,410 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx]; } } + } + asm volatile("s_waitcnt 0"); + for (int ks = 0; ks < k_rnd; ks++) { + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + float4 eval = ((float4*)s)[(threadIdx.x + threadIdx.y * THRDS) + + ks * THRDS * 4 + nt * THRDS * 4 * k_rnd]; + vals[nt].f4 += eval; + } + } + __builtin_amdgcn_sched_barrier(0); for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { for (uint32_t j = 0; j < 4; j++) { - int g_nindx = - j + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; - int g_adr = g_mindx + M * g_nindx * 4; - vals[nt][j] = glbl[g_adr]; + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + if (nindx < actlN) { + int adr = mindx + M * nindx; + if constexpr (std::is_same_v) { + vals[nt].s8[j] += __bfloat162float(biases[nt][j]); + C[adr] = __float2bfloat16(vals[nt].s8[j]); + } else { + vals[nt].s8[j] += __half2float(biases[nt][j]); + C[adr] = __float2half(vals[nt].s8[j]); + } + } + } + } + } + + #ifndef WVSPLITKRC_1KPASS + m0 += CuCount * WvPrGrp * YTILE / GrpsShrB; + m = (m0 + m1) % Mmod; + k_str = (m0 / Mmod) * kFit * kfitsPerRdc; + k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc; + if (k_str >= K) break; + kBase = 0; + #endif + } +} + +// Wave-SplitK with fast non-deterministic atomicAdd reduction. +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + __attribute__((amdgpu_waves_per_eu(1, 1))) + wvSplitKrc_fast_(const int actlN, const int K, const int Kap, const int M, + const int Bx, const int By, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + const scalar_t* __restrict__ BIAS, float* glbl, + scalar_t* C, const int CuCount) { + constexpr int NTILE = 16; + constexpr int APAD = 1; + constexpr int ASTRD = 64; + constexpr int BPAD = 1; + constexpr int WVLDS_ = THRDS * A_CHUNK / CHUNKK; + constexpr int WVLDS = ((WVLDS_ + A_CHUNK * BPAD)) * YTILE; + + constexpr int max_lds_len = LDS_SIZE / 2; + + using scalar16 = + __attribute__((__vector_size__((A_CHUNK * 2) * sizeof(float)))) float; + using scalar8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; + union bigType { + scalar_t h[A_CHUNK]; + float f[A_CHUNK / 2]; + unsigned int i[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + unsigned long l[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; + scalar8 h8; + }; + using big4 = __attribute__((__vector_size__(4 * sizeof(bigType)))) __bf16; + + __shared__ scalar_t stg[WvPrGrp * WVLDS / GrpsShrB]; + unsigned int* myStg = (unsigned int*)(&stg[WVLDS * (threadIdx.y / GrpsShrB)]); + __shared__ scalar_t s[max_lds_len - WvPrGrp * WVLDS / GrpsShrB]; + + #ifdef WVSPLITKRC_1KPASS + int constexpr kFit = 512 / CHUNKK; + int constexpr kfitsPerRdc = 1; + #else + constexpr int TUC_ = (THRDS * UNRL * A_CHUNK); + // find biggest k size that fits padded into LDS + constexpr uint32_t kFit__ = (max_lds_len - WvPrGrp * WVLDS / GrpsShrB) / N; + constexpr uint32_t kFit_ = (kFit__ * ASTRD) / (APAD + ASTRD); + uint32_t kFit = kFit_ - (kFit_ % TUC_); + uint32_t kfitsPerRdc = (K + kFit - 1) / kFit; + + // find best k split to fill the CUs + if (((K + kfitsPerRdc * kFit - 1) / (kfitsPerRdc * kFit)) * numCuWithFullK <= + CuCount) { + while (true) { + while (kFit > TUC_) { + uint32_t kFit_ = kFit - TUC_; + if (((K + (kfitsPerRdc * kFit_ - 1)) / (kfitsPerRdc * kFit_)) * + numCuWithFullK > + CuCount) + break; + kFit = kFit_; + } + if (((K + ((kfitsPerRdc - 1) * kFit - 1)) / ((kfitsPerRdc - 1) * kFit)) * + numCuWithFullK <= + CuCount) + kfitsPerRdc--; + else + break; + } + } + #endif + + bool doRdc = true; // Assuming (kfitsPerRdc * kFit < K) is always true + uint32_t numCuWithFullK = + ((M + (WvPrGrp * YTILE / GrpsShrB) - 1) / (WvPrGrp * YTILE / GrpsShrB)); + uint32_t Mmod = numCuWithFullK * (WvPrGrp * YTILE / GrpsShrB); + + // given above k-split, find this wave's position + uint32_t kFitPdd = kFit * CHUNKK + ((kFit * CHUNKK) / ASTRD) * APAD; + uint32_t m0 = (blockIdx.x * WvPrGrp / GrpsShrB) * YTILE; + uint32_t m1 = ((threadIdx.y % WvPrGrp) / GrpsShrB) * YTILE; + uint32_t m = (m0 + m1) % Mmod; + const uint32_t k_str = (m0 / Mmod) * kFit * kfitsPerRdc; + uint32_t k_end = (m0 / Mmod + 1) * kFit * kfitsPerRdc; + const uint32_t k_rnd = (K + kFit * kfitsPerRdc - 1) / (kFit * kfitsPerRdc); + + // Fast: counter lives right after a single M*N slab. + int* cntr = (int*)(&glbl[M * N]); + + scalar8 sum4[N / NTILE / GrpsShrB][1] = {0}; + bigType bigB_[YTILE / GrpsShrB / CHUNKK][UNRL]; + const uint32_t bLoader = (threadIdx.y % GrpsShrB); + uint32_t kBase = 0; + if (k_str >= K) return; + if (m >= Mmod) return; + + bool noreloada = false; + constexpr bool FAST_UNSAFE_RDC_INIT = false; + + #ifdef WVSPLITKRC_1KPASS + // Early glbl init, B[] loading, if 1KPASS + if constexpr (FAST_UNSAFE_RDC_INIT) { + if (m + (threadIdx.x % 16) < M) { + if (doRdc) { + if (k_str == 0) { + int mindx = m + (threadIdx.x % 16); + int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr_ = mindx + M * nindx_ / 4; + __hip_atomic_store(&cntr[adr_], 0, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr = mindx + M * nindx; + __hip_atomic_store(&glbl[adr], 0, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + } + } + } + } + } + + // Load first B[] chunk + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k_str + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + (threadIdx.x % (THRDS / CHUNKK)) * A_CHUNK; + const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)]; + #pragma unroll + for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK) { + bigB_[y / CHUNKK][k2].h8 = (loadnt( + (scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB + + bLoader + m, + M - 1) * + K]))); + } + } + + { + #else + while (m < Mmod) { + if constexpr (FAST_UNSAFE_RDC_INIT) { + if (m + (threadIdx.x % 16) < M) { + if (doRdc) { + if (k_str == 0) { + int mindx = m + (threadIdx.x % 16); + int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr_ = mindx + M * nindx_ / 4; + __hip_atomic_store(&cntr[adr_], 0, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr = mindx + M * nindx; + __hip_atomic_store(&glbl[adr], 0, __ATOMIC_RELAXED, + __HIP_MEMORY_SCOPE_AGENT); + } + } + } + } + } + } + #endif + + #ifndef WVSPLITKRC_1KPASS + for (uint32_t k1 = k_str; k1 < k_end; k1 += THRDS * A_CHUNK * UNRL) { + #else + const uint32_t k1 = k_str; + { + #endif + #ifndef WVSPLITKRC_1KPASS + const bool reloada = (!noreloada) && + ((k1 == k_str) || (k1 == k_str + kBase + kFit)) && + (k1 < k_end); + // load next chunk of A[] to LDS + if (reloada) { + if (k1 != k_str) kBase += kFit; + __syncthreads(); + #else + const bool reloada = (!noreloada) && + ((k1 == k_str) || (k1 == k_str + kBase + kFit)) && + (k1 < k_end); + if (reloada) { + #endif + constexpr int sprdN = 4; + const uint32_t thrd = threadIdx.x % (THRDS / CHUNKK); + + #ifndef WVSPLITKRC_1KPASS + #pragma unroll + for (int k = 0; k < kFit; + k += (THRDS * (WvPrGrp / sprdN) * A_CHUNK) / CHUNKK) { + #else + const unsigned int k = 0; + { + #endif + unsigned int kOff = k + (thrd * A_CHUNK); + unsigned int kOffcp = min__(K - A_CHUNK, k_str + kOff); + for (unsigned int n = 0; n < N; n += CHUNKK * sprdN) { + __builtin_amdgcn_global_load_lds( + (int*)(&A[min__(Kap * actlN - A_CHUNK, + kOffcp + Kap * (n / CHUNKK + + (N / CHUNKK) * (threadIdx.x / + (64 / CHUNKK)) + + (threadIdx.y % sprdN)))]), + (int*)(&s[(k + + kFitPdd * ((n / CHUNKK) + (threadIdx.y % sprdN)))]), + 16, 0, 0); + } + + // Stage loaded B[] to LDS for MFMA swizzling... + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + (threadIdx.x % (THRDS / CHUNKK)) * A_CHUNK; + const bool oob_k = (k_ >= K); + for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK) { + uint32_t idx = + (threadIdx.x % (THRDS / CHUNKK)) * 4 + + ((y + threadIdx.x / (THRDS / CHUNKK)) * GrpsShrB + bLoader) * + ((THRDS / CHUNKK + BPAD) * 4); + // zero out if oob + *((scalar8*)&myStg[idx]) = + (oob_k) // TODO: ever necessary (y*GrpsShrB+bLoader+m>=M) ? + ? 0 + : bigB_[y / CHUNKK][k2].h8; + } + } + } + } + } + #ifndef WVSPLITKRC_1KPASS + // Fire load of next B[] chunk... + if ((k1 + THRDS * A_CHUNK * UNRL < k_end) && + (k1 + THRDS * A_CHUNK * UNRL < K)) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + THRDS * A_CHUNK * UNRL + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)]; + #pragma unroll + for (uint32_t y = 0; y < YTILE / GrpsShrB; y += CHUNKK) { + bigB_[y / CHUNKK][k2].h8 = (loadnt( + (scalar8*)(&B_[min__((y + threadIdx.x / (THRDS / CHUNKK)) * + GrpsShrB + + bLoader + m, + M - 1) * + K]))); + } + } + } + #endif + + // B[] staging is cooperative across GrpsShrB, so sync here before reading + // back. This wait is currently inserted by compiler, but not gauranteed. + asm volatile("s_waitcnt 0"); + __syncthreads(); + + // read back B[] swizzled for MFMA... + bigType bigB[YTILE / CHUNKK][UNRL]; + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + for (uint32_t y = 0; y < YTILE / CHUNKK; y++) { + unsigned int idx = + (threadIdx.x % YTILE) * ((THRDS / CHUNKK + BPAD) * 4) + + (threadIdx.x / YTILE) * 4 + y * 16; + bigB[y][k2].h8 = *((scalar8*)&myStg[idx]); + } + } + + // rReadback A[] swizzled for MFMA... + bigType bigA[N / GrpsShrB / CHUNKK][UNRL]; + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK - kBase - k_str; + #pragma unroll + for (uint32_t nt = 0; nt < N / GrpsShrB; nt += NTILE) { + #pragma unroll + for (uint32_t n = 0; n < NTILE / CHUNKK; n++) { + uint32_t idxa = + ((nt + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) % (N / CHUNKK) + + (threadIdx.x % NTILE)) * + kFitPdd + + ((nt + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / + (N / CHUNKK)) * + A_CHUNK * (64 / CHUNKK) + + A_CHUNK * ((threadIdx.x / NTILE) + n * 4) + k; + bigA[nt / CHUNKK + n][k2] = *((const bigType*)(&(s[idxa]))); + } + } + } + + // Do the MFMAs + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + #pragma unroll + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + #pragma unroll + for (uint32_t j = 0; j < YTILE / CHUNKK; j++) { + if constexpr (std::is_same_v) { + sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x32_f16( + bigA[nt * (YTILE / CHUNKK) + j][k2].h8, bigB[j][k2].h8, + sum4[nt][0], 0, 0, 0); + } else { // bf16 + sum4[nt][0] = __builtin_amdgcn_mfma_f32_16x16x32_bf16( + bigA[nt * (YTILE / CHUNKK) + j][k2].h8, bigB[j][k2].h8, + sum4[nt][0], 0, 0, 0); + } + } + } + } + } + + union flt4 { + scalar8 s8; + float2 f2[2]; + float4 f4; + }; + + if (m + (threadIdx.x % 16) < M) { + int my_cntr; + int mindx = m + (threadIdx.x % 16); + int g_mindx = m * 4 + (threadIdx.x % 64); // coalesced atomic reduction + scalar_t biases[N / NTILE / GrpsShrB][4] = {}; + // Atomic add the output, read biases + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + int g_nindx = + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; + int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4; + for (uint32_t j = 0; j < 4; j++) + atomicAdd((&glbl[g_adr + j]), sum4[nt][0][j]); + } + + __atomic_signal_fence(__ATOMIC_SEQ_CST); + asm volatile("s_waitcnt vmcnt(0)" ::: "memory"); + __atomic_signal_fence(__ATOMIC_SEQ_CST); + + int nindx_ = (0 + (threadIdx.x / 16) * 4) + 0 * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + int adr_ = mindx + M * nindx_ / 4; + my_cntr = atomicAdd(&cntr[adr_], 1); + + // Update the complete counter + flt4 vals[N / NTILE / GrpsShrB] = {}; + // If we're the last k-shard, read back the value and convert... + if (my_cntr + 1 == k_rnd) { + cntr[adr_] = 0; // clear for next round + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + int g_nindx = + (nt * NTILE + (N / GrpsShrB) * (threadIdx.y % GrpsShrB)) / 4; + int g_adr = g_mindx * 4 + 0 + M * g_nindx * 4; + vals[nt].f4 = *(float4*)(&glbl[g_adr]); + *(float4*)(&glbl[g_adr]) = {}; // clear out for next round + } + if (BIAS) { + for (uint32_t nt = 0; nt < N / NTILE / GrpsShrB; nt++) { + for (uint32_t j = 0; j < 4; j++) { + int nindx = (j + (threadIdx.x / 16) * 4) + nt * NTILE + + (N / GrpsShrB) * (threadIdx.y % GrpsShrB); + biases[nt][j] = BIAS[(mindx % Bx) + (nindx % By) * Bx]; + } } } __builtin_amdgcn_sched_barrier(0); @@ -1727,11 +2168,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (nindx < actlN) { int adr = mindx + M * nindx; if constexpr (std::is_same_v) { - vals[nt][j] += __bfloat162float(biases[nt][j]); - C[adr] = __float2bfloat16(vals[nt][j]); + vals[nt].s8[j] += __bfloat162float(biases[nt][j]); + C[adr] = __float2bfloat16(vals[nt].s8[j]); } else { - vals[nt][j] += __half2float(biases[nt][j]); - C[adr] = __float2half(vals[nt][j]); + vals[nt].s8[j] += __half2float(biases[nt][j]); + C[adr] = __float2half(vals[nt].s8[j]); } } } @@ -1748,23 +2189,36 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #endif } } + #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitKrc_(const int actlN, const int K, const int M, - const int Bx, const int By, const scalar_t* B, - const scalar_t* __restrict__ A, - const scalar_t* __restrict__ BIAS, float* glbl, - // int* cntr, - scalar_t* C, const int CuCount){UNREACHABLE_CODE} -#endif // defined(__HIP__GFX9__) TODO: Add NAVI support +__global__ void wvSplitKrc_deterministic_( + const int actlN, const int K, const int Kap, const int M, const int Bx, + const int By, const scalar_t* B, const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, float* glbl, scalar_t* C, + const int CuCount) { + UNREACHABLE_CODE +} +template +__global__ void wvSplitKrc_fast_(const int actlN, const int K, const int Kap, + const int M, const int Bx, const int By, + const scalar_t* B, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ BIAS, float* glbl, + scalar_t* C, + const int CuCount){UNREACHABLE_CODE} +#endif // defined(__gfx950__) torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, const std::optional& in_bias, - const int64_t CuCount) { - auto M_in = in_a.size(0); - auto N_in = in_b.size(0); - auto K_in = in_a.size(1); + const int64_t CuCount, const bool fast_skinny_gemm) { + auto M_in = in_b.size(0); + auto N_in = in_a.size(0); + auto K_in = in_b.size(1); + auto Kap_in = in_a.stride(0); + auto Bx_in = (in_bias.has_value() && in_bias->numel() > 0) ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) @@ -1781,13 +2235,9 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, auto out_c = torch::empty( {N_in, M_in}, - torch::TensorOptions().dtype(in_b.dtype()).device(in_b.device())); + torch::TensorOptions().dtype(in_a.dtype()).device(in_a.device())); auto N_p2 = 1U << (32 - __builtin_clz(N_in - 1)); - auto axl_glbl = torch::empty( - {N_p2 + N_p2 / 4, M_in + M_in / 4}, - torch::TensorOptions().dtype(torch::kFloat32).device(in_b.device())); - axl_glbl.zero_(); // disable for FAST_UNSAFE_RDC_INIT dim3 grid(CuCount); @@ -1795,61 +2245,103 @@ torch::Tensor wvSplitKrc(const at::Tensor& in_a, const at::Tensor& in_b, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); // const int max_lds_len = get_lds_size() / 2; -#define WVSPLITKrc(_N, _GrpsShrB, _CHUNKK) \ + // With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile), + // and each working on a 512-shard of K, how many CUs would we need? + int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512); + + // How many of 4 waves in a group can work on same 16 Ms at same time? First + // try to maximize this. This reduces the Ms each group works on, i.e. + // increasing the number of CUs needed. + int GrpsShrB = min(N_p2 / 16, 4); + + // Given the above, how many CUs would we need? + int CuNeeded = rndup_cus * GrpsShrB; + + if (CuNeeded > CuCount) std::runtime_error("Invalid wvSplitKrc size"); + + // Can we increase SplitK by shrinking the K-shared to 256? + int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1; + + static torch::Tensor axl_glbl = + torch::zeros( + 8 * 1024 * 1024, + torch::TensorOptions().dtype(torch::kFloat32).device(in_a.device())) + .detach(); + auto glbl = axl_glbl.data_ptr(); + +#define WVSPLITKrc_DET(_N, _GrpsShrB, _CHUNKK) \ { \ dim3 block(64, 4); \ - wvSplitKrc_ \ - <<>>(N_in, K_in, M_in, Bx_in, By_in, af4, bf4, \ - biasf4, glbl, c, CuCount); \ + wvSplitKrc_deterministic_ \ + <<>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \ + af4, bf4, biasf4, glbl, c, CuCount); \ + } + +#define WVSPLITKrc_FAST(_N, _GrpsShrB, _CHUNKK) \ + { \ + dim3 block(64, 4); \ + wvSplitKrc_fast_ \ + <<>>(N_in, K_in, Kap_in, M_in, Bx_in, By_in, \ + af4, bf4, biasf4, glbl, c, CuCount); \ } - AT_DISPATCH_REDUCED_FLOATING_TYPES(in_b.scalar_type(), "wvSplitKrc", [&] { + AT_DISPATCH_REDUCED_FLOATING_TYPES(in_a.scalar_type(), "wvSplitKrc", [&] { using fptype = typename scalar::type; - fptype* af4 = reinterpret_cast(in_a.data_ptr()); + const fptype* af4 = reinterpret_cast(in_a.data_ptr()); const fptype* bf4 = reinterpret_cast(in_b.data_ptr()); const fptype* biasf4 = (in_bias.has_value() && in_bias->numel() > 0) ? reinterpret_cast(in_bias->data_ptr()) : nullptr; fptype* c = reinterpret_cast(out_c.data_ptr()); - auto glbl = axl_glbl.data_ptr(); - - // With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile), - // and each working on a 512-shard of K, how many CUs would we need? - int rndup_cus = ((M_in + 64 - 1) / 64) * ((K_in + 512 - 1) / 512); - - // How many of 4 waves in a group can work on same 16 Ms at same time? First - // try to maximize this. This reduces the Ms each group works on, i.e. - // increasing the number of CUs needed. - int GrpsShrB = min(N_p2 / 16, 4); - - // Given the above, how many CUs would we need? - int CuNeeded = rndup_cus * GrpsShrB; - - if (CuNeeded > CuCount) std::runtime_error("Invalid wvSplitKrc size"); - - // Can we increase SplitK by shrinking the K-shared to 256? - int chunkk = (CuNeeded * 2 <= CuCount) ? 2 : 1; - - switch (N_p2) { - case 16: - WVSPLITKrc(16, 1, 1) break; - case 32: - if (chunkk == 2) - WVSPLITKrc(32, 2, 2) else if (chunkk == 1) WVSPLITKrc(32, 2, 1) break; - case 64: - if (chunkk == 2) - WVSPLITKrc(64, 4, 2) else if (chunkk == 1) WVSPLITKrc(64, 4, 1) break; - case 128: - if (chunkk == 2) - WVSPLITKrc(128, 4, 2) else if (chunkk == 1) - WVSPLITKrc(128, 4, 1) break; - default: - throw std::runtime_error( - "Unsupported N value: " + std::to_string(M_in) + "," + - std::to_string(K_in) + "," + std::to_string(N_in)); + + if (fast_skinny_gemm) { + switch (N_p2) { + case 16: + WVSPLITKrc_FAST(16, 1, 1) break; + case 32: + if (chunkk == 2) + WVSPLITKrc_FAST(32, 2, 2) else if (chunkk == 1) + WVSPLITKrc_FAST(32, 2, 1) break; + case 64: + if (chunkk == 2) + WVSPLITKrc_FAST(64, 4, 2) else if (chunkk == 1) + WVSPLITKrc_FAST(64, 4, 1) break; + case 128: + if (chunkk == 2) + WVSPLITKrc_FAST(128, 4, 2) else if (chunkk == 1) + WVSPLITKrc_FAST(128, 4, 1) break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } + } else { + switch (N_p2) { + case 16: + WVSPLITKrc_DET(16, 1, 1) break; + case 32: + if (chunkk == 2) + WVSPLITKrc_DET(32, 2, 2) else if (chunkk == 1) + WVSPLITKrc_DET(32, 2, 1) break; + case 64: + if (chunkk == 2) + WVSPLITKrc_DET(64, 4, 2) else if (chunkk == 1) + WVSPLITKrc_DET(64, 4, 1) break; + case 128: + if (chunkk == 2) + WVSPLITKrc_DET(128, 4, 2) else if (chunkk == 1) + WVSPLITKrc_DET(128, 4, 1) break; + default: + throw std::runtime_error( + "Unsupported N value: " + std::to_string(M_in) + "," + + std::to_string(K_in) + "," + std::to_string(N_in)); + } } }); + +#undef WVSPLITKrc_DET +#undef WVSPLITKrc_FAST return out_c; } @@ -1919,7 +2411,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } - // Fetch activation matrix from either just LDS or from both LDS / memory + // Fetch activation matrix from either just LDS or from both LDS / memory #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; @@ -1930,7 +2422,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } - // Do the matrix multiplication in interleaved manner + // Do the matrix multiplication in interleaved manner #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t n = 0; n < N; n++) { @@ -1963,12 +2455,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 0) { scalar_t biases[N][YTILE] = {}; - if (BIAS) + if (BIAS) { for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx]; } } + } for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { if (y + m >= M) break; // To avoid mem access fault. @@ -2064,7 +2557,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } - // Fetch activation matrix from either just LDS or from both LDS / memory + // Fetch activation matrix from either just LDS or from both LDS / memory #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; @@ -2078,7 +2571,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } - // Do the matrix multiplication in interleaved manner + // Do the matrix multiplication in interleaved manner #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { for (uint32_t n = 0; n < N; n++) { @@ -2111,12 +2604,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.x == 0) { scalar_t biases[N][YTILE] = {}; - if (BIAS) + if (BIAS) { for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx]; } } + } for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { if (y + m >= M) break; // To avoid mem access fault. diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index b0b44964c243..c6a4163525c1 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -28,7 +28,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { // Custom gemm op for skinny matrix-matrix multiplication rocm_ops.def( - "wvSplitKrc(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount) -> " + "wvSplitKrc(Tensor in_a, Tensor in_b, Tensor? in_bias, int CuCount, " + "bool fast_skinny_gemm=False) -> " "Tensor"); rocm_ops.impl("wvSplitKrc", torch::kCUDA, &wvSplitKrc); diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 2564f182912f..db0c8c4de022 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import math import pytest @@ -12,267 +13,1254 @@ from vllm.utils.platform_utils import get_cu_count DTYPES = [torch.bfloat16, torch.float16] -BIAS_MODES = [0, 1, 2] -# Specific (N, K, M) combinations for targeted testing -NKM_FACTORS_LLMM1 = [ - # Small, medium, large cases - (1, 8, 16), - (1, 32, 64), + +TOLERANCE_ULP = 1 # ULPs for fp32-accumulation kernels +LLMM1_TOLERANCE_ULP = 5 # 5x wider for dtype-precision accumulation + +# LLMM1: N must be 1, M must be divisible by rows_per_block (4, 8). +# K values span small (128) to large (6144), all multiples of 8. +# M values are all divisible by lcm(4,8)=8. +NKM_LLMM1 = [ (1, 128, 256), (1, 512, 1024), (1, 2048, 4096), - # Edge cases with specific K sizes - (1, 6144, 1024), - (1, 8192, 2048), - # Very large case - (1, 4096, 8192), + (1, 6144, 1024), # wide K, moderate M + (1, 4096, 8192), # large M ] -NKM_FACTORS_WVSPLITK = [ - # Different batch sizes with key dimensions - (1, 16, 16), +# wvSplitK: N in {1..4}, K % 8 == 0. +# Includes small shapes (64x64) through large (16384x8192), +# plus tiny-M edge cases (M=8) that stress YTILE/commitColumn logic. +NKM_WVSPLITK = [ (1, 64, 64), (2, 256, 256), (3, 1024, 1024), (4, 4096, 4096), - # Extended K values - (1, 9216, 512), - (2, 10240, 1024), - (4, 16384, 8192), - # Minimum M constraint validation (m >= 8) - (1, 64, 8), - (2, 128, 8), - (4, 256, 8), -] - -N_FACTORS_WVSPLITKRC = [ - 13, - 16, - 17, - 25, - 29, - 31, - 32, - 41, - 51, - 64, - 71, - 81, - 91, - 103, - 117, - 128, + (1, 9216, 512), # wide K, narrow M + (4, 16384, 8192), # max batch x large dimensions + (1, 64, 8), # tiny M, single batch + (4, 256, 8), # tiny M, max batch ] -K_FACTORS_WVSPLITKRC = [2880, 2880 + 8, 3072, 3072 + 8] -M_FACTORS_WVSPLITKRC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16] +# wvSplitKrc: N values that round to valid N_p2 in {16,32,64,128}. +# K values include non-power-of-2 (2880) and aligned (3072), +# plus +8 offsets to test non-aligned K. +# M values include aligned and +16 offsets for boundary testing. +N_RC = [13, 16, 32, 64, 103, 128] +K_RC = [2880, 2880 + 8, 3072, 3072 + 8] +M_RC = [128, 128 + 16, 256, 256 + 16, 640, 640 + 16] -NKM_FACTORS_WVSPLITK_FP8 = [ - # FP8-specific cases with K % 16 == 0 - (1, 16, 16), - (1, 32, 16 + 16), +# FP8: N in {1..4}, K % 16 == 0. +# Spans from small (64) to very large (65552) K. +# +16 offsets test non-aligned-but-legal shapes. +NKM_FP8 = [ (1, 64, 64), - (1, 64, 64 + 16), - (1, 64 + 16, 64), (1, 64 + 16, 64 + 16), - (4, 64, 64), (4, 64, 64 + 16), - (4, 64 + 16, 64), - (4, 64 + 16, 64 + 16), - (2, 512, 512), - (3, 512, 512), (3, 512, 512 + 16), - (4, 512, 512), - (3, 2048, 2048), - (3, 2048, 2048 + 16), - (4, 2048 + 16, 2048), (4, 2048 + 16, 2048 + 16), (4, 4096, 4096), - (4, 16400, 2048), (4, 16400, 2048 + 16), - # Extended FP8 dimensions not covered by WVSPLITK (1, 14336, 1024), (2, 24576, 2048), (4, 32768, 28672), - (4, 32768 * 2, 28672), - (4, 32768 * 2, 28672 + 16), - (4, 32768 * 2 + 16, 28672), - (4, 32768 * 2 + 16, 28672 + 16), + (4, 32768 * 2 + 16, 28672 + 16), # K=65552, max stress +] + +# Model-realistic shapes matching real LLM architectures. +# (1, 4096, 32000) = Llama-7B hidden -> vocab single-token decode +# (4, 4096, 32000) = same with 4-token batch +# (1, 8192, 28672) = Llama-70B FFN intermediate +# (1, 4096, 14336) = Llama-7B FFN gate projection +# (1, 4096, 151936) = Qwen-7B hidden -> vocab (large vocabulary) +NKM_MODEL_REALISTIC = [ + (1, 4096, 32000), + (4, 4096, 32000), + (1, 8192, 28672), + (1, 4096, 14336), + (1, 4096, 151936), ] -SEEDS = [0] +DISTRIBUTIONS = ["normal", "mixed_scale", "sparse_activations"] def pad_fp8(weight): + """Pad FP8 tensor to 256-byte alignment, then remove padding. + + This simulates the memory layout after cudaMalloc alignment, + testing that kernels correctly handle stride != size(1). + """ num_pad = 256 // weight.element_size() import torch.nn.functional as F return F.pad(weight, (0, num_pad), "constant", 0)[..., :-num_pad] -@pytest.mark.parametrize("xnorm", [False, True]) -@pytest.mark.parametrize("n", N_FACTORS_WVSPLITKRC) -@pytest.mark.parametrize("k", K_FACTORS_WVSPLITKRC) -@pytest.mark.parametrize("m", M_FACTORS_WVSPLITKRC) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("bias_mode", BIAS_MODES) -@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") -@pytest.mark.skipif(not on_gfx950(), reason="only meant for gfx950") -def test_rocm_wvsplitkrc_kernel(xnorm, n, k, m, dtype, seed, bias_mode): - torch.manual_seed(seed) - cu_count = get_cu_count() +def _fits_rc(n, k, m, cu_count): + """Check if wvSplitKrc shape fits within available CUs. - # Next ^2 of n + The kernel assigns ceil(M/64) x ceil(K/512) workgroups, multiplied by + GrpsShrB (number of waves sharing B-tile loads within a group). + If this exceeds CuCount, the kernel cannot launch. + """ N_p2 = 1 << (n - 1).bit_length() - # With 64 Ms per CU (each of 4 SIMDs working on a 16x16 tile), - # and each working on a 512-shard of K, how many CUs would we need? rndup_cus = ((m + 64 - 1) // 64) * ((k + 512 - 1) // 512) - # How many of 4 waves in a group can work on same 16 Ms at same time? - # This reduces the Ms each group works on, i.e. increasing the number of CUs needed. GrpsShrB = min(N_p2 // 16, 4) - # Given the above, how many CUs would we need? - CuNeeded = rndup_cus * GrpsShrB - # candidate for atomic reduce count splitk? - fits_wvsplitkrc = CuNeeded <= cu_count + return rndup_cus * GrpsShrB <= cu_count - if not fits_wvsplitkrc: - pytest.skip("Too large for wvSplitKrc") - xavier = ( - math.sqrt(2 / k) if xnorm else 1 - ) # normalize to avoid large output-bias deltas - A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier - B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier +def _mixed_scale_multiplier(dtype): + """Maximum safe dynamic-range multiplier for mixed_scale distribution. - BIAS = None - if bias_mode == 1: - BIAS = torch.rand(m, dtype=dtype, device="cuda") * 2 - 1 - elif bias_mode == 2: - BIAS = torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1 + With fp32 accumulation, a partial sum S and a small addend x are + correctly accumulated as long as |x| > |S| * epsilon_fp32, which is + essentially always true. But we also need the INPUTS themselves + to be representable in dtype, and the OUTPUTS to not overflow. - ref_out = torch.nn.functional.linear(A, B, BIAS) - out = ops.wvSplitKrc(B, A.view(-1, A.size(-1)), cu_count, BIAS) + We limit the hot/cold ratio to 0.25/epsilon_dtype so that: + 1. Cold-channel contributions retain >= 2 significant bits when + added to hot-channel partial sums in fp32 and then rounded + to dtype output. + 2. The dynamic range stays within dtype representable range. + + Results: bf16 -> ~32x, fp16 -> 50x (capped for small-K stability). + """ + eps = torch.finfo(dtype).eps + return min(50.0, 0.25 / eps) + + +def _make_inputs(n, k, m, dtype, distribution, seed=0): + """Generate test matrices with controlled distributions. + + Returns A (NxK) and B (MxK) in the specified dtype on CUDA. + """ + torch.manual_seed(seed) + xavier = math.sqrt(2 / k) + + if distribution == "normal": + A = torch.randn(n, k, dtype=dtype, device="cuda") * xavier + B = torch.randn(m, k, dtype=dtype, device="cuda") * xavier + + elif distribution == "mixed_scale": + A = torch.randn(n, k, dtype=dtype, device="cuda") * xavier + B = torch.randn(m, k, dtype=dtype, device="cuda") * xavier + hot = torch.randperm(k, device="cuda")[: max(1, k // 10)] + B[:, hot] *= _mixed_scale_multiplier(dtype) + + elif distribution == "sparse_activations": + A = torch.randn(n, k, dtype=dtype, device="cuda").clamp(min=0) * xavier + B = torch.randn(m, k, dtype=dtype, device="cuda") * xavier + + elif distribution == "uniform": + A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier + B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier - if xnorm: - torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8) else: - torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-2) + raise ValueError(f"Unknown distribution: {distribution}") + return A, B + + +def _make_inputs_fp8(n, k, m, distribution, seed=0): + """Generate FP8-quantized inputs. + + Starts from fp32 to avoid double-quantization artifacts (generating + in bf16 then quantizing to fp8 would compound rounding errors). + Returns (A_fp8, B_fp8, scale_a, scale_b). + """ + A, B = _make_inputs(n, k, m, torch.float32, distribution, seed) + A_fp8, sa = ref_dynamic_per_tensor_fp8_quant(A) + B_fp8, sb = ref_dynamic_per_tensor_fp8_quant(B) + return A_fp8, B_fp8, sa, sb + + +def _make_bias(n, m, dtype, bias_mode, seed=42): + """Generate bias tensor. + + bias_mode=0: None (no bias) + bias_mode=1: 1-D bias (M,) - typical layer bias + bias_mode=2: 2-D bias (N,M) - per-batch bias (rare but supported) + """ + torch.manual_seed(seed) + if bias_mode == 0: + return None + elif bias_mode == 1: + return torch.rand(m, dtype=dtype, device="cuda") * 2 - 1 + elif bias_mode == 2: + return torch.rand(n, m, dtype=dtype, device="cuda") * 2 - 1 + raise ValueError(f"Unknown bias_mode: {bias_mode}") + + +def _gemm_tol(dtype): + """Tolerance for fp32-accumulation kernels (wvSplitK, wvSplitKrc). + + Both kernel and reference produce: + result = round_to_dtype(fp32_accumulated_sum) + + We use 1*epsilon (1 ULP), the theoretical minimum for fp32 + accumulation where kernel and reference follow the same + reduction order. + + Resulting tolerances: + bf16: rtol = atol ~= 7.81e-3 (practical: <= 0.8% relative error) + fp16: rtol = atol ~= 9.77e-4 (practical: <= 0.1% relative error) + """ + eps = torch.finfo(dtype).eps + tol = TOLERANCE_ULP * eps + return tol, tol + + +def _llmm1_tol(dtype): + """Tolerance for LLMM1 (dtype-precision accumulation). + + LLMM1 accumulates 8 products per thread in native dtype via + __hmul2/__hfma2, introducing O(4 * epsilon_dtype) rounding per group. + This compounds with output rounding to produce errors significantly + larger than the fp32-accumulation baseline. + + The 5x multiplier (5 ULP total) covers worst-case dtype-precision + accumulation while remaining tight enough to catch real precision + regressions. + + Resulting tolerances: + bf16: rtol = atol ~= 3.91e-2 (practical: up to 3.9% relative) + fp16: rtol = atol ~= 4.88e-3 (practical: up to 0.5% relative) + """ + eps = torch.finfo(dtype).eps + tol = LLMM1_TOLERANCE_ULP * eps + return tol, tol + + +def _fp8_gemm_tol(dtype, k): + """Tolerance for FP8 GEMM (wvSplitKQ). + + Two error sources combine: + + 1. Quantization noise: each fp8 input has relative error <= epsilon_fp8/2. + Over K products, the noise in the dot product grows as O(sqrt(K)) + under a random-noise model (central limit theorem on K + independent quantization errors). + + Empirical coefficient: 0.005 * sqrt(K) * epsilon_fp8, capped at 0.05 + to prevent false passes on very large K. + + 2. Output rounding: same as fp32-accumulation kernels, 1 * epsilon_dtype. + + We take the maximum of both components for atol, with rtol scaled + down by 10x (since relative tolerance only needs to handle the + multiplicative component of quantization noise, not the additive floor). + """ + eps = torch.finfo(dtype).eps + fp8_eps = 0.0625 # epsilon for fp8-e4m3 + quant_component = min(math.sqrt(k) * fp8_eps * 0.005, 0.05) + output_component = TOLERANCE_ULP * eps + atol = max(quant_component, output_component) + rtol = max(quant_component / 10, TOLERANCE_ULP * eps) + return atol, rtol + + +def _assert_accurate(out, ref, dtype, label="", atol_override=None, rtol_override=None): + """Assert GEMM accuracy with principled, dtype-derived tolerances. + + Per-element check: pass iff |diff| <= max(atol, rtol * |ref|). + The dual criterion ensures: + - Near-zero outputs judged by absolute error (avoids division-by-zero + in relative checks) + - Large outputs judged by relative error (scales with magnitude) + + Three-pronged failure criteria (ALL must hold for a pass): + 1. >=99.999% of elements pass per-element tolerance + (allows only extremely rare rounding edge cases) + 2. No element exceeds 3x its tolerance + (catches catastrophic single-element bugs like wrong index) + 3. Mean absolute error < atol * 0.25 + (catches systematic bias early, before it reaches full tolerance) + """ + atol, rtol = _gemm_tol(dtype) + if atol_override is not None: + atol = atol_override + if rtol_override is not None: + rtol = rtol_override + + diff = (out.float() - ref.float()).abs() + ref_abs = ref.float().abs() + + per_elem_tol = torch.maximum( + torch.full_like(diff, atol), + rtol * ref_abs, + ) + + element_ok = diff <= per_elem_tol + pass_rate = element_ok.float().mean().item() + max_violation_ratio = (diff / per_elem_tol.clamp(min=1e-30)).max().item() + mean_abs = diff.mean().item() + + passed = ( + pass_rate >= 0.99999 and max_violation_ratio <= 3.0 and mean_abs < atol * 0.25 + ) + + if not passed: + w = diff.argmax().item() + nc = out.shape[-1] if out.dim() > 1 else out.shape[0] + reasons = [] + if pass_rate < 0.99999: + reasons.append(f"pass_rate={pass_rate:.7f}<0.99999") + if max_violation_ratio > 3.0: + reasons.append(f"max_violation={max_violation_ratio:.1f}x>3x") + if mean_abs >= atol * 0.25: + reasons.append(f"mean_abs={mean_abs:.6e}>={atol * 0.25:.6e}") + + raise AssertionError( + f"Accuracy FAILED ({label})\n" + f" tolerances: atol={atol:.6e}, rtol={rtol:.6e}\n" + f" failures: {'; '.join(reasons)}\n" + f" max_abs={diff.max().item():.6e}, " + f"mean_abs={mean_abs:.6e}, " + f"pass_rate={pass_rate:.7f}\n" + f" worst [{w // nc},{w % nc}]: " + f"got={out.flatten()[w].item():.6f} " + f"ref={ref.flatten()[w].item():.6f} " + f"tol={per_elem_tol.flatten()[w].item():.6f}\n" + f" shape={list(out.shape)}, dtype={dtype}" + ) + +def _assert_deterministic(fn, num_runs=10, label=""): + """Assert bitwise-identical output across repeated runs. -@pytest.mark.parametrize("n,k,m", NKM_FACTORS_LLMM1) + Non-determinism in GEMM kernels propagates through softmax and + corrupts logprobs, causing token-selection instability during + generation. The deterministic kernel path must be bitwise exact + because downstream consumers (e.g., speculative decoding + verification) rely on reproducible logits. + """ + results = [fn() for _ in range(num_runs)] + for i in range(1, num_runs): + if not torch.equal(results[0], results[i]): + d = (results[0].float() - results[i].float()).abs() + ndiff = (d > 0).sum().item() + raise AssertionError( + f"Non-determinism ({label}): run 0 vs {i}, " + f"{ndiff}/{d.numel()} elements differ, " + f"max_diff={d.max().item():.6e}" + ) + + +def _set_rows(idx_tensor, n): + return [set(idx_tensor[r].tolist()) for r in range(n)] + + +def _assert_logprobs(out_logits, ref_logits, label="", top1_lp_max_diff=0.01): + """Assert logprobs derived from kernel output match reference. + + Checks three properties that directly impact generation quality: + + 1. Top-1 token agreement (>= 99%): + Determines greedy decoding output. A top-1 mismatch means the + kernel would produce different text than the reference. + We allow 1% disagreement because near-tied logits (where two + tokens have nearly equal probability) are inherently sensitive + to rounding - a 0.001 logit difference can flip the argmax. + + 2. Top-5 set overlap (>= 95%): + Determines sampling diversity. Even when top-1 agrees, if the + top-5 set is different, sampling-based generation explores + different token spaces. + + 3. Top-1 logprob accuracy (<= 0.01 nats): + Determines confidence calibration. A 0.01 nat error in the + top-1 logprob translates to ~1% multiplicative probability + error, acceptable for most applications. + """ + n, m = ref_logits.shape + ref_lp = torch.nn.functional.log_softmax(ref_logits.float(), dim=-1) + out_lp = torch.nn.functional.log_softmax(out_logits.float(), dim=-1) + + ref_t1 = ref_lp.argmax(-1) + out_t1 = out_lp.argmax(-1) + t1_rate = (ref_t1 == out_t1).float().mean().item() + + k5 = min(5, m) + ref_t5 = _set_rows(ref_lp.topk(k5, -1).indices, n) + out_t5 = _set_rows(out_lp.topk(k5, -1).indices, n) + t5_rate = sum(len(a & b) / k5 for a, b in zip(ref_t5, out_t5)) / n + + lp_diff = ( + (ref_lp.gather(1, ref_t1.unsqueeze(1)) - out_lp.gather(1, ref_t1.unsqueeze(1))) + .abs() + .max() + .item() + ) + + failures = [] + if t1_rate < 0.99: + failures.append(f"top1_match={t1_rate:.4f}<0.99") + if t5_rate < 0.95: + failures.append(f"top5_overlap={t5_rate:.4f}<0.95") + if lp_diff > top1_lp_max_diff: + failures.append(f"top1_lp_diff={lp_diff:.6e}>{top1_lp_max_diff}") + if failures: + raise AssertionError(f"Logprobs FAILED ({label}): {'; '.join(failures)}") + + +def _assert_logprobs_deterministic(fn, num_runs=10, label=""): + """Assert logprobs are bitwise identical across runs. + + Even if raw logits have sub-ULP non-deterministic differences, + softmax amplifies them near decision boundaries - a 1e-7 logit + jitter between two near-tied tokens can flip the argmax and + change generated text. + """ + lps = [ + torch.nn.functional.log_softmax(fn().float(), dim=-1) for _ in range(num_runs) + ] + for i in range(1, num_runs): + if not torch.equal(lps[0], lps[i]): + d = (lps[0] - lps[i]).abs() + flips = (lps[0].argmax(-1) != lps[i].argmax(-1)).sum().item() + raise AssertionError( + f"Logprobs non-determinism ({label}): run 0 vs {i}, " + f"{(d > 0).sum().item()} diffs, " + f"max={d.max().item():.6e}, " + f"top1_flips={flips}/{lps[0].shape[0]}" + ) + + +# Shape constraints enforced: +# - N=1 (all tuples have N=1) +# - M divisible by rows_per_block (all M in NKM_LLMM1 are divisible by 8) +# - K multiple of 8 (all K in NKM_LLMM1 satisfy this) +@pytest.mark.parametrize("n,k,m", NKM_LLMM1) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("rows_per_block", [2, 4, 8, 16]) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") +@pytest.mark.parametrize("rows_per_block", [4, 8]) +@pytest.mark.parametrize("distribution", DISTRIBUTIONS) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") @torch.inference_mode() -def test_rocm_llmm1_kernel(n, k, m, dtype, rows_per_block, seed): - torch.manual_seed(seed) - # TODO: Zero-centering the inputs causes errors for LLMM1! - # Without that the numbers quickly saturate, and may - # be giving false matches. - A = torch.rand(n, k, dtype=dtype, device="cuda") - B = torch.rand(m, k, dtype=dtype, device="cuda") +def test_llmm1_accuracy(n, k, m, dtype, rows_per_block, distribution): + assert n == 1, f"LLMM1 requires N=1, got {n}" + assert m % rows_per_block == 0, ( + f"M={m} not divisible by rows_per_block={rows_per_block}" + ) - ref_out = torch.matmul(A, B.t()) + A, B = _make_inputs(n, k, m, dtype, distribution) out = ops.LLMM1(B, A, rows_per_block) + ref = torch.matmul(A, B.t()) - torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) + atol, rtol = _llmm1_tol(dtype) + _assert_accurate( + out, + ref, + dtype, + atol_override=atol, + rtol_override=rtol, + label=f"LLMM1 {n}x{k}x{m} {dtype} rpb={rows_per_block} {distribution}", + ) -@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) +@pytest.mark.parametrize("n,k,m", [(1, 2048, 4096), (1, 4096, 8192)]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") -def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): - torch.manual_seed(seed) - cu_count = get_cu_count() - - A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5 - B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5 +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +@torch.inference_mode() +def test_llmm1_determinism(n, k, m, dtype): + A, B = _make_inputs(n, k, m, dtype, "normal") + _assert_deterministic( + lambda: ops.LLMM1(B, A, 4), label=f"LLMM1 {n}x{k}x{m} {dtype}" + ) - ref_out = torch.nn.functional.linear(A, B) - out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count) - torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) +@pytest.mark.parametrize("n,k,m", [(1, 2048, 4096), (1, 4096, 8192)]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +@torch.inference_mode() +def test_llmm1_logprobs(n, k, m, dtype): + A, B = _make_inputs(n, k, m, dtype, "normal") + _assert_logprobs( + ops.LLMM1(B, A, 4), + torch.matmul(A, B.t()), + label=f"LLMM1 lp {n}x{k}x{m} {dtype}", + ) -@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) +# Shape constraints enforced: +# - N in {1,2,3,4} (all tuples satisfy this) +# - K % 8 == 0 (all K values satisfy this) +@pytest.mark.parametrize("n,k,m", NKM_WVSPLITK) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") -def test_rocm_wvsplitk_bias1D_kernel(n, k, m, dtype, seed): - torch.manual_seed(seed) - cu_count = get_cu_count() +@pytest.mark.parametrize("distribution", DISTRIBUTIONS) +@pytest.mark.parametrize("bias_mode", [0, 1, 2]) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +def test_wvsplitk_accuracy(n, k, m, dtype, distribution, bias_mode): + assert 1 <= n <= 4, f"wvSplitK requires N in {{1..4}}, got {n}" + assert k % 8 == 0, f"wvSplitK requires K%8==0, got K={k}" - xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas - A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier - B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier - BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5 + A, B = _make_inputs(n, k, m, dtype, distribution) + BIAS = _make_bias(n, m, dtype, bias_mode) + cu = get_cu_count() + _assert_accurate( + ops.wvSplitK(B, A.view(-1, A.size(-1)), cu, BIAS), + torch.nn.functional.linear(A, B, BIAS), + dtype, + label=f"wvSplitK {n}x{k}x{m} {dtype} {distribution} b={bias_mode}", + ) - ref_out = torch.nn.functional.linear(A, B, BIAS) - out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) - torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) +@pytest.mark.parametrize("n,k,m", NKM_WVSPLITK) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +def test_wvsplitk_determinism(n, k, m, dtype): + A, B = _make_inputs(n, k, m, dtype, "normal") + cu = get_cu_count() + _assert_deterministic( + lambda: ops.wvSplitK(B, A.view(-1, A.size(-1)), cu), + label=f"wvSplitK {n}x{k}x{m} {dtype}", + ) -@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK) +@pytest.mark.parametrize("n,k,m", NKM_MODEL_REALISTIC) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), reason="only test for rocm") -def test_rocm_wvsplitk_bias2D_kernel(n, k, m, dtype, seed): - torch.manual_seed(seed) - cu_count = get_cu_count() +@pytest.mark.parametrize("distribution", DISTRIBUTIONS) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +def test_wvsplitk_model_realistic(n, k, m, dtype, distribution): + """Test with shapes from real LLM architectures. + + These shapes exercise the 'big' kernel path (A >> LDS capacity) + and stress the persistent-WG iteration over large M dimensions. + The vocabulary-sized M (32000, 151936) ensures the commitColumn + fragmentation logic at the tail is exercised. + """ + assert 1 <= n <= 4 + assert k % 8 == 0 - xavier = math.sqrt(2 / k) # normalize to avoid large output-bias deltas - A = (torch.rand(n, k, dtype=dtype, device="cuda") - 0.5) * xavier - B = (torch.rand(m, k, dtype=dtype, device="cuda") - 0.5) * xavier - BIAS = torch.rand(n, m, dtype=dtype, device="cuda") - 0.5 + A, B = _make_inputs(n, k, m, dtype, distribution) + cu = get_cu_count() + _assert_accurate( + ops.wvSplitK(B, A.view(-1, A.size(-1)), cu), + torch.nn.functional.linear(A, B), + dtype, + label=f"wvSplitK realistic {n}x{k}x{m} {dtype} {distribution}", + ) - ref_out = torch.nn.functional.linear(A, B, BIAS) - out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu_count, BIAS) - torch.testing.assert_close(out, ref_out, atol=1e-8, rtol=1e-2) +# Shape constraints enforced: +# - gfx950 only (skipif) +# - N -> N_p2 in {16, 32, 64, 128} (all N_RC values round to valid N_p2) +# - K % 8 == 0 (all K_RC values satisfy this) +# - CU fit checked via _fits_rc at runtime +# Full dimension sweep - normal distribution, both kernel paths +@pytest.mark.parametrize("fast_skinny_gemm", [False, True]) +@pytest.mark.parametrize("n", N_RC) +@pytest.mark.parametrize("k", K_RC) +@pytest.mark.parametrize("m", M_RC) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +@pytest.mark.skipif(not on_gfx950(), reason="gfx950 only") +def test_wvsplitkrc_accuracy(fast_skinny_gemm, n, k, m, dtype): + cu = get_cu_count() + if not _fits_rc(n, k, m, cu): + pytest.skip("Shape exceeds CU capacity for wvSplitKrc") + A, B = _make_inputs(n, k, m, dtype, "normal") + _assert_accurate( + ops.wvSplitKrc(A, B, cu, None, fast_skinny_gemm=fast_skinny_gemm), + torch.nn.functional.linear(A, B), + dtype, + label=f"wvSplitKrc fast={fast_skinny_gemm} {n}x{k}x{m} {dtype}", + ) +# Feature coverage: xnorm, padding, bias @pytest.mark.parametrize("xnorm", [False, True]) -@pytest.mark.parametrize("n,k,m", NKM_FACTORS_WVSPLITK_FP8) +@pytest.mark.parametrize("padded_a", [False, True]) +@pytest.mark.parametrize("bias_mode", [0, 1, 2]) +@pytest.mark.parametrize("n", [13, 64, 128]) +@pytest.mark.parametrize("k", [2880, 3072]) +@pytest.mark.parametrize("m", [128, 640]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +@pytest.mark.skipif(not on_gfx950(), reason="gfx950 only") +def test_wvsplitkrc_features(xnorm, padded_a, bias_mode, n, k, m, dtype): + """Test feature interactions: xnorm scaling, stride padding, bias modes. + + xnorm=True uses xavier scaling sqrt(2/K), keeping values small enough + that a tighter absolute tolerance (1e-3) is appropriate - this + verifies that the kernel doesn't introduce artifacts at small magnitudes. + """ + cu = get_cu_count() + if not _fits_rc(n, k, m, cu): + pytest.skip("Shape exceeds CU capacity") + xavier = math.sqrt(2 / k) if xnorm else 1 + torch.manual_seed(0) + A = (torch.rand(n, k, dtype=dtype, device="cuda") * 2 - 1) * xavier + B = (torch.rand(m, k, dtype=dtype, device="cuda") * 2 - 1) * xavier + if padded_a: + A = pad_fp8(A) + BIAS = _make_bias(n, m, dtype, bias_mode) + ref = torch.nn.functional.linear(A, B, BIAS) + out = ops.wvSplitKrc(A, B, cu, BIAS) + if xnorm: + # Xavier-scaled values are O(sqrt(2/K)) ~= 0.02-0.04 for typical K. + # Output magnitudes are O(1/sqrt(K)), making 1e-3 a ~1-3 ULP check. + torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-8) + else: + _assert_accurate( + out, + ref, + dtype, + label=f"wvSplitKrc xnorm={xnorm} pad={padded_a} " + f"bias={bias_mode} {n}x{k}x{m} {dtype}", + ) + + +# Distribution stress on representative subset +@pytest.mark.parametrize("distribution", DISTRIBUTIONS) +@pytest.mark.parametrize("fast_skinny_gemm", [False, True]) +@pytest.mark.parametrize("n", [32, 128]) +@pytest.mark.parametrize("k", [2880, 3072]) +@pytest.mark.parametrize("m", [128, 640]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +@pytest.mark.skipif(not on_gfx950(), reason="gfx950 only") +def test_wvsplitkrc_distributions(distribution, fast_skinny_gemm, n, k, m, dtype): + cu = get_cu_count() + if not _fits_rc(n, k, m, cu): + pytest.skip("Shape exceeds CU capacity") + A, B = _make_inputs(n, k, m, dtype, distribution) + _assert_accurate( + ops.wvSplitKrc(A, B, cu, None, fast_skinny_gemm=fast_skinny_gemm), + torch.nn.functional.linear(A, B), + dtype, + label=f"wvSplitKrc {distribution} fast={fast_skinny_gemm} {n}x{k}x{m} {dtype}", + ) + + +# Determinism - deterministic path only (fast_skinny_gemm=False) +# The fast path uses atomicAdd which is non-deterministic by design. +@pytest.mark.parametrize("n", N_RC) +@pytest.mark.parametrize("k", [2880, 3072]) +@pytest.mark.parametrize("m", [128, 256, 640]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +@pytest.mark.skipif(not on_gfx950(), reason="gfx950 only") +def test_wvsplitkrc_determinism(n, k, m, dtype): + cu = get_cu_count() + if not _fits_rc(n, k, m, cu): + pytest.skip("Shape exceeds CU capacity") + A, B = _make_inputs(n, k, m, dtype, "normal") + _assert_deterministic( + lambda: ops.wvSplitKrc(A, B, cu, None, fast_skinny_gemm=False), + label=f"wvSplitKrc {n}x{k}x{m} {dtype}", + ) + + +# Shape constraints enforced: +# - MI3XX + fp8 support (skipif) +# - N in {1..4} (all NKM_FP8 tuples satisfy this) +# - K % 16 == 0 (all K values in NKM_FP8 satisfy this) +@pytest.mark.parametrize("n,k,m", NKM_FP8) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("padded_a", [False, True]) -@pytest.mark.parametrize("padded_b", [False, True]) @pytest.mark.parametrize("biased", [False, True]) @pytest.mark.skipif( not (current_platform.is_rocm() and current_platform.supports_fp8()), - reason="only test for rocm fp8", + reason="rocm fp8 only", ) -def test_rocm_wvsplitk_fp8_kernel( - xnorm, n, k, m, dtype, seed, padded_a, padded_b, biased -): - torch.manual_seed(seed) +def test_fp8_accuracy(n, k, m, dtype, padded_a, biased): + assert 1 <= n <= 4, f"wvSplitKQ requires N in {{1..4}}, got {n}" + assert k % 16 == 0, f"wvSplitKQ requires K%16==0, got K={k}" - xavier = math.sqrt(2 / k) if xnorm else 1 # normalize to avoid large deltas - A = (torch.rand(n, k, device="cuda") * 2 - 1) * xavier - B = (torch.rand(m, k, device="cuda") * 2 - 1) * xavier - - A, scale_a = ref_dynamic_per_tensor_fp8_quant(A) - B, scale_b = ref_dynamic_per_tensor_fp8_quant(B) - if padded_b: - B = pad_fp8(B) + A, B, sa, sb = _make_inputs_fp8(n, k, m, "normal") if padded_a: A = pad_fp8(A) + BIAS = None if not biased else (torch.rand(m, dtype=dtype, device="cuda") * 2 - 1) + ref = torch._scaled_mm(A, B.t(), out_dtype=dtype, scale_a=sa, scale_b=sb, bias=BIAS) + out = ops.wvSplitKQ(B, A, dtype, sa, sb, get_cu_count(), BIAS) + atol, rtol = _fp8_gemm_tol(dtype, k) + _assert_accurate( + out, + ref, + dtype, + atol_override=atol, + rtol_override=rtol, + label=f"FP8 {n}x{k}x{m} {dtype} pad={padded_a} bias={biased}", + ) - BIAS = None if (not biased) else (torch.rand(m, dtype=dtype, device="cuda") * 2 - 1) - ref_out = torch._scaled_mm( - A, B.t(), out_dtype=dtype, scale_a=scale_a, scale_b=scale_b, bias=BIAS +@pytest.mark.parametrize( + "n,k,m", [(3, 512, 512 + 16), (4, 4096, 4096), (2, 24576, 2048), (4, 32768, 28672)] +) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("distribution", ["normal", "sparse_activations"]) +@pytest.mark.skipif( + not (current_platform.is_rocm() and current_platform.supports_fp8()), + reason="rocm fp8 only", +) +def test_fp8_distributions(n, k, m, dtype, distribution): + """Test FP8 with non-normal distributions. + + mixed_scale is excluded because fp8 quantization with per-tensor + scaling cannot resolve the dynamic range - hot channels saturate + fp8 range while cold channels quantize to zero, making accuracy + comparison meaningless. sparse_activations is meaningful because + the zero structure is preserved through quantization. + """ + A, B, sa, sb = _make_inputs_fp8(n, k, m, distribution) + ref = torch._scaled_mm(A, B.t(), out_dtype=dtype, scale_a=sa, scale_b=sb) + out = ops.wvSplitKQ(B, A, dtype, sa, sb, get_cu_count()) + atol, rtol = _fp8_gemm_tol(dtype, k) + _assert_accurate( + out, + ref, + dtype, + atol_override=atol, + rtol_override=rtol, + label=f"FP8 dist {n}x{k}x{m} {dtype} {distribution}", ) - out = ops.wvSplitKQ(B, A, dtype, scale_a, scale_b, get_cu_count(), BIAS) - if xnorm: - torch.testing.assert_close(out, ref_out, atol=1e-3, rtol=1e-8) - elif k >= 32 * 1024: - # wider pytrch thresh for large-K & no xnorm - torch.testing.assert_close(out, ref_out, atol=0.07, rtol=5e-2) - else: - torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + +@pytest.mark.parametrize( + "n,k,m", [(4, 4096, 4096), (2, 24576, 2048), (4, 32768, 28672)] +) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif( + not (current_platform.is_rocm() and current_platform.supports_fp8()), + reason="rocm fp8 only", +) +def test_fp8_determinism(n, k, m, dtype): + A, B, sa, sb = _make_inputs_fp8(n, k, m, "normal") + _assert_deterministic( + lambda: ops.wvSplitKQ(B, A, dtype, sa, sb, get_cu_count()), + label=f"FP8 {n}x{k}x{m} {dtype}", + ) + + +@pytest.mark.parametrize( + "n,k,m", [(4, 4096, 4096), (4, 16400, 2048 + 16), (1, 14336, 1024)] +) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif( + not (current_platform.is_rocm() and current_platform.supports_fp8()), + reason="rocm fp8 only", +) +def test_fp8_padded_b(n, k, m, dtype): + """Test FP8 with padded B tensor (stride(0) != size(1)). + + The wvSplitKQ kernel reads B using Kbp = in_b.stride(0), which + differs from K when the tensor has alignment padding. This test + verifies the kernel uses the stride, not the logical size. + """ + A, B, sa, sb = _make_inputs_fp8(n, k, m, "normal") + B = pad_fp8(B) + ref = torch._scaled_mm(A, B.t(), out_dtype=dtype, scale_a=sa, scale_b=sb) + out = ops.wvSplitKQ(B, A, dtype, sa, sb, get_cu_count()) + atol, rtol = _fp8_gemm_tol(dtype, k) + _assert_accurate( + out, + ref, + dtype, + atol_override=atol, + rtol_override=rtol, + label=f"FP8 padB {n}x{k}x{m} {dtype}", + ) + + +# These tests verify that GEMM accuracy is sufficient for the downstream +# task that matters most: correct token selection in autoregressive +# generation. A GEMM kernel can pass element-wise tolerance checks +# while still producing wrong tokens if errors cluster near decision +# boundaries in the softmax output. +@pytest.mark.parametrize( + "n,k,m", [(1, 4096, 32000), (4, 4096, 32000), (1, 8192, 28672), (1, 4096, 151936)] +) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("distribution", DISTRIBUTIONS) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +def test_wvsplitk_logprobs(n, k, m, dtype, distribution): + A, B = _make_inputs(n, k, m, dtype, distribution) + cu = get_cu_count() + _assert_logprobs( + ops.wvSplitK(B, A.view(-1, A.size(-1)), cu), + torch.nn.functional.linear(A, B), + label=f"wvSplitK lp {n}x{k}x{m} {dtype} {distribution}", + ) + + +@pytest.mark.parametrize( + "n,k,m", [(1, 4096, 32000), (4, 4096, 32000), (1, 4096, 151936)] +) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +def test_wvsplitk_logprobs_determinism(n, k, m, dtype): + A, B = _make_inputs(n, k, m, dtype, "normal") + cu = get_cu_count() + _assert_logprobs_deterministic( + lambda: ops.wvSplitK(B, A.view(-1, A.size(-1)), cu), + label=f"wvSplitK lp det {n}x{k}x{m} {dtype}", + ) + + +@pytest.mark.parametrize("fast_skinny_gemm", [False, True]) +@pytest.mark.parametrize("n,k,m", [(16, 2880, 128), (64, 3072, 256), (128, 2880, 640)]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("distribution", DISTRIBUTIONS) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +@pytest.mark.skipif(not on_gfx950(), reason="gfx950 only") +def test_wvsplitkrc_logprobs(fast_skinny_gemm, n, k, m, dtype, distribution): + cu = get_cu_count() + if not _fits_rc(n, k, m, cu): + pytest.skip("Shape exceeds CU capacity") + A, B = _make_inputs(n, k, m, dtype, distribution) + _assert_logprobs( + ops.wvSplitKrc(A, B, cu, None, fast_skinny_gemm=fast_skinny_gemm), + torch.nn.functional.linear(A, B), + label=f"wvSplitKrc lp fast={fast_skinny_gemm} {n}x{k}x{m} " + f"{dtype} {distribution}", + ) + + +@pytest.mark.parametrize("n,k,m", [(16, 2880, 128), (128, 2880, 640)]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +@pytest.mark.skipif(not on_gfx950(), reason="gfx950 only") +def test_wvsplitkrc_logprobs_determinism(n, k, m, dtype): + cu = get_cu_count() + if not _fits_rc(n, k, m, cu): + pytest.skip("Shape exceeds CU capacity") + A, B = _make_inputs(n, k, m, dtype, "normal") + _assert_logprobs_deterministic( + lambda: ops.wvSplitKrc(A, B, cu, None, fast_skinny_gemm=False), + label=f"wvSplitKrc lp det {n}x{k}x{m} {dtype}", + ) + + +@pytest.mark.parametrize("n,k,m", [(4, 4096, 4096), (1, 14336, 1024), (2, 24576, 2048)]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif( + not (current_platform.is_rocm() and current_platform.supports_fp8()), + reason="rocm fp8 only", +) +def test_fp8_logprobs(n, k, m, dtype): + """FP8 logprob check with relaxed top-1 logprob threshold. + + FP8 quantization noise is larger than dtype rounding noise, so + we allow 0.02 nats of logprob difference (vs 0.01 for bf16/fp16). + This corresponds to ~2% multiplicative probability error. + """ + A, B, sa, sb = _make_inputs_fp8(n, k, m, "normal") + ref = torch._scaled_mm(A, B.t(), out_dtype=dtype, scale_a=sa, scale_b=sb) + out = ops.wvSplitKQ(B, A, dtype, sa, sb, get_cu_count()) + _assert_logprobs( + out, ref, label=f"FP8 lp {n}x{k}x{m} {dtype}", top1_lp_max_diff=0.02 + ) + + +@pytest.mark.parametrize("n,k,m", [(4, 4096, 4096), (2, 24576, 2048)]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif( + not (current_platform.is_rocm() and current_platform.supports_fp8()), + reason="rocm fp8 only", +) +def test_fp8_logprobs_determinism(n, k, m, dtype): + A, B, sa, sb = _make_inputs_fp8(n, k, m, "normal") + _assert_logprobs_deterministic( + lambda: ops.wvSplitKQ(B, A, dtype, sa, sb, get_cu_count()), + label=f"FP8 lp det {n}x{k}x{m} {dtype}", + ) + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +@torch.inference_mode() +def test_wvsplitk_nan_propagation(dtype): + """NaN must propagate to the affected output row, not leak to others. + + Silent NaN absorption masks upstream bugs (e.g., uninitialized + memory, division by zero in prior layers). NaN leaking across + rows indicates an indexing bug in the kernel. + """ + n, k, m = 2, 256, 256 + torch.manual_seed(0) + A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5 + B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5 + A[0, 0] = float("nan") + out = ops.wvSplitK(B, A.view(-1, A.size(-1)), get_cu_count()) + assert out[0].isnan().any(), "NaN lost in wvSplitK row 0" + assert not out[1].isnan().any(), "NaN leaked in wvSplitK row 1" + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +@torch.inference_mode() +def test_llmm1_nan_propagation(dtype): + """NaN must propagate to the output. + + LLMM1 requires N=1, so we cannot test row isolation -- only that + NaN is not silently dropped by the half-precision FMA chain or + the warp shuffle reduction. + """ + n, k, m = 1, 256, 256 + torch.manual_seed(0) + A = torch.rand(n, k, dtype=dtype, device="cuda") + B = torch.rand(m, k, dtype=dtype, device="cuda") + A[0, 0] = float("nan") + out = ops.LLMM1(B, A, 4) + assert out[0].isnan().any(), "NaN lost in LLMM1" + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +@pytest.mark.skipif(not on_gfx950(), reason="gfx950 only") +def test_wvsplitkrc_nan_propagation(dtype): + """NaN row isolation for wvSplitKrc. + + The K-split-and-reduce architecture (store partial sums to global, + then sum) could mask NaN if the reduction uses non-NaN-propagating + operations. We verify both propagation and isolation. + """ + cu = get_cu_count() + n, k, m = 16, 2880, 128 + if not _fits_rc(n, k, m, cu): + pytest.skip("Shape exceeds CU capacity") + torch.manual_seed(0) + A = torch.rand(n, k, dtype=dtype, device="cuda") - 0.5 + B = torch.rand(m, k, dtype=dtype, device="cuda") - 0.5 + A[0, 0] = float("nan") + out = ops.wvSplitKrc(A, B, cu, None) + assert out[0].isnan().any(), "NaN lost in wvSplitKrc row 0" + assert not out[1].isnan().any(), "NaN leaked in wvSplitKrc row 1" + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +def test_zero_and_bias(dtype): + """Zero inputs must produce exactly zero; with bias, exactly the bias. + + This catches: + 1. Uninitialized accumulators (would produce non-zero from garbage) + 2. Incorrect bias indexing (would produce wrong values) + 3. Spurious accumulation artifacts (e.g., from LDS residuals) + + The rtol=0, atol=0 check is exact - zero x anything is exactly + zero in IEEE 754, with no tolerance needed. + """ + cu = get_cu_count() + n, k, m = 2, 256, 256 + A = torch.zeros(n, k, dtype=dtype, device="cuda") + B = torch.zeros(m, k, dtype=dtype, device="cuda") + out = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu) + assert torch.all(out == 0), f"0x0 produced non-zero: max={out.abs().max().item()}" + BIAS = torch.rand(m, dtype=dtype, device="cuda") - 0.5 + out_b = ops.wvSplitK(B, A.view(-1, A.size(-1)), cu, BIAS) + torch.testing.assert_close(out_b, BIAS.unsqueeze(0).expand(n, -1), atol=0, rtol=0) + + +@pytest.mark.parametrize("num_runs", [10]) +@pytest.mark.parametrize("enforce_eager", [True, False]) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +def test_e2e_logprob_reproducibility(num_runs, enforce_eager, vllm_runner): + """End-to-end logprob reproducibility across repeated inference runs. + + With enforce_eager=True (no CUDA graphs), results must be bitwise identical. + With enforce_eager=False (CUDA graphs enabled), we allow a small tolerance + (1e-6) to account for non-determinism introduced by graph capture/replay. + """ + # When CUDA graphs are enabled, allow tolerance up to 1e-6. + # When eager, require exact bitwise reproducibility. + atol = 1e-6 if not enforce_eager else 0.0 + + model = "TitanML/tiny-mixtral" + prompts = [ + "The capital of France is", + "In quantum computing, a qubit", + "def fibonacci(n):\n", + ] + max_tokens = 32 + top_logprobs = 5 + + all_runs = [] + for run_idx in range(num_runs): + with vllm_runner( + model, + dtype="half", + enforce_eager=enforce_eager, + max_num_seqs=1, + seed=0, + enable_prefix_caching=False, + ) as llm: + outputs = llm.generate_greedy_logprobs(prompts, max_tokens, top_logprobs) + all_runs.append(outputs) + + mode_label = "eager" if enforce_eager else "cuda-graph" + + # Per-run comparison tables and summary collection + summary_rows = [] + for run_idx in range(1, num_runs): + print(f"\n{'=' * 80}") + print(f" Run 0 vs Run {run_idx} (mode={mode_label}, atol={atol:.1e})") + print(f"{'=' * 80}") + + for prompt_idx, prompt in enumerate(prompts): + ref_token_ids, ref_text, ref_lps = all_runs[0][prompt_idx] + cur_token_ids, cur_text, cur_lps = all_runs[run_idx][prompt_idx] + + token_match = sum( + r == c for r, c in zip(ref_token_ids, cur_token_ids) + ) / len(ref_token_ids) + + # Compute logprob stats across all positions + lp_diffs = [] + top5_mismatches = 0 + for pos in range(len(ref_lps)): + ref_top = ref_lps[pos] + cur_top = cur_lps[pos] + if set(ref_top.keys()) != set(cur_top.keys()): + top5_mismatches += 1 + for token_id in set(ref_top.keys()) & set(cur_top.keys()): + d = abs(ref_top[token_id].logprob - cur_top[token_id].logprob) + lp_diffs.append(d) + + max_lp_diff = max(lp_diffs) if lp_diffs else 0.0 + mean_lp_diff = sum(lp_diffs) / len(lp_diffs) if lp_diffs else 0.0 + within_tol = all(d <= atol for d in lp_diffs) + + print(f'\n Prompt {prompt_idx}: "{prompt[:40]}..."') + print(f" {'─' * 60}") + print( + f" {'Token match:':<25} {token_match:>8.2%} " + f"({sum(r == c for r, c in zip(ref_token_ids, cur_token_ids))}" + f"/{len(ref_token_ids)})" + ) + print( + f" {'Top-5 set mismatches:':<25} {top5_mismatches:>8d} " + f"/ {len(ref_lps)} positions" + ) + print(f" {'Max logprob diff:':<25} {max_lp_diff:>12.2e}") + print(f" {'Mean logprob diff:':<25} {mean_lp_diff:>12.2e}") + print(f" {'Within tolerance:':<25} {'YES' if within_tol else 'NO':>8}") + + summary_rows.append( + { + "run": f"0 vs {run_idx}", + "prompt": prompt_idx, + "token_match": token_match, + "top5_mismatch": top5_mismatches, + "max_lp_diff": max_lp_diff, + "mean_lp_diff": mean_lp_diff, + "within_tol": within_tol, + } + ) + + # Assertions + assert ref_token_ids == cur_token_ids, ( + f"[{mode_label}] Token mismatch run 0 vs {run_idx}, " + f"prompt {prompt_idx}: " + f"ref={ref_token_ids[:10]}... " + f"cur={cur_token_ids[:10]}..." + ) + + for pos in range(len(ref_lps)): + ref_top = ref_lps[pos] + cur_top = cur_lps[pos] + ref_ids = set(ref_top.keys()) + cur_ids = set(cur_top.keys()) + + assert ref_ids == cur_ids, ( + f"[{mode_label}] Top-{top_logprobs} set mismatch at pos {pos}, " + f"run 0 vs {run_idx}, prompt {prompt_idx}: " + f"ref={ref_ids} cur={cur_ids}" + ) + + for token_id in ref_ids: + ref_val = ref_top[token_id].logprob + cur_val = cur_top[token_id].logprob + diff = abs(ref_val - cur_val) + assert diff <= atol, ( + f"[{mode_label}] Logprob mismatch at pos {pos}, " + f"token {token_id}, " + f"run 0 vs {run_idx}, prompt {prompt_idx}: " + f"ref={ref_val:.10f} cur={cur_val:.10f} " + f"diff={diff:.2e} > atol={atol:.1e}" + ) + + # Summary table + print(f"\n{'=' * 80}") + print( + f" REPRODUCIBILITY SUMMARY " + f"({num_runs} runs, {len(prompts)} prompts, " + f"mode={mode_label}, atol={atol:.1e})" + ) + print(f"{'=' * 80}") + print( + f" {'Comparison':<10} {'Prompt':<7} {'Tok Match':>10} " + f"{'Top5 Miss':>10} {'Max LP Diff':>12} " + f"{'Mean LP Diff':>13} {'In Tol':>8}" + ) + print(f" {'─' * 72}") + for row in summary_rows: + print( + f" {row['run']:<10} {row['prompt']:<7} " + f"{row['token_match']:>9.2%} " + f"{row['top5_mismatch']:>10d} " + f"{row['max_lp_diff']:>12.2e} " + f"{row['mean_lp_diff']:>13.2e} " + f"{'YES' if row['within_tol'] else 'NO':>8}" + ) + + all_within = all(r["within_tol"] for r in summary_rows) + all_token = all(r["token_match"] == 1.0 for r in summary_rows) + worst_lp = max(r["max_lp_diff"] for r in summary_rows) + print(f" {'─' * 72}") + print(f" All within tolerance: {'YES' if all_within else 'NO'}") + print(f" All tokens match: {'YES' if all_token else 'NO'}") + print(f" Worst logprob diff: {worst_lp:.2e}") + print(f"{'=' * 80}\n") + + +@pytest.mark.parametrize("num_runs", [10]) +@pytest.mark.skipif(not current_platform.is_rocm(), reason="rocm only") +def test_e2e_logprob_stability(num_runs, vllm_runner): + """Softer e2e check: logprobs within 0.001 nats where tokens agree. + + Same locked-down settings as the strict test. If the strict test + fails but this passes, the non-determinism source is small enough + to only affect near-tied logits. If both fail, something + fundamental is wrong beyond argmax sensitivity. + """ + model = "TitanML/tiny-mixtral" + prompts = [ + "The capital of France is", + "In quantum computing, a qubit", + "def fibonacci(n):\n", + ] + max_tokens = 32 + top_logprobs = 5 + + all_runs = [] + for run_idx in range(num_runs): + with vllm_runner( + model, + dtype="half", + enforce_eager=True, + max_num_seqs=1, + seed=0, + enable_prefix_caching=False, + ) as llm: + outputs = llm.generate_greedy_logprobs(prompts, max_tokens, top_logprobs) + all_runs.append(outputs) + + # Per-run comparison tables and summary collection + summary_rows = [] + + for run_idx in range(1, num_runs): + print(f"\n{'=' * 80}") + print(f" Run 0 vs Run {run_idx}") + print(f"{'=' * 80}") + + for prompt_idx, prompt in enumerate(prompts): + ref_token_ids, ref_text, ref_lps = all_runs[0][prompt_idx] + cur_token_ids, cur_text, cur_lps = all_runs[run_idx][prompt_idx] + + token_match = sum( + r == c for r, c in zip(ref_token_ids, cur_token_ids) + ) / len(ref_token_ids) + + # Compute logprob stats only at agreeing positions + lp_diffs_agree = [] + positions_checked = 0 + positions_skipped = 0 + for pos in range(min(len(ref_lps), len(cur_lps))): + if ( + pos < len(ref_token_ids) + and pos < len(cur_token_ids) + and ref_token_ids[pos] != cur_token_ids[pos] + ): + positions_skipped += 1 + continue + positions_checked += 1 + ref_top1_id = max(ref_lps[pos], key=lambda k: ref_lps[pos][k].logprob) + if ref_top1_id in cur_lps[pos]: + d = abs( + ref_lps[pos][ref_top1_id].logprob + - cur_lps[pos][ref_top1_id].logprob + ) + lp_diffs_agree.append(d) + + max_lp_diff = max(lp_diffs_agree) if lp_diffs_agree else 0.0 + mean_lp_diff = ( + sum(lp_diffs_agree) / len(lp_diffs_agree) if lp_diffs_agree else 0.0 + ) + + print(f'\n Prompt {prompt_idx}: "{prompt[:40]}..."') + print(f" {'─' * 60}") + print( + f" {'Token match:':<25} {token_match:>8.2%} " + f"({sum(r == c for r, c in zip(ref_token_ids, cur_token_ids))}" + f"/{len(ref_token_ids)})" + ) + print(f" {'Positions checked:':<25} {positions_checked:>8d}") + print(f" {'Positions skipped:':<25} {positions_skipped:>8d}") + print(f" {'Max top-1 LP diff:':<25} {max_lp_diff:>12.2e}") + print(f" {'Mean top-1 LP diff:':<25} {mean_lp_diff:>12.2e}") + + summary_rows.append( + { + "run": f"0 vs {run_idx}", + "prompt": prompt_idx, + "token_match": token_match, + "pos_checked": positions_checked, + "pos_skipped": positions_skipped, + "max_lp_diff": max_lp_diff, + "mean_lp_diff": mean_lp_diff, + } + ) + + # Assertions + assert token_match >= 0.70, ( + f"Token match rate {token_match:.2%} < 70%, " + f"run 0 vs {run_idx}, prompt {prompt_idx}" + ) + + for pos in range(min(len(ref_lps), len(cur_lps))): + if ( + pos < len(ref_token_ids) + and pos < len(cur_token_ids) + and ref_token_ids[pos] != cur_token_ids[pos] + ): + continue + ref_top1_id = max(ref_lps[pos], key=lambda k: ref_lps[pos][k].logprob) + if ref_top1_id in cur_lps[pos]: + diff = abs( + ref_lps[pos][ref_top1_id].logprob + - cur_lps[pos][ref_top1_id].logprob + ) + assert diff < 0.001, ( + f"Logprob drift {diff:.6f} >= 0.001 at pos {pos}, " + f"run 0 vs {run_idx}, prompt {prompt_idx}" + ) + + # Summary table + print(f"\n{'=' * 80}") + print(f" STABILITY SUMMARY ({num_runs} runs, {len(prompts)} prompts)") + print(f"{'=' * 80}") + print( + f" {'Comparison':<10} {'Prompt':<7} {'Tok Match':>10} " + f"{'Checked':>8} {'Skipped':>8} " + f"{'Max LP Diff':>12} {'Mean LP Diff':>13}" + ) + print(f" {'─' * 72}") + for row in summary_rows: + print( + f" {row['run']:<10} {row['prompt']:<7} " + f"{row['token_match']:>9.2%} " + f"{row['pos_checked']:>8d} " + f"{row['pos_skipped']:>8d} " + f"{row['max_lp_diff']:>12.2e} " + f"{row['mean_lp_diff']:>13.2e}" + ) + + worst_match = min(r["token_match"] for r in summary_rows) + worst_lp = max(r["max_lp_diff"] for r in summary_rows) + total_skipped = sum(r["pos_skipped"] for r in summary_rows) + total_checked = sum(r["pos_checked"] for r in summary_rows) + print(f" {'─' * 72}") + print(f" Worst token match: {worst_match:.2%}") + print(f" Worst logprob diff (agreeing positions): {worst_lp:.2e}") + print(f" Total positions checked/skipped: {total_checked}/{total_skipped}") + print(f"{'=' * 80}\n") diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e48ba6c997eb..7db1ce18e058 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2058,9 +2058,13 @@ def wvSplitK( def wvSplitKrc( - a: torch.Tensor, b: torch.Tensor, cu_count: int, bias: torch.Tensor = None + a: torch.Tensor, + b: torch.Tensor, + cu_count: int, + bias: torch.Tensor = None, + fast_skinny_gemm: bool = False, ) -> torch.Tensor: - return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count) + return torch.ops._rocm_C.wvSplitKrc(a, b, bias, cu_count, fast_skinny_gemm) def wvSplitKQ( diff --git a/vllm/config/model.py b/vllm/config/model.py index 5fb81ee424e1..27b5587dde8f 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -283,6 +283,12 @@ class ModelConfig: definitions""" io_processor_plugin: str | None = None """IOProcessor plugin name to load at model startup""" + fast_skinny_gemm: bool = False + """When enabled on ROCm, uses the non-deterministic atomicAdd reduction + path in skinny GEMM kernels (wvSplitKrc) for higher throughput. The + default (False) uses a deterministic store-then-reduce path that + guarantees bitwise reproducibility across runs with different batch + dimensions, at a small cost in LDS pressure and an extra sync.""" # Pooler config pooler_config: PoolerConfig | None = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index a962baba2a92..eb2f41d146a8 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -547,6 +547,7 @@ class EngineArgs: enable_flashinfer_autotune: bool = get_field( KernelConfig, "enable_flashinfer_autotune" ) + fast_skinny_gemm: bool = ModelConfig.fast_skinny_gemm worker_cls: str = ParallelConfig.worker_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls @@ -734,6 +735,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: model_group.add_argument( "--io-processor-plugin", **model_kwargs["io_processor_plugin"] ) + model_group.add_argument( + "--fast-skinny-gemm", **model_kwargs["fast_skinny_gemm"] + ) # Model loading arguments load_kwargs = get_kwargs(LoadConfig) @@ -1333,6 +1337,7 @@ def create_model_config(self) -> ModelConfig: logits_processors=self.logits_processors, video_pruning_rate=self.video_pruning_rate, io_processor_plugin=self.io_processor_plugin, + fast_skinny_gemm=self.fast_skinny_gemm, ) def validate_tensorizer_args(self): diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index a6a5ef106f3b..e6fce9aa1c50 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -24,6 +24,13 @@ "expert_gate", } +_fast_skinny_gemm: bool = False + + +def set_fast_skinny_gemm(enabled: bool): + global _fast_skinny_gemm + _fast_skinny_gemm = enabled + def is_layer_moe_router_gate(prefix: str) -> bool: if not prefix: @@ -150,10 +157,6 @@ def rocm_unquantized_gemm_impl( k = weight.shape[1] cu_count = get_cu_count() - if use_aiter_triton_gemm(n, m, k, x.dtype): - from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 - - return gemm_a16w16(x, weight, bias) # Next ^2 of n N_p2 = 1 << (n - 1).bit_length() @@ -167,7 +170,6 @@ def rocm_unquantized_gemm_impl( CuNeeded = rndup_cus * GrpsShrB # candidate for atomic reduce count splitk? fits_wvsplitkrc = CuNeeded <= cu_count - use_skinny_reduce_counting = ( envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx950() @@ -177,14 +179,20 @@ def rocm_unquantized_gemm_impl( and k % 8 == 0 and k > 512 and m % 16 == 0 + and m * n <= 128 * 1024 # max reduce buffer and fits_wvsplitkrc - and x.is_contiguous() + and weight.is_contiguous() ) ) if use_skinny_reduce_counting: - x_view = x.reshape(-1, x.size(-1)) - out = ops.wvSplitKrc(weight, x_view, cu_count, bias) - return out.reshape(*x.shape[:-1], weight.shape[0]) + return ops.wvSplitKrc( + x, weight, cu_count, bias, fast_skinny_gemm=_fast_skinny_gemm + ) + + if use_aiter_triton_gemm(n, m, k, x.dtype): + from aiter.ops.triton.gemm_a16w16 import gemm_a16w16 + + return gemm_a16w16(x, weight, bias) use_skinny = ( envs.VLLM_ROCM_USE_SKINNY_GEMM diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 99b799ea413d..cef3668ebeff 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -58,6 +58,7 @@ MRotaryEmbedding, XDRotaryEmbedding, ) +from vllm.model_executor.layers.utils import set_fast_skinny_gemm from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader from vllm.model_executor.model_loader.reload import ( finalize_layerwise_reload, @@ -89,6 +90,7 @@ PlaceholderRange, ) from vllm.multimodal.utils import group_mm_kwargs_by_modality +from vllm.platforms import current_platform from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingType from vllm.sequence import IntermediateTensors @@ -4180,6 +4182,11 @@ def load_model(self, eep_scale_up: bool = False) -> None: self.model_config.model, scope="global", ) + + # Set ROCm skinny GEMM mode before loading model + if current_platform.is_rocm(): + set_fast_skinny_gemm(self.model_config.fast_skinny_gemm) + global_expert_loads, old_global_expert_indices_per_model, rank_mapping = ( EplbState.get_eep_state(self.parallel_config) if eep_scale_up