Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "conv_util.hpp"
#include "device.hpp"
#include "device_tensor.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp"
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "host_tensor.hpp"
Expand Down Expand Up @@ -78,6 +79,52 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::
7, // CThreadTransferSrcDstVectorDim
1>; // CThreadTransferDstScalarPerVector

#if 1
using DeviceConv2DFwdInstance = ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,

256, // block_size
128, // m_per_block
64, // n_per_block
4, // k_per_block
8, // k1
32, // m_per_xdl
32, // n_per_xdl
2, // m_xdl_per_wave
1, // n_xdl_per_wave


S<4,64,1>, // thread_cluster_length
S<1,0,2>, // thread_cluster_arrange_order
S<1,0,2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
true, // add_extra_dim

S<4,64,1>, // thread_cluster_length
S<1,0,2>, // thread_cluster_arrange_order
S<1,0,2>, // src_access_order
2, // src_vector_dim
8, // src_scalar_per_vector
8, // dst_scalar_per_vector
true, // add_extra_dim

1, // m_xdl_per_wave
1, // n_xdl_per_wave
S<1,1,32,1,1,8>, // m_n_block_wave_per_xdl
8 // scalar_per_vector

>;
#endif

template <ck::index_t NumDimSpatial>
using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<InDataType,
WeiDataType,
Expand All @@ -95,7 +142,11 @@ DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial)
return std::make_unique<DeviceConvNDFwdInstance<3>>();
}
case 2: {
#if 0
return std::make_unique<DeviceConvNDFwdInstance<2>>();
#else
return std::make_unique<DeviceConv2DFwdInstance>();
#endif
}
case 1: {
return std::make_unique<DeviceConvNDFwdInstance<1>>();
Expand Down Expand Up @@ -291,7 +342,7 @@ int main(int argc, char* argv[])

float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << conv->GetTypeString()
<< std::endl;

if(do_verification)
Expand Down