diff --git a/csrc/apis/attention.hpp b/csrc/apis/attention.hpp index 12a6799d..702cbd30 100644 --- a/csrc/apis/attention.hpp +++ b/csrc/apis/attention.hpp @@ -53,11 +53,10 @@ static void fp8_gemm_nt_skip_head_mid(const std::pairget_arch_major(); @@ -66,7 +65,9 @@ static void fp8_gemm_nt_skip_head_mid(const std::pair& c, - const std::tuple& recipe, + std::optional> recipe, const std::string& compiled_dims) { // Shape must be `[B, M, K] @ [B, N, K].T` const auto& major_a = a.stride(-1) == 1 ? cute::UMMA::Major::K : cute::UMMA::Major::MN; @@ -163,15 +163,16 @@ static void fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, return; // Transform scaling factors - const auto& transformed_sfa = layout::transform_sf_into_required_layout(sfa, m, k, recipe, batch_size, true, false); - const auto& transformed_sfb = layout::transform_sf_into_required_layout(sfb, n, k, recipe, batch_size, false, false); + const auto& [transformed_sfa, transformed_sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + sfa, sfb, m, n, k, recipe, std::nullopt, std::nullopt, batch_size, batch_size, false); // Dispatch implementation - const auto& arch_major = device_runtime->get_arch_major(); + const auto arch_major = device_runtime->get_arch_major(); if (arch_major == 10) { sm100_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, compiled_dims); } else { - DG_HOST_UNREACHABLE("Unsupported architecture"); + const auto& major_sfb = get_major_type_ab(sfb); + sm90_fp8_bmm(a, transformed_sfa, b, transformed_sfb, c, d, batch_size, m, n, k, major_a, major_b, major_sfb, compiled_dims); } } @@ -182,6 +183,7 @@ static void fp8_einsum(const std::string& expr, const std::optional& c, const std::tuple& recipe) { // Some hardcoded Einstein sum kernels + const auto arch_major = device_runtime->get_arch_major(); if (expr == "bhr,hdr->bhd") { // Permute dims to satisfy the order of (batch_size, m, n, k) // (batch_size, m, n, k): (h, b, d, r) @@ 
-190,7 +192,7 @@ static void fp8_einsum(const std::string& expr, const auto& perm_d = d.permute({1, 0, 2}); const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt; fp8_bmm(perm_a, perm_sfa, b.first, b.second, perm_d, perm_c, recipe, "nk"); - } else if (expr == "bhd,hdr->bhr") { + } else if (expr == "bhd,hdr->bhr" and arch_major == 10) { // (batch_size, m, n, k): (h, b, r, d) const auto& perm_a = a.first.permute({1, 0, 2}); const auto& perm_sfa = a.second.permute({1, 0, 2}); @@ -199,7 +201,7 @@ static void fp8_einsum(const std::string& expr, const auto& perm_d = d.permute({1, 0, 2}); const auto& perm_c = c.has_value() ? std::make_optional(c.value().permute({1, 0, 2})) : std::nullopt; fp8_bmm(perm_a, perm_sfa, perm_b, perm_sfb, perm_d, perm_c, recipe, "nk"); - } else if (expr == "bhd,bhr->hdr") { + } else if (expr == "bhd,bhr->hdr" and arch_major == 10) { // (batch_size, m, n, k): (h, d, r, b) const auto& perm_a = a.first.permute({1, 2, 0}); const auto& perm_sfa = a.second.permute({1, 2, 0}); diff --git a/csrc/apis/gemm.hpp b/csrc/apis/gemm.hpp index fb205de2..6770cf92 100644 --- a/csrc/apis/gemm.hpp +++ b/csrc/apis/gemm.hpp @@ -46,13 +46,16 @@ static bool early_return(const int& m, const int &n, const int& k, } #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE -static void fp8_gemm_nt(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - std::optional> recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { + +static void fp8_fp4_gemm_nt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + std::optional> recipe, + std::optional> recipe_a, + std::optional> recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { // Shape must be `[M, K] @ [N, K].T` const auto& major_a = get_major_type_ab(a.first); const auto& major_b = get_major_type_ab(b.first); @@ -65,12 +68,11 @@ static void 
fp8_gemm_nt(const std::pair& a, check_major_type_cd(d); // Type and shape checks - const auto& [m , k ] = get_shape<2>(a.first); - const auto& [n , k_] = get_shape<2>(b.first); - const auto& [m_, n_] = get_shape<2>(d); + const auto arch_major = device_runtime->get_arch_major(); + const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [n , k_] = check_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [m_, n_] = get_shape<2>(d); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); - DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16 or d.scalar_type() == torch::kFloat); // Early return for trivial cases @@ -78,88 +80,104 @@ static void fp8_gemm_nt(const std::pair& a, return; // Transform SFA and SFB into compute-required layout - if (not recipe.has_value()) - recipe = get_default_recipe(a.second.scalar_type(), b.second.scalar_type()); - DG_HOST_ASSERT(recipe.value() == std::make_tuple(1, 1, 128) or recipe.value() == std::make_tuple(1, 128, 128)); - const auto& sfa = layout::transform_sf_into_required_layout(a.second, m, k, recipe.value(), std::nullopt, true, disable_ue8m0_cast); - const auto& sfb = layout::transform_sf_into_required_layout(b.second, n, k, recipe.value(), std::nullopt, false, disable_ue8m0_cast); + const auto [sfa, sfb, gran_k_a, gran_k_b] = layout::transform_sf_pair_into_required_layout( + a.second, b.second, m, n, k, recipe, recipe_a, recipe_b, std::nullopt, std::nullopt, disable_ue8m0_cast); // Dispatch into different implements - const auto& arch_major = device_runtime->get_arch_major(); if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { - if (std::get<1>(recipe.value()) == 1) { + const int gran_n = recipe.has_value() ? 
std::get<1>(recipe.value()) : std::get<0>(recipe_b.value()); + if (gran_n == 1) { sm90_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); } else { const auto& major_sfb = get_major_type_ab(sfb); sm90_fp8_gemm_1d2d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, major_sfb, compiled_dims); } } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, major_a, major_b, compiled_dims); + sm100_fp8_fp4_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, k, gran_k_a, gran_k_b, + major_a, major_b, compiled_dims); } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } } -static void fp8_gemm_nn(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - const std::optional>& recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - fp8_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, - d, c, recipe, compiled_dims, disable_ue8m0_cast); +static void fp8_fp4_gemm_nn(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt(a, {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); } -static void fp8_gemm_tn(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - const std::optional>& recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, - {b.first.transpose(0, 1), b.second.transpose(0, 1)}, - d, c, recipe, compiled_dims, disable_ue8m0_cast); +static void fp8_fp4_gemm_tn(const std::pair& a, + const std::pair& b, + const 
torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, + {b.first.transpose(0, 1), b.second.transpose(0, 1)}, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); } -static void fp8_gemm_tt(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const std::optional& c, - const std::optional>& recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { - fp8_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, - d, c, recipe, compiled_dims, disable_ue8m0_cast); +static void fp8_fp4_gemm_tt(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const std::optional& c, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast) { + fp8_fp4_gemm_nt({a.first.transpose(0, 1), a.second.transpose(0, 1)}, b, + d, c, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast); } -static void m_grouped_fp8_gemm_nt_contiguous(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const torch::Tensor& m_indices, - std::optional> recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { +static void m_grouped_fp8_fp4_gemm_nt_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + std::optional> recipe, + std::optional> recipe_a, + std::optional> recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { // Shape must be `[M, K] @ [G, N, K].mT` const auto& major_a = get_major_type_ab(a.first); const auto& major_b = get_major_type_ab(b.first); 
DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); if (fp8_requires_k_major()) DG_HOST_ASSERT(major_b == cute::UMMA::Major::K); - DG_HOST_ASSERT(m_indices.is_contiguous()); + DG_HOST_ASSERT(grouped_layout.is_contiguous()); // Type and shape checks - const auto& [m, k] = get_shape<2>(a.first); - const auto& [num_groups, n, k_] = get_shape<3>(b.first); - const auto& [m_, n_] = get_shape<2>(d); - const auto& m__ = static_cast(m_indices.numel()); - DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + const auto arch_major = device_runtime->get_arch_major(); + const auto [m , k ] = check_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [num_groups, n, k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [m_, n_] = get_shape<2>(d); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); - DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); - DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt); + DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt); + + // Layout checks + if (use_psum_layout) { + const auto& [num_groups_] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(num_groups == num_groups_); + } else { + const auto& [m__] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(m == m__); + DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value()); + } // D must be N-major check_major_type_cd(d); @@ -169,44 +187,48 @@ static void m_grouped_fp8_gemm_nt_contiguous(const std::pairget_arch_major(); if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { const auto& major_sfb = get_major_type_ab(sfb); - sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, m_indices, + DG_HOST_ASSERT(not use_psum_layout); + sm90_m_grouped_fp8_gemm_contiguous_1d2d(a.first, sfa, b.first, sfb, d, grouped_layout, num_groups, m, n, k, major_a, major_b, 
major_sfb, compiled_dims); } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_m_grouped_fp8_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, m_indices, - num_groups, m, n, k, major_a, major_b, compiled_dims); + sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(a.first, sfa, b.first, sfb, d, grouped_layout, + num_groups, m, n, k, gran_k_a, gran_k_b, major_a, major_b, + compiled_dims, use_psum_layout, expected_m_for_psum_layout); } else { DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } } -static void m_grouped_fp8_gemm_nn_contiguous(const std::pair& a, +static void m_grouped_fp8_fp4_gemm_nn_contiguous(const std::pair& a, + const std::pair& b, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::string& compiled_dims, + const bool& disable_ue8m0_cast, + const bool& use_psum_layout) { + m_grouped_fp8_fp4_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)}, + d, grouped_layout, recipe, recipe_a, recipe_b, compiled_dims, disable_ue8m0_cast, use_psum_layout, std::nullopt); +} + +static void m_grouped_fp8_fp4_gemm_nt_masked(const std::pair& a, const std::pair& b, const torch::Tensor& d, - const torch::Tensor& m_indices, - const std::optional>& recipe, + const torch::Tensor& masked_m, + const int& expected_m, + std::optional> recipe, + std::optional> recipe_a, + std::optional> recipe_b, const std::string& compiled_dims, const bool& disable_ue8m0_cast) { - m_grouped_fp8_gemm_nt_contiguous(a, {b.first.transpose(1, 2), b.second.transpose(1, 2)}, - d, m_indices, recipe, compiled_dims, disable_ue8m0_cast); -} - -static void m_grouped_fp8_gemm_nt_masked(const std::pair& a, - const std::pair& b, - const torch::Tensor& d, - const torch::Tensor& masked_m, - const int& expected_m, - std::optional> recipe, - const std::string& compiled_dims, - const bool& disable_ue8m0_cast) { // Shape must 
be `[G, M, K] @ [G, N, K].mT` const auto& major_a = get_major_type_ab(a.first); const auto& major_b = get_major_type_ab(b.first); @@ -214,15 +236,14 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pair(a.first); - const auto& [num_groups_, n, k_] = get_shape<3>(b.first); - const auto& [num_groups__, m_, n_] = get_shape<3>(d); - const auto& num_groups___ = static_cast(masked_m.numel()); + const auto arch_major = device_runtime->get_arch_major(); + const auto [num_groups , m , k ] = check_grouped_ab_fp8_fp4(a.first, major_a, arch_major); + const auto [num_groups_ , n , k_] = check_grouped_ab_fp8_fp4(b.first, major_b, arch_major); + const auto [num_groups__, m_, n_] = get_shape<3>(d); + const auto num_groups___ = static_cast(masked_m.numel()); DG_HOST_ASSERT(num_groups == num_groups_ and num_groups == num_groups__ and num_groups == num_groups___); DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0); - DG_HOST_ASSERT(a.first.scalar_type() == torch::kFloat8_e4m3fn); - DG_HOST_ASSERT(b.first.scalar_type() == torch::kFloat8_e4m3fn); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(masked_m.scalar_type() == torch::kInt); @@ -230,20 +251,18 @@ static void m_grouped_fp8_gemm_nt_masked(const std::pairget_arch_major(); if (arch_major == 9 and sfa.scalar_type() == torch::kFloat) { const auto& major_sfb = get_major_type_ab(sfb); sm90_m_grouped_fp8_gemm_masked_1d2d(a.first, sfa, b.first, sfb, d, masked_m, num_groups, m, n, k, expected_m, major_a, major_b, major_sfb, compiled_dims); } else if (arch_major == 10 and sfa.scalar_type() == torch::kInt) { - sm100_m_grouped_fp8_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, - num_groups, m, n, k, expected_m, major_a, major_b, compiled_dims); + sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(a.first, sfa, b.first, sfb, d, masked_m, + num_groups, m, n, k, expected_m, gran_k_a, gran_k_b, + major_a, major_b, compiled_dims); } else { 
DG_HOST_UNREACHABLE("Unsupported architecture or scaling factor types"); } @@ -262,9 +281,10 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pair(d); - const auto& [_, m_] = get_shape<2>(a.first); - const auto& [__, n_] = get_shape<2>(b.first); - DG_HOST_ASSERT(m == m_ and n == n_); + const auto& [sum_k_ , m_] = get_shape<2>(a.first); + const auto& [sum_k__, n_] = get_shape<2>(b.first); + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); + DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); // Contiguity checks DG_HOST_ASSERT(a.first.is_contiguous()); @@ -283,8 +303,8 @@ static void k_grouped_fp8_gemm_tn_contiguous(const std::pairget_arch_major(); if (arch_major == 10) { - fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, - cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); + sm100_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, + cute::UMMA::Major::MN, cute::UMMA::Major::MN, compiled_dims); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } @@ -305,9 +325,7 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pair(d); const auto& sum_mk = a.first.numel(); const auto& sum_nk = b.first.numel(); - int sum_k = 0; - for (const auto& k: ks) - sum_k += k; + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); DG_HOST_ASSERT(sum_mk == static_cast(sum_k) * m); DG_HOST_ASSERT(sum_nk == static_cast(sum_k) * n); @@ -334,7 +352,7 @@ static void k_grouped_fp8_gemm_nt_contiguous(const std::pairget_arch_major(); if (arch_major == 9) { - sm90_fp8_k_grouped_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer, + sm90_k_grouped_fp8_gemm_1d1d(a.first, sfa, b.first, sfb, c, d, m, n, ks, ks_tensor, tensor_map_buffer, cute::UMMA::Major::K, cute::UMMA::Major::K, compiled_dims); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); @@ -404,25 +422,36 @@ static void bf16_gemm_tt(const torch::Tensor& a, } static void 
m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& d, const torch::Tensor& m_indices, - const std::string& compiled_dims) { + const torch::Tensor& d, const torch::Tensor& grouped_layout, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { // Shape must be `[M, K] @ [G, N, K].mT` const auto& major_a = get_major_type_ab(a); const auto& major_b = get_major_type_ab(b); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K); - DG_HOST_ASSERT(m_indices.is_contiguous()); + DG_HOST_ASSERT(grouped_layout.is_contiguous()); // Type and shape checks const auto& [m, k] = get_shape<2>(a); const auto& [num_groups, n, k_] = get_shape<3>(b); const auto& [m_, n_] = get_shape<2>(d); - const auto& m__ = static_cast(m_indices.numel()); - DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + DG_HOST_ASSERT(m == m_ and n == n_ and k == k_); DG_HOST_ASSERT(n > 0 and k > 0 and num_groups > 0); DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(b.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); - DG_HOST_ASSERT(m_indices.scalar_type() == torch::kInt); + DG_HOST_ASSERT(grouped_layout.scalar_type() == torch::kInt); + + // Layout checks + if (use_psum_layout) { + const auto& [num_groups_] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(num_groups == num_groups_); + } else { + const auto& [m__] = get_shape<1>(grouped_layout); + DG_HOST_ASSERT(m == m__); + DG_HOST_ASSERT(not expected_m_for_psum_layout.has_value()); + } // D must be N-major check_major_type_cd(d); @@ -434,21 +463,24 @@ static void m_grouped_bf16_gemm_nt_contiguous(const torch::Tensor& a, const torc // Dispatch implementation const auto& arch_major = device_runtime->get_arch_major(); if (arch_major == 9) { - sm90_m_grouped_bf16_gemm_contiguous(a, b, d, m_indices, + DG_HOST_ASSERT(not use_psum_layout); + sm90_m_grouped_bf16_gemm_contiguous(a, b, d, 
grouped_layout, num_groups, m, n, k, major_a, major_b, compiled_dims); } else if (arch_major == 10) { - sm100_m_grouped_bf16_gemm_contiguous(a, b, d, m_indices, - num_groups, m, n, k, major_a, major_b, compiled_dims); + sm100_m_grouped_bf16_gemm_contiguous(a, b, d, grouped_layout, + num_groups, m, n, k, major_a, major_b, compiled_dims, + use_psum_layout, expected_m_for_psum_layout); } else { DG_HOST_UNREACHABLE("Unsupported architecture"); } } static void m_grouped_bf16_gemm_nn_contiguous(const torch::Tensor& a, const torch::Tensor& b, - const torch::Tensor& d, const torch::Tensor& m_indices, - const std::string& compiled_dims) { + const torch::Tensor& d, const torch::Tensor& grouped_layout, + const std::string& compiled_dims, + const bool& use_psum_layout) { m_grouped_bf16_gemm_nt_contiguous(a, b.transpose(1, 2), - d, m_indices, compiled_dims); + d, grouped_layout, compiled_dims, use_psum_layout, std::nullopt); } static void m_grouped_bf16_gemm_nt_masked(const torch::Tensor& a, const torch::Tensor& b, @@ -498,9 +530,10 @@ static void k_grouped_bf16_gemm_tn_contiguous(const torch::Tensor& a, const std::string& compiled_dims) { // Shape checks const auto& [num_groups, m, n] = get_shape<3>(d); - const auto& [_, m_] = get_shape<2>(a); - const auto& [__, n_] = get_shape<2>(b); - DG_HOST_ASSERT(m == m_ and n == n_); + const auto& [sum_k_ , m_] = get_shape<2>(a); + const auto& [sum_k__, n_] = get_shape<2>(b); + const int sum_k = std::accumulate(ks.begin(), ks.end(), 0); + DG_HOST_ASSERT(m == m_ and n == n_ and sum_k == sum_k_ and sum_k == sum_k__); // Contiguity checks DG_HOST_ASSERT(a.is_contiguous()); @@ -563,38 +596,50 @@ static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b, static void register_apis(pybind11::module_& m) { #if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE - // FP8 GEMMs - m.def("fp8_gemm_nt", &fp8_gemm_nt, + // FP8 FP4 GEMMs + m.def("fp8_fp4_gemm_nt", &fp8_fp4_gemm_nt, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = 
std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); - m.def("fp8_gemm_nn", &fp8_gemm_nn, + m.def("fp8_fp4_gemm_nn", &fp8_fp4_gemm_nn, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); - m.def("fp8_gemm_tn", &fp8_gemm_tn, + m.def("fp8_fp4_gemm_tn", &fp8_fp4_gemm_tn, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, py::arg("compiled_dims") = "mn", py::arg("disable_ue8m0_cast") = false); - m.def("fp8_gemm_tt", &fp8_gemm_tt, + m.def("fp8_fp4_gemm_tt", &fp8_fp4_gemm_tt, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, py::arg("compiled_dims") = "mn", py::arg("disable_ue8m0_cast") = false); - m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous, - py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), - py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", - py::arg("disable_ue8m0_cast") = false); - m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous, - py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), - py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk", - py::arg("disable_ue8m0_cast") = false); - m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked, + m.def("m_grouped_fp8_fp4_gemm_nt_contiguous", &m_grouped_fp8_fp4_gemm_nt_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = 
std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false, + py::arg("use_psum_layout") = false, + py::arg("expected_m_for_psum_layout") = std::nullopt); + m.def("m_grouped_fp8_fp4_gemm_nn_contiguous", &m_grouped_fp8_fp4_gemm_nn_contiguous, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, + py::arg("compiled_dims") = "nk", + py::arg("disable_ue8m0_cast") = false, + py::arg("use_psum_layout") = false); + m.def("m_grouped_fp8_fp4_gemm_nt_masked", &m_grouped_fp8_fp4_gemm_nt_masked, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("expected_m"), py::arg("recipe") = std::nullopt, + py::arg("recipe_a") = std::nullopt, py::arg("recipe_b") = std::nullopt, py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false); m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"), @@ -606,6 +651,15 @@ static void register_apis(pybind11::module_& m) { py::arg("ks_tensor"), py::arg("c") = std::nullopt, py::arg("recipe") = std::make_tuple(1, 1, 128), py::arg("compiled_dims") = "mn"); + + // FP8 GEMM alias names + m.attr("fp8_gemm_nt") = m.attr("fp8_fp4_gemm_nt"); + m.attr("fp8_gemm_nn") = m.attr("fp8_fp4_gemm_nn"); + m.attr("fp8_gemm_tn") = m.attr("fp8_fp4_gemm_tn"); + m.attr("fp8_gemm_tt") = m.attr("fp8_fp4_gemm_tt"); + m.attr("m_grouped_fp8_gemm_nt_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nt_contiguous"); + m.attr("m_grouped_fp8_gemm_nn_contiguous") = m.attr("m_grouped_fp8_fp4_gemm_nn_contiguous"); + m.attr("m_grouped_fp8_gemm_nt_masked") = m.attr("m_grouped_fp8_fp4_gemm_nt_masked"); #endif #if DG_TENSORMAP_COMPATIBLE @@ -627,11 +681,14 @@ static void register_apis(pybind11::module_& m) { py::arg("c") = std::nullopt, py::arg("compiled_dims") = "mn"); 
m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous, - py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), - py::arg("compiled_dims") = "nk"); + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("compiled_dims") = "nk", + py::arg("use_psum_layout") = false, + py::arg("expected_m_for_psum_layout") = std::nullopt); m.def("m_grouped_bf16_gemm_nn_contiguous", &m_grouped_bf16_gemm_nn_contiguous, - py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"), - py::arg("compiled_dims") = "nk"); + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("grouped_layout"), + py::arg("compiled_dims") = "nk", + py::arg("use_psum_layout") = false); m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked, py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"), py::arg("expected_m"), py::arg("compiled_dims") = "nk"); diff --git a/csrc/apis/hyperconnection.hpp b/csrc/apis/hyperconnection.hpp new file mode 100644 index 00000000..0a85b10f --- /dev/null +++ b/csrc/apis/hyperconnection.hpp @@ -0,0 +1,70 @@ +#pragma once + +#include "../utils/compatibility.hpp" + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +#include "../jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp" +#include "../jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp" +#endif + +namespace deep_gemm::hyperconnection { + +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE +static void tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const std::optional& num_splits) { + // A and B must be K-major, D must be N-major + DG_HOST_ASSERT(get_major_type_ab(a) == cute::UMMA::Major::K); + DG_HOST_ASSERT(get_major_type_ab(b) == cute::UMMA::Major::K); + check_major_type_cd(d); + + // S must be contiguous + DG_HOST_ASSERT(sqr_sum.is_contiguous()); + + // Type and shape checks + const auto& [m, k ] = get_shape<2>(a); + const auto& [n, k_] = get_shape<2>(b); + if 
(num_splits.has_value()) { + const auto& [num_splits_, m_, n_] = get_shape<3>(d); + const auto& [num_splits__, m__] = get_shape<2>(sqr_sum); + DG_HOST_ASSERT(num_splits.value() == num_splits_ and num_splits.value() == num_splits__ and num_splits.value() >= 1); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + } else { + const auto& [m_, n_] = get_shape<2>(d); + const auto& [m__] = get_shape<1>(sqr_sum); + DG_HOST_ASSERT(m == m_ and m == m__ and n == n_ and k == k_); + } + DG_HOST_ASSERT(n > 0 and k > 0); + DG_HOST_ASSERT(a.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(b.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(d.scalar_type() == torch::kFloat); + DG_HOST_ASSERT(sqr_sum.scalar_type() == torch::kFloat); + + // Do nothing if the problem is empty + if (m == 0) + return; + + // Dispatch into different implements + const auto& arch_major = device_runtime->get_arch_major(); + if (arch_major == 9) { + sm90_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? num_splits.value() : 1); + } else if (arch_major == 10) { + sm100_tf32_hc_prenorm_gemm(a, b, d, sqr_sum, m, n, k, num_splits.has_value() ? 
num_splits.value() : 1); + } else { + DG_HOST_UNREACHABLE("Unsupported architecture"); + } +} + +#endif + +static void register_apis(pybind11::module_& m) { +#if DG_FP8_COMPATIBLE and DG_TENSORMAP_COMPATIBLE + m.def("tf32_hc_prenorm_gemm", &tf32_hc_prenorm_gemm, + py::arg("a"), py::arg("b"), py::arg("d"), py::arg("sqr_sum"), + py::arg("num_splits") = std::nullopt); +#endif +} + +} // namespace deep_gemm::hyperconnection diff --git a/csrc/apis/layout.hpp b/csrc/apis/layout.hpp index dcc4def0..3a37d36a 100644 --- a/csrc/apis/layout.hpp +++ b/csrc/apis/layout.hpp @@ -1,20 +1,34 @@ #pragma once #include "../utils/layout.hpp" +#include "../utils/compatibility.hpp" + +#if DG_TENSORMAP_COMPATIBLE #include "../jit_kernels/impls/smxx_layout.hpp" +#endif namespace deep_gemm::layout { +#if DG_TENSORMAP_COMPATIBLE static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, const int& mn, const int& k, - const std::tuple& recipe, + const std::optional>& recipe, + const std::optional>& recipe_ab, const std::optional& num_groups, const bool& is_sfa, const bool& disable_ue8m0_cast) { - const auto& gran_mn = is_sfa ? std::get<0>(recipe) : std::get<1>(recipe); - const auto& gran_k = std::get<2>(recipe); const auto& arch_major = device_runtime->get_arch_major(); + int gran_mn, gran_k; + if (recipe.has_value()) { + DG_HOST_ASSERT(not recipe_ab.has_value()); + gran_mn = is_sfa ? 
std::get<0>(recipe.value()) : std::get<1>(recipe.value()); + gran_k = std::get<2>(recipe.value()); + } else { + DG_HOST_ASSERT(recipe_ab.has_value()); + std::tie(gran_mn, gran_k) = recipe_ab.value(); + } + // Pre-transform checks check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups); @@ -22,30 +36,44 @@ static torch::Tensor transform_sf_into_required_layout(const torch::Tensor& sf, if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) return get_mn_major_tma_aligned_tensor(sf); - // (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major - if (sf.scalar_type() == torch::kFloat and gran_mn == 1 and gran_k == 128 and arch_major == 10) { - DG_HOST_ASSERT(not disable_ue8m0_cast); - return get_mn_major_tma_aligned_packed_ue8m0_tensor(sf); - } - // (FP32, 128, 128) on SM90: no need to transform, check SFB requirements if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and (arch_major == 9 or disable_ue8m0_cast)) return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, false, true, torch::kFloat); - // (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major - if (sf.scalar_type() == torch::kFloat and gran_mn == 128 and gran_k == 128 and arch_major == 10) { + // (FP32, x, gran_k) on SM100: transform to (INT, 1, gran_k), TMA-aligned and MN-major + if (sf.scalar_type() == torch::kFloat and (gran_k == 32 or gran_k == 128) and arch_major == 10) { DG_HOST_ASSERT(not disable_ue8m0_cast); - const auto& broadcasted = sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(128)); + const auto& broadcasted = gran_mn == 1 ? 
sf : + sf.index_select(-2, torch::arange(mn, at::TensorOptions().device(sf.device())).floor_divide_(gran_mn)); return get_mn_major_tma_aligned_packed_ue8m0_tensor(broadcasted); } - // (INT, 1, 128) on SM100: transform to TMA-aligned and MN-major - if (sf.scalar_type() == torch::kInt and gran_mn == 1 and gran_k == 128 and arch_major == 10) + // (INT, 1, gran_k) on SM100: transform to TMA-aligned and MN-major + if (sf.scalar_type() == torch::kInt and gran_mn == 1 and (gran_k == 32 or gran_k == 128) and arch_major == 10) return check_sf_layout(sf, mn, k, gran_mn, gran_k, num_groups, true, false, torch::kInt); DG_HOST_UNREACHABLE("Unknown SF transformation"); } +static std::tuple transform_sf_pair_into_required_layout( + const torch::Tensor& sfa, const torch::Tensor& sfb, + const int& m, const int& n, const int& k, + std::optional>& recipe, + const std::optional>& recipe_a, + const std::optional>& recipe_b, + const std::optional& num_groups_a, + const std::optional& num_groups_b, + const bool& disable_ue8m0_cast = false) { + DG_HOST_ASSERT(recipe_a.has_value() == recipe_b.has_value()); + if (not recipe_a.has_value() and not recipe.has_value()) + recipe = get_default_recipe(sfa.scalar_type(), sfb.scalar_type()); + const auto transformed_sfa = transform_sf_into_required_layout(sfa, m, k, recipe, recipe_a, num_groups_a, true, disable_ue8m0_cast); + const auto transformed_sfb = transform_sf_into_required_layout(sfb, n, k, recipe, recipe_b, num_groups_b, false, disable_ue8m0_cast); + const int gran_k_a = recipe_a.has_value() ? std::get<1>(recipe_a.value()) : std::get<2>(recipe.value()); + const int gran_k_b = recipe_b.has_value() ? 
std::get<1>(recipe_b.value()) : std::get<2>(recipe.value()); + return std::make_tuple(transformed_sfa, transformed_sfb, gran_k_a, gran_k_b); +} + static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Tensor& sf, const std::vector& ks, const torch::Tensor& ks_tensor, @@ -69,17 +97,24 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te DG_HOST_UNREACHABLE("Unknown cases"); } +#endif + static void register_apis(pybind11::module_& m) { + +#if DG_TENSORMAP_COMPATIBLE m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout, - py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"), + py::arg("sf"), py::arg("mn"), py::arg("k"), + py::arg("recipe") = std::nullopt, py::arg("recipe_ab") = std::nullopt, py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false, py::arg("disable_ue8m0_cast") = false); m.def("get_tma_aligned_size", &get_tma_aligned_size); - m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout); m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor); m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor); m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor); +#endif + + m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout); } } // namespace deep_gemm::layout diff --git a/csrc/apis/runtime.hpp b/csrc/apis/runtime.hpp index 9ef42078..a5d313e2 100644 --- a/csrc/apis/runtime.hpp +++ b/csrc/apis/runtime.hpp @@ -1,6 +1,8 @@ #pragma once +#if DG_TENSORMAP_COMPATIBLE #include "../jit/compiler.hpp" +#endif #include "../jit/device_runtime.hpp" namespace deep_gemm::runtime { @@ -18,10 +20,11 @@ static void register_apis(pybind11::module_& m) { m.def("get_tc_util", [&]() { return device_runtime->get_tc_util(); }); - m.def("init", [&](const std::string& library_root_path, const 
std::string& cuda_home_path_by_python) { +#if DG_TENSORMAP_COMPATIBLE Compiler::prepare_init(library_root_path, cuda_home_path_by_python); KernelRuntime::prepare_init(cuda_home_path_by_python); +#endif }); } diff --git a/csrc/indexing/main.cu b/csrc/indexing/main.cu index 4e260f9c..1b96da2f 100644 --- a/csrc/indexing/main.cu +++ b/csrc/indexing/main.cu @@ -15,6 +15,10 @@ #include #include +// Hyperconnection kernels +#include +#include + // Layout kernels #include #include diff --git a/csrc/jit/compiler.hpp b/csrc/jit/compiler.hpp index de717793..3dc0cfbf 100644 --- a/csrc/jit/compiler.hpp +++ b/csrc/jit/compiler.hpp @@ -24,6 +24,7 @@ class Compiler { static std::filesystem::path library_include_path; static std::filesystem::path cuda_home; static std::string library_version; + static std::filesystem::path cuobjdump_path; static std::string get_library_version() { std::vector buffer; @@ -45,6 +46,7 @@ class Compiler { Compiler::library_include_path = Compiler::library_root_path / "include"; Compiler::cuda_home = cuda_home_path_by_python; Compiler::library_version = get_library_version(); + Compiler::cuobjdump_path = Compiler::cuda_home / "bin" / "cuobjdump"; } std::string signature, flags; @@ -56,6 +58,7 @@ class Compiler { DG_HOST_ASSERT(not library_include_path.empty()); DG_HOST_ASSERT(not cuda_home.empty()); DG_HOST_ASSERT(not library_version.empty()); + DG_HOST_ASSERT(not cuobjdump_path.empty()); // Cache settings cache_dir_path = std::filesystem::path(get_env("HOME")) / ".deep_gemm"; @@ -108,25 +111,57 @@ class Compiler { // Compile into a temporary CUBIN const auto tmp_cubin_path = get_tmp_file_path(); - compile(code, dir_path, tmp_cubin_path); + if (get_env("DG_JIT_DUMP_ASM") or get_env("DG_JIT_DUMP_PTX")) { + // Dump PTX if needed + const auto tmp_ptx_path = get_tmp_file_path(); + compile(code, dir_path, tmp_cubin_path, tmp_ptx_path); + + // Replace into the cache directory + std::filesystem::rename(tmp_ptx_path, dir_path / "kernel.ptx"); + } else { + 
compile(code, dir_path, tmp_cubin_path); + } // Replace into the cache directory - make_dirs(dir_path); - std::filesystem::rename(tmp_cubin_path, dir_path / "kernel.cubin"); + const auto cubin_path = dir_path / "kernel.cubin"; + std::filesystem::rename(tmp_cubin_path, cubin_path); + + // Disassemble if needed + if (get_env("DG_JIT_DUMP_ASM") or get_env("DG_JIT_DUMP_SASS")) { + // Dump into a temporary SASS + const auto tmp_sass_path = get_tmp_file_path(); + disassemble(cubin_path, tmp_sass_path); + + // Replace into the current directory + std::filesystem::rename(tmp_sass_path, dir_path / "kernel.sass"); + } // Put into the runtime cache - const auto& runtime = kernel_runtime_cache->get(dir_path); + const auto runtime = kernel_runtime_cache->get(dir_path); DG_HOST_ASSERT(runtime != nullptr); return runtime; } - virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const = 0; + static void disassemble(const std::filesystem::path &cubin_path, const std::filesystem::path &sass_path) { + // Disassemble the CUBIN file to SASS + const auto command = fmt::format("{} --dump-sass {} > {}", cuobjdump_path.c_str(), cubin_path.c_str(), sass_path.c_str()); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + printf("Running cuobjdump command: %s\n", command.c_str()); + const auto [return_code, output] = call_external_command(command); + if (return_code != 0) { + printf("cuobjdump failed: %s\n", output.c_str()); + DG_HOST_ASSERT(false and "cuobjdump failed"); + } + } + + virtual void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path, const std::optional &ptx_path = std::nullopt) const = 0; }; DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_root_path); DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, library_include_path); DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuda_home); DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, 
library_version); +DG_DECLARE_STATIC_VAR_IN_CLASS(Compiler, cuobjdump_path); class NVCCCompiler final: public Compiler { std::filesystem::path nvcc_path; @@ -164,17 +199,19 @@ class NVCCCompiler final: public Compiler { const auto& arch = device_runtime->get_arch(false, nvcc_major > 12 or nvcc_minor >= 9); flags = fmt::format("{} -I{} --gpu-architecture=sm_{} " "--compiler-options=-fPIC,-O3,-fconcepts,-Wno-deprecated-declarations,-Wno-abi " - "-cubin -O3 --expt-relaxed-constexpr --expt-extended-lambda", + "-O3 --expt-relaxed-constexpr --expt-extended-lambda", flags, library_include_path.c_str(), arch); } - void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override { + void compile(const std::string &code, const std::filesystem::path& dir_path, + const std::filesystem::path &cubin_path, + const std::optional &ptx_path) const override { // Write the code into the cache directory const auto& code_path = dir_path / "kernel.cu"; put(code_path, code); // Compile - const auto& command = fmt::format("{} {} -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); + const auto& command = fmt::format("{} {} -cubin -o {} {}", nvcc_path.c_str(), code_path.c_str(), cubin_path.c_str(), flags); if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) printf("Running NVCC command: %s\n", command.c_str()); const auto& [return_code, output] = call_external_command(command); @@ -183,6 +220,18 @@ class NVCCCompiler final: public Compiler { DG_HOST_ASSERT(false and "NVCC compilation failed"); } + // Compile to PTX if needed + if (ptx_path.has_value()) { + const auto ptx_command = fmt::format("{} {} -ptx -o {} {}", nvcc_path.c_str(), code_path.c_str(), ptx_path->c_str(), flags); + if (get_env("DG_JIT_DEBUG", 0) or get_env("DG_JIT_PRINT_COMPILER_COMMAND", 0)) + printf("Running NVCC PTX command: %s\n", ptx_command.c_str()); + const auto [ptx_return_code, ptx_output] = 
call_external_command(ptx_command); + if (ptx_return_code != 0) { + printf("NVCC PTX compilation failed: %s\n", ptx_output.c_str()); + DG_HOST_ASSERT(false and "NVCC PTX compilation failed"); + } + } + // Check local memory usage if (get_env("DG_JIT_PTXAS_CHECK", 0)) DG_HOST_ASSERT(not std::regex_search(output, std::regex(R"(Local memory used)"))); @@ -219,11 +268,13 @@ class NVRTCCompiler final: public Compiler { // Override the compiler flags // Only NVRTC >= 12.9 supports arch-specific family suffix const auto& arch = device_runtime->get_arch(false, major > 12 or minor >= 9); - flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {}", + flags = fmt::format("{} {}--gpu-architecture=sm_{} -default-device {} --device-int128", flags, include_dirs, arch, pch_flags); } - void compile(const std::string &code, const std::filesystem::path& dir_path, const std::filesystem::path &cubin_path) const override { + void compile(const std::string &code, const std::filesystem::path& dir_path, + const std::filesystem::path &cubin_path, + const std::optional &ptx_path) const override { // Write the code into the cache directory const auto& code_path = dir_path / "kernel.cu"; put(code_path, code); @@ -266,6 +317,17 @@ class NVRTCCompiler final: public Compiler { } } + if (ptx_path.has_value()) { + // Get PTX size and data if needed + size_t ptx_size; + DG_NVRTC_CHECK(nvrtcGetPTXSize(program, &ptx_size)); + std::string ptx_data(ptx_size, '\0'); + DG_NVRTC_CHECK(nvrtcGetPTX(program, ptx_data.data())); + + // Write into the file system + put(ptx_path.value(), ptx_data); + } + // Get CUBIN size and data size_t cubin_size; DG_NVRTC_CHECK(nvrtcGetCUBINSize(program, &cubin_size)); diff --git a/csrc/jit/device_runtime.hpp b/csrc/jit/device_runtime.hpp index 20ab6935..d33743ef 100644 --- a/csrc/jit/device_runtime.hpp +++ b/csrc/jit/device_runtime.hpp @@ -17,19 +17,7 @@ class DeviceRuntime { static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024; public: -#if 
TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 3) - // For PyTorch 2.3+, share the PyTorch cuBLASLt handle - DeviceRuntime() = default; - - static cublasLtHandle_t get_cublaslt_handle() { - return at::cuda::getCurrentCUDABlasLtHandle(); - } - - static torch::Tensor get_cublaslt_workspace() { - return torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)); - } -#else - // Otherwise, create the cuBLASLt handle ourselves + // Create the cuBLASLt handle ourselves cublasLtHandle_t cublaslt_handle{}; std::shared_ptr cublaslt_workspace; @@ -49,7 +37,6 @@ class DeviceRuntime { torch::Tensor get_cublaslt_workspace() const { return *cublaslt_workspace; } -#endif std::shared_ptr get_prop() { if (cached_prop == nullptr) { diff --git a/csrc/jit/handle.hpp b/csrc/jit/handle.hpp index 131e1c9f..34447f91 100644 --- a/csrc/jit/handle.hpp +++ b/csrc/jit/handle.hpp @@ -40,10 +40,7 @@ DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleLoad); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleUnload); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuModuleGetFunction); DECL_LAZY_CUDA_DRIVER_FUNCTION(cuLaunchKernelEx); - -#if DG_TENSORMAP_COMPATIBLE DECL_LAZY_CUDA_DRIVER_FUNCTION(cuTensorMapEncodeTiled); -#endif #if CUDART_VERSION >= 12080 and defined(DG_JIT_USE_RUNTIME_API) @@ -166,7 +163,6 @@ static auto launch_kernel(const KernelHandle& kernel, const LaunchConfigHandle& void *ptr_args[] = { &args... 
}; return lazy_cuLaunchKernelEx(&config, kernel, ptr_args, nullptr); } - #endif } // namespace deep_gemm diff --git a/csrc/jit_kernels/heuristics/common.hpp b/csrc/jit_kernels/heuristics/common.hpp index be66454a..a49584f4 100644 --- a/csrc/jit_kernels/heuristics/common.hpp +++ b/csrc/jit_kernels/heuristics/common.hpp @@ -59,7 +59,8 @@ struct GemmConfig { // Templated configs GemmType gemm_type; KernelType kernel_type; - at::ScalarType ab_dtype, cd_dtype; + MmaKind mma_kind; + at::ScalarType a_dtype, b_dtype, cd_dtype; cute::UMMA::Major major_a; cute::UMMA::Major major_b; bool with_accumulation; @@ -99,9 +100,9 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne const int& m, const int& n, const int& k, const int& block_m, const int& block_n, const int& block_k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const MmaKind& mma_kind, const at::ScalarType& cd_dtype, const int& num_stages, const MulticastConfig& multicast_config) { - const int& ab_elem_size = static_cast(c10::elementSize(ab_dtype)); + const int& ab_elem_size = static_cast(get_element_size(mma_kind)); const int& cd_elem_size = static_cast(c10::elementSize(cd_dtype)); const int& load_block_m = ArchSpec::get_ab_load_block_m(multicast_config, block_m); @@ -119,7 +120,7 @@ static SharedMemoryConfig get_smem_config(const GemmType& gemm_type, const Kerne // SF shared memory const auto& [smem_sfa_per_stage, smem_sfb_per_stage] = - ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, ab_dtype, cd_dtype); + ArchSpec::get_sf_smem_size_per_stage(kernel_type, block_m, block_n, block_k, mma_kind, cd_dtype); const int& smem_extra_sfb = ArchSpec::get_extra_sfb_smem_size(m, n, k, block_m, block_n, block_k); // M-barriers and tensor memory pointers @@ -151,21 +152,35 @@ template static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& kernel_type, const 
int& m, const int& n, const int& k, const int& num_groups, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const at::ScalarType& a_dtype, const at::ScalarType& b_dtype, + const at::ScalarType& cd_dtype, const bool& with_accumulation, const int& num_sms) { - DG_HOST_ASSERT(ab_dtype == torch::kFloat8_e4m3fn or ab_dtype == torch::kBFloat16); + const auto mma_kind = (a_dtype == torch::kBFloat16 ? MmaKind::BF16 : MmaKind::MXFP8FP4); + if (mma_kind == MmaKind::BF16) { + DG_HOST_ASSERT(a_dtype == torch::kBFloat16 and b_dtype == torch::kBFloat16); + } else { + DG_HOST_ASSERT(a_dtype == torch::kFloat8_e4m3fn or a_dtype == kPackedFP4); + DG_HOST_ASSERT(b_dtype == torch::kFloat8_e4m3fn or b_dtype == kPackedFP4); + } DG_HOST_ASSERT(cd_dtype == torch::kBFloat16 or cd_dtype == torch::kFloat); // Select M/N block sizes auto block_ms = ArchSpec::get_block_m_candidates(kernel_type, major_a, m); if (gemm_type == GemmType::MGroupedContiguous) block_ms = std::vector{get_mk_alignment_for_contiguous_layout()}; - if (gemm_type == GemmType::MGroupedMasked) // Exclude 256 for performance - block_ms = std::vector{64, 128}; - const auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype); + if (gemm_type == GemmType::MGroupedMasked or gemm_type == GemmType::MGroupedContiguousWithPsumLayout) + block_ms = std::vector{64, 128}; // Exclude 256 for performance + auto block_ns = ArchSpec::get_block_n_candidates(kernel_type, cd_dtype); + + // NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B + // TODO: Optimize it + if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN) + block_ms = std::vector{128}; + if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN) + block_ns = std::vector{128}; // K block size is selected in a fixed manner - const auto& block_k = 128 / static_cast(c10::elementSize(ab_dtype)); + const auto& block_k = (mma_kind == MmaKind::BF16 ? 
64 : 128); // Some util functions const auto& get_num_blocks = [=](const int& block_m, const int& block_n) { @@ -186,7 +201,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k for (const auto& block_n: block_ns) { const int& num_waves = get_num_waves(block_m, block_n); const auto& last_util = get_last_wave_util(block_m, block_n); - if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, ab_dtype, cd_dtype, m, n, k, block_m, block_n, block_k)) + if (not ArchSpec::is_block_size_legal(kernel_type, major_a, major_b, mma_kind, cd_dtype, m, n, k, block_m, block_n, block_k)) continue; bool success = false; @@ -218,8 +233,16 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k // Decide the number of TMA multicasts and whether broadcast on A MulticastConfig best_multicast_config = {1, false}; - const auto& [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( + auto [is_legal_on_a, is_legal_on_b] = ArchSpec::get_multicast_legality( gemm_type, num_groups, m, n, best_block_m, best_block_n, num_sms); + + // NOTES: TMA copy .b4x16_p64 only supports Swizzle 128B + // TODO: Optimize it + if (a_dtype == kPackedFP4 and major_a == cute::UMMA::Major::MN) + is_legal_on_a = false; + if (b_dtype == kPackedFP4 and major_b == cute::UMMA::Major::MN) + is_legal_on_b = false; + const bool is_legal[2] = {is_legal_on_b, is_legal_on_a}; bool order[2] = {false, true}; if (best_block_m > best_block_n) @@ -236,14 +259,14 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k int best_num_stages = 0; SharedMemoryConfig best_smem_config; for (int num_stages = 32; num_stages > 0; -- num_stages) { - if (not ArchSpec::is_num_stages_legal(ab_dtype, cd_dtype, num_stages, best_block_m, best_block_n, block_k)) + if (not ArchSpec::is_num_stages_legal(mma_kind, cd_dtype, num_stages, best_block_m, best_block_n, block_k)) continue; best_smem_config = get_smem_config(gemm_type, kernel_type, 
m, n, k, best_block_m, best_block_n, block_k, major_a, major_b, - ab_dtype, cd_dtype, + mma_kind, cd_dtype, num_stages, best_multicast_config); if (best_smem_config.smem_size <= smem_capacity) { best_num_stages = num_stages; @@ -255,7 +278,7 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k // Recompute the minimal number of SMs required // NOTES: less L2 cache usage and less GPU frequency drop int num_min_sms = num_sms; - if (ArchSpec::should_minimize_num_sms()) { + if (get_env("DG_JIT_MINIMIZE_NUM_SMS", 0)) { num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, best_num_waves); num_min_sms = align(num_min_sms, best_multicast_config.num_multicast); DG_HOST_ASSERT(num_min_sms <= num_sms); @@ -264,7 +287,9 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k const auto& config = GemmConfig { .gemm_type = gemm_type, .kernel_type = kernel_type, - .ab_dtype = ab_dtype, + .mma_kind = mma_kind, + .a_dtype = a_dtype, + .b_dtype = b_dtype, .cd_dtype = cd_dtype, .major_a = major_a, .major_b = major_b, @@ -284,21 +309,22 @@ static GemmConfig get_best_config(const GemmType& gemm_type, const KernelType& k // Only SM100 BF16 kernels support tensor core control if (config.tc_util < 100) - DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and ab_dtype == torch::kBFloat16); + DG_HOST_ASSERT(device_runtime->get_arch_major() == 10 and mma_kind == MmaKind::BF16); // Print configs for the first time if (get_env("DG_JIT_DEBUG") or get_env("DG_PRINT_CONFIGS")) { auto key = std::make_tuple(gemm_type, kernel_type, m, n, k, num_groups, major_a, major_b, - ab_dtype, cd_dtype, with_accumulation, num_sms); + mma_kind, a_dtype, b_dtype, cd_dtype, with_accumulation, num_sms); static std::set printed; if (printed.count(key) == 0) { printf("GEMM type: %d, kernel type: %d, M: %d, N: %d, K: %d, groups: %d, " - "A major: %d, B major: %d, AB dtype: %s, CD dtype: %s, accumulation: %d, " + "A 
major: %d, B major: %d, MMA kind: %d, A dtype: %s, B dtype: %s, CD dtype: %s, accumulation: %d, " "SM limit: %d -> block M: %d, block N: %d, block K: %d, stages: %d, last stages: %d, " "SMs: %d, multicast: %d, multicast on A: %d, shared memory: %d bytes, swizzle A: %d, " "swizzle B: %d, swizzle CD: %d, SMs: %d, threads: %d, TC util: %d%%\n", static_cast(gemm_type), static_cast(kernel_type), m, n, k, num_groups, - static_cast(major_a), static_cast(major_b), c10::toString(ab_dtype), c10::toString(cd_dtype), + static_cast(major_a), static_cast(major_b), static_cast(mma_kind), + c10::toString(a_dtype), c10::toString(b_dtype), c10::toString(cd_dtype), static_cast(with_accumulation), num_sms, best_block_m, best_block_n, block_k, best_num_stages, config.num_last_stages, num_min_sms, best_multicast_config.num_multicast, static_cast(best_multicast_config.is_multicast_on_a), diff --git a/csrc/jit_kernels/heuristics/sm100.hpp b/csrc/jit_kernels/heuristics/sm100.hpp index 0ac4cc28..dd1e6024 100644 --- a/csrc/jit_kernels/heuristics/sm100.hpp +++ b/csrc/jit_kernels/heuristics/sm100.hpp @@ -53,18 +53,18 @@ struct SM100ArchSpec { } static std::pair get_sf_uttcp_aligned_block_sizes( - const int& block_m, const int& block_n, const at::ScalarType& ab_dtype) { + const int& block_m, const int& block_n, const MmaKind& mma_kind) { constexpr int num_utccp_aligned_elems = 128; - switch (ab_dtype) { - case torch::kBFloat16: return {0, 0}; - case torch::kFloat8_e4m3fn: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)}; + switch (mma_kind) { + case MmaKind::BF16: return {0, 0}; + case MmaKind::MXFP8FP4: return {align(block_m, num_utccp_aligned_elems), align(block_n, num_utccp_aligned_elems)}; default: DG_HOST_UNREACHABLE("Unknown dtype"); } } static bool is_block_size_legal(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const MmaKind& 
mma_kind, const at::ScalarType& cd_dtype, const int& m, const int& n, const int& k, const int& block_m, const int& block_n, const int& block_k) { // Layout A/D does not support `block_n % 16 != 0` @@ -82,7 +82,7 @@ struct SM100ArchSpec { // Check tensor memory validity int sf_block_m = 0, sf_block_n = 0; if (kernel_type == KernelType::Kernel1D1D) { - const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype); + const auto& [sf_block_m_, sf_block_n_] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind); sf_block_m = sf_block_m_, sf_block_n = sf_block_n_; } if (((2 * block_n) + (sf_block_m / 32) + (sf_block_n / 32)) > 512) @@ -90,19 +90,15 @@ struct SM100ArchSpec { // NOTES: when B is MN-major, we restrict `block_n` to multiples of 64, // since TMA performance degrades when `swizzle_b <= 32B` (i.e., when `block_ns % 64 != 0`), even with 3D TMA - return major_b == cute::UMMA::Major::K or (block_n * c10::elementSize(ab_dtype)) % 64 == 0; + return major_b == cute::UMMA::Major::K or (block_n * get_element_size(mma_kind)) % 64 == 0; } - static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype, const int& num_stages, const int& block_m, const int& block_n, const int& block_k) { return true; } - static bool should_minimize_num_sms() { - return true; - } - static std::pair get_multicast_legality(const GemmType& gemm_type, const int& num_groups, const int& m, const int& n, const int& block_m, const int& block_n, const int& num_sms) { @@ -129,14 +125,14 @@ struct SM100ArchSpec { static std::pair get_sf_smem_size_per_stage(const KernelType& kernel_type, const int& block_m, const int& block_n, const int& block_k, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) { - if (ab_dtype == torch::kBFloat16) + const MmaKind& mma_kind, const at::ScalarType& cd_dtype) { + if (mma_kind == 
MmaKind::BF16) return {0, 0}; int smem_sfa_per_stage = 0; int smem_sfb_per_stage = 0; if (kernel_type == KernelType::Kernel1D1D) { - const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, ab_dtype); + const auto [sf_block_m, sf_block_n] = get_sf_uttcp_aligned_block_sizes(block_m, block_n, mma_kind); smem_sfa_per_stage = sf_block_m * 4; smem_sfb_per_stage = sf_block_n * 4; } else { diff --git a/csrc/jit_kernels/heuristics/sm90.hpp b/csrc/jit_kernels/heuristics/sm90.hpp index 8a9b23a0..2fd2e9ec 100644 --- a/csrc/jit_kernels/heuristics/sm90.hpp +++ b/csrc/jit_kernels/heuristics/sm90.hpp @@ -60,7 +60,7 @@ struct SM90ArchSpec { static bool is_block_size_legal(const KernelType& kernel_type, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + const MmaKind& mma_kind, const at::ScalarType& cd_dtype, const int& m, const int& n, const int& k, const int& block_m, const int& block_n, const int& block_k) { // SM90 FP32 output does not support `block_m == 256` @@ -89,19 +89,15 @@ struct SM90ArchSpec { return block_m <= 128 or block_n <= 128; } - static bool is_num_stages_legal(const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype, + static bool is_num_stages_legal(const MmaKind& mma_kind, const at::ScalarType& cd_dtype, const int& num_stages, const int& block_m, const int& block_n, const int& block_k) { // Unrolling both stages and `num_former_iters` will cause large code size - if (ab_dtype == torch::kFloat8_e4m3fn and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4) + if (mma_kind == MmaKind::MXFP8FP4 and block_k % block_n != 0 and block_k / std::gcd(block_n, block_k) <= 4) return num_stages <= 4; return true; } - static bool should_minimize_num_sms() { - return true; - } - static std::pair get_multicast_legality(const GemmType& gemm_type, const int& num_groups, const int& m, const int& n, const int& block_m, const int& block_n, 
const int& num_sms) { @@ -134,8 +130,8 @@ struct SM90ArchSpec { static std::pair get_sf_smem_size_per_stage(const KernelType& kernel_type, const int& block_m, const int& block_n, const int& block_k, - const at::ScalarType& ab_dtype, const at::ScalarType& cd_dtype) { - if (ab_dtype == torch::kBFloat16) + const MmaKind& mma_kind, const at::ScalarType& cd_dtype) { + if (mma_kind == MmaKind::BF16) return {0, 0}; // NOTES: 128 is for 2D TMA alignment requirement diff --git a/csrc/jit_kernels/impls/runtime_utils.hpp b/csrc/jit_kernels/impls/runtime_utils.hpp index 8f8504d5..b245b94a 100644 --- a/csrc/jit_kernels/impls/runtime_utils.hpp +++ b/csrc/jit_kernels/impls/runtime_utils.hpp @@ -37,11 +37,12 @@ static std::string to_string(const cute::UMMA::Major& major) { static std::string to_string(const GemmType& type) { switch (type) { - case GemmType::Normal: return "GemmType::Normal"; - case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous"; - case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked"; - case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous"; - case GemmType::Batched: return "GemmType::Batched"; + case GemmType::Normal: return "GemmType::Normal"; + case GemmType::MGroupedContiguous: return "GemmType::MGroupedContiguous"; + case GemmType::MGroupedMasked: return "GemmType::MGroupedMasked"; + case GemmType::MGroupedContiguousWithPsumLayout: return "GemmType::MGroupedContiguousWithPsumLayout"; + case GemmType::KGroupedContiguous: return "GemmType::KGroupedContiguous"; + case GemmType::Batched: return "GemmType::Batched"; } DG_HOST_UNREACHABLE("Unknown GEMM type"); } @@ -51,6 +52,8 @@ static std::string to_string(const at::ScalarType& dtype) { case torch::kInt: return "int"; case torch::kFloat: return "float"; case torch::kBFloat16: return "cutlass::bfloat16_t"; + case torch::kFloat8_e4m3fn: return "cutlass::float_e4m3_t"; + case kPackedFP4: return "cutlass::detail::float_e2m1_unpacksmem_t"; default: 
DG_HOST_UNREACHABLE("Unsupported dtype"); } } @@ -65,6 +68,7 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType& case torch::kFloat: return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; case torch::kBFloat16: return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; case torch::kFloat8_e4m3fn: return CU_TENSOR_MAP_DATA_TYPE_UINT8; + case kPackedFP4: return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; default: DG_HOST_UNREACHABLE("Unsupported dtype"); } } @@ -98,6 +102,10 @@ static CUtensorMap make_tma_2d_desc(const torch::Tensor& t, if (swizzle_mode != 0) smem_inner_dim = swizzle_mode / elem_size; + // Inner dim must be a multiple of 64B for .b4x16_p64 + if (t.scalar_type() == kPackedFP4) + DG_HOST_ASSERT(gmem_inner_dim % 128 == 0); + CUtensorMap tensor_map; const cuuint64_t gmem_dims[2] = {static_cast(gmem_inner_dim), static_cast(gmem_outer_dim)}; const cuuint32_t smem_dims[2] = {static_cast(smem_inner_dim), static_cast(smem_outer_dim)}; @@ -126,6 +134,10 @@ static CUtensorMap make_tma_3d_desc(const torch::Tensor& t, if (swizzle_mode != 0) smem_dim_0 = swizzle_mode / elem_size; + // Inner dim must be a multiple of 64B for .b4x16_p64 + if (t.scalar_type() == kPackedFP4) + DG_HOST_ASSERT(gmem_dim_0 % 128 == 0); + CUtensorMap tensor_map; const cuuint64_t gmem_dims[3] = {static_cast(gmem_dim_0), static_cast(gmem_dim_1), static_cast(gmem_dim_2),}; const cuuint32_t smem_dims[3] = {static_cast(smem_dim_0), static_cast(smem_dim_1), static_cast(smem_dim_2)}; @@ -204,7 +216,7 @@ static CUtensorMap make_tma_cd_desc(const torch::Tensor& t, static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, const torch::Tensor& t, int shape_mn, int shape_k, - const int& block_mn, const int& block_k, + const int& block_mn, const int& gran_k, const int& num_groups, const int& swizzle_mode, const int& swizzle_base = 0, const bool& allow_tf32 = false) { @@ -215,7 +227,7 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major, shape_mn = get_tma_aligned_size(shape_mn, 
static_cast(t.element_size())); return make_tma_2d_desc(t, - shape_mn, ceil_div(shape_k, block_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups, + shape_mn, ceil_div(shape_k, gran_k * (t.scalar_type() == torch::kFloat ? 1 : 4)) * num_groups, block_mn, 1, shape_mn, swizzle_mode, swizzle_base, diff --git a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp index 45810ed5..95f72729 100644 --- a/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm100_bf16_gemm.hpp @@ -79,11 +79,11 @@ static void sm100_bf16_gemm(const torch::Tensor& a, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { - const auto& aligned_k = align(k, 64); const auto& config = get_best_config( GemmType::Normal, KernelType::KernelNoSF, m, n, k, 1, major_a, major_b, - torch::kBFloat16, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, @@ -104,7 +104,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a, // Launch const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + .m = m, .n = n, .k = k, .num_groups = 1, .compiled_dims = compiled_dims, .gemm_config = config, @@ -124,16 +124,25 @@ static void sm100_bf16_gemm(const torch::Tensor& a, static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& d, - const torch::Tensor& m_indices, + const torch::Tensor& grouped_layout, const int& num_groups, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { - const auto& aligned_k = align(k, 64); + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + const auto& gemm_type 
= use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`. + // Otherwise, treat the contiguous layout as a whole. + const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m; + const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? num_groups : 1; + const auto& config = get_best_config( - GemmType::MGroupedContiguous, KernelType::KernelNoSF, + gemm_type, KernelType::KernelNoSF, // NOTES: `num_groups` is 1, since the contiguous layout is seen as a whole - m, n, k, 1, major_a, major_b, - torch::kBFloat16, d.scalar_type(), false, + m_for_config, n, k, num_groups_for_config, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, @@ -154,14 +163,14 @@ static void sm100_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, // Launch const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + .m = m, .n = n, .k = k, .num_groups = num_groups, .compiled_dims = compiled_dims, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = m_indices.data_ptr(), + .grouped_layout = grouped_layout.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_cd = tensor_map_cd @@ -179,11 +188,11 @@ static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a, const int& expected_m, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { - const auto& aligned_k = align(k, 64); const auto& config = get_best_config( GemmType::MGroupedMasked, KernelType::KernelNoSF, expected_m, n, k, num_groups, major_a, major_b, - torch::kBFloat16, d.scalar_type(), 
false, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); const auto& tensor_map_a = make_tma_a_desc(major_a, a, m, k, @@ -204,7 +213,7 @@ static void sm100_m_grouped_bf16_gemm_masked(const torch::Tensor& a, // Launch const SM100BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + .m = m, .n = n, .k = k, .num_groups = num_groups, .compiled_dims = compiled_dims, .gemm_config = config, @@ -243,7 +252,8 @@ static void sm100_bf16_k_grouped_gemm(const torch::Tensor& a, const auto& config = get_best_config( GemmType::KGroupedContiguous, KernelType::KernelNoSF, m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, - torch::kBFloat16, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Create tensor descriptors @@ -290,7 +300,8 @@ static void sm100_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, const auto& config = get_best_config( GemmType::Batched, KernelType::KernelNoSF, b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K, - torch::kBFloat16, tensor_d.scalar_type(), false, + tensor_a.scalar_type(), tensor_b.scalar_type(), + tensor_d.scalar_type(), false, device_runtime->get_num_sms()); const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); @@ -337,7 +348,8 @@ static void sm100_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, const auto& config = get_best_config( GemmType::Batched, KernelType::KernelNoSF, b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN, - torch::kBFloat16, tensor_d.scalar_type(), false, + tensor_a.scalar_type(), tensor_b.scalar_type(), + tensor_d.scalar_type(), false, device_runtime->get_num_sms()); const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); diff --git a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp index 896c2485..07a977d7 100644 
--- a/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp @@ -15,10 +15,11 @@ namespace deep_gemm { -class SM100FP8Gemm1D1DRuntime final: public LaunchRuntime { +class SM100FP8FP4Gemm1D1DRuntime final: public LaunchRuntime { public: struct Args { int m, n, k, num_groups; + int gran_k_a, gran_k_b; const std::string& compiled_dims; const std::optional& epilogue_type; @@ -41,6 +42,7 @@ using namespace deep_gemm; static void __instantiate_kernel() {{ auto ptr = reinterpret_cast(&sm100_fp8_gemm_1d1d_impl< + {}, {}, {}, {}, {}, {}, {}, {}, {}, {}, @@ -50,12 +52,14 @@ static void __instantiate_kernel() {{ {}, {}, {}, {}, {}, + {}, {}, {}, {}, {}, {} >); }}; )", to_string(args.gemm_config.major_a), to_string(args.gemm_config.major_b), + args.gran_k_a, args.gran_k_b, get_compiled_dim(args.m, 'm', args.compiled_dims), get_compiled_dim(args.n, 'n', args.compiled_dims), get_compiled_dim(args.k, 'k', args.compiled_dims), args.gemm_config.block_m, args.gemm_config.block_n, args.gemm_config.block_k, args.num_groups, @@ -64,7 +68,8 @@ static void __instantiate_kernel() {{ args.gemm_config.thread_config.num_non_epilogue_threads, args.gemm_config.thread_config.num_epilogue_threads, args.gemm_config.multicast_config.num_multicast, args.gemm_config.multicast_config.is_multicast_on_a, args.gemm_config.num_sms, - to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, to_string(args.gemm_config.cd_dtype), + to_string(args.gemm_config.gemm_type), args.gemm_config.with_accumulation, + to_string(args.gemm_config.a_dtype), to_string(args.gemm_config.b_dtype), to_string(args.gemm_config.cd_dtype), get_default_epilogue_type(args.epilogue_type)); } @@ -78,19 +83,20 @@ static void __instantiate_kernel() {{ } }; -static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, - const torch::Tensor& b, const torch::Tensor& sfb, - const std::optional& c, - const torch::Tensor& d, - const int& m, const int& n, 
const int& k, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims, - const std::optional& epilogue_type = std::nullopt) { - const auto& aligned_k = align(k, 128); +static void sm100_fp8_fp4_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const std::optional& epilogue_type = std::nullopt) { const auto& config = get_best_config( GemmType::Normal, KernelType::Kernel1D1D, m, n, k, 1, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); const auto& cd = c.value_or(d); @@ -110,14 +116,16 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa static_cast(d.stride(-2)), 1, config.smem_config.swizzle_cd_mode); const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, 1, 0); + config.block_m, gran_k_a, 1, 0); const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, - config.block_n, config.block_k, 1, 0); + config.block_n, gran_k_b, 1, 0); // Launch - const SM100FP8Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, .num_groups = 1, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, .compiled_dims = compiled_dims, .epilogue_type = epilogue_type, .gemm_config = config, @@ -131,24 +139,33 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa .tensor_map_sfb = tensor_map_sfb, .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); - 
const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); - SM100FP8Gemm1D1DRuntime::launch(runtime, args); + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_fp8_fp4_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); } -static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, - const torch::Tensor& b, const torch::Tensor& sfb, - const torch::Tensor& d, - const torch::Tensor& m_indices, - const int& num_groups, const int& m, const int& n, const int& k, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { - const auto& aligned_k = align(k, 128); +static void sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& grouped_layout, + const int& num_groups, const int& m, const int& n, const int& k, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims, + const bool& use_psum_layout, + const std::optional& expected_m_for_psum_layout) { + const auto& gemm_type = use_psum_layout ? GemmType::MGroupedContiguousWithPsumLayout : GemmType::MGroupedContiguous; + + // NOTES: If actual M is dynamic, estimate config via `num_groups` and `expected_m`. + // Otherwise, treat the contiguous layout as a whole. + const auto& m_for_config = expected_m_for_psum_layout.has_value() ? expected_m_for_psum_layout.value() : m; + const auto& num_groups_for_config = expected_m_for_psum_layout.has_value() ? 
num_groups : 1; + const auto& config = get_best_config( - GemmType::MGroupedContiguous, KernelType::Kernel1D1D, - // NOTES: `num_groups` is 1, since the contiguous layout is seen as a whole - m, n, k, 1, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), false, + gemm_type, KernelType::Kernel1D1D, + m_for_config, n, k, num_groups_for_config, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); // Create tensor descriptors @@ -168,45 +185,48 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con static_cast(d.stride(-2)), 1, config.smem_config.swizzle_cd_mode); const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, 1, 0); + config.block_m, gran_k_a, 1, 0); const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, - config.block_n, config.block_k, num_groups, 0); + config.block_n, gran_k_b, num_groups, 0); // Launch kernel - const SM100FP8Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, .num_groups = num_groups, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, config.smem_config.smem_size, config.multicast_config.num_multicast), - .grouped_layout = m_indices.data_ptr(), + .grouped_layout = grouped_layout.data_ptr(), .tensor_map_a = tensor_map_a, .tensor_map_b = tensor_map_b, .tensor_map_sfa = tensor_map_sfa, .tensor_map_sfb = tensor_map_sfb, .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); - const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code); - SM100FP8Gemm1D1DRuntime::launch(runtime, args); + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + 
const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_contiguous_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); } -static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, - const torch::Tensor& b, const torch::Tensor& sfb, - const torch::Tensor& d, - const torch::Tensor& masked_m, - const int& num_groups, const int& m, const int& n, const int& k, - const int& expected_m, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { - const auto& aligned_k = align(k, 128); +static void sm100_m_grouped_fp8_fp4_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const torch::Tensor& d, + const torch::Tensor& masked_m, + const int& num_groups, const int& m, const int& n, const int& k, + const int& expected_m, + const int& gran_k_a, const int& gran_k_b, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { const auto& config = get_best_config( GemmType::MGroupedMasked, KernelType::Kernel1D1D, expected_m, n, k, num_groups, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), false, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); // Create tensor descriptors @@ -226,14 +246,16 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t static_cast(d.stride(-2)), num_groups, config.smem_config.swizzle_cd_mode); const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, - config.block_m, config.block_k, num_groups, 0); + config.block_m, gran_k_a, num_groups, 0); const auto& tensor_map_sfb = make_tma_sf_desc(cute::UMMA::Major::MN, sfb, n, k, - config.block_n, config.block_k, num_groups, 0); + config.block_n, gran_k_b, num_groups, 0); // Launch kernel - const SM100FP8Gemm1D1DRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + 
const SM100FP8FP4Gemm1D1DRuntime::Args& args = { + .m = m, .n = n, .k = k, .num_groups = num_groups, + .gran_k_a = gran_k_a, + .gran_k_b = gran_k_b, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, @@ -247,19 +269,19 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t .tensor_map_sfb = tensor_map_sfb, .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); - const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code); - SM100FP8Gemm1D1DRuntime::launch(runtime, args); + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_m_grouped_fp8_fp4_gemm_masked_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); } -static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, - const torch::Tensor& b, const torch::Tensor& sfb, - const std::optional& c, - const torch::Tensor& d, - const int& m, const int& n, - const std::vector& ks, const torch::Tensor& ks_tensor, - const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, - const std::string& compiled_dims) { +static void sm100_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& m, const int& n, + const std::vector& ks, const torch::Tensor& ks_tensor, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, + const std::string& compiled_dims) { DG_HOST_ASSERT(major_a == cute::UMMA::Major::MN and major_b == cute::UMMA::Major::MN); int sum_k = 0, sum_sf_k = 0; @@ -274,7 +296,8 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& const auto& config = get_best_config( GemmType::KGroupedContiguous, KernelType::Kernel1D1D, m, n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, - torch::kFloat8_e4m3fn, 
d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Create tensor descriptors @@ -299,9 +322,11 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& config.block_n, config.block_k, 1, 0); // Launch kernel - const SM100FP8Gemm1D1DRuntime::Args& args = { + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { .m = m, .n = n, .k = sum_k, .num_groups = num_groups, + .gran_k_a = 128, + .gran_k_b = 128, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, @@ -315,9 +340,9 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& .tensor_map_sfb = tensor_map_sfb, .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); - const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code); - SM100FP8Gemm1D1DRuntime::launch(runtime, args); + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); + const auto& runtime = compiler->build("sm100_k_grouped_fp8_gemm_1d1d", code); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); } static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, @@ -330,7 +355,8 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, const auto& config = get_best_config( GemmType::Batched, KernelType::Kernel1D1D, m, n, k, batch_size, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); const int& load_block_m = SM100ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); @@ -364,9 +390,11 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, config.block_n, config.block_k, batch_size, 0); // Launch - const SM100FP8Gemm1D1DRuntime::Args& args = { + const SM100FP8FP4Gemm1D1DRuntime::Args& args = { .m = m, .n = n, .k = k, 
.num_groups = batch_size, + .gran_k_a = 128, + .gran_k_b = 128, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, .gemm_config = config, @@ -380,9 +408,9 @@ static void sm100_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, .tensor_map_sfb = tensor_map_sfb, .tensor_map_cd = tensor_map_cd }; - const auto& code = SM100FP8Gemm1D1DRuntime::generate(args); + const auto& code = SM100FP8FP4Gemm1D1DRuntime::generate(args); const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code); - SM100FP8Gemm1D1DRuntime::launch(runtime, args); + SM100FP8FP4Gemm1D1DRuntime::launch(runtime, args); } } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp b/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp new file mode 100644 index 00000000..4f3ce5b1 --- /dev/null +++ b/csrc/jit_kernels/impls/sm100_tf32_hc_prenorm_gemm.hpp @@ -0,0 +1,149 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm100.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM100BF16HCPrenormGemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k; + int block_m, block_n, block_k; + int num_splits; + int swizzle_cd_mode; + int num_stages; + int num_mma_threads, num_cast_and_reduce_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + float* sqr_sum; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = reinterpret_cast(&sm100_tf32_hc_prenorm_gemm_impl< + {}, {}, + {}, {}, {}, + {}, + {}, + {}, + {}, {} + >); +}}; +)", + args.n, args.k, + args.block_m, args.block_n, args.block_k, + 
args.num_splits, + args.swizzle_cd_mode, + args.num_stages, + args.num_mma_threads, args.num_cast_and_reduce_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum)); + } +}; + +static void sm100_tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const int& m, const int& n, const int& k, + const int& num_splits) { + constexpr int block_m = 64; + constexpr int block_k = 64; + constexpr int num_mma_threads = 128; + constexpr int num_cast_and_reduce_threads = 128; + + const int block_n = align(n, 16); + DG_HOST_ASSERT(n <= block_n); + DG_HOST_ASSERT(n <= 128 and n % 8 == 0); + DG_HOST_ASSERT(k % block_k == 0); + + const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + block_m, block_k, + static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, a.element_size()), 0, + true); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + block_n, block_k, + static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, b.element_size()), 0, + true); + const auto& tensor_map_d = num_splits == 1 ? 
make_tma_cd_desc(d, m, n, + block_m, block_n, + static_cast(d.stride(-2)), 1, + swizzle_cd_mode) + : make_tma_3d_desc(d, n, m, num_splits, + block_n, block_m, 1, + static_cast(d.stride(-2)), + static_cast(d.stride(-3)), + swizzle_cd_mode); + + // Calculate stages + int num_stages = 12, smem_size = 0; + while (num_stages > 0) { + const int smem_a_per_stage = block_m * block_k * static_cast(sizeof(nv_bfloat16)); + const int smem_b_per_stage = block_n * block_k * static_cast(sizeof(float)); + const int smem_cd = block_m * swizzle_cd_mode; + const int smem_barriers = (num_stages * 4 + 1) * 8; + const int smem_tmem_ptr = 4; + smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages + + smem_cd + smem_barriers + smem_tmem_ptr; + + if (smem_size <= SM100ArchSpec::smem_capacity) + break; + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split K: %d, " + "stages: %d, shared memory: %d, swizzle CD: %d\n", + m, n, k, block_m, block_n, block_k, num_splits, + num_stages, smem_size, swizzle_cd_mode); + } + + // Launch + const SM100BF16HCPrenormGemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .num_splits = num_splits, + .swizzle_cd_mode = swizzle_cd_mode, + .num_stages = num_stages, + .num_mma_threads = num_mma_threads, + .num_cast_and_reduce_threads = num_cast_and_reduce_threads, + .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_mma_threads + num_cast_and_reduce_threads, smem_size, 1), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .sqr_sum = sqr_sum.data_ptr() + }; + const auto& code = SM100BF16HCPrenormGemmRuntime::generate(args); + const auto& runtime = compiler->build("sm100_tf32_hc_prenorm_gemm", code); + SM100BF16HCPrenormGemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git 
a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp index 97450193..32003f88 100644 --- a/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp +++ b/csrc/jit_kernels/impls/sm90_bf16_gemm.hpp @@ -79,13 +79,11 @@ static void sm90_bf16_gemm(const torch::Tensor& a, const int& m, const int& n, const int& k, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const std::string& compiled_dims) { - DG_HOST_ASSERT(not c.has_value()); - - const auto& aligned_k = align(k, 64); const auto& config = get_best_config( GemmType::Normal, KernelType::KernelNoSF, m, n, k, 1, major_a, major_b, - torch::kBFloat16, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Requires no TMA splits @@ -107,7 +105,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a, // Launch const SM90BF16GemmRuntime::Args& args = { - .m = m, .n = n, .k = aligned_k, + .m = m, .n = n, .k = k, .num_groups = 1, .compiled_dims = compiled_dims, .gemm_config = config, @@ -138,7 +136,8 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a, const auto& config = get_best_config( GemmType::MGroupedContiguous, KernelType::KernelNoSF, m, n, k, 1, major_a, major_b, - torch::kBFloat16, d.scalar_type(), false, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); // Requires no TMA splits @@ -192,7 +191,8 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a, const auto& config = get_best_config( GemmType::MGroupedMasked, KernelType::KernelNoSF, expected_m, n, k, num_groups, major_a, major_b, - torch::kBFloat16, d.scalar_type(), false, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); // Requires no TMA splits @@ -253,7 +253,8 @@ static void sm90_bf16_k_grouped_gemm(const torch::Tensor& a, const auto& config = get_best_config( GemmType::KGroupedContiguous, KernelType::KernelNoSF, m, 
n, max_k, num_groups, cute::UMMA::Major::MN, cute::UMMA::Major::MN, - torch::kBFloat16, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Create tensor descriptors @@ -300,7 +301,8 @@ static void sm90_bf16_bhr_hdr_bhd(const torch::Tensor& tensor_a, const auto& config = get_best_config( GemmType::Batched, KernelType::KernelNoSF, b, d, r, h, cute::UMMA::Major::K, cute::UMMA::Major::K, - torch::kBFloat16, tensor_d.scalar_type(), false, + tensor_a.scalar_type(), tensor_b.scalar_type(), + tensor_d.scalar_type(), false, device_runtime->get_num_sms()); const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); @@ -346,7 +348,8 @@ static void sm90_bf16_bhd_hdr_bhr(const torch::Tensor& tensor_a, const auto& config = get_best_config( GemmType::Batched, KernelType::KernelNoSF, b, r, d, h, cute::UMMA::Major::K, cute::UMMA::Major::MN, - torch::kBFloat16, tensor_d.scalar_type(), false, + tensor_a.scalar_type(), tensor_b.scalar_type(), + tensor_d.scalar_type(), false, device_runtime->get_num_sms()); const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp index ec2b9b97..e61841b3 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d1d.hpp @@ -88,7 +88,8 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, const auto& config = get_best_config( GemmType::Normal, KernelType::Kernel1D1D, m, n, k, 1, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Requires no TMA splits @@ -138,7 +139,7 @@ static void sm90_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, 
SM90FP8Gemm1D1DRuntime::launch(runtime, args); } -static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, +static void sm90_k_grouped_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa, const torch::Tensor& b, const torch::Tensor& sfb, const std::optional& c, const torch::Tensor& d, @@ -156,7 +157,8 @@ static void sm90_fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Te const auto& config = get_best_config( GemmType::KGroupedContiguous, KernelType::Kernel1D1D, m, n, max_k, num_groups, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Requires no TMA splits diff --git a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp index a1acba50..2696b5a0 100644 --- a/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp +++ b/csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp @@ -87,11 +87,11 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, DG_HOST_ASSERT(not c.has_value() and d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::Normal, KernelType::Kernel1D2D, m, n, k, 1, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), c.has_value(), + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), device_runtime->get_num_sms()); // Requires no TMA splits @@ -118,7 +118,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa, // Launch const SM90FP8Gemm1D2DRuntime::Args& args = { .major_sfb = major_sfb, - .m = m, .n = n, .k = aligned_k, + .m = m, .n = n, .k = k, .num_groups = 1, .compiled_dims = compiled_dims, .epilogue_type = epilogue_type, @@ -148,11 +148,11 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const 
torch::Tensor& a, cons DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); - const auto& aligned_k = align(k, 128); const auto& config = get_best_config( GemmType::MGroupedContiguous, KernelType::Kernel1D2D, m, n, k, 1, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), false, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); // Requires no TMA splits @@ -179,7 +179,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons // Launch const SM90FP8Gemm1D2DRuntime::Args& args = { .major_sfb = major_sfb, - .m = m, .n = n, .k = aligned_k, + .m = m, .n = n, .k = k, .num_groups = num_groups, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, @@ -207,14 +207,14 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to const int& expected_m, const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, const std::string& compiled_dims) { - const auto& aligned_k = align(k, 128); DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); const auto& config = get_best_config( GemmType::MGroupedMasked, KernelType::Kernel1D2D, expected_m, n, k, num_groups, major_a, major_b, - torch::kFloat8_e4m3fn, d.scalar_type(), false, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), false, device_runtime->get_num_sms()); // Requires no TMA splits @@ -241,7 +241,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to // Launch const SM90FP8Gemm1D2DRuntime::Args& args = { .major_sfb = major_sfb, - .m = m, .n = n, .k = aligned_k, + .m = m, .n = n, .k = k, .num_groups = num_groups, .compiled_dims = compiled_dims, .epilogue_type = std::nullopt, @@ -261,4 +261,71 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const 
to SM90FP8Gemm1D2DRuntime::launch(runtime, args); } +static void sm90_fp8_bmm(const torch::Tensor& a, const torch::Tensor& sfa, + const torch::Tensor& b, const torch::Tensor& sfb, + const std::optional& c, + const torch::Tensor& d, + const int& batch_size, const int& m, const int& n, const int& k, + const cute::UMMA::Major& major_a, const cute::UMMA::Major& major_b, const cute::UMMA::Major& major_sfb, + const std::string& compiled_dims) { + DG_HOST_ASSERT(d.scalar_type() == torch::kBFloat16); + DG_HOST_ASSERT(major_a == cute::UMMA::Major::K and major_b == cute::UMMA::Major::K); + + const auto& config = get_best_config( + GemmType::Batched, KernelType::Kernel1D2D, + m, n, k, batch_size, major_a, major_b, + a.scalar_type(), b.scalar_type(), + d.scalar_type(), c.has_value(), + device_runtime->get_num_sms()); + + // Requires no TMA splits + DG_HOST_ASSERT(config.smem_config.swizzle_a_mode == config.block_k); + DG_HOST_ASSERT(config.smem_config.swizzle_b_mode == config.block_k); + const int& load_block_m = SM90ArchSpec::get_ab_load_block_m(config.multicast_config, config.block_m); + const auto& tensor_map_a = make_tma_3d_desc(a, k, m, batch_size, + config.block_k, load_block_m, 1, + a.stride(1), + a.stride(0), + config.smem_config.swizzle_a_mode); + + const int& load_block_n = SM90ArchSpec::get_ab_load_block_n(config.multicast_config, config.block_n); + const auto& tensor_map_b = make_tma_3d_desc(b, k, n, batch_size, + config.block_k, load_block_n, 1, + b.stride(1), + b.stride(0), + config.smem_config.swizzle_b_mode); + + const int& store_block_m = SM90ArchSpec::get_cd_store_block_m(config.block_m); + const int& store_block_n = SM90ArchSpec::get_cd_store_block_n(config.block_n); + const auto& tensor_map_d = make_tma_3d_desc(d, n, m, batch_size, + store_block_n, store_block_m, 1, + d.stride(1), d.stride(0), + config.smem_config.swizzle_cd_mode); + + const auto& tensor_map_sfa = make_tma_sf_desc(cute::UMMA::Major::MN, sfa, m, k, + config.block_m, config.block_k, 
batch_size, 0); + + // Launch + const SM90FP8Gemm1D2DRuntime::Args& args = { + .major_sfb = major_sfb, + .m = m, .n = n, .k = k, + .num_groups = batch_size, + .compiled_dims = compiled_dims, + .epilogue_type = std::nullopt, + .gemm_config = config, + .launch_args = LaunchArgs(config.num_sms, config.thread_config.num_threads, + config.smem_config.smem_size, + config.multicast_config.num_multicast), + .sfb = sfb.data_ptr(), + .grouped_layout = nullptr, + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .tensor_map_sfa = tensor_map_sfa, + }; + const auto& code = SM90FP8Gemm1D2DRuntime::generate(args); + const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code); + SM90FP8Gemm1D2DRuntime::launch(runtime, args); +} + } // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp b/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp new file mode 100644 index 00000000..aeea2623 --- /dev/null +++ b/csrc/jit_kernels/impls/sm90_tf32_hc_prenorm_gemm.hpp @@ -0,0 +1,152 @@ +#pragma once + +#include + +#include "../../jit/compiler.hpp" +#include "../../jit/device_runtime.hpp" +#include "../../jit/kernel_runtime.hpp" +#include "../../utils/exception.hpp" +#include "../../utils/format.hpp" +#include "../../utils/math.hpp" +#include "../heuristics/sm90.hpp" +#include "runtime_utils.hpp" + +namespace deep_gemm { + +class SM90BF16HCPrenormGemmRuntime final: public LaunchRuntime { +public: + struct Args { + int m, n, k; + int block_m, block_n, block_k; + int num_splits; + int swizzle_cd_mode; + int num_stages; + int num_math_threads, num_tma_threads; + + LaunchArgs launch_args; + + CUtensorMap tensor_map_a; + CUtensorMap tensor_map_b; + CUtensorMap tensor_map_d; + float* sqr_sum; + }; + + static std::string generate_impl(const Args& args) { + return fmt::format(R"( +#include + +using namespace deep_gemm; + +static void __instantiate_kernel() {{ + auto ptr = 
reinterpret_cast(&sm90_tf32_hc_prenorm_gemm_impl< + {}, {}, + {}, {}, {}, + {}, + {}, + {}, + {}, {} + >); +}}; +)", + args.n, args.k, + args.block_m, args.block_n, args.block_k, + args.num_splits, + args.swizzle_cd_mode, + args.num_stages, + args.num_math_threads, args.num_tma_threads); + } + + static void launch_impl(const KernelHandle& kernel, const LaunchConfigHandle& config, Args args) { + // TODO: optimize `args` copy + DG_CUDA_UNIFIED_CHECK(launch_kernel(kernel, config, + args.m, args.tensor_map_a, args.tensor_map_b, args.tensor_map_d, args.sqr_sum)); + } +}; + +static void sm90_tf32_hc_prenorm_gemm(const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& d, + const torch::Tensor& sqr_sum, + const int& m, const int& n, const int& k, + const int& num_splits) { + constexpr int block_m = 64; + constexpr int block_k = 64; + constexpr int num_math_threads = 128; + constexpr int num_tma_threads = 128; + constexpr int num_threads = num_math_threads + num_tma_threads; + + const int block_n = align(n, 16); + DG_HOST_ASSERT(n <= block_n); + // Only support small N for now + DG_HOST_ASSERT(n <= 32 and n % 8 == 0); + DG_HOST_ASSERT(k % block_k == 0); + + const auto& swizzle_cd_mode = get_swizzle_mode(block_n, sizeof(float)); + const auto& tensor_map_a = make_tma_a_desc(cute::UMMA::Major::K, a, m, k, + block_m, block_k, + static_cast(a.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, a.element_size()), 0, + true); + const auto& tensor_map_b = make_tma_b_desc(cute::UMMA::Major::K, b, n, k, + block_n, block_k, + static_cast(b.stride(get_non_contiguous_dim(cute::UMMA::Major::K))), 1, + get_swizzle_mode(block_k, b.element_size()), 0, + true); + const auto& tensor_map_d = num_splits == 1 ? 
make_tma_cd_desc(d, m, n, + block_m, block_n, + static_cast(d.stride(-2)), 1, + swizzle_cd_mode) + : make_tma_3d_desc(d, n, m, num_splits, + block_n, block_m, 1, + static_cast(d.stride(-2)), + static_cast(d.stride(-3)), + swizzle_cd_mode); + + // Calculate stages + int num_stages = 12, smem_size = 0; + while (num_stages > 0) { + const int smem_a_per_stage = block_m * block_k * static_cast(sizeof(nv_bfloat16)); + const int smem_b_per_stage = block_n * block_k * static_cast(sizeof(float)); + const int smem_cd = block_m * swizzle_cd_mode; + const int smem_barriers = num_stages * 2 * 8; + smem_size = (smem_a_per_stage + smem_b_per_stage) * num_stages + + smem_cd + smem_barriers; + + if (smem_size <= SM90ArchSpec::smem_capacity) + break; + -- num_stages; + } + DG_HOST_ASSERT(num_stages > 0); + + // Print configs + if (get_env("DG_JIT_DEBUG", 0)) { + printf("M: %d, N: %d, K: %d -> " + "block M: %d, block N: %d, block K: %d, split K: %d, " + "stages: %d, shared memory: %d, swizzle CD: %d\n", + m, n, k, block_m, block_n, block_k, num_splits, + num_stages, smem_size, swizzle_cd_mode); + } + + smem_size = SM90ArchSpec::smem_capacity; + + // Launch + const SM90BF16HCPrenormGemmRuntime::Args& args = { + .m = m, .n = n, .k = k, + .block_m = block_m, .block_n = block_n, .block_k = block_k, + .num_splits = num_splits, + .swizzle_cd_mode = swizzle_cd_mode, + .num_stages = num_stages, + .num_math_threads = num_math_threads, + .num_tma_threads = num_tma_threads, + .launch_args = LaunchArgs(num_splits * ceil_div(m, block_m), num_threads, smem_size, 1), + .tensor_map_a = tensor_map_a, + .tensor_map_b = tensor_map_b, + .tensor_map_d = tensor_map_d, + .sqr_sum = sqr_sum.data_ptr() + }; + const auto& code = SM90BF16HCPrenormGemmRuntime::generate(args); + const auto& runtime = compiler->build("sm90_tf32_hc_prenorm_gemm", code); + SM90BF16HCPrenormGemmRuntime::launch(runtime, args); +} + +} // namespace deep_gemm diff --git a/csrc/jit_kernels/impls/smxx_cublaslt.hpp
b/csrc/jit_kernels/impls/smxx_cublaslt.hpp index 7641dddd..dc20e334 100644 --- a/csrc/jit_kernels/impls/smxx_cublaslt.hpp +++ b/csrc/jit_kernels/impls/smxx_cublaslt.hpp @@ -37,7 +37,6 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a, const bool& accumulate) { cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; cudaDataType_t scale_type = CUDA_R_32F; - const int& math_sms = device_runtime->get_num_sms(); // Operation description cublasLtMatmulDesc_t desc; @@ -45,9 +44,13 @@ static void call_cublaslt_api(const cublasOperation_t& trans_a, DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a))); DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b))); DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); + +#if DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE + const int& math_sms = device_runtime->get_num_sms(); DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sms, sizeof(math_sms))); - -#if DG_FP8_COMPATIBLE +#endif + +#if DG_FP8_COMPATIBLE and DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE bool fp8_fast_accumulate = false; if (a.scalar_type() == torch::kFloat8_e4m3fn) DG_CUBLASLT_CHECK(cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fp8_fast_accumulate, sizeof(fp8_fast_accumulate))); diff --git a/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp index 82288994..1240aad8 100644 --- a/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp +++ b/csrc/jit_kernels/impls/smxx_fp8_paged_mqa_logits.hpp @@ -174,13 +174,13 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, const int& logits_stride, const int& block_table_stride, const int& num_sms, - const int& num_math_warp_groups) { + const int& split_kv) { const int num_specialized_threads 
= 128; + const int mma_m = (device_runtime->get_arch_major() == 10 ? 128 : 64); + const int num_math_warp_groups = split_kv / mma_m; const int num_math_threads = num_math_warp_groups * 128; - const int num_extra_threads = device_runtime->get_arch_major() == 10 ? 128 : 0; - const int num_q_stages = 3, num_kv_stages = 3; - const int split_kv = num_math_warp_groups * block_kv; - DG_HOST_ASSERT(logits_stride % (num_math_warp_groups * block_kv) == 0); + const int num_q_stages = 3, num_kv_stages = (device_runtime->get_arch_major() == 10 ? 4 : 3); + DG_HOST_ASSERT(split_kv % mma_m == 0 and logits_stride % split_kv == 0); // Construct TMAs DG_HOST_ASSERT(head_dim == 32 or head_dim == 64 or head_dim == 128); @@ -196,23 +196,39 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, next_n * num_heads, 1, next_n * num_heads, 0); // Calculate shared memory size - const int swizzle_alignment = head_dim * 8; + int smem_size = 0; + if (device_runtime->get_arch_major() == 9) { + const int swizzle_alignment = head_dim * 8; - const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); - const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast(weights.element_size()), swizzle_alignment); - const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment); + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); + const int aligned_smem_weight_size_per_stage = align(next_n * num_heads * static_cast(weights.element_size()), swizzle_alignment); + const int smem_q_pipe_size = num_q_stages * (smem_q_size_per_stage + aligned_smem_weight_size_per_stage) + align(num_q_stages * 8 * 2, swizzle_alignment); - const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv_cache.element_size()); - const int aligned_smem_kv_scale_size_per_stage = align(block_kv * 
static_cast(kv_cache_scales.element_size()), swizzle_alignment); - const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment); + const int smem_kv_size_per_stage = block_kv * head_dim * static_cast(kv_cache.element_size()); + const int aligned_smem_kv_scale_size_per_stage = align(block_kv * static_cast(kv_cache_scales.element_size()), swizzle_alignment); + const int smem_kv_pipe_size = num_kv_stages * (smem_kv_size_per_stage + aligned_smem_kv_scale_size_per_stage) + align(num_kv_stages * 8 * 2, swizzle_alignment); - // Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90 - const int smem_umma_barriers = num_math_warp_groups * 2 * 8; - const int smem_tmem_ptr = 4; + // Allocate some shared memory for UMMA barriers and tensor memory pointer, although it is not used in SM90 + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int smem_tmem_ptr = 4; - const int smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr; - DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); - DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + smem_size = smem_q_pipe_size + num_math_warp_groups * smem_kv_pipe_size + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM90ArchSpec::smem_capacity); + } else { + const int smem_q_size_per_stage = next_n * num_heads * head_dim * static_cast(q.element_size()); + const int smem_kv_size_per_stage = split_kv * head_dim * static_cast(kv_cache.element_size()); + const int smem_kv_scale_size_per_stage = split_kv * static_cast(kv_cache_scales.element_size()); + const int smem_weight_size_per_stage = next_n * num_heads * static_cast(weights.element_size()); + + const int smem_barriers = (num_q_stages + num_kv_stages) * 2 * 8; + const int smem_umma_barriers = num_math_warp_groups * 2 * 8; + const int 
smem_tmem_ptr = 4; + + smem_size = num_q_stages * (smem_q_size_per_stage + smem_weight_size_per_stage) + + num_kv_stages * (smem_kv_size_per_stage + smem_kv_scale_size_per_stage) + + smem_barriers + smem_umma_barriers + smem_tmem_ptr; + DG_HOST_ASSERT(smem_size <= SM100ArchSpec::smem_capacity); + } // Launch const SMXXFP8PagedMQALogitsRuntime::Args& args = { @@ -238,7 +254,7 @@ static void smxx_fp8_paged_mqa_logits(const torch::Tensor& q, .num_specialized_threads = num_specialized_threads, .num_math_threads = num_math_threads, .launch_args = LaunchArgs(num_sms, - num_specialized_threads + num_math_threads + num_extra_threads, + num_specialized_threads + num_math_threads, smem_size) }; const auto& code = SMXXFP8PagedMQALogitsRuntime::generate(args); diff --git a/csrc/python_api.cpp b/csrc/python_api.cpp index c2201887..0354f1f8 100644 --- a/csrc/python_api.cpp +++ b/csrc/python_api.cpp @@ -3,6 +3,7 @@ #include "apis/attention.hpp" #include "apis/einsum.hpp" +#include "apis/hyperconnection.hpp" #include "apis/gemm.hpp" #include "apis/layout.hpp" #include "apis/runtime.hpp" @@ -15,8 +16,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.doc() = "DeepGEMM C++ library"; + // TODO: make SM80 incompatible issues raise errors deep_gemm::attention::register_apis(m); deep_gemm::einsum::register_apis(m); + deep_gemm::hyperconnection::register_apis(m); deep_gemm::gemm::register_apis(m); deep_gemm::layout::register_apis(m); deep_gemm::runtime::register_apis(m); diff --git a/csrc/utils/compatibility.hpp b/csrc/utils/compatibility.hpp index fb45e3d8..9e2d6720 100644 --- a/csrc/utils/compatibility.hpp +++ b/csrc/utils/compatibility.hpp @@ -2,9 +2,16 @@ #include #include +#include // `torch::kFloat8_e4m3fn` is supported since PyTorch 2.1 #define DG_FP8_COMPATIBLE (TORCH_VERSION_MAJOR > 2 or (TORCH_VERSION_MAJOR == 2 and TORCH_VERSION_MINOR >= 1)) // `cuTensorMapEncodeTiled` is supported since CUDA Driver API 12.1 -#define DG_TENSORMAP_COMPATIBLE (CUDA_VERSION >= 12010) \ No newline 
at end of file +#define DG_TENSORMAP_COMPATIBLE (CUDA_VERSION >= 12010) + +// `cublasGetErrorString` is supported since CUDA Runtime API 11.4.2 +#define DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE (CUDART_VERSION >= 11042) + +// `CUBLASLT_MATMUL_DESC_FAST_ACCUM` and `CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET` are supported since CUDA Runtime API 11.8 +#define DG_CUBLASLT_ADVANCED_FEATURES_COMPATIBLE (CUDART_VERSION >= 11080) \ No newline at end of file diff --git a/csrc/utils/exception.hpp b/csrc/utils/exception.hpp index 9beb0da3..2aa27066 100644 --- a/csrc/utils/exception.hpp +++ b/csrc/utils/exception.hpp @@ -5,6 +5,8 @@ #include #include +#include "compatibility.hpp" + namespace deep_gemm { class DGException final : public std::exception { @@ -74,6 +76,25 @@ do { \ #endif #ifndef DG_CUBLASLT_CHECK + +#if !DG_CUBLAS_GET_ERROR_STRING_COMPATIBLE +inline const char* cublasGetStatusString(cublasStatus_t status) { + switch(status) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + default: return "Unknown cuBLAS error"; + } +} +#endif + #define DG_CUBLASLT_CHECK(cmd) \ do { \ const auto& e = (cmd); \ diff --git a/csrc/utils/layout.hpp b/csrc/utils/layout.hpp index 8d4d00b2..d67cfcfb 100644 --- a/csrc/utils/layout.hpp +++ b/csrc/utils/layout.hpp @@ -36,15 +36,34 @@ static bool fp8_requires_k_major() { 
// Tensor utils template static auto get_shape(const torch::Tensor& t) { + DG_HOST_ASSERT(t.dim() == N); return [&t] (std::index_sequence) { return std::make_tuple(static_cast(t.sizes()[Is])...); }(std::make_index_sequence()); } +static std::tuple check_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { + auto [mn, k] = get_shape<2>(ab); + if (ab.scalar_type() != torch::kFloat8_e4m3fn) { + DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); + major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2); + } + return std::make_tuple(mn, k); +} + +static std::tuple check_grouped_ab_fp8_fp4(const torch::Tensor& ab, const cute::UMMA::Major& major, const int& arch_major) { + auto [num_groups, mn, k] = get_shape<3>(ab); + if (ab.scalar_type() != torch::kFloat8_e4m3fn) { + DG_HOST_ASSERT(ab.scalar_type() == kPackedFP4 and arch_major == 10); + major == cute::UMMA::Major::K ? (k *= 2) : (mn *= 2); + } + return std::make_tuple(num_groups, mn, k); +} + // Recipe static std::tuple get_default_recipe(const torch::ScalarType& sfa_dtype, const torch::ScalarType& sfb_dtype) { - const auto& arch_major = device_runtime->get_arch_major(); + const auto arch_major = device_runtime->get_arch_major(); if (arch_major == 9) { DG_HOST_ASSERT(sfa_dtype == torch::kFloat and sfb_dtype == torch::kFloat); return {1, 128, 128}; @@ -70,7 +89,7 @@ static torch::Tensor check_sf_layout(const torch::Tensor& sf, DG_HOST_ASSERT(sf.scalar_type() == type_check.value()); // Always do shape checks - const auto& sf_dtype = sf.scalar_type(); + const auto sf_dtype = sf.scalar_type(); DG_HOST_ASSERT(sf_dtype == torch::kFloat or sf_dtype == torch::kInt); DG_HOST_ASSERT(sf.dim() == static_cast(num_groups.has_value()) + 2); if (num_groups.has_value()) diff --git a/csrc/utils/math.hpp b/csrc/utils/math.hpp index 264d2d10..9ece5a3b 100644 --- a/csrc/utils/math.hpp +++ b/csrc/utils/math.hpp @@ -6,6 +6,9 @@ namespace deep_gemm { +// TODO: Use 
`torch::kFloat4_e2m1fn_x2` +constexpr auto kPackedFP4 = torch::kUInt8; + template static T ceil_div(const T& a, const T& b) { return (a + b - 1) / b; diff --git a/deep_gemm/__init__.py b/deep_gemm/__init__.py index 9f18383f..1c07f5d9 100644 --- a/deep_gemm/__init__.py +++ b/deep_gemm/__init__.py @@ -1,8 +1,6 @@ import os import subprocess import torch -from torch.version import cuda as cuda_version -from packaging import version # Set some default environment provided at setup try: @@ -29,9 +27,15 @@ cublaslt_gemm_tn, cublaslt_gemm_tt, ) -if version.parse(cuda_version) >= version.parse('12.1'): +try: # DeepGEMM Kernels from ._C import ( + # FP8 FP4 GEMMs + fp8_fp4_gemm_nt, fp8_fp4_gemm_nn, + fp8_fp4_gemm_tn, fp8_fp4_gemm_tt, + m_grouped_fp8_fp4_gemm_nt_contiguous, + m_grouped_fp8_fp4_gemm_nn_contiguous, + m_grouped_fp8_fp4_gemm_nt_masked, # FP8 GEMMs fp8_gemm_nt, fp8_gemm_nn, fp8_gemm_tn, fp8_gemm_tt, @@ -55,6 +59,8 @@ fp8_mqa_logits, get_paged_mqa_logits_metadata, fp8_paged_mqa_logits, + # Hyperconnection kernels + tf32_hc_prenorm_gemm, # Layout kernels transform_sf_into_required_layout, get_mk_alignment_for_contiguous_layout @@ -64,6 +70,9 @@ # TODO: remove these later fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_gemm_nt_masked bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked +except ImportError: + # Expected behavior for CUDA runtime version before 12.1 + pass # Some utils from . import testing @@ -71,7 +80,10 @@ from .utils import * # Legacy Triton kernels for A100 -from . import legacy +try: + from . 
import legacy +except Exception as e: + print(f'Failed to load legacy DeepGEMM A100 Triton kernels: {e}') # Initialize CPP modules def _find_cuda_home() -> str: @@ -97,4 +109,4 @@ def _find_cuda_home() -> str: _find_cuda_home() # CUDA home ) -__version__ = '2.2.0' +__version__ = '2.3.0' diff --git a/deep_gemm/include/deep_gemm/common/scheduler.cuh b/deep_gemm/include/deep_gemm/common/scheduler.cuh index 88101940..f93b96ee 100644 --- a/deep_gemm/include/deep_gemm/common/scheduler.cuh +++ b/deep_gemm/include/deep_gemm/common/scheduler.cuh @@ -51,6 +51,8 @@ struct Scheduler { uint32_t current_group_idx = 0; // Only used for masked layout uint32_t current_m_cumsum = 0; + // Only used for contiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; // Only used for k-grouped layout uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; uint32_t next_group_idx, next_shape_k; @@ -72,12 +74,16 @@ struct Scheduler { current_shape_k = shape_k; if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { num_blocks = num_m_blocks * num_n_blocks; - } else if (kGemmType == GemmType::MGroupedContiguous) { + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { num_blocks = num_m_blocks * num_n_blocks; this->grouped_layout = grouped_layout; - } else if (kGemmType == GemmType::MGroupedMasked) { + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { this->grouped_layout = grouped_layout; - } else if (kGemmType == GemmType::KGroupedContiguous) { + current_psum_m = __ldg(grouped_layout); + num_m_blocks = ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { this->grouped_layout = grouped_layout; get_next_k_group(current_group_idx, current_shape_k); next_group_idx = current_group_idx + 1; @@
-131,7 +137,7 @@ struct Scheduler { } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0; return offset * shape_dim + block_idx * block_size; - } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { const auto offset = kWithGroupOffset ? current_group_idx : 0; return offset * shape_dim + block_idx * block_size; } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { @@ -172,6 +178,28 @@ struct Scheduler { } get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + while (true) { + // Within current group + if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks) + break; + + // Move to check the next group + if (++ current_group_idx == kNumGroups) + return false; + + // NOTES: `num_m_blocks` varies with the increase of the group index + last_psum_m = align(current_psum_m, 128u); + current_psum_m = __ldg(grouped_layout + current_group_idx); + current_m_block_cumsum += num_m_blocks; + num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M); + } + + get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx); + + // NOTES: `last_psum_m` is aligned with 128 + m_block_idx += last_psum_m / BLOCK_M; + DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M"); } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { while (true) { // End of the task @@ -248,6 +276,9 @@ struct Scheduler { return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; } else if constexpr (kGemmType == GemmType::MGroupedMasked) { return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx); + } else { + // 
Unreachable + DG_TRAP_ONLY_DEVICE_ASSERT(false); } } }; diff --git a/deep_gemm/include/deep_gemm/common/sm100_utils.cuh b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh index b48b0518..537cbe08 100644 --- a/deep_gemm/include/deep_gemm/common/sm100_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm100_utils.cuh @@ -97,7 +97,8 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id const auto& layout_type = to_umma_layout_type(); const auto& num_non_contiguous = 128 / get_atom_base(layout_type); if constexpr (kMajorMode == cute::UMMA::Major::K) { - // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); // Atom size: 8 x `kSwizzleMode` (in bytes, on K) @@ -131,8 +132,8 @@ cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_id } __device__ __forceinline__ -uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sf_id) { - desc.a_sf_id_ = sf_id, desc.b_sf_id_ = sf_id; +uint64_t make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; return static_cast(static_cast(desc)) << 32; } @@ -154,6 +155,20 @@ __device__ __forceinline__ void tcgen05_after_thread_sync() { asm volatile("tcgen05.fence::after_thread_sync;"); } +__device__ __forceinline__ +void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier); 
+ asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbarrier_addr), "l"(cache_hint) + : "memory" + ); +} + // UMMA versions with relaxed assertions struct SM100_MMA_F16BF16_SS { __device__ static void @@ -231,4 +246,21 @@ struct SM100_MMA_MXF8F6F4_2x1SM_SS { } }; +struct SM100_MMA_F16BF16_WS_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + } // namespace `deep_gemm::sm100` diff --git a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh index 14648977..0874b675 100644 --- a/deep_gemm/include/deep_gemm/common/sm90_utils.cuh +++ b/deep_gemm/include/deep_gemm/common/sm90_utils.cuh @@ -152,6 +152,51 @@ struct BF16MMASelector { using type = decltype(select_type()); }; +template +struct TF32MMARS { + + template + __forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? 
ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(reinterpret_cast(a), desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 8; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct TF32MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (kUseRS) { + if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN(); + if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN(); + if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN(); + if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN(); + if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN(); + if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN(); + DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N"); + } + } + + static constexpr auto select_type() { + if constexpr (kUseRS) { + return TF32MMARS(); + } else { + DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now"); + } + } + + using type = decltype(select_type()); +}; template struct SM90_U32x2_STSM_N { diff --git a/deep_gemm/include/deep_gemm/common/types.hpp b/deep_gemm/include/deep_gemm/common/types.hpp index 2f35c50c..410c5469 100644 --- a/deep_gemm/include/deep_gemm/common/types.hpp +++ b/deep_gemm/include/deep_gemm/common/types.hpp @@ -2,14 +2,36 @@ namespace deep_gemm { +enum class MmaKind { + BF16 = 0, + MXFP8FP4 = 1, +}; + +constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + enum class GemmType { - Normal = 0, - MGroupedContiguous = 1, - MGroupedMasked = 2, - KGroupedContiguous = 3, - Batched = 4 + Normal = 0, + MGroupedContiguous = 1, + 
MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, }; +constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + enum class KernelType { Kernel1D1D = 0, Kernel1D2D = 1, diff --git a/deep_gemm/include/deep_gemm/common/utils.cuh b/deep_gemm/include/deep_gemm/common/utils.cuh index 171a3d3d..8fb6c2fc 100644 --- a/deep_gemm/include/deep_gemm/common/utils.cuh +++ b/deep_gemm/include/deep_gemm/common/utils.cuh @@ -148,6 +148,10 @@ __device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); } +__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); +} + template __device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); diff --git a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh index 293ec94e..0227b3e8 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -388,7 +388,7 @@ sm100_bf16_gemm_impl(int* grouped_layout, cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); // The pipeline stage - const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; const 
auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; // Store into shared memory diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh index da7f461c..45a603ad 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -14,6 +14,7 @@ namespace deep_gemm { using namespace deep_gemm::sm100; template __global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) sm100_fp8_gemm_1d1d_impl(int* grouped_layout, @@ -45,16 +47,21 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; constexpr uint32_t kNumTMAStoreStages = 2; - constexpr uint32_t kNumSFStagesPerLoad = sizeof(uint32_t) / sizeof(cutlass::float_ue8m0_t); constexpr uint32_t kNumUTCCPAlignedElems = 128; DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); + // Overwrite shape constants if the compiler gives shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; shape_k = SHAPE_K != 0 ? 
SHAPE_K : shape_k; - const uint32_t shape_sf_k = ceil_div(shape_k, BLOCK_K * kNumSFStagesPerLoad); + const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4); + const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4); // Utils bool is_leader_cta = cute::block_rank_in_cluster() == 0; @@ -78,8 +85,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Share memory sizes constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; - constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); - constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); @@ -89,7 +96,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); // NOTES: Make sure we have enough shared memory for UMMA padding - static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t); DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size @@ -118,10 +125,10 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); }); auto smem_a = 
PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); }); auto smem_b = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); }); // SFA/SFB shared memory @@ -225,28 +232,31 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); if constexpr (kMajorA == cute::UMMA::Major::K) - tma_copy( + tma_copy( &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); if constexpr (kMajorA == cute::UMMA::Major::MN) - tma_copy( + tma_copy( &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::K) - tma_copy( + tma_copy( &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); if constexpr (kMajorB == cute::UMMA::Major::MN) - tma_copy( + tma_copy( &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); - auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 
1 : 2); // Issue SFA and SFB TMAs at certain stages // No swizzling, so one TMA for one SF is enough - const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad; - if (sf_stage_in_group_idx == 0) { + if (k_block_idx % kNumSFAStagesPerLoad == 0) { tma_copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M, - scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::SF_K>(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad))); + scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad))); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { tma_copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N, - scheduler.template get_global_idx(shape_sf_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFStagesPerLoad), m_block_idx)); - num_arrival_bytes += (BLOCK_M + BLOCK_N) * sizeof(uint32_t); + scheduler.template get_global_idx(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx)); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); } // Arrive at full barriers @@ -260,9 +270,8 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // TODO: refactor `UMMA_M` calculation constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? 
kNumMulticast : 1); - constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); - auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); auto sf_desc = make_sf_desc(nullptr); @@ -313,19 +322,20 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Do SF copy at certain stages // NOTES: CUTLASS UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves - const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad; - if (sf_stage_in_group_idx == 0 and cute::elect_one_sync()) { - using cute_utccp_t = cute::conditional_t; - - // SFA and SFB copy - // TODO: process shared memory descriptor by addition + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; replace_smem_desc_addr(sf_desc, smem_ptr); cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) { #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; @@ -337,17 +347,20 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, // Issue UMMA in the leader CTA using mma_t = cute::conditional_t; - const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sf_stage_in_group_idx); const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); if (cute::elect_one_sync()) { #pragma unroll for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { - b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 
0, k * UMMA_K); + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); #pragma unroll for (uint32_t w = 0; w < kNumMWaves; ++ w) { DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); - a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K); mma_t::fma(a_desc, b_desc, accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, k_block_idx > 0 or k > 0, @@ -391,11 +404,14 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, full_barriers[stage_idx]->wait(phase); // Transpose for UTCCP at certain stages - const uint32_t sf_stage_in_group_idx = k_block_idx % kNumSFStagesPerLoad; - if (sf_stage_in_group_idx == 0) { + if (k_block_idx % kNumSFAStagesPerLoad == 0) { #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { #pragma unroll for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); @@ -454,7 +470,7 @@ sm100_fp8_gemm_1d1d_impl(int* grouped_layout, cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); // The pipeline stage - const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), 
IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); // Store into shared memory diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh index 96d61732..c51f3600 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -143,7 +143,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const auto& get_next_block_q_idx = [&]() -> cute::tuple { return {block_q_idx + gridDim.x, q_iter_idx + 1}; }; - uint32_t seq_k_start[BLOCK_Q]; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { uint32_t start = cute::numeric_limits::max(); uint32_t end = cute::numeric_limits::min(); @@ -152,8 +152,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, for (uint32_t i = 0; i < BLOCK_Q; ++ i) { const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); + seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); start = min(start, min(seq_k_start[i], seq_len_kv)); - end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); } start = start / 4 * 4; return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage @@ -278,9 +279,9 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const auto& v_offset = lane_idx; // Preload weights - constexpr uint32_t kNumWeightsInReg = 52; + constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads); float weights[BLOCK_Q][kNumWeightsInReg]; - DG_STATIC_ASSERT(kNumWeightsInReg <= kNumHeads and kNumWeightsInReg % 4 == 0, "Invalid kNumWeightsInReg"); + DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of 
weights in registers"); while (block_q_idx < num_q_blocks) { CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); @@ -337,7 +338,7 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, #pragma unroll for (uint32_t i = 0; i < BLOCK_Q; ++ i) { - float* accum = reinterpret_cast(shifted_accum + i * kNumHeads); + auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); auto sum_0 = make_float2(0, 0); auto sum_1 = make_float2(0, 0); @@ -367,14 +368,14 @@ void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, sum_1 = transform_smem(j + 2, sum_1); } - float result = sum_0.x + sum_0.y + sum_1.x + sum_1.y; - result *= scale_kv; + auto sum = __fadd2_rn(sum_0, sum_1); + float result = scale_kv * (sum.x + sum.y); // Store into the global memory // NOTES: we have redundant writes here, consider more carefully const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; if constexpr (kIsCompressedLogits) { - if (kv_offset + v_offset >= seq_k_start[i]) + if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i]) logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result; } else { logits[q_idx * stride_logits + kv_offset + v_offset] = result; diff --git a/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh index 3984accd..049ba746 100644 --- a/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -22,8 +22,9 @@ template -__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads + 128, 1) + uint32_t kNumSpecializedThreads, uint32_t kNumMathThreads, + uint32_t kNumMathWarpGroups = kNumMathThreads / 128> +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, const uint64_t logits_stride, const uint64_t block_table_stride, 
const uint32_t* context_lens, float* logits, @@ -40,9 +41,7 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, const auto& lane_idx = get_lane_idx(); // Prefetch TMA descriptors - static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); - DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { cute::prefetch_tma_descriptor(&tensor_map_q); cute::prefetch_tma_descriptor(&tensor_map_kv); @@ -54,78 +53,58 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Shared memory configs static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float); static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); - static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + - constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); - - static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); - static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); - static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); - static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + - constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); - - static constexpr uint32_t 
SMEM_UMMA_SIZE = kNumMathWarpGroups * 2 * 8 + static_cast(sizeof(uint32_t)); // Align to swizzling alignment bytes extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); - // Q data and barriers on shared memory + // Q and KV data on shared memory auto smem_q = PatternVisitor([&](const uint32_t& i) { return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); }); - auto smem_weights = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); - }); - auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); - auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); - auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); - - // Separate math warpgroups and tma load warps into KV groups - // Each math warpgroup corresponds to a tma load warp - const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? 
(threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); - - // Per group KV data and barriers on shared memory - const auto& smem_kv_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; auto smem_kv = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * i); + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i); }); - auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { - return reinterpret_cast(smem_buffer + smem_kv_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages; + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); }); - auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); - auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); - auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); - // UMMA barriers and TMEM pointer on shared memory - auto umma_barrier_ptr = reinterpret_cast(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kNumMathWarpGroups); + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return 
barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); - auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); + auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); constexpr uint32_t kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups; DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); - const bool& is_math_warp = (warp_idx < (kNumMathThreads / 32)); // 0 ~ 16 - const bool& is_tma_load_warp = (warp_idx >= (kNumMathThreads / 32) and warp_idx < (kNumMathThreads / 32 + 4)); // 16 ~ 20 - const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 4)); // 20 + const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4); + const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4); + const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1); // Initialize barriers if (is_tma_load_warp and cute::elect_one_sync()) { - if (kv_group_idx == 0) { - #pragma unroll - for (uint32_t i = 0; i < kNumQStages; ++ i) { - full_q_barriers[i]->init(1); - empty_q_barriers[i]->init(kNumMathThreads); - } + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); } - if (kv_group_idx < kNumMathWarpGroups) { - #pragma unroll - for (uint32_t i = 0; i < kNumKVStages; ++ i) { - full_kv_barriers[i]->init(1); - empty_kv_barriers[i]->init(128); - } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); } 
cutlass::arch::fence_barrier_init(); } @@ -144,12 +123,13 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, __syncthreads(); // Register reconfigurations - constexpr uint32_t kNumSpecializedRegisters = 32; - constexpr uint32_t kNumMathRegisters = 104; + constexpr uint32_t kNumSpecializedRegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; // Scheduler - auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); - DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); + constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); // Q and KV pipeline const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { @@ -161,19 +141,18 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, uint32_t q_iter_idx = 0, kv_iter_idx = 0; // UMMA settings - // Construct instruction with layout F - constexpr uint32_t UMMA_M = 64; + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); constexpr uint32_t UMMA_N = kNextN * kNumHeads; + DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); if (is_tma_load_warp) { // TMA warp-group for loading data cutlass::arch::warpgroup_reg_dealloc(); - if (kv_group_idx >= kNumMathWarpGroups) - return; const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { - if (kv_group_idx == 0 and cute::elect_one_sync()) { + if (cute::elect_one_sync()) { tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); @@ -199,6 
+178,14 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, kv_idx = next_kv_idx; num_kv = next_num_kv; + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0); + } + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); + // Wait Q consumer release and issue TMA Q if (prefetch_q) { CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); @@ -206,25 +193,26 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, issue_tma_q(q_stage_idx, q_idx + 1); } - // Read KV block index - // TODO: deal with `-1`? - if (kv_idx == 0 or kv_block_idx_ptr == 32) { - kv_block_idx_ptr = 0; - kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ? - __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); - } - const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + int kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; // Wait KV consumer release CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); - // Issue TMA KV if (cute::elect_one_sync()) { - tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], - smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); - tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], - smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) { + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, + 0, 0, 1, kv_block_idx[i]); + 
tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); } @@ -245,32 +233,26 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, uint32_t q_idx = batch_size, kv_idx; uint32_t next_q_idx, next_kv_idx, next_num_kv; uint32_t q_stage_idx, q_phase; - uint32_t umma_phase = 0; - - auto smem_kv = PatternVisitor([&](const uint32_t& stage_idx) { - return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_PIPE_SIZE + SMEM_KV_SIZE_PER_STAGE * stage_idx); - }); + uint32_t umma_phase = 1; while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { - if (q_idx != next_q_idx) { + if (q_idx != next_q_idx) CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); - } q_idx = next_q_idx; kv_idx = next_kv_idx; CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); - DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size"); DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); - #pragma unroll for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { - empty_umma_barriers[i]->wait(umma_phase & 1); + empty_umma_barriers[i]->wait(umma_phase); #pragma unroll for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { auto a_desc = make_umma_desc( - smem_kv[kv_stage_idx] + i * SMEM_KV_PIPE_SIZE, 0, k * UMMA_K); + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); auto b_desc = make_umma_desc( smem_q[q_stage_idx], 0, k * UMMA_K); cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); @@ -285,10 +267,12 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Offsets const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); - float weights[kNextN][kNumHeads / 4]; - const auto& sub_warp_offset = (warp_idx % 4) * 16; - const auto& v_0_offset = lane_idx / 4 + 0; - 
const auto& v_1_offset = lane_idx / 4 + 8; + const uint32_t thread_idx = threadIdx.x; + + // Weights + constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? kNumHeads : cute::min(48, kNumHeads)); + float weights[kNextN][kNumWeightsInReg]; + DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none uint32_t q_idx = batch_size, kv_idx; @@ -310,9 +294,8 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, // Read weights #pragma unroll for (uint32_t i = 0; i < kNextN; ++ i) { - #pragma unroll - for (uint32_t j = 0; j < kNumHeads / 4; ++ j) - weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); } } @@ -321,75 +304,80 @@ void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, kv_idx = next_kv_idx; // Calculate KV offset in advance - auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV; - // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` + // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]` // Wait TMA KV arrival CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); full_kv_barriers[kv_stage_idx]->wait(kv_phase); // Read per-KV scales - auto scale_kv = make_float2(ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset), - ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset)); + float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx); - empty_umma_barriers[warpgroup_idx]->arrive(); // Wait UMMA arrival - full_umma_barriers[warpgroup_idx]->wait(umma_phase & 1); + full_umma_barriers[warpgroup_idx]->wait(umma_phase); 
umma_phase ^= 1; // Release KV empty empty_kv_barriers[kv_stage_idx]->arrive(); // Reduce over the head dim and store - static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN; + uint32_t shifted_accum[kNumLDTMElems]; + DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM"); + auto tmem_load = [&](auto... Is) { + if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + empty_umma_barriers[warpgroup_idx]->arrive(); + #pragma unroll for (uint32_t i = 0; i < kNextN; ++ i) { - // Load from the tensor memory - constexpr uint32_t kNumLDTMElems = UMMA_M * kNumHeads / 128; - uint32_t shifted_accum[kNumLDTMElems]; - DG_STATIC_ASSERT(kNumLDTMElems == 16 or kNumLDTMElems == 32 or kNumLDTMElems == 64, "Invalid LDTM"); - auto tmem_load = [&](auto... 
Is) { - if constexpr (kNumLDTMElems == 16) { - cute::SM100_TMEM_LOAD_16dp256b4x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 32) { - cute::SM100_TMEM_LOAD_16dp256b8x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); - } else if constexpr (kNumLDTMElems == 64) { - cute::SM100_TMEM_LOAD_16dp256b16x::copy(tmem_start + i * kNumHeads, shifted_accum[Is]...); - } - }; - [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); - cutlass::arch::fence_view_async_tmem_load(); - - // Transform - const auto& transform_2 = [&](const uint32_t& j, const uint32_t& k, const float2& sum) { - auto a = make_float2(fmaxf(*reinterpret_cast(&shifted_accum[j * 4 + k]), 0), - fmaxf(*reinterpret_cast(&shifted_accum[j * 4 + k + 2]), 0)); - auto b = make_float2(weights[i][j * 2 + k], weights[i][j * 2 + k]); - return __ffma2_rn(a, b, sum); - }; + auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); - // Intra-thread reduction auto sum_0 = make_float2(0, 0); auto sum_1 = make_float2(0, 0); + + const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + #pragma unroll - for (uint32_t j = 0; j < kNumHeads / 8; ++ j) { - sum_0 = transform_2(j, 0, sum_0); - sum_1 = transform_2(j, 1, sum_1); + for (int j = 0; j < kNumWeightsInReg; j += 4) { + sum_0 = transform_reg(j, sum_0); + sum_1 = transform_reg(j + 2, sum_1); } - auto v = __fmul2_rn(__fadd2_rn(sum_0, sum_1), scale_kv); - // Inter-thread reduction + const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), + ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); + return __ffma2_rn(a, b, sum); + }; + #pragma unroll - 
for (uint32_t j = 0; j < 2; ++ j) { - const auto& offset = 1u << j; - v.x += __shfl_xor_sync(0xffffffffu, v.x, offset); - v.y += __shfl_xor_sync(0xffffffffu, v.y, offset); + for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { + sum_0 = transform_smem(j, sum_0); + sum_1 = transform_smem(j + 2, sum_1); } + + auto sum = __fadd2_rn(sum_0, sum_1); + float result = scale_kv * (sum.x + sum.y); + // Store into the global memory // NOTES: we have redundant writes here, consider more carefully - logits[kv_offset + i * logits_stride + v_0_offset] = v.x; - logits[kv_offset + i * logits_stride + v_1_offset] = v.y; + logits[kv_offset + i * logits_stride + thread_idx] = result; } } } else { diff --git a/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 00000000..4e4ff21d --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,345 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__device__ __forceinline__ +uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) { + // Calculate the index of the bank group to be written in the atom + const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)` + // - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)` + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? 
(offset) : (bank_group_idx % kNumBankGroups); + col ^= row % (kSwizzleMode / kSwizzleBase); + + return row * 128 + col * kSwizzleBase; +} + +template +__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) +sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t kNumCastStages = 2; + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + constexpr auto kMajorA = cute::UMMA::Major::K; + constexpr auto kMajorB = cute::UMMA::Major::K; + DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages"); + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); 
+ cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 4 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + full_cast_barriers[i]->init(kNumCastAndReduceThreads); + empty_barriers[i]->init(1); + empty_cast_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, 
tmem_ptr_in_smem); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + + // Dispatch warps into different roles + if (warp_idx < kNumMMAThreads / 32) { + // TMA load warp + if (warp_idx == 0 and cute::elect_one_sync()) { + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } + + // MMA issue warp + if (warp_idx == 1) { + // Make instruction descriptor + constexpr uint32_t UMMA_M = BLOCK_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(float); + constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float); + using umma_t = cute::SM100_MMA_TF32_TS; + auto instr_desc = cute::UMMA::make_instr_desc(); + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + DG_STATIC_ASSERT(kNumStages 
<= 32, "Too many stages"); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + const uint32_t& b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Launch MMAs + // We can not unroll this part + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + const auto& cast_stage_idx = s % kNumCastStages; + full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1); + tcgen05_after_thread_sync(); + + // Issue UMMA + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK; + const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK; + const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK; + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); + umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc); + } + + // Commit + cutlass::arch::umma_arrive(reinterpret_cast(empty_cast_barriers[cast_stage_idx])); + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + + // Commit to epilogue threads + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + 
DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Only support layout F (M = 64) and D (M = 128) + DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough"); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Source and destination memory address + uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup; + auto smem_ptr = reinterpret_cast(smem_cd) + // Base pointer + warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset + get_swizzled_smem_offset(i, lane_idx); // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16)) + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + if constexpr (BLOCK_M == 64) + __syncwarp(); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0); + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } else { + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block 
M"); + DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32; + + // TODO: make even larger block K + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + + // Launch reductions + float2 sum[2] = {float2{0, 0}, float2{0, 0}}; + #pragma unroll kNumStages + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b) + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + const auto& smem_base_ptr = reinterpret_cast(smem_a[stage_idx]) + // Base pointer + sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset + + // 4 lanes shared a bank group + uint32_t uint32_values[2][kNumLoads]; + DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads"); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; i += 2) { + auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset(i + lane_idx / 16, lane_idx % 16); + sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], + uint32_values[0][i + 1], uint32_values[1][i + 1], + smem_ptr); + } + + // Wait tensor memory empty + const auto& cast_stage_idx = s % kNumCastStages; + empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1); + + // Cast, reduce and store into tensor memory + float2 fp32x2_values[2][kNumLoads]; + const auto& upper_view = reinterpret_cast(&fp32x2_values[0]); + const auto& lower_view = reinterpret_cast(&fp32x2_values[1]); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + #pragma unroll + for (uint32_t u = 0; 
u < 2; ++ u) { + fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast(&uint32_values[u][i])); + sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]); + } + + // Store upper and lower part at the same time + const auto idx_0 = i * 2, idx_1 = i * 2 + 1; + cute::SM100_TMEM_STORE_16dp256b1x::copy( + upper_view[idx_0], upper_view[idx_1], + lower_view[idx_0], lower_view[idx_1], + cast_stage_idx * BLOCK_K + i * 8); + } + cutlass::arch::fence_view_async_tmem_store(); + + // Arrive for issuing MMAs + tcgen05_before_thread_sync(); + full_cast_barriers[cast_stage_idx]->arrive(); + } + + // Intra-warp reduction and write back + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y); + const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; + if (lane_idx % 4 == 0 and m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum; + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh index 34ce31d5..7a77e4e8 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -350,7 +350,7 @@ sm90_bf16_gemm_impl(int* grouped_layout, cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); // Use TMA store to write back to global memory - const auto m_idx = scheduler.template get_global_idx<(kGemmType != GemmType::MGroupedContiguous), IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx); DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); if (threadIdx.x < 
BLOCK_N / TMA_D_BLOCK_N) { auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh index 588de44f..9247304c 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -171,20 +171,23 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, empty_barriers[stage_idx]->wait(phase ^ 1); // Issue TMA A + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; auto& full_barrier = *full_barriers[stage_idx]; const uint32_t k_idx = k_block_idx * BLOCK_K; - tma_copy(&tensor_map_a, &full_barrier, + tma_copy(&tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), - num_tma_multicast_a); + num_tma_multicast_a, batch_idx); tma_copy(&tensor_map_sfa, &full_barrier, - smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.get_global_idx(shape_k_scales, 1, k_block_idx), + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), num_tma_multicast_a); // Issue TMA B - tma_copy(&tensor_map_b, &full_barrier, + tma_copy(&tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), - num_tma_multicast_b); + num_tma_multicast_b, batch_idx); full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); } } @@ -222,7 +225,7 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, // Load B scales with math warp-groups // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks if (threadIdx.x >= 32) { - auto previous_group_offset = scheduler.get_global_idx(shape_n_sfb * 
shape_k_scales, 0, 0, m_block_idx); + auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales; const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1; auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb; @@ -413,9 +416,14 @@ sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; - cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, - epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + in_block_n_offset), - scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx)); + auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + in_block_n_offset); + auto m_idx = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr, + n_idx, m_idx, scheduler.current_group_idx); + } else { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx); + } cute::tma_store_arrive(); } __syncwarp(); diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh index bc696eb1..d58c7162 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh @@ -127,7 +127,7 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const auto& get_next_block_q_idx = [&]() -> cute::tuple { return {block_q_idx + gridDim.x, q_iter_idx + 1}; }; - uint32_t seq_k_start[BLOCK_Q]; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { uint32_t start = 
cute::numeric_limits::max(); uint32_t end = cute::numeric_limits::min(); @@ -136,8 +136,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, for (uint32_t i = 0; i < BLOCK_Q; ++ i) { const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); + seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); start = min(start, min(seq_k_start[i], seq_len_kv)); - end = max(end, min(__ldg(cu_seq_len_k_end + q_idx), seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); } start = start / 4 * 4; return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage @@ -304,9 +305,9 @@ void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, // NOTES: we have redundant writes here, consider more carefully const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; if constexpr (kIsCompressedLogits) { - if (kv_offset + v_0_offset >= seq_k_start[i]) + if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i]) logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0; - if (kv_offset + v_1_offset >= seq_k_start[i]) + if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i]) logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1; } else { logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0; diff --git a/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh index 2bcf2eb9..482a85a8 100644 --- a/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh +++ b/deep_gemm/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -58,7 +58,7 @@ void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t ne } template + uint32_t BLOCK_KV, uint32_t kNumBlocksPerSplit> struct PagedMQALogitsScheduler { uint32_t batch_size; const uint32_t* context_lens; @@ -79,8 +79,8 @@ struct 
PagedMQALogitsScheduler { const auto& current_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx); const auto& end_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx + 1); - current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumMathWarpGroups; - end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumMathWarpGroups; + current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; + end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; current_num_kv = get_num_kv(current_q_idx); } @@ -93,7 +93,7 @@ struct PagedMQALogitsScheduler { if (q_idx == end_q_idx and kv_idx == end_kv_idx) return false; - current_kv_idx += kNumMathWarpGroups; + current_kv_idx += kNumBlocksPerSplit; if (current_kv_idx >= current_num_kv) { ++ current_q_idx; current_kv_idx = 0; diff --git a/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh b/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 00000000..e3bf9847 --- /dev/null +++ b/deep_gemm/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,287 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ __forceinline__ +uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) { + constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase; + + const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; + + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? 
(offset) : (bank_group_idx % kNumBankGroups); + col ^= row % kGroupsInSwizzleRange; + + return (row * kNumBankGroups + col) % kGroupsInSwizzleRange; +} + +template +__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) +sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // kSwizzleAMode and kSwizzleBMode must be 128 for now + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K"); + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode"); + + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory 
(layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 256; + + // TMA load warp + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cutlass::arch::warpgroup_reg_dealloc(); + for 
(uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + + for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + } + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + constexpr uint32_t WGMMA_M = 64; + constexpr uint32_t WGMMA_N = BLOCK_N; + constexpr uint32_t WGMMA_K = 8; + + using WGMMA = typename TF32MMASelector::type; + float accum[WGMMA::kNumAccum] = {0}; + + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + float sqr_sum_acc_0 = 0; + float sqr_sum_acc_1 = 0; + + #pragma unroll kNumStages < 8 ? 
kNumStages : kNumStages / 2 + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128; + constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K; + + float a[kNumRegPerWgmma * kNumWgmmaPerBlockK]; + // Assume swizzle A mode is 128 + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + + // Load BF16 A fragment from shared memory into registers, and transpose to FP32 + uint32_t row = warp_idx * 16 + lane_idx / 4; + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + // Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a + uint32_t bank_group_idx = (row ^ i) % 8; + nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + + uint32_t elem_offset = lane_idx % 4; + nv_bfloat16 a_bf16[kNumRegPerWgmma]; + a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset]; + a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4]; + a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset]; + a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4]; + + auto a_bf16x2_ptr = reinterpret_cast(a_bf16); + auto a_float2_ptr = reinterpret_cast(a); + float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]); + float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]); + a_float2_ptr[i * 2 + 0] = a_float2_0; + a_float2_ptr[i * 2 + 1] = a_float2_1; + sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x; + sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y; + } + + warpgroup_wait<0>(); + if (s > 0) + empty_barriers[(s - 1) % kNumStages]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + 
warpgroup_arrive(); + + constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float); + constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K; + DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K"); + + #pragma unroll + for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) { + #pragma unroll + for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) { + auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); + WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1); + } + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + } + + const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0); + const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1); + + const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4); + if (lane_idx % 4 == 0) { + if (m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum_0; + if (m_idx + 8 < shape_m) + sqr_sum[m_offset + m_idx + 8] = reduced_sum_1; + } + warpgroup_wait<0>(); + empty_barriers[(num_total_stages-1) % kNumStages]->arrive(); + + // Write accum to shared memory + // Every 2 threads (one pair) will write to the same bank group (16 bytes). + // Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d + uint32_t is_odd_pair = lane_idx / 2 % 2; + + // Four threads per group; write the data to the same row. + uint32_t row_idx = lane_idx / 4; + + // Even/odd index pairs write to the same column, we need to reorder idx: + // group even pair indices consecutively, and likewise for odd ones. 
+ uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx; + + auto shifted_smem_ptr = reinterpret_cast(smem_cd) + + (warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows + lane_idx % 2 * 8; // One thread of a pair writes 8 bytes + + #pragma unroll + for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) { + // Get the swizzled bank group index (16 bytes per group) + uint32_t bank_group_idx = get_swizzled_bank_group_idx(i + is_odd_pair, reordered_pair_idx); + auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group + + // 0/1 write to the same row, 2/3 write to another row + auto values = reinterpret_cast(accum + i * 2); + st_shared(smem_ptr, values[0], values[1]); + st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, 1); + + // Issue TMA stores + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/deep_gemm/testing/numeric.py b/deep_gemm/testing/numeric.py index d06a03b9..a42c4318 100644 --- a/deep_gemm/testing/numeric.py +++ b/deep_gemm/testing/numeric.py @@ -5,6 +5,8 @@ def calc_diff(x: torch.Tensor, y: torch.Tensor): x, y = x.double(), y.double() denominator = (x * x + y * y).sum() + if denominator == 0: # Which means that all elements in x and y are 0 + return 0.0 sim = 2 * (x * y).sum() / denominator return 1 - sim diff --git a/deep_gemm/utils/layout.py b/deep_gemm/utils/layout.py index b0ef293d..790e0d66 100644 --- 
a/deep_gemm/utils/layout.py +++ b/deep_gemm/utils/layout.py @@ -1,10 +1,16 @@ -from .._C import ( - get_tma_aligned_size, - get_mk_alignment_for_contiguous_layout, - get_mn_major_tma_aligned_tensor, - get_mn_major_tma_aligned_packed_ue8m0_tensor, - get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor -) +try: + from .._C import ( + get_tma_aligned_size, + get_mn_major_tma_aligned_tensor, + get_mn_major_tma_aligned_packed_ue8m0_tensor, + get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor + ) +except ImportError: + # Expected behavior for CUDA runtime version before 12.1 + pass + +# Valid for all CUDA versions +from .._C import get_mk_alignment_for_contiguous_layout # Some alias get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/deep_gemm/utils/math.py b/deep_gemm/utils/math.py index 1a47e155..c65026e5 100644 --- a/deep_gemm/utils/math.py +++ b/deep_gemm/utils/math.py @@ -15,35 +15,35 @@ def ceil_to_ue8m0(x: torch.Tensor): return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) -def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - padded_n = align(n, 128) + padded_n = align(n, gran_k) x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) x_padded[:, :n] = x - x_view = x_padded.view(m, -1, 128) + x_view = x_padded.view(m, -1, gran_k) x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) sf = x_amax / 448.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf -def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: - assert x.dim() == 2 and x.size(0) % 128 == 0 +def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> 
Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(0) % gran_k == 0 m, n = x.shape - x_view = x.view(-1, 128, n) + x_view = x.view(-1, gran_k, n) x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) sf = x_amax / 448.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf -def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: +def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: assert x.dim() == 2 m, n = x.shape - x_padded = torch.zeros((align(m, 128), align(n, 128)), dtype=x.dtype, device=x.device) + x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device) x_padded[:m, :n] = x - x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128) + x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k) x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) sf = x_amax / 448.0 sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf @@ -58,3 +58,50 @@ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) - sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) return x_scaled, sf.squeeze() + + +def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + ax = x.abs().clamp_max(6.0) + # {0, 0.5, 1, 1.5, 2, 3, 4, 6} + # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], + device=x.device, dtype=ax.dtype) + idx = torch.bucketize(ax, boundaries) + code = idx.to(torch.uint8) + sign = (x < 0) & (idx != 0) + code = code | (sign.to(torch.uint8) << 3) + return code # uint8, 0..15 + + +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + assert n % 2 == 0 + padded_n = align(n, gran_k) 
+ x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = x_amax / 6.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = x_view * (1.0 / sf.unsqueeze(2)) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n) + codes2 = codes.view(m, padded_n // 2, 2) + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8 + return packed[:, :n // 2].contiguous(), sf + + +def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: + assert a.dtype == torch.uint8 + assert a.dim() == 2 + m, n2 = a.shape + n = n2 * 2 + assert (m % 2) == 0 + lo = a & 0x0F + hi = (a >> 4) & 0x0F + codes = torch.empty((m, n), device=a.device, dtype=torch.uint8) + codes[:, 0::2], codes[:, 1::2] = lo, hi + codes_t = codes.transpose(0, 1).contiguous() + codes2 = codes_t.view(n, m // 2, 2) + out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) + return out.contiguous() \ No newline at end of file diff --git a/tests/generators.py b/tests/generators.py index 162274c7..ee22e515 100644 --- a/tests/generators.py +++ b/tests/generators.py @@ -1,12 +1,13 @@ import enum import random import torch -from typing import Generator, List +from typing import Generator, List, Optional, Tuple from deep_gemm.testing import get_arch_major from deep_gemm.utils import ( align, ceil_div, per_token_cast_to_fp8, per_channel_cast_to_fp8, per_block_cast_to_fp8, + per_token_cast_to_fp4, transpose_packed_fp4, get_mk_alignment_for_contiguous_layout ) @@ -35,6 +36,51 @@ def is_k_major(self): def is_mn_major(self): return self.value == 1 + + +class QuantConfig: + _legacy_quant_config = (128, 128, False, False) + + def __init__(self, value: Tuple[int, int, bool, bool] = _legacy_quant_config): + self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b = value + + def print(self): + print(f' > Testing with 
gran_k_a={self.gran_k_a}, gran_k_b={self.gran_k_b}, ' + f'is_fp4_a={self.is_fp4_a}, is_fp4_b={self.is_fp4_b}') + + def is_legacy(self) -> bool: + return (self.gran_k_a, self.gran_k_b, self.is_fp4_a, self.is_fp4_b) == self._legacy_quant_config + + def get_recipes(self, is_wgrad: bool = False) -> Tuple[Tuple, Tuple, Tuple]: + recipe, recipe_a, recipe_b = None, None, None + if self.is_legacy(): + recipe = (1, 1, 128) if is_wgrad else None + else: + recipe_a = (1, self.gran_k_a) + recipe_b = (1, self.gran_k_b) if self.is_fp4_b or is_wgrad else (self.gran_k_b, self.gran_k_b) + return recipe, recipe_a, recipe_b + + def max_diff(self) -> float: + if self.is_fp4_a and self.is_fp4_b: + return 0.02 + if self.is_fp4_a or self.is_fp4_b: + return 0.01 + return 0.001 + + @staticmethod + def get_list_from_dtype(dtype: torch.dtype) -> List: + if dtype == torch.bfloat16: + return [None] + quant_config_list = [QuantConfig()] + if get_arch_major() == 10: + quant_config_list.append(QuantConfig((128, 32, False, True))) + return quant_config_list + + +def reset_seed(seed: int = 0): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) def get_ue8m0_usage(kernel_type: KernelType) -> bool: @@ -60,9 +106,14 @@ def get_major_ab(allow_a_mn_major: bool, allow_b_mn_major: bool) -> Generator: yield major_a, major_b +def get_psum_layout_usage() -> tuple: + return (False, True) if get_arch_major() == 10 else (False, ) + + def enumerate_normal(dtype: torch.dtype) -> Generator: assert dtype in (torch.float8_e4m3fn, torch.bfloat16) + quant_config_list = QuantConfig.get_list_from_dtype(dtype) fp32_output_nk = [(256, 7168), (129280, 7168)] bf16_output_nk = [(2112, 7168), (576, 7168), (24576, 1536), (32768, 512), (7168, 16384), (4096, 7168), (7168, 2048)] m_fwd_list, m_bwd_list = [1, 128, 4096], [4096, ] @@ -73,39 +124,61 @@ def enumerate_normal(dtype: torch.dtype) -> Generator: nk_list += fp32_output_nk for kernel_type in get_kernel_types(dtype): - # Forward - for m in 
m_fwd_list: - for i in range(len(nk_list)): - n, k = nk_list[i] - out_dtype = torch.bfloat16 if i < len(bf16_output_nk) else torch.float - yield kernel_type, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype - - # Backward - for m in m_bwd_list: - for n, k in nk_list: - override_major = MajorTypeAB.MNMajor - override_kernel_type = kernel_type - if get_arch_major() == 9 and dtype == torch.float8_e4m3fn: - override_major = MajorTypeAB.KMajor - override_kernel_type = KernelType.Kernel1D1D - yield kernel_type, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad - yield override_kernel_type, n, m, k, override_major, override_major, True, torch.float # Wgrad - yield override_kernel_type, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad + for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + reset_seed() + + # Forward + for m in m_fwd_list: + for i in range(len(nk_list)): + n, k = nk_list[i] + out_dtype = torch.bfloat16 if i < len(bf16_output_nk) else torch.float + yield kernel_type, quant_config, m, n, k, MajorTypeAB.KMajor, MajorTypeAB.KMajor, False, out_dtype + + # Backward + for m in m_bwd_list: + for n, k in nk_list: + override_major = MajorTypeAB.MNMajor + override_kernel_type = kernel_type + if get_arch_major() == 9 and dtype == torch.float8_e4m3fn: + override_major = MajorTypeAB.KMajor + override_kernel_type = KernelType.Kernel1D1D + yield kernel_type, quant_config, m, k, n, MajorTypeAB.KMajor, override_major, False, torch.bfloat16 # Dgrad + yield override_kernel_type, quant_config, n, m, k, override_major, override_major, True, torch.float # Wgrad + yield override_kernel_type, quant_config, n, m, k, override_major, override_major, False, torch.bfloat16 # Wgrad def enumerate_m_grouped_contiguous(dtype: torch.dtype) -> Generator: + quant_config_list = QuantConfig.get_list_from_dtype(dtype) + m_group_list = [(4, 8192), (8, 4096)] + n_k_list = [(6144, 7168), (7168, 
3072), (4096, 4096), (4096, 2048)] for kernel_type in get_kernel_types(dtype): - for num_groups, expected_m_per_group, n, k in ((4, 8192, 4096, 7168), (4, 8192, 7168, 2048), (8, 4096, 4096, 7168), (8, 4096, 7168, 2048)): - for major_a, major_b in get_major_ab(False, get_arch_major() != 9 or dtype != torch.float8_e4m3fn): - yield kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b + for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + for use_psum_layout in get_psum_layout_usage(): + reset_seed() + for num_groups, expected_m_per_group in m_group_list: + for n, k in n_k_list: + for major_a, major_b in get_major_ab(False, get_arch_major() != 9 or dtype != torch.float8_e4m3fn): + yield kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout def enumerate_m_grouped_masked(dtype: torch.dtype) -> Generator: + quant_config_list = QuantConfig.get_list_from_dtype(dtype) max_m = 4096 + m_group_list = [(6, 1024), (32, 192), (32, 50)] + n_k_list = [(6144, 7168), (7168, 3072), (4096, 4096), (4096, 2048)] for kernel_type in get_kernel_types(dtype): - for num_groups, m in ((1, 1024), (2, 512), (4, 256)): - for n, k in ((4096, 7168), (7168, 2048), ): - yield kernel_type, num_groups, max_m, m, n, k + for quant_config in quant_config_list: + if len(quant_config_list) > 1: + quant_config.print() + for use_psum_layout in get_psum_layout_usage(): + reset_seed() + for num_groups, m in m_group_list: + for n, k in n_k_list: + yield kernel_type, quant_config, num_groups, max_m, m, n, k, use_psum_layout def enumerate_k_grouped_contiguous(dtype: torch.dtype): @@ -145,11 +218,46 @@ def enumerate_transpose(): yield mn + delta, k +def cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool, + use_ue8m0: bool, use_block_cast_for_fp8: bool = False): + if is_fp4: + x_fp4 = per_token_cast_to_fp4(x, use_ue8m0=use_ue8m0, gran_k=gran_k) + x = x_fp4 if 
major.is_k_major() else (transpose_packed_fp4(x_fp4[0]).T, x_fp4[1]) + else: + x_fp8 = per_block_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \ + else per_token_cast_to_fp8(x, use_ue8m0=use_ue8m0, gran_k=gran_k) + x = x_fp8 if major.is_k_major() else (x_fp8[0].T.contiguous().T, x_fp8[1]) + return x + + +def grouped_cast_fp8_fp4_with_major(x: torch.Tensor, major: MajorTypeAB, gran_k: int, is_fp4: bool, + use_ue8m0: bool, use_block_cast_for_fp8: bool = False): + num_groups, mn, k = x.size() + if is_fp4: + x_fp4 = (torch.empty((num_groups, mn, k // 2), device='cuda', dtype=torch.uint8) if major.is_k_major() else \ + torch.empty((num_groups, k, mn // 2), device='cuda', dtype=torch.uint8), + torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + x_i_fp4 = per_token_cast_to_fp4(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) + x_fp4[0][i], x_fp4[1][i] = x_i_fp4 if major.is_k_major() else (transpose_packed_fp4(x_i_fp4[0]), x_i_fp4[1]) + x = x_fp4 if major.is_k_major() else (x_fp4[0].mT, x_fp4[1]) + else: + x_fp8 = (torch.empty_like(x, dtype=torch.float8_e4m3fn), + torch.empty((num_groups, ceil_div(mn, gran_k), ceil_div(k, gran_k)), device='cuda', dtype=torch.float) if use_block_cast_for_fp8 \ + else torch.empty((num_groups, mn, ceil_div(k, gran_k)), device='cuda', dtype=torch.float)) + for i in range(num_groups): + x_fp8[0][i], x_fp8[1][i] = per_block_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) if use_block_cast_for_fp8 \ + else per_token_cast_to_fp8(x[i], use_ue8m0=use_ue8m0, gran_k=gran_k) + x = x_fp8 if major.is_k_major() else (x_fp8[0].mT.contiguous().mT, x_fp8[1]) + return x + + def generate_normal(m: int, n: int, k: int, major_a: MajorTypeAB, major_b: MajorTypeAB, accumulate: bool, out_dtype: torch.dtype, kernel_type: KernelType, - use_ue8m0: bool = False, use_bf16: bool = False): + use_ue8m0: bool = False, use_bf16: bool = False, + quant_config: Optional[QuantConfig] 
= None): a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((n, k), device='cuda', dtype=torch.bfloat16) d = torch.randn((m, n), device='cuda', dtype=out_dtype) * 32 if accumulate else \ @@ -161,25 +269,28 @@ def generate_normal(m: int, n: int, k: int, a = a if major_a.is_k_major() else a.T.contiguous().T b = b if major_b.is_k_major() else b.T.contiguous().T return a, b, c, d, ref_d + + quant_config = QuantConfig() if quant_config is None else quant_config + a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, + use_block_cast_for_fp8=not (kernel_type.is_1d1d() and accumulate)) - a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0) - b_fp8 = per_token_cast_to_fp8(b, use_ue8m0=use_ue8m0) if kernel_type.is_1d1d() and accumulate \ - else per_block_cast_to_fp8(b, use_ue8m0=use_ue8m0) - a_fp8 = a_fp8 if major_a.is_k_major() else (a_fp8[0].T.contiguous().T, a_fp8[1]) - b_fp8 = b_fp8 if major_b.is_k_major() else (b_fp8[0].T.contiguous().T, b_fp8[1]) - return a_fp8, b_fp8, c, d, ref_d + return a, b, c, d, ref_d def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: int, k: int, major_a: MajorTypeAB, major_b: MajorTypeAB, - use_ue8m0: bool = False, use_bf16: bool = False): + use_ue8m0: bool = False, use_bf16: bool = False, + use_psum_layout: bool = False, + quant_config: Optional[QuantConfig] = None): actual_ms = [int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)] aligned_ms = [align(actual_m, get_mk_alignment_for_contiguous_layout()) for actual_m in actual_ms] m = sum(aligned_ms) a = torch.randn((m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) - m_indices = torch.empty(m, device='cuda', dtype=torch.int32) + grouped_layout = torch.empty(num_groups, device='cuda', dtype=torch.int32) if 
use_psum_layout \ + else torch.empty(m, device='cuda', dtype=torch.int32) d = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) ref_d = torch.randn((m, n), device='cuda', dtype=torch.bfloat16) @@ -187,48 +298,61 @@ def generate_m_grouped_contiguous(num_groups: int, expected_m_per_group: int, n: for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)): actual_end = start + actual_m aligned_end = start + aligned_m - m_indices[start:actual_end] = i - m_indices[actual_end:aligned_end] = -1 - ref_d[start:aligned_end] = a[start:aligned_end] @ b[i].t() + if use_psum_layout: + grouped_layout[i] = actual_end + else: + grouped_layout[start: actual_end] = i + grouped_layout[actual_end: aligned_end] = -1 + a[actual_end: aligned_end] = 0 + ref_d[start: aligned_end] = a[start: aligned_end] @ b[i].t() start = aligned_end - ref_d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(ref_d), ref_d) if use_bf16: b = b if major_b.is_k_major() else b.mT.contiguous().mT - return m, a, b, m_indices, d, ref_d + return m, a, b, grouped_layout, d, ref_d assert major_a.is_k_major() - a_fp8 = per_token_cast_to_fp8(a, use_ue8m0=use_ue8m0) - b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), - torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)) + quant_config = QuantConfig() if quant_config is None else quant_config + a = cast_fp8_fp4_with_major(a, major_a, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = grouped_cast_fp8_fp4_with_major(b, major_b, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, use_block_cast_for_fp8=True) + + return m, a, b, grouped_layout, d, ref_d + + +def layout_masked_to_psum(x: torch.Tensor, psum_m: torch.Tensor): + num_groups, max_m, _ = x.size() + x_psum = torch.empty_like(x).view(num_groups * max_m, -1) + last_psum_m = 0 for i in range(num_groups): - b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) - b_fp8 = b_fp8 if major_b.is_k_major() 
else (b_fp8[0].mT.contiguous().mT, b_fp8[1]) - return m, a_fp8, b_fp8, m_indices, d, ref_d + x_psum[last_psum_m: psum_m[i]] = x[i, :psum_m[i] - last_psum_m] + last_psum_m = align(psum_m[i], 128) + return x_psum def generate_m_grouped_masked(num_groups: int, max_m: int, expected_m_per_group: int, n: int, k: int, - use_ue8m0: bool = False, use_bf16: bool = False): + use_ue8m0: bool = False, use_bf16: bool = False, + use_psum_layout: bool = False, + quant_config: Optional[QuantConfig] = None): a = torch.randn((num_groups, max_m, k), device='cuda', dtype=torch.bfloat16) b = torch.randn((num_groups, n, k), device='cuda', dtype=torch.bfloat16) d = torch.empty((num_groups, max_m, n), device='cuda', dtype=torch.bfloat16) ref_d = torch.einsum('gmk,gnk->gmn', a, b) masked_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) + psum_m = torch.empty((num_groups, ), device='cuda', dtype=torch.int) for j in range(num_groups): masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3)) + psum_m[j] = (0 if j == 0 else align(psum_m[j - 1], 128)) + masked_m[j] assert masked_m.amax().item() <= max_m if use_bf16: - return a, b, masked_m, d, ref_d + return a, b, masked_m, psum_m, d, ref_d - a_fp8 = (torch.empty_like(a, dtype=torch.float8_e4m3fn), torch.empty((num_groups, max_m, ceil_div(k, 128)), device='cuda', dtype=torch.float)) - b_fp8 = (torch.empty_like(b, dtype=torch.float8_e4m3fn), torch.empty((num_groups, ceil_div(n, 128), ceil_div(k, 128)), device='cuda', dtype=torch.float)) - for i in range(num_groups): - a_fp8[0][i], a_fp8[1][i] = per_token_cast_to_fp8(a[i], use_ue8m0=use_ue8m0) - b_fp8[0][i], b_fp8[1][i] = per_block_cast_to_fp8(b[i], use_ue8m0=use_ue8m0) + quant_config = QuantConfig() if quant_config is None else quant_config + a = grouped_cast_fp8_fp4_with_major(a, MajorTypeAB.KMajor, quant_config.gran_k_a, quant_config.is_fp4_a, use_ue8m0) + b = grouped_cast_fp8_fp4_with_major(b, MajorTypeAB.KMajor, quant_config.gran_k_b, quant_config.is_fp4_b, use_ue8m0, 
use_block_cast_for_fp8=True) - return a_fp8, b_fp8, masked_m, d, ref_d + return a, b, masked_m, psum_m, d, ref_d def generate_k_grouped_contiguous(num_groups: int, m: int, n: int, major_a: MajorTypeAB, major_b: MajorTypeAB, ks: List[int], diff --git a/tests/test_attention.py b/tests/test_attention.py index e4cd8e5e..b26cf673 100644 --- a/tests/test_attention.py +++ b/tests/test_attention.py @@ -1,12 +1,14 @@ +import dataclasses import random import torch -from typing import Tuple +from typing import Tuple, List import deep_gemm from deep_gemm.testing import ( bench_kineto, calc_diff, count_bytes, - ignore_env, get_arch_major + ignore_env, get_arch_major, + test_filter ) from deep_gemm.utils import ceil_div, per_custom_dims_cast_to_fp8 @@ -154,7 +156,7 @@ def test_mqa_logits(): ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) logits = logits.masked_fill(neginf_mask, 0) diff = calc_diff(logits, ref_logits) - assert diff < 1e-3, f"{diff=}" + assert diff < 1e-3, f'{diff=}' else: ref_cost = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke, cost_only=True) @@ -204,8 +206,6 @@ def ref_fp8_paged_mqa_logits(q: torch.Tensor, kv_cache: torch.Tensor, def test_paged_mqa_logits(): - # TODO: fully refactor with PyTest - print('Testing FP8 Paged MQA Logits:') max_model_len = 111 * 1000 for is_context_lens_2d in (False, True): @@ -264,7 +264,7 @@ def test_paged_mqa_logits(): else: t, clean_t = bench_kineto(lambda: deep_gemm.fp8_paged_mqa_logits(q_fp8, kv_cache_fp8, weights, context_lens, block_tables, schedule_metadata, max_model_len, clean_logits=True), ('fp8_paged_mqa_logits', 'clean_logits')) - clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens) + clean_bytes = (batch_size * next_n * max_model_len - neginf_mask.sum().item()) * 4 + count_bytes(context_lens) print(f' > BSZ={batch_size:3}, NextN={next_n:1}, H={heads:2}, D={index_dim:2}, L={avg_kv:6}: ' f'{tflops / t:4.0f} TFLOPS, {t * 
1e6:3.0f} us, ' f'{(input_bytes + output_bytes) / t / 1e9:4.0f} GB/s', end='') @@ -273,6 +273,8 @@ def test_paged_mqa_logits(): print() + + if __name__ == '__main__': torch.manual_seed(0) random.seed(0) diff --git a/tests/test_bf16.py b/tests/test_bf16.py index f2f41c4a..1a3b0467 100644 --- a/tests/test_bf16.py +++ b/tests/test_bf16.py @@ -9,7 +9,7 @@ calc_diff, count_bytes ) from generators import ( - get_arch_major, + get_arch_major, layout_masked_to_psum, align, enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous ) @@ -18,11 +18,7 @@ def test_gemm() -> None: print('Testing GEMM:') scores = [] - for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16): - # TODO: support accumulation for SM90 BF16 GEMM - if get_arch_major() == 9 and accumulate: - continue - + for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.bfloat16): major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' out_opt = 'FP32' if out_dtype == torch.float else 'BF16' @@ -56,29 +52,30 @@ def test_gemm() -> None: def test_m_grouped_gemm_contiguous() -> None: print('Testing m-grouped contiguous GEMM:') - for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(torch.bfloat16): + for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(torch.bfloat16): major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' for test_alias in (False, True): - m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, 
n, k, major_a, major_b, + use_bf16=True, use_psum_layout=use_psum_layout) func_name = f"m_grouped_bf16_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" if test_alias: assert major_a.is_k_major() b = b if major_b.is_k_major() else b.mT assert a[0].is_contiguous() and b[0].is_contiguous() - getattr(deep_gemm, func_name)(a, b, d, m_indices) - d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d) + getattr(deep_gemm, func_name)(a, b, d, grouped_layout, use_psum_layout=use_psum_layout) diff = calc_diff(d, ref_d) assert diff < 1e-5, f'{m=}, {n=}, {k=}, {major_opt}, {diff:.5f}, alias={test_alias}' - m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_bf16=True) + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_bf16=True, use_psum_layout=use_psum_layout) # noinspection PyShadowingNames def test_func(): - deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, m_indices) + deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a, b, d, grouped_layout, use_psum_layout=use_psum_layout) t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}): ' + print(f' > Perf ({num_groups=}, m={m:5}, n={n:5}, k={k:5}, layout={major_opt}, psum={use_psum_layout}): ' f'{t * 1e6:4.0f} us | ' f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') @@ -89,29 +86,52 @@ def test_m_grouped_gemm_masked() -> None: print('Testing m-grouped masked GEMM:') # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. 
- for _, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.bfloat16): - # Test correctness - for i in range(10): - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True) - deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + for _, _, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.bfloat16): + num_tests = 8 + sum_t, max_t = 0, 0 + sum_ops, sum_bytes = 0, 0 + + for i in range(num_tests): + a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, + use_bf16=True, use_psum_layout=use_psum_layout) + if use_psum_layout: + a_psum = layout_masked_to_psum(a, psum_m) + d_psum = layout_masked_to_psum(d, psum_m) + + # noinspection PyShadowingNames + def test_func(): + if use_psum_layout: + deep_gemm.m_grouped_bf16_gemm_nt_contiguous(a_psum, b, d_psum, psum_m, + use_psum_layout=True, expected_m_for_psum_layout=expected_m_per_group) + else: + deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + + test_func() for j in range(num_groups): - diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) + if masked_m[j].item() == 0: + continue + if use_psum_layout: + d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]] + else: + d_slice = d[j, :masked_m[j].item()] + diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()]) assert diff < 1e-5, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}' - # Construct full cases - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_bf16=True) - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group) + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = 
bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) - # Test performance with fixed shapes - valid_m = masked_m.sum().item() - t = bench_kineto(test_func, 'bf16_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): ' - f'{t * 1e6:4.0f} us | ' - f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') + sum_t += t + max_t = max(max_t, t) + sum_ops += 2 * valid_m * n * k + sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b) + + print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, ' + f'psum={1 if use_psum_layout else 0}): ' + f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | ' + f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | ' + f'{sum_bytes / sum_t / 1e9:4.0f} GB/s') print() @@ -148,7 +168,7 @@ def test_func(): def test_cublaslt_gemm() -> None: print('Testing cuBLASLt GEMM:') - for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16): + for kernel_type, _, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(dtype=torch.bfloat16): major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N' out_opt = 'FP32' if out_dtype == torch.float else 'BF16' @@ -159,7 +179,8 @@ def test_cublaslt_gemm() -> None: diff = calc_diff(d, ref_d) assert diff < 6e-7, f'{diff=}, ({m=}, {n=}, {k=}, {major_opt=}, {accumulate=}, {out_dtype=})' - t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), 'nvjet', suppress_kineto_output=True,) + t_nvjet, t_gemv, t_gemm = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a, b, d, c=c), ('nvjet', 'gemv', 'gemm'), suppress_kineto_output=True) + t = t_nvjet + t_gemv + t_gemm print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, layout={major_opt}, {out_opt}, {acc_opt}): 
' f'{t * 1e6:5.0f} us | ' f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' diff --git a/tests/test_einsum.py b/tests/test_einsum.py index edbe225b..b7979989 100644 --- a/tests/test_einsum.py +++ b/tests/test_einsum.py @@ -80,7 +80,6 @@ def test_bhd_hdr_bhr(): print() -@test_filter(lambda: get_arch_major() >= 10) def test_fp8_bhr_hdr_bhd(use_ue8m0: bool = True): print('Testing FP8 "bhr, hdr -> bhd":') for h, r, d in [(8, 4096, 1024)]: diff --git a/tests/test_fp8.py b/tests/test_fp8.py deleted file mode 100644 index 50d25c7c..00000000 --- a/tests/test_fp8.py +++ /dev/null @@ -1,175 +0,0 @@ -import copy -import numpy as np -import random -import torch - -import deep_gemm -from deep_gemm.testing import ( - bench_kineto, - calc_diff, count_bytes, - ignore_env, get_arch_major -) - -from generators import ( - KernelType, get_ue8m0_usage, - enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, - generate_normal, generate_m_grouped_contiguous, generate_m_grouped_masked, generate_k_grouped_contiguous -) - - -@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9) -def test_gemm() -> None: - print('Testing GEMM:') - scores = [] - for kernel_type, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn): - major_opt = 'N' if major_a.is_k_major() else 'T' - major_opt += 'T' if major_b.is_k_major() else 'N' - out_opt = 'FP32' if out_dtype == torch.float else 'BF16' - acc_opt = f'acc={int(accumulate)}' - kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' - use_ue8m0 = get_ue8m0_usage(kernel_type) - disable_ue8m0_cast = not use_ue8m0 - recipe = (1, 1, 128) if kernel_type.is_1d1d() and accumulate else None - - for test_alias in (False, True): - a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0) - func_name = f'fp8_gemm_{major_opt.lower() if test_alias else "nt"}' - if test_alias: - a = a if major_a.is_k_major() 
else (a[0].T, a[1].T) - b = b if major_b.is_k_major() else (b[0].T, b[1].T) - assert a[0].is_contiguous() and b[0].is_contiguous() - getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe) - diff = calc_diff(d, ref_d) - assert diff < 0.001, (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, ' - f'{diff:.5f}, alias={test_alias}') - - a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0) - t = bench_kineto(lambda: deep_gemm.fp8_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe), - 'fp8_gemm', suppress_kineto_output=True) - cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), suppress_kineto_output=True) - print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' - f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' - f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS') - if cublas_t > 0: - scores.append((cublas_t + split_k_t) / t) - print(f"Average speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n") - - -def test_m_grouped_gemm_contiguous() -> None: - print('Testing m-grouped contiguous GEMM:') - - for kernel_type, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn): - major_opt = 'N' if major_a.is_k_major() else 'T' - major_opt += 'T' if major_b.is_k_major() else 'N' - kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' - use_ue8m0 = get_ue8m0_usage(kernel_type) - disable_ue8m0_cast = not use_ue8m0 - - for test_alias in (False, True): - m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0) - func_name = 
f"m_grouped_fp8_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" - if test_alias: - assert major_a.is_k_major() - b = b if major_b.is_k_major() else (b[0].mT, b[1].mT) - assert a[0].is_contiguous() and b[0].is_contiguous() - getattr(deep_gemm, func_name)(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast) - d = torch.where((m_indices == -1).unsqueeze(1), torch.zeros_like(d), d) - diff = calc_diff(d, ref_d) - assert diff < 0.001, f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' - m, a, b, m_indices, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, use_ue8m0=use_ue8m0) - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_fp8_gemm_nt_contiguous(a, b, d, m_indices, disable_ue8m0_cast=disable_ue8m0_cast) - - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}): ' - f'{t * 1e6:4.0f} us | ' - f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') - print() - - -def test_m_grouped_gemm_masked() -> None: - print('Testing m-grouped masked GEMM:') - - # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. 
- for kernel_type, num_groups, max_m, expected_m_per_group, n, k in enumerate_m_grouped_masked(torch.float8_e4m3fn): - kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' - use_ue8m0 = get_ue8m0_usage(kernel_type) - disable_ue8m0_cast = not use_ue8m0 - - # Test correctness - for i in range(10): - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) - for j in range(num_groups): - if masked_m[j].item() == 0: - continue - diff = calc_diff(d[j, :masked_m[j].item()], ref_d[j, :masked_m[j].item()]) - assert diff < 0.001, f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' - - # Construct full cases - a, b, masked_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, use_ue8m0=use_ue8m0) - - # noinspection PyShadowingNames - def test_func(): - deep_gemm.m_grouped_fp8_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast) - - # Test performance with fixed shapes - valid_m = masked_m.sum().item() - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, {kernel_opt}): ' - f'{t * 1e6:4.0f} us | ' - f'{2 * valid_m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{(count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b)) / 1e9 / t:4.0f} GB/s') - print() - - -def test_k_grouped_gemm_contiguous() -> None: - print('Testing k-grouped contiguous GEMM:') - - k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \ - else deep_gemm.k_grouped_fp8_gemm_tn_contiguous - for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.float8_e4m3fn): - use_ue8m0 = 
get_ue8m0_usage(KernelType.Kernel1D1D) - - for test_empty_groups in (False, True): - new_ks = copy.deepcopy(ks) - if test_empty_groups and len(ks) > 1: - new_ks[random.randint(0, num_groups - 1)] = 0 - k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0) - new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda') - k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c) - - diff = calc_diff(d, ref_d) - assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}' - - # Test performance - k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0) - ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') - - # noinspection PyShadowingNames - def test_func(): - k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c) - - t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) - print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): ' - f'{t * 1e6:4.0f} us | ' - f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' - f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s') - print() - - -if __name__ == '__main__': - torch.manual_seed(0) - random.seed(0) - - print('Library path:') - print(f' > {deep_gemm.__path__}\n') - - test_gemm() - test_m_grouped_gemm_contiguous() - test_m_grouped_gemm_masked() - test_k_grouped_gemm_contiguous() diff --git a/tests/test_fp8_fp4.py b/tests/test_fp8_fp4.py new file mode 100644 index 00000000..f7e3e1c4 --- /dev/null +++ b/tests/test_fp8_fp4.py @@ -0,0 +1,207 @@ +import copy +import numpy as np +import random +import torch + +import deep_gemm +from deep_gemm.testing import ( + bench_kineto, + calc_diff, count_bytes, + ignore_env, get_arch_major +) + +from generators import ( + KernelType, get_ue8m0_usage, layout_masked_to_psum, align, + enumerate_normal, enumerate_m_grouped_contiguous, enumerate_m_grouped_masked, enumerate_k_grouped_contiguous, + generate_normal, generate_m_grouped_contiguous, 
generate_m_grouped_masked, generate_k_grouped_contiguous +) + + +@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9) +def test_gemm() -> None: + print('Testing GEMM:') + scores = [] + for kernel_type, quant_config, m, n, k, major_a, major_b, accumulate, out_dtype in enumerate_normal(torch.float8_e4m3fn): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + out_opt = 'FP32' if out_dtype == torch.float else 'BF16' + acc_opt = f'acc={int(accumulate)}' + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes(is_wgrad=(kernel_type.is_1d1d() and accumulate)) + + for test_alias in (False, True): + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config) + func_name = f'fp8_fp4_gemm_{major_opt.lower() if test_alias else "nt"}' + if test_alias: + a = a if major_a.is_k_major() else (a[0].T, a[1].T) + b = b if major_b.is_k_major() else (b[0].T, b[1].T) + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + diff = calc_diff(d, ref_d) + assert diff < quant_config.max_diff(), (f'{m=}, {n=}, {k=}, {kernel_opt}, {major_opt=}, {accumulate=}, {out_dtype=}, ' + f'{diff:.5f}, alias={test_alias}') + + a, b, c, d, ref_d = generate_normal(m, n, k, major_a, major_b, accumulate, out_dtype, kernel_type, use_ue8m0=use_ue8m0, quant_config=quant_config) + t = bench_kineto(lambda: deep_gemm.fp8_fp4_gemm_nt(a, b, d, c=c, disable_ue8m0_cast=disable_ue8m0_cast, recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b), + 'fp8_gemm', suppress_kineto_output=True) + cublas_t, split_k_t = bench_kineto(lambda: deep_gemm.cublaslt_gemm_nt(a[0], b[0], d, c=c), ('nvjet', 'reduce'), 
suppress_kineto_output=True) \ + if not quant_config.is_fp4_a and not quant_config.is_fp4_b else (0, 0) + print(f' > Perf (m={m:6}, n={n:6}, k={k:6}, {kernel_opt}, layout={major_opt}, {out_opt}, {acc_opt}): ' + f'{t * 1e6:6.1f} us | {2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{(count_bytes(a, b, d) + count_bytes(c) * int(accumulate)) / 1e9 / t:4.0f} GB/s | ' + f'{(cublas_t + split_k_t) / t:.2f}x cuBLAS') + if cublas_t > 0: + scores.append((cublas_t + split_k_t) / t) + print(f"Average FP8xFP8 GEMM speedup over cuBLASLt: {float(np.prod(scores)) ** (1.0 / len(scores)):.3f}x\n") + + +def test_m_grouped_gemm_contiguous() -> None: + print('Testing m-grouped contiguous GEMM:') + + for kernel_type, quant_config, num_groups, expected_m_per_group, n, k, major_a, major_b, use_psum_layout in enumerate_m_grouped_contiguous(dtype=torch.float8_e4m3fn): + major_opt = 'N' if major_a.is_k_major() else 'T' + major_opt += 'T' if major_b.is_k_major() else 'N' + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes() + + for test_alias in (False, True): + m, a, b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + func_name = f"m_grouped_fp8_fp4_gemm_{(major_opt.lower() if test_alias else 'nt')}_contiguous" + if test_alias: + assert major_a.is_k_major() + b = b if major_b.is_k_major() else (b[0].mT, b[1].mT) + assert a[0].is_contiguous() and b[0].is_contiguous() + getattr(deep_gemm, func_name)(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + diff = calc_diff(d, ref_d) + assert diff < quant_config.max_diff(), f'{m=}, {n=}, {k=}, {major_opt}, {kernel_opt}, {diff:.5f}, alias={test_alias}' + m, a, 
b, grouped_layout, d, ref_d = generate_m_grouped_contiguous(num_groups, expected_m_per_group, n, k, major_a, major_b, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + + # noinspection PyShadowingNames + def test_func(): + deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a, b, d, grouped_layout, disable_ue8m0_cast=disable_ue8m0_cast, use_psum_layout=use_psum_layout, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=}, m={m:5}, n={n:6}, k={k:5}, {kernel_opt}, layout={major_opt}, psum={use_psum_layout}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d) / 1e9 / t:4.0f} GB/s') + print() + + +def test_m_grouped_gemm_masked() -> None: + print('Testing m-grouped masked GEMM:') + + # TODO: when the actual `m` is greater than `expected_m_per_group`, efficiency may significantly decrease. + for kernel_type, quant_config, num_groups, max_m, expected_m_per_group, n, k, use_psum_layout in enumerate_m_grouped_masked(torch.float8_e4m3fn): + kernel_opt = f'1D1D' if kernel_type.is_1d1d() else '1D2D' + use_ue8m0 = get_ue8m0_usage(kernel_type) + disable_ue8m0_cast = not use_ue8m0 + recipe, recipe_a, recipe_b = quant_config.get_recipes() + + num_tests = 8 + sum_t, max_t = 0, 0 + sum_ops, sum_bytes = 0, 0 + + for i in range(num_tests): + a, b, masked_m, psum_m, d, ref_d = generate_m_grouped_masked(num_groups, max_m, expected_m_per_group, n, k, + use_ue8m0=use_ue8m0, use_psum_layout=use_psum_layout, + quant_config=quant_config) + if use_psum_layout: + a_psum = (layout_masked_to_psum(a[0], psum_m), layout_masked_to_psum(a[1], psum_m)) + d_psum = layout_masked_to_psum(d, psum_m) + + # noinspection PyShadowingNames + def test_func(): + if use_psum_layout: + deep_gemm.m_grouped_fp8_fp4_gemm_nt_contiguous(a_psum, b, d_psum, psum_m, disable_ue8m0_cast=disable_ue8m0_cast, + use_psum_layout=True, 
expected_m_for_psum_layout=expected_m_per_group, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + else: + deep_gemm.m_grouped_fp8_fp4_gemm_nt_masked(a, b, d, masked_m, expected_m_per_group, disable_ue8m0_cast=disable_ue8m0_cast, + recipe=recipe, recipe_a=recipe_a, recipe_b=recipe_b) + + test_func() + for j in range(num_groups): + if masked_m[j].item() == 0: + continue + if use_psum_layout: + d_slice = d_psum[: psum_m[j]] if j == 0 else d_psum[align(psum_m[j - 1], 128): psum_m[j]] + else: + d_slice = d[j, :masked_m[j].item()] + diff = calc_diff(d_slice, ref_d[j, :masked_m[j].item()]) + assert diff < quant_config.max_diff(), f'{max_m=}, {n=}, {k=}, {j=}, masked_m={masked_m[j]}, {kernel_opt}, {num_groups=}, {diff:.5f}' + + # Test performance with fixed shapes + valid_m = masked_m.sum().item() + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + + sum_t += t + max_t = max(max_t, t) + sum_ops += 2 * valid_m * n * k + sum_bytes += count_bytes(a, d) * valid_m / (max_m * num_groups) + count_bytes(b) + + print(f' > Perf (num_groups={num_groups:2}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}, ' + f'{kernel_opt}, psum={1 if use_psum_layout else 0}): ' + f'{sum_t / num_tests * 1e6:4.0f} us (max: {max_t * 1e6:3.0f} us) | ' + f'{sum_ops / sum_t / 1e12:4.0f} TFLOPS | ' + f'{sum_bytes / sum_t / 1e9:4.0f} GB/s') + print() + + +@ignore_env('DG_JIT_PTXAS_CHECK', lambda: get_arch_major() == 9) +def test_k_grouped_gemm_contiguous() -> None: + print('Testing k-grouped contiguous GEMM:') + + k_grouped_fp8_gemm_contiguous = deep_gemm.k_grouped_fp8_gemm_nt_contiguous if get_arch_major() == 9 \ + else deep_gemm.k_grouped_fp8_gemm_tn_contiguous + for num_groups, m, n, major_a, major_b, ks, expected_k_per_group in enumerate_k_grouped_contiguous(torch.float8_e4m3fn): + use_ue8m0 = get_ue8m0_usage(KernelType.Kernel1D1D) + + for test_empty_groups in (False, True): + new_ks = copy.deepcopy(ks) + if test_empty_groups and len(ks) > 1: + 
new_ks[random.randint(0, num_groups - 1)] = 0 + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, new_ks, use_ue8m0=use_ue8m0) + new_ks_tensor = torch.tensor(new_ks, dtype=torch.int, device='cuda') + k_grouped_fp8_gemm_contiguous(a, b, d, new_ks, new_ks_tensor, c) + + diff = calc_diff(d, ref_d) + assert diff < 0.001, f'{m=}, {n=}, {k=}, {ks=}, {diff:.5f}' + + # Test performance + k, a, b, c, d, ref_d = generate_k_grouped_contiguous(num_groups, m, n, major_a, major_b, ks, use_ue8m0=use_ue8m0) + ks_tensor = torch.tensor(ks, dtype=torch.int, device='cuda') + + # noinspection PyShadowingNames + def test_func(): + k_grouped_fp8_gemm_contiguous(a, b, d, ks, ks_tensor, c) + + t = bench_kineto(test_func, 'fp8_gemm', suppress_kineto_output=True) + print(f' > Perf ({num_groups=:2}, m={m:5}, n={n:5}, k={k:5}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, c, d) / 1e9 / t:4.0f} GB/s') + print() + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_gemm() + test_m_grouped_gemm_contiguous() + test_m_grouped_gemm_masked() + test_k_grouped_gemm_contiguous() diff --git a/tests/test_hyperconnection.py b/tests/test_hyperconnection.py new file mode 100644 index 00000000..24faf22c --- /dev/null +++ b/tests/test_hyperconnection.py @@ -0,0 +1,57 @@ +import torch +import random + +import deep_gemm +from deep_gemm.testing import ( + test_filter, + bench_kineto, + calc_diff, count_bytes +) +from deep_gemm.utils import align +from generators import get_arch_major + + +@test_filter(lambda: get_arch_major() >= 9) +def test_hc_prenorm_gemm() -> None: + # Needs TF32 precision for PyTorch GEMMs + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + print('Testing hyperconnection prenorm GEMM:') + for m in (13, 137, 4096, 8192): + for n, k in [(24, 28672), (24, 7680), (24, 7168)]: 
+ for num_splits in [None, 16]: + a = torch.randn((m, k), dtype=torch.bfloat16, device='cuda') + b = torch.randn((n, k), dtype=torch.float, device='cuda') + d = torch.empty((m, n), dtype=torch.float, device='cuda') if num_splits is None else \ + torch.empty((num_splits, m, n), dtype=torch.float, device='cuda') + s = torch.empty((m, ), dtype=torch.float, device='cuda') if num_splits is None else \ + torch.empty((num_splits, m), dtype=torch.float, device='cuda') + deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits) + final_d = d if num_splits is None else d.sum(0) + final_s = s if num_splits is None else s.sum(0) + + ref_d = a.float() @ b.T + ref_s = a.float().square().sum(-1) + + diff = max(calc_diff(final_d, ref_d), calc_diff(final_s, ref_s)) + assert diff < 1e-8, f'{m=}, {n=}, {k=}, {diff:.10f}' + + t = bench_kineto(lambda: deep_gemm.tf32_hc_prenorm_gemm(a, b, d, s, num_splits=num_splits), 'tf32_hc_prenorm_gemm', suppress_kineto_output=True) + print(f' > Perf (m={m:5}, n={n:5}, k={k:5}, num_splits={(num_splits or 0):2}): ' + f'{t * 1e6:4.0f} us | ' + f'{2 * m * n * k / t / 1e12:4.0f} TFLOPS | ' + f'{count_bytes(a, b, d, s) / 1e9 / t:4.0f} GB/s') + print() + + + + +if __name__ == '__main__': + torch.manual_seed(0) + random.seed(0) + + print('Library path:') + print(f' > {deep_gemm.__path__}\n') + + test_hc_prenorm_gemm() diff --git a/tests/test_legacy.py b/tests/test_legacy.py index 6559b1cd..4456799f 100644 --- a/tests/test_legacy.py +++ b/tests/test_legacy.py @@ -13,7 +13,7 @@ def test_m_grouped_gemm_contiguous_tl() -> None: print('Testing m-grouped contiguous Triton GEMM:') - for _, num_groups, expected_m_per_group, n, k, major_a, major_b in enumerate_m_grouped_contiguous(torch.bfloat16): + for _, _, num_groups, expected_m_per_group, n, k, major_a, major_b, _ in enumerate_m_grouped_contiguous(torch.bfloat16): major_opt = 'N' if major_a.is_k_major() else 'T' major_opt += 'T' if major_b.is_k_major() else 'N'