Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/ck_tile/03_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ struct PipelineTypeTraits<ck_tile::GemmPipeline::PRESHUFFLE_V2>
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<PipelineProblem>;
};

auto create_args()
inline auto create_args()
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
Expand Down
4 changes: 2 additions & 2 deletions example/ck_tile/03_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,8 @@ bool do_verify(const ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
return pass;
}

std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t>
parse_gemm_size(ck_tile::ArgParser& arg_parser)
std::tuple<ck_tile::index_t, ck_tile::index_t, ck_tile::index_t> inline parse_gemm_size(
ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
Expand Down
2 changes: 2 additions & 0 deletions include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -986,6 +986,8 @@ struct MoeSortingKernel
p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
}
}
__syncthreads();

smem_cumdup(num_experts) = smem_cumsum(num_experts);

// fill the p_sorted_token_ids/p_sorted_weights
Expand Down
1 change: 1 addition & 0 deletions include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,7 @@ struct GroupedGemmKernel
const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
block_sync_lds();
block_id = block_id + grid_size; // advance to next block
// NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
if(block_id >= cum_grid_size)
Expand Down
7 changes: 4 additions & 3 deletions include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,7 @@ struct StreamKKernel
tile_idx += kargs.tile_partitioner.get_grid())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
block_sync_lds();
}

// Stream-K section
Expand Down Expand Up @@ -679,8 +680,8 @@ struct StreamKKernel
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
ck_tile::hip_check_error(hipGetDevice(&dev));
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
int num_cu = dev_prop.multiProcessorCount;

return num_cu;
Expand All @@ -700,7 +701,7 @@ struct StreamKKernel
constexpr int min_block_per_cu = 1;
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;

hip_check_error(
ck_tile::hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));

return max(occupancy, 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ struct UniversalGemmKernel
using Kernel = UniversalGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
const auto kernel = kentry<1, Kernel, KernelArgs>;
int occupancy;
hip_check_error(
ck_tile::hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize().x, 0));

const int grid_size = get_available_compute_units(s) * occupancy;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,35 @@

namespace ck_tile {

/// Compile-time traits and loop/tail helpers for the V1 "A gmem / B gmem / C reg"
/// GEMM pipeline.
///
/// Every member is static; the struct carries no state. Hot-loop presence and the
/// tail number are fixed constants here: the runtime arguments accepted by the
/// helpers below are ignored.
///
/// @tparam Problem  Pipeline problem descriptor; only `Problem::TransposeC` is read here.
template <typename Problem>
struct BaseGemmPipelineAGmemBGmemCRegV1
{
    static constexpr bool UsePersistentKernel = false;
    static constexpr index_t GlobalBufferNum  = 1;
    static constexpr index_t PrefillStages    = 1;
    static constexpr index_t PrefetchStages   = 1;

    /// Forwards `Problem::TransposeC` unchanged.
    CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }

    /// Always reports a hot loop; the `index_t` argument is not inspected.
    // NOTE(review): the argument is presumably the K-loop iteration count — confirm with callers.
    CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }

    /// Always yields `TailNumber::Empty`; the `index_t` argument is not inspected.
    CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
    {
        return TailNumber::Empty;
    }

    /// Invokes `run_func` with the fixed compile-time pair
    /// (`bool_constant<true>`, `integral_constant<TailNumber, TailNumber::Empty>`).
    ///
    /// The runtime `bool` and `TailNumber` arguments are accepted but ignored, so
    /// exactly one instantiation of `run_func` is ever selected.
    ///
    /// @param run_func  Callable taking the two tag objects described above.
    /// @return Whatever `run_func` returns.
    template <typename Callable>
    CK_TILE_HOST_DEVICE static auto TailHandler(const Callable& run_func, bool, TailNumber)
    {
        using HotLoopTag = bool_constant<true>;
        using TailTag    = integral_constant<TailNumber, TailNumber::Empty>;

        return run_func(HotLoopTag{}, TailTag{});
    }
};

// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAGmemBGmemCRegV1
struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
{
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
Expand Down Expand Up @@ -48,14 +72,14 @@ struct GemmPipelineAGmemBGmemCRegV1
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeA()
{
return Problem::VectorSizeA;
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
}
template <bool IsWave32Host = false>
static constexpr index_t GetVectorSizeB()
{
return Problem::VectorSizeB;
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
}
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }

static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,34 @@

namespace ck_tile {

/// Compile-time traits and loop/tail helpers for the V2 "A gmem / B gmem / C reg"
/// GEMM pipeline.
///
/// Every member is static; the struct carries no state. Hot-loop presence and the
/// tail number are fixed constants here: the runtime arguments accepted by the
/// helpers below are ignored.
///
/// @tparam Problem  Pipeline problem descriptor; `Problem::TransposeC` and
///                  `Problem::Traits::UsePersistentKernel` are read here.
template <typename Problem>
struct BaseGemmPipelineAGmemBGmemCRegV2
{
    /// Taken directly from the problem's traits.
    static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
    static constexpr index_t GlobalBufferNum  = 1;
    static constexpr index_t PrefillStages    = 1;
    static constexpr index_t PrefetchStages   = 2;

    /// Forwards `Problem::TransposeC` unchanged.
    CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }

    /// Always reports a hot loop; the `index_t` argument is not inspected.
    // NOTE(review): the argument is presumably the K-loop iteration count — confirm with callers.
    CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }

    /// Always yields `TailNumber::Empty`; the `index_t` argument is not inspected.
    CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
    {
        return TailNumber::Empty;
    }

    /// Invokes `run_func` with the fixed compile-time pair
    /// (`bool_constant<true>`, `integral_constant<TailNumber, TailNumber::Empty>`).
    ///
    /// The runtime `bool` and `TailNumber` arguments are accepted but ignored, so
    /// exactly one instantiation of `run_func` is ever selected.
    ///
    /// @param run_func  Callable taking the two tag objects described above.
    /// @return Whatever `run_func` returns.
    template <typename Callable>
    CK_TILE_HOST_DEVICE static auto TailHandler(const Callable& run_func, bool, TailNumber)
    {
        using HotLoopTag = bool_constant<true>;
        using TailTag    = integral_constant<TailNumber, TailNumber::Empty>;

        return run_func(HotLoopTag{}, TailTag{});
    }
};
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
struct GemmPipelineAGmemBGmemCRegV2
struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2<Problem>
{
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
Expand Down
5 changes: 3 additions & 2 deletions include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ template <bool kPadM_,
bool UseStructuredSparsity_ = false,
bool UsePersistentKernel_ = false,
index_t NumWaveGroups_ = 1,
bool Preshuffle_ = false>
bool Preshuffle_ = false,
int VectorSize_ = 16>
struct TileGemmUniversalTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
static constexpr int _VectorSize = 16;
static constexpr int _VectorSize = VectorSize_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;

using AsLayout = AsLayout_;
Expand Down
2 changes: 1 addition & 1 deletion include/ck_tile/ops/reduce/block/block_reduce2d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ struct BlockReduce2dCrossWarpSync

if constexpr(num_reduce_warps == 1)
return;

block_sync_lds();
// Each warp's lane 0 writes its partial results to shared memory
const index_t smem_offset = warp_id;
if(lane_id == 0)
Expand Down
Loading