Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
303 changes: 303 additions & 0 deletions example/ck_tile/20_grouped_convolution/gemm_configs.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>
#include <variant>

#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/utility/json_dump.hpp"

#define CK_TILE_PIPELINE_COMPUTE_V3 1
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4

struct GemmConfigBase
{
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;

static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;

static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;

static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
static constexpr bool TiledMMAPermuteN = false;
};

template <typename PrecType>
struct GemmConfigMemoryInterwave : public GemmConfigBase
{
// Memory friendly for Interwave scheduler
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;

static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
};

template <typename PrecType>
struct GemmConfigMemoryIntrawave : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 32;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;

static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
};

template <typename PrecType>
struct GemmConfigComputeV3 : public GemmConfigBase
{
// Compute V3 only support Intrawave scheduler
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 64;

static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;

static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};

template <typename PrecType>
struct GemmConfigComputeV3_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;

static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
};

template <typename PrecType>
struct GemmConfigComputeV3_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;

static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;

static constexpr int kBlockPerCu = 2;
};

template <typename PrecType>
struct GemmConfigComputeV3_WMMA : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 4;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 16;

static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;

static constexpr int kBlockPerCu = 2;
};

template <typename PrecType>
struct GemmConfigComputeV4 : public GemmConfigBase
{
// Compute V4 only support Intrawave scheduler
// Using the ping pong reader in the lds level
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;

static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};

template <typename PrecType>
struct GemmConfigComputeV4_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 256;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 2;
static constexpr ck_tile::index_t N_Warp = 2;
static constexpr ck_tile::index_t K_Warp = 1;

static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;

static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
};

template <typename PrecType>
struct GemmConfigComputeV5 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);

static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 1;
static constexpr ck_tile::index_t K_Warp = 2;

static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = 16;

static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
};

template <typename InDataType, typename WeiDataType = InDataType, typename OutDataType = InDataType>
struct ConvTypeConfig;

template <>
struct ConvTypeConfig<ck_tile::half_t>
{
using InDataType = ck_tile::half_t;
using WeiDataType = ck_tile::half_t;
using AccDataType = float;
using OutDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};

template <>
struct ConvTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
{
using InDataType = ck_tile::bf16_t;
using WeiDataType = ck_tile::bf16_t;
using AccDataType = float;
using OutDataType = ck_tile::bf16_t;
};

template <typename T>
struct DataTypeTraits;

template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};

template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};

template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};

template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;

template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
};

template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
};

template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
};

template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
};
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "grouped_convolution_backward_data_invoker.hpp"
#include "run_grouped_convolution_bwd_data_example.inc"

template <typename GemmWarpConfig>
template <template <typename PrecType> typename GemmConfig>
int run_grouped_conv_bwd_data_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionBackwardDataInvoker;
Expand All @@ -31,14 +31,14 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
if(data_type == "fp16")
{
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
GemmWarpConfig,
GemmConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_bwd_data_example_prec_type<Invoker,
GemmWarpConfig,
GemmConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
Expand All @@ -51,8 +51,8 @@ int run_grouped_conv_bwd_data_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_bwd_data_example<GemmWarpConfig_Wmma>(argc, argv);
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3_WMMA>(argc, argv);
#else
return !run_grouped_conv_bwd_data_example<GemmWarpConfig_Mfma>(argc, argv);
return !run_grouped_conv_bwd_data_example<GemmConfigComputeV3>(argc, argv);
#endif
}
Loading
Loading