Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 55 additions & 4 deletions example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "conv_util.hpp"
#include "device.hpp"
#include "device_tensor.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "host_tensor.hpp"
Expand Down Expand Up @@ -78,6 +79,52 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::
7, // CThreadTransferSrcDstVectorDim
1>; // CThreadTransferDstScalarPerVector

// 2-D forward-convolution device instance based on the C-shuffle XDL device op
// (NHWC input / KYXC weight / NHWK output, per the template's name).
// `#if 1` is always true, so this alias is always compiled; the guard only acts
// as a manual on/off switch.
// NOTE(review): the per-argument labels below are informal shorthand and are not
// checked against the template's real parameter names. The argument list is
// purely positional, so confirm each slot against the template declaration
// before changing any value.
#if 1
using DeviceConv2DFwdInstance = ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
ck::half_t, // presumably the input data type — TODO confirm
ck::half_t, // presumably the weight data type — TODO confirm
ck::half_t, // presumably the output data type — TODO confirm
float, // presumably the accumulation data type — TODO confirm
ck::tensor_operation::element_wise::PassThrough, // element-wise op (pass-through); presumably applied to the input — TODO confirm
ck::tensor_operation::element_wise::PassThrough, // element-wise op (pass-through); presumably applied to the weight — TODO confirm
ck::tensor_operation::element_wise::PassThrough, // element-wise op (pass-through); presumably applied to the output — TODO confirm
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, // no shape-specific specialization selected

// Block / XDL tile configuration.
256, // block_size
128, // m_per_block
64, // n_per_block
4, // k_per_block
8, // k1
32, // m_per_xdl
32, // n_per_xdl
2, // m_xdl_per_wave
1, // n_xdl_per_wave


// First block-transfer group.
// NOTE(review): presumably describes the A (input) operand — TODO confirm.
S<4,64,1>, // thread_cluster_length
S<1,0,2>, // thread_cluster_arrange_order
S<1,0,2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
true, // add_extra_dim

// Second block-transfer group, same values as the first.
// NOTE(review): presumably describes the B (weight) operand — TODO confirm.
S<4,64,1>, // thread_cluster_length
S<1,0,2>, // thread_cluster_arrange_order
S<1,0,2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
true, // add_extra_dim

// Trailing group.
// NOTE(review): these two labels repeat the ones used above with different
// values, so they must refer to different template parameters — presumably the
// per-shuffle XDL counts of the C-shuffle output stage. TODO confirm.
1, // m_xdl_per_wave
1, // n_xdl_per_wave
S<1,1,32,1,1,8>, // m_n_block_wave_per_xdl
8 // scalar_per_vector

>;
#endif

template <ck::index_t NumDimSpatial>
using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
WeiDataType,
Expand All @@ -95,7 +142,11 @@ DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial)
return std::make_unique<DeviceConvNDFwdInstance<3>>();
}
case 2: {
#if 0
return std::make_unique<DeviceConvNDFwdInstance<2>>();
#else
return std::make_unique<DeviceConv2DFwdInstance>();
#endif
}
case 1: {
return std::make_unique<DeviceConvNDFwdInstance<1>>();
Expand Down Expand Up @@ -291,7 +342,7 @@ int main(int argc, char* argv[])

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << conv->GetTypeString()
<< std::endl;

if(do_verification)
Expand Down Expand Up @@ -320,17 +371,17 @@ int main(int argc, char* argv[])
{
case 3: {
auto ref_conv = ReferenceConvNDFwdInstance<3>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
case 2: {
auto ref_conv = ReferenceConvNDFwdInstance<2>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
Comment thread
shaojiewang marked this conversation as resolved.
Outdated
}
case 1: {
auto ref_conv = ReferenceConvNDFwdInstance<1>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
Comment thread
shaojiewang marked this conversation as resolved.
Outdated
}
default: {
Expand Down
6 changes: 3 additions & 3 deletions example/09_convnd_fwd/convnd_fwd_xdl_fp32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,17 +324,17 @@ int main(int argc, char* argv[])
{
case 3: {
auto ref_conv = ReferenceConvNDFwdInstance<3>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
case 2: {
auto ref_conv = ReferenceConvNDFwdInstance<2>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
case 1: {
auto ref_conv = ReferenceConvNDFwdInstance<1>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
default: {
Expand Down
6 changes: 3 additions & 3 deletions example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,17 +322,17 @@ int main(int argc, char* argv[])
{
case 3: {
auto ref_conv = ReferenceConvNDFwdInstance<3>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
case 2: {
auto ref_conv = ReferenceConvNDFwdInstance<2>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
case 1: {
auto ref_conv = ReferenceConvNDFwdInstance<1>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
default: {
Expand Down
6 changes: 3 additions & 3 deletions example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,17 +332,17 @@ int main(int argc, char* argv[])
{
case 3: {
auto ref_conv = ReferenceConvBwdDataInstance<3>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
case 2: {
auto ref_conv = ReferenceConvBwdDataInstance<2>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
case 1: {
auto ref_conv = ReferenceConvBwdDataInstance<1>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
default: {
Expand Down
6 changes: 3 additions & 3 deletions example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,17 +403,17 @@ int main(int argc, char* argv[])
{
case 3: {
auto ref_conv = ReferenceConvBwdWeightInstance<3>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
case 2: {
auto ref_conv = ReferenceConvBwdWeightInstance<2>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
case 1: {
auto ref_conv = ReferenceConvBwdWeightInstance<1>();
verify_f(ref_conv);
return verify_f(ref_conv);
break;
}
default: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,11 +520,11 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W

a_grid_desc_k0_m_k1_ = descs[I0];
b_grid_desc_k0_n_k1_ = descs[I1];
c_grid_desc_m_n_ = descs[I2];

block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);

c_grid_desc_m_n_ = descs[I2];

if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
Expand Down