Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
7797f24
initial commit
amd-khushbu Sep 9, 2025
0759f38
remove extra files
amd-khushbu Sep 9, 2025
be11e28
fixing errors
amd-khushbu Sep 9, 2025
00aeb62
updated ReadMe file for mapping of diff quants with diff configs
amd-khushbu Sep 9, 2025
171c71e
Merge branch 'develop' into fix_quant_example
amd-khushbu Sep 9, 2025
c4a5daa
addressing review comments
amd-khushbu Sep 10, 2025
26d2603
addressing review comments
amd-khushbu Sep 10, 2025
cbd5d63
Resolved merge conflicts
amd-khushbu Sep 10, 2025
dc0445e
Resolved merge conflicts
amd-khushbu Sep 10, 2025
b3acd23
[CK TILE GEMM] Replace get_preshuffle_or with is_quantpreshuffle_enabled
CongMa13 Sep 10, 2025
bb3c01c
initial commit
amd-khushbu Sep 11, 2025
ef41cfe
debugging
amd-khushbu Sep 12, 2025
ff27dad
working fp8 for init constant
amd-khushbu Sep 13, 2025
fee560b
fp8 working with all inits
amd-khushbu Sep 13, 2025
1eeb194
updated block level code with comments
amd-khushbu Sep 16, 2025
51d3616
changing the loop iter
amd-khushbu Sep 16, 2025
f1731f5
debugging
amd-khushbu Sep 18, 2025
559238a
debugging
amd-khushbu Sep 18, 2025
d186d0e
debugging
amd-khushbu Sep 18, 2025
abdd61e
code fix
amd-khushbu Sep 19, 2025
13af4d1
merging with develop and resolving merge conflicts
amd-khushbu Sep 19, 2025
a5d8437
code clean up
amd-khushbu Sep 19, 2025
46f1b59
Merge branch 'develop' into preshuffle
amd-khushbu Sep 19, 2025
667b5dd
clang formatted
amd-khushbu Sep 19, 2025
c66907c
Add comment
amd-khushbu Sep 19, 2025
0b9450a
code cleanup
amd-khushbu Sep 19, 2025
e285c76
rebase to develop
amd-khushbu Sep 22, 2025
a63dda8
clang formatted
amd-khushbu Sep 22, 2025
ee78f67
merge conflicts fixes
amd-khushbu Sep 22, 2025
dd7e998
applying the latest int4 changes to the piepline
amd-khushbu Sep 22, 2025
d8c8db7
Merge branch 'develop' into preshuffle
amd-khushbu Sep 23, 2025
d226aed
fixing test code for updated traits
amd-khushbu Sep 23, 2025
845922a
Adding gtest
Sep 23, 2025
bfdceee
review comments addressed
amd-khushbu Sep 24, 2025
e962313
Merge branch 'develop' into preshuffle
amd-khushbu Sep 24, 2025
4c03bf8
addressing review comments
amd-khushbu Sep 25, 2025
7667f6c
Merge branch 'develop' into preshuffle
amd-khushbu Sep 25, 2025
46c50a9
remove c++20 code
amd-khushbu Sep 25, 2025
05c2bde
Merge branch 'develop' into preshuffle
amd-khushbu Sep 26, 2025
b1320e7
added flush cache changes
amd-khushbu Sep 26, 2025
418dc52
Merge branch 'develop' into preshuffle
amd-khushbu Sep 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions example/ck_tile/38_block_scale_gemm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,6 @@ User need to select correct mapping of config for each quant mode:
| For selecting AQuant | aquant | GemmConfigQuant |
| For selecting Aquant with Preshuffle | aquant | GemmConfigPreshuffleQuant |
| For selecting BQuant | bquant | GemmConfigQuant |
| For selecting PreShuffle Weight matrix with Bquant | bquant | GemmConfigPreshuffleB_Bquant |
| For selecting RowCol quant | rowcolquant | GemmConfigRowColQuant |

34 changes: 27 additions & 7 deletions example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ template <typename GemmConfig,
float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
// B datatype is safe to use as compute type as it should be at least fp8
using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant,
typename TypeConfig::BDataType,
Expand All @@ -41,10 +40,14 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::PreshuffleQuant,
GemmConfig::PreshuffleB,
ALayout,
BLayout,
CLayout,
QuantMode>;
QuantMode,
ALayout, // for AQLayout
BLayout, // for BQLayout
GemmConfig::DoubleSmemBuffer>;

using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<typename TypeConfig::ADataType,
typename TypeConfig::BDataType,
Expand All @@ -53,7 +56,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
GemmTraits,
ComputeDataType>;

using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == false,
Comment thread
amd-khushbu marked this conversation as resolved.
Outdated
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>;

const ck_tile::index_t K_split =
(args.K + GemmConfig::K_Tile - 1) / GemmConfig::K_Tile * GemmConfig::K_Tile;
Expand Down Expand Up @@ -107,9 +113,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant,
ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
std::conditional_t<GemmConfig::PreshuffleB == false,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>>>;

using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename TypeConfig::ADataType,
Expand Down Expand Up @@ -177,6 +186,14 @@ int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int a
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant) &&
GemmConfig::PreshuffleB)
{
throw std::runtime_error(
"Preshuffling weight matrix is not supported for AQuant or RowColQuant");
}

if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
Expand Down Expand Up @@ -372,4 +389,7 @@ int run_gemm_example(int argc, char* argv[])
}
}

int main(int argc, char* argv[]) { return !run_gemm_example<GemmConfigQuant>(argc, argv); }
// Example entry point: runs the quantized GEMM example with the
// GemmConfigPreshuffleB_Bquant configuration.
//
// The value returned by run_gemm_example is logically inverted to form the
// process exit code: a non-zero result yields exit code 0, a zero result
// yields exit code 1.
int main(int argc, char* argv[])
{
    const auto example_result = run_gemm_example<GemmConfigPreshuffleB_Bquant>(argc, argv);
    return example_result ? 0 : 1;
}
22 changes: 21 additions & 1 deletion example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerM01 = 4;

static constexpr bool PreshuffleQuant = false;
static constexpr bool PreshuffleB = false;
static constexpr bool DoubleSmemBuffer = false;
};

Expand Down Expand Up @@ -145,6 +146,26 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase
static constexpr bool PreshuffleQuant = true;
};

template <typename PrecType>
struct GemmConfigPreshuffleB_Bquant : public GemmConfigBase
Comment thread
amd-khushbu marked this conversation as resolved.
Outdated
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();

static constexpr bool PreshuffleB = true;
static constexpr bool DoubleSmemBuffer = true;
};

template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,
Expand Down Expand Up @@ -222,7 +243,6 @@ auto create_args(int argc, char* argv[])
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("aq_layout", "R", "Aq tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("bq_layout", "C", "Bq tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
Expand Down
48 changes: 42 additions & 6 deletions example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,22 @@ auto shuffle_aq(const ck_tile::HostTensor<T>* t, int block_aq_k)
return ck_tile::reference_permute(t_view, {1, 0, 2});
}

/// Produces the host-side "pre-shuffled" copy of a 2-D B tensor that callers
/// upload to the device when GemmConfig::PreshuffleB is enabled.
///
/// With K = t.get_lengths()[0] and N = t.get_lengths()[1], the elements of `t`
/// are copied, in the tensor's own iteration order, into a 5-D tensor of lengths
///   { N / N_Warp_Tile, N_Warp_Tile, K / K_Warp_Tile, divisor, K_Warp_Tile / divisor }
/// which is then permuted with axis order {0, 2, 3, 1, 4} by
/// ck_tile::reference_permute.
///
/// @tparam GemmConfig Supplies N_Warp_Tile and K_Warp_Tile.
/// @tparam T          Element type of the tensor.
/// @param  t          2-D tensor; N must be a multiple of N_Warp_Tile and K a
///                    multiple of K_Warp_Tile.
/// @return The result of ck_tile::reference_permute on the reshaped copy.
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
    assert(t.get_lengths().size() == 2);
    int n_ = t.get_lengths()[1];
    int k_ = t.get_lengths()[0];

    // K_Warp_Tile is split into `divisor` chunks: 2 when N_Warp_Tile is 32, otherwise 4.
    constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;

    // Every division below must be exact. If one truncates, the 5-D view holds
    // fewer elements than `t` and the std::copy below writes past its end.
    static_assert(GemmConfig::K_Warp_Tile % divisor == 0,
                  "K_Warp_Tile must be divisible by the K split factor");
    assert(n_ % GemmConfig::N_Warp_Tile == 0 && "N must be a multiple of N_Warp_Tile");
    assert(k_ % GemmConfig::K_Warp_Tile == 0 && "K must be a multiple of K_Warp_Tile");

    ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
                                   GemmConfig::N_Warp_Tile,
                                   k_ / GemmConfig::K_Warp_Tile,
                                   divisor,
                                   GemmConfig::K_Warp_Tile / divisor});
    std::copy(t.begin(), t.end(), t_view.begin());
    return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}

template <typename GemmConfig,
typename TypeConfig,
typename ALayout,
Expand Down Expand Up @@ -124,7 +140,8 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
? "AQuantGrouped"
: (QuantMode == ck_tile::QuantType::BQuantGrouped ? "BQuantGrouped"
: "RowColQuant"))
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : "
<< " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false")
<< " PreshuffleB = " << (GemmConfig::PreshuffleB ? "true" : "false") << " : "
<< ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;

Expand Down Expand Up @@ -383,17 +400,36 @@ int run_gemm_example_with_layouts(int argc,
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}

if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
if constexpr(GemmConfig::PreshuffleB)
{
ck_tile::HostTensor<BDataType> b_k_n_dev = shuffle_b<GemmConfig>(b_k_n);
Comment thread
amd-khushbu marked this conversation as resolved.
Outdated
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
ck_tile::permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
if constexpr(GemmConfig::PreshuffleB)
{
ck_tile::HostTensor<BDataType> b_k_n_dev = shuffle_b<GemmConfig>(b_k_n);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
}

c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ struct WarpGemmAttributeMfmaIterateK
static constexpr index_t kN = Impl::kN;
static constexpr index_t kK = Impl::kK * kKIter;
static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
static constexpr index_t kCMLane = Impl::kCMLane;

CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

Expand Down
3 changes: 3 additions & 0 deletions include/ck_tile/ops/gemm_group_quant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_br_bquant_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp"
#include "ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp"
Expand All @@ -13,6 +14,8 @@
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,9 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>

if constexpr(Traits::PreshuffleQuant)
{
static_assert(false,
"It is not supported yet to enable both Preshuffle and "
"TransposeC.");
// static_assert(false,
// "It is not supported yet to enable both Preshuffle and
// " "TransposeC.");
if constexpr(Traits::TransposeC) // transposed C
{
// TODO:
Expand Down
Loading
Loading