Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ int main(int argc, char* argv[])

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << conv->GetTypeString()
<< std::endl;

if(do_verification)
Expand Down Expand Up @@ -320,18 +320,15 @@ int main(int argc, char* argv[])
{
case 3: {
auto ref_conv = ReferenceConvNDFwdInstance<3>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
case 2: {
auto ref_conv = ReferenceConvNDFwdInstance<2>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
case 1: {
auto ref_conv = ReferenceConvNDFwdInstance<1>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
Expand Down
9 changes: 3 additions & 6 deletions example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,18 +324,15 @@ int main(int argc, char* argv[])
{
case 3: {
auto ref_conv = ReferenceConvNDFwdInstance<3>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
case 2: {
auto ref_conv = ReferenceConvNDFwdInstance<2>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
case 1: {
auto ref_conv = ReferenceConvNDFwdInstance<1>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
Expand Down
9 changes: 3 additions & 6 deletions example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,18 +322,15 @@ int main(int argc, char* argv[])
{
case 3: {
auto ref_conv = ReferenceConvNDFwdInstance<3>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
case 2: {
auto ref_conv = ReferenceConvNDFwdInstance<2>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
case 1: {
auto ref_conv = ReferenceConvNDFwdInstance<1>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
Expand Down
9 changes: 3 additions & 6 deletions example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,18 +332,15 @@ int main(int argc, char* argv[])
{
case 3: {
auto ref_conv = ReferenceConvBwdDataInstance<3>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
case 2: {
auto ref_conv = ReferenceConvBwdDataInstance<2>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
case 1: {
auto ref_conv = ReferenceConvBwdDataInstance<1>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
Expand Down
9 changes: 3 additions & 6 deletions example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,18 +403,15 @@ int main(int argc, char* argv[])
{
case 3: {
auto ref_conv = ReferenceConvBwdWeightInstance<3>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
case 2: {
auto ref_conv = ReferenceConvBwdWeightInstance<2>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
case 1: {
auto ref_conv = ReferenceConvBwdWeightInstance<1>();
verify_f(ref_conv);
break;
return verify_f(ref_conv);
}
default: {
throw std::runtime_error("Unsupported number of spatial dimensions provided!");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

using Block2CTileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, CGridDesc_M_N>;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@asroy Is there a performance concern not to use the default BlockToCTileMap_M00_N0_M01Adapt given by the underlying gridwise gemm? I can imagine for small problem sizes the overhead introduced by workgroup mapping can outweigh the benefits.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason, the case (in the PR description) run at 52TFlops with the BlockToCTileMap_M00_N0_M01 with M01=1, and 44TFlops with BlockToCTileMap_M00_N0_M01Adapt


// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1<
BlockSize,
Expand Down Expand Up @@ -477,8 +479,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
ck::index_t M01,
ck::index_t N01,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
Expand All @@ -490,8 +490,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
c_grid_desc_m_n_{},
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
in_element_op_{in_element_op},
wei_element_op_{wei_element_op},
out_element_op_{out_element_op},
Expand Down Expand Up @@ -520,10 +518,9 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W

a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
c_grid_desc_m_n_ = descs[I2];

c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_};

if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
Expand All @@ -546,9 +543,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
typename GridwiseGemm::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
Block2CTileMap block_2_ctile_map_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_;
Expand Down Expand Up @@ -661,7 +656,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
Block2CTileMap,
true>;

ave_time = launch_and_time_kernel(
Expand Down Expand Up @@ -695,7 +690,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
Block2CTileMap,
false>;

ave_time = launch_and_time_kernel(
Expand Down Expand Up @@ -814,8 +809,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
in_element_op,
wei_element_op,
out_element_op};
Expand Down Expand Up @@ -854,8 +847,6 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
in_element_op,
wei_element_op,
out_element_op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

using Block2CTileMap = BlockToCTileMap_M00_N0_M01<MPerBlock, NPerBlock, CGridDesc_M_N>;

// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
BlockSize,
Expand Down Expand Up @@ -664,8 +666,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
ck::index_t M01,
ck::index_t N01,
InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
Expand All @@ -677,8 +677,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
c_grid_desc_m_n_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
block_2_ctile_map_{},
M01_{M01},
N01_{N01},
in_element_op_{in_element_op},
wei_element_op_{wei_element_op},
out_element_op_{out_element_op},
Expand All @@ -705,8 +703,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);

block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_};

if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
Expand All @@ -727,9 +725,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
Block2CTileMap block_2_ctile_map_;
InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_;
OutElementwiseOperation out_element_op_;
Expand Down Expand Up @@ -793,7 +789,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
Block2CTileMap,
true>;

ave_time = launch_and_time_kernel(stream_config,
Expand Down Expand Up @@ -824,7 +820,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
Block2CTileMap,
false>;

ave_time = launch_and_time_kernel(stream_config,
Expand Down Expand Up @@ -955,8 +951,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
in_element_op,
wei_element_op,
out_element_op};
Expand Down Expand Up @@ -995,8 +989,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
conv_filter_dilations,
input_left_pads,
input_right_pads,
1,
1,
in_element_op,
wei_element_op,
out_element_op);
Expand All @@ -1012,8 +1004,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
auto str = std::stringstream();

// clang-format off
str << "DeviceConv" << std::to_string(NumDimSpatial)
<< "DFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
str << "DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;

template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
template <bool HasMainK0BlockLoop, typename Block2CTileMap>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
Expand Down