From 3e6c3f8c35c0f5d559243079870beefdc2b31982 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 7 Feb 2022 17:46:43 +0100 Subject: [PATCH 01/82] Convolution ND * Code unification across dimensions for generating tensor descriptors. * Example * Instances --- ...nd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp | 111 +++ .../device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp | 859 ++++++++++++++++++ example/8_convnd_fwd_xdl/README.md | 57 ++ example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp | 283 ++++++ example/CMakeLists.txt | 3 + 5 files changed, 1313 insertions(+) create mode 100644 device_operation/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp create mode 100644 device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp create mode 100644 example/8_convnd_fwd_xdl/README.md create mode 100644 example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp diff --git a/device_operation/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/device_operation/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 00000000000..2aa81442c47 --- /dev/null +++ b/device_operation/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,111 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_convnd_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_convnd_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000000..9d41bd5181e --- /dev/null +++ b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,859 @@ +#ifndef DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_fwd.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Inputs with up to 3 spatial dimentions +// @li Input tensor in NHWC data format +// @li Weight tensor in KYXC data format +// @li Output tensor in NHWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template +struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd +{ + using DeviceOp = DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = SpatialDims; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto GetWeightTensorDescriptor(ck::index_t gemm_n, ck::index_t gemm_k) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_n, gemm_k)); + + // wei_gemmk0_gemmn_gemmk1_grid_desc + return transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_n)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto + GetOutputTensorDescriptor(ck::index_t gemm_m, ck::index_t gemm_n, ck::index_t gemm_m_pad) + { + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_n)); + + // out_gemmm_gemmn_grid_desc + return transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(gemm_n)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + static auto GetInputTensorDescriptor(ck::index_t N, + ck::index_t C, + ck::index_t gemm_m, + ck::index_t gemm_k, + ck::index_t gemm_m_pad, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_right_pad_transform(gemm_m, gemm_m_pad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<2>{}, Sequence<0, 1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + const index_t X = filter_spatial_lengths[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_m)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + template ::type = false> + static auto GetInputTensorDescriptor(ck::index_t N, + ck::index_t C, + ck::index_t gemm_m, + ck::index_t gemm_k, + ck::index_t gemm_m_pad, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_right_pad_transform(gemm_m, gemm_m_pad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_m)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + static index_t GetGemmMRaw(ck::index_t N, + const std::vector& output_spatial_lengths) + { + return N * std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + 1, + std::multiplies()); + } + + static index_t GetGemmK(ck::index_t C, const std::vector& filter_spatial_lengths) + { + return C * std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + 1, + std::multiplies()); + } + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t GemmMRaw = GetGemmMRaw(N, output_spatial_lengths); + const index_t GemmN = K; + const index_t GemmK = GetGemmK(C, filter_spatial_lengths); + + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; + + assert(GemmK % GemmK1Number == 0); + + // C = A^T*B + // A: + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + GetInputTensorDescriptor(N, + C, + GemmMRaw, + GemmK, + GemmMPad, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + // B: + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmN, GemmK); + // C: + const auto out_gemmm_gemmn_grid_desc = GetOutputTensorDescriptor(GemmMRaw, GemmN, GemmMPad); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(ck::index_t i = 0; i < SpatialDims; ++i) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(ck::index_t i = 0; i < SpatialDims; ++i) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.input_left_pads_[i] == 0 && + arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/example/8_convnd_fwd_xdl/README.md b/example/8_convnd_fwd_xdl/README.md new file mode 100644 index 00000000000..b6fd5711cdf --- /dev/null +++ b/example/8_convnd_fwd_xdl/README.md @@ -0,0 +1,57 @@ +# Instructions for ```convnd_fwd_xdl``` Example + +## Docker script +```bash +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash +``` + +## Build ```convnd_fwd_xdl``` +```bash +mkdir build && cd build +``` + +```bash +# Need to specify target ID, example below is gfx908 +cmake \ +-D BUILD_DEV=OFF \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +.. +``` + +```bash + make -j convnd_fwd_xdl +``` + +## Run ```convnd_fwd_xdl``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx +./example/convnd_fwd_xdl 0 1 5 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP32) +``` +in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_{216, 165888, 8} +arg.b_grid_desc_k0_n_k1_{216, 256, 8} +arg.c_grid_desc_m_n_{ 165888, 256} +launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.43206 ms, 102.486 TFlops, 232.947 GB/s +``` diff --git a/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp new file mode 100644 index 00000000000..fad4072bca2 --- /dev/null +++ b/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp @@ -0,0 +1,283 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "tensor_layout.hpp" +#include "device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +using DeviceConvFwdInstance = + ck::tensor_operation::device::DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + // clang-format off +// | InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| +// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| +// | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| +// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + , S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>; +// clang-format on + +template +void host_verify(const Tensor& in, + const Tensor& wei, + Tensor& out, + const std::vector& conv_strides, + const std::vector& conv_dilations, + const std::vector& in_left_pads, + const std::vector&, + const InElementOp& in_element_op, + const WeiElementOp& wei_element_op, + const OutElementOp& out_element_op) +{ + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + double v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + { + int hi = ho * conv_strides[0] + y * conv_dilations[0] - in_left_pads[0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + { + int wi = wo * conv_strides[1] + x * conv_dilations[1] - in_left_pads[1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + v += in_element_op(static_cast(in(n, c, hi, wi))) * + wei_element_op(static_cast(wei(k, c, y, x))); + } + } + } + } + double v2 = out(n, k, ho, wo); + + out_element_op(v2, v); + + out(n, k, ho, wo) = v2; + }; + + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); +} + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + + // Conv shape + ck::index_t N = 128; + ck::index_t K = 256; + ck::index_t C = 192; + ck::index_t Y = 3; + ck::index_t X = 3; + ck::index_t Hi = 71; + ck::index_t Wi = 71; + ck::index_t conv_stride_h = 2; + ck::index_t conv_stride_w = 2; + ck::index_t conv_dilation_h = 1; + ck::index_t conv_dilation_w = 1; + ck::index_t in_left_pad_h = 1; + ck::index_t in_left_pad_w = 1; + ck::index_t in_right_pad_h = 1; + ck::index_t in_right_pad_w = 1; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + } + else if(argc == 19) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + C = std::stoi(argv[6]); + Y = std::stoi(argv[7]); + X = std::stoi(argv[8]); + Hi = std::stoi(argv[9]); + Wi = std::stoi(argv[10]); + conv_stride_h = std::stoi(argv[11]); + conv_stride_w = std::stoi(argv[12]); + conv_dilation_h = std::stoi(argv[13]); + conv_dilation_w = std::stoi(argv[14]); + in_left_pad_h = std::stoi(argv[15]); + in_left_pad_w = std::stoi(argv[16]); + in_right_pad_h = std::stoi(argv[17]); + in_right_pad_w = std::stoi(argv[18]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " + "RightPx\n"); + exit(0); + } + + const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; + const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; + + const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + + const std::vector conv_filter_strides{{conv_stride_h, conv_stride_w}}; + const std::vector conv_filter_dilations{{conv_dilation_h, conv_dilation_w}}; + const std::vector input_left_pads{{in_left_pad_h, in_left_pad_w}}; + const std::vector input_right_pads{{in_right_pad_h, in_right_pad_w}}; + + // tensor layout + auto f_host_tensor_descriptor = [](std::size_t N_, + std::size_t C_, + std::size_t H, + std::size_t W, + auto layout) { + if constexpr(ck::is_same::value || + ck::is_same::value || + ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, H * W, W, 1})); + } + else if constexpr(ck::is_same::value || + ck::is_same::value || + ck::is_same::value) + { + return HostTensorDescriptor(std::vector({N_, C_, H, W}), + std::vector({C_ * H * W, 1, W * C_, C_})); + } + }; + + Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); + Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); + Tensor out_n_k_ho_wo_host_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor out_n_k_ho_wo_device_result( + f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + + std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; + std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; + std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * + out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + + // do GEMM + auto conv = DeviceConvFwdInstance{}; + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + N, + K, + C, + std::vector{{Hi, Wi}}, + std::vector{{Y, X}}, + std::vector{{Ho, Wo}}, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker.Run(argument, nrepeat); + + std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + + std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + + sizeof(WeiDataType) * (K * C * Y * X) + + sizeof(OutDataType) * (N * K * Ho * Wo); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + host_verify(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); + + check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + } +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index f9474425bcd..fce6f047cc4 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -19,6 +19,7 @@ set(CONV2D_FWD_XDL_SOURCE 4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp) set(CONV2D_FWD_XDL_BIAS_RELU_SOURCE 5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp) set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp) set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp) +set(CONVND_FWD_XDL_SOURCE 8_convnd_fwd_xdl/convnd_fwd_xdl.cpp) add_executable(gemm_xdl ${GEMM_XDL_SOURCE}) add_executable(gemm_xdl_bias_relu ${GEMM_XDL_BIAS_RELU_SOURCE}) @@ -27,6 +28,7 @@ add_executable(conv2d_fwd_xdl ${CONV2D_FWD_XDL_SOURCE}) add_executable(conv2d_fwd_xdl_bias_relu ${CONV2D_FWD_XDL_BIAS_RELU_SOURCE}) add_executable(conv2d_fwd_xdl_bias_relu_add ${CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE}) add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE}) +add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) target_link_libraries(gemm_xdl PRIVATE host_tensor) target_link_libraries(gemm_xdl_bias_relu PRIVATE host_tensor) @@ -35,3 +37,4 @@ target_link_libraries(conv2d_fwd_xdl PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_bias_relu PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_bias_relu_add PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor) +target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor) From 65d8da4fa6a0a0e653981219baf7a8b9acb90a9f Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 8 Feb 2022 11:33:55 +0100 Subject: [PATCH 02/82] Move convnd f32 instance file to comply with repo structure. --- .../device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename device_operation/{ => src}/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp (100%) diff --git a/device_operation/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/device_operation/src/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp similarity index 100% rename from device_operation/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp rename to device_operation/src/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp From e2e3d2e075fd35f626af39db9d42625f909f0873 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 8 Feb 2022 11:34:29 +0100 Subject: [PATCH 03/82] Conv 1D tensor layouts. --- device_operation/include/tensor_layout.hpp | 26 ++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/device_operation/include/tensor_layout.hpp b/device_operation/include/tensor_layout.hpp index b69572d2c08..01b6d6eb6ba 100644 --- a/device_operation/include/tensor_layout.hpp +++ b/device_operation/include/tensor_layout.hpp @@ -21,6 +21,32 @@ struct ColumnMajor : public BaseTensorLayout namespace convolution { +// 1D Conv +struct NWC : public BaseTensorLayout +{ +}; + +struct KXC : public BaseTensorLayout +{ +}; + +struct NWK : public BaseTensorLayout +{ +}; + +struct NCW : public BaseTensorLayout +{ +}; + +struct KCX : public BaseTensorLayout +{ +}; + +struct NKW : public BaseTensorLayout +{ +}; + +// 2D Conv struct NHWC : public BaseTensorLayout { }; From e035c337a37de190358b6b9cb1ca96be71804381 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 8 Feb 2022 12:04:47 +0100 Subject: [PATCH 04/82] Formatting and use ReferenceConv --- example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp | 122 +++++++++----------- 1 file changed, 55 insertions(+), 67 deletions(-) diff --git a/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp index fad4072bca2..3eff7dd3a0a 100644 --- a/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp +++ b/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp @@ -4,6 +4,7 @@ #include #include #include + #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -11,8 +12,9 @@ #include "host_tensor_generator.hpp" #include "device_tensor.hpp" #include "tensor_layout.hpp" -#include "device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "reference_conv_fwd.hpp" using InDataType = float; using WeiDataType = float; @@ -33,65 +35,46 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; -using DeviceConvFwdInstance = - ck::tensor_operation::device::DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K +using DeviceConvFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< // clang-format off -// | InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| -// | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| -// | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| -// | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - , S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>; + InDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + 2, // SptialDims + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 4, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockTransferAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector // clang-format on -template -void host_verify(const Tensor& in, - const Tensor& wei, - Tensor& out, - const std::vector& conv_strides, - const std::vector& conv_dilations, - const std::vector& in_left_pads, - const std::vector&, - const InElementOp& in_element_op, - const WeiElementOp& wei_element_op, - const OutElementOp& out_element_op) -{ - auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - double v = 0; - for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) - { - for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[0] + y * conv_dilations[0] - in_left_pads[0]; - for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * conv_strides[1] + x * conv_dilations[1] - in_left_pads[1]; - if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && - wi < in.mDesc.GetLengths()[3]) - { - v += in_element_op(static_cast(in(n, c, hi, wi))) * - wei_element_op(static_cast(wei(k, c, y, x))); - } - } - } - } - double v2 = out(n, k, ho, wo); - - out_element_op(v2, v); - - out(n, k, ho, wo) = v2; - }; - - make_ParallelTensorFunctor(f_nchw, - out.mDesc.GetLengths()[0], - out.mDesc.GetLengths()[1], - out.mDesc.GetLengths()[2], - out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); -} +using ReferenceConvFwdInstance = ck::tensor_operation::host:: + ReferenceConvFwd; int main(int argc, char* argv[]) { @@ -265,16 +248,21 @@ int main(int argc, char* argv[]) if(do_verification) { - host_verify(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); From ec89868c42afc3354307e58ad11234d2ed3b7dad Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 11 Feb 2022 11:50:36 +0100 Subject: [PATCH 05/82] Reference ConvFwd supporting 1D and 2D convolution. --- .../include/reference_conv_fwd.hpp | 140 +++++++++++++----- 1 file changed, 107 insertions(+), 33 deletions(-) diff --git a/reference_operation/include/reference_conv_fwd.hpp b/reference_operation/include/reference_conv_fwd.hpp index f929f3cda58..c86d6ce5d34 100644 --- a/reference_operation/include/reference_conv_fwd.hpp +++ b/reference_operation/include/reference_conv_fwd.hpp @@ -2,6 +2,7 @@ #define REFERENCE_CONV_FWD_HPP #include +#include #include #include "device_base.hpp" #include "host_tensor.hpp" @@ -10,21 +11,38 @@ namespace ck { namespace tensor_operation { namespace host { -// out[N, K, Ho, Wo] = in[N, C, Hi, Wi] * wei[K, C, Y, X] +// +// @brief Reference implementation for forward convolution. +// +// @paragraph Supported tensor layouts. Input tensor supports NCHiWi data layout. +// Weights tensor supports KCYX data layout. Output tensor supports +// NKHoWo data layout. +// +// @tparam InDataType Input tensor data type. +// @tparam WeiDataType Weights tensor data type. +// @tparam OutDataType Output tensor data type. +// @tparam InElementwiseOperation Functor for input tensor elementwise +// operation. +// @tparam WeiElementwiseOperation Functor for weights tensor elementwise +// operation. +// @tparam SpatialNDims Number of spatial dimensions. +// template + typename OutElementwiseOperation, + ck::index_t SpatialNDims = 2, + typename std::enable_if= 1 && SpatialNDims <= 3, bool>::type = false> struct ReferenceConvFwd : public device::BaseOperator { // Argument struct Argument : public device::BaseArgument { - Argument(const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, + Argument(const Tensor& input, + const Tensor& weights, + Tensor& output, std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, @@ -32,9 +50,9 @@ struct ReferenceConvFwd : public device::BaseOperator InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) - : in_n_c_hi_wi_{in_n_c_hi_wi}, - wei_k_c_y_x_{wei_k_c_y_x}, - out_n_k_ho_wo_{out_n_k_ho_wo}, + : input_{input}, + weights_{weights}, + output_{output}, conv_strides_{conv_filter_strides}, conv_dilations_{conv_filter_dilations}, in_left_pads_{input_left_pads}, @@ -45,9 +63,9 @@ struct ReferenceConvFwd : public device::BaseOperator { } - const Tensor& in_n_c_hi_wi_; - const Tensor& wei_k_c_y_x_; - Tensor& out_n_k_ho_wo_; + const Tensor& input_; + const Tensor& weights_; + Tensor& output_; std::vector conv_strides_; std::vector conv_dilations_; @@ -60,7 +78,65 @@ struct ReferenceConvFwd : public device::BaseOperator }; // Invoker + template struct Invoker : public device::BaseInvoker + { + }; + + template <> + struct Invoker<1> : public device::BaseInvoker + { + using Argument = ReferenceConvFwd::Argument; + + float Run(const Argument& arg) + { + auto f_ncw = [&](auto n, auto k, auto wo) { + float v_acc = 0; + + for(int c = 0; c < arg.weights_.mDesc.GetLengths()[1]; ++c) + { + for(int x = 0; x < arg.weights_.mDesc.GetLengths()[2]; ++x) + { + int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] - + arg.in_left_pads_[0]; + if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2]) + { + float v_in; + float v_wei; + + arg.in_element_op_(v_in, + static_cast(arg.input_(n, c, wi))); + arg.wei_element_op_(v_wei, + static_cast(arg.weights_(k, c, x))); + + v_acc += v_in * v_wei; + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc); + arg.output_(n, k, wo) = v_out; + }; + + make_ParallelTensorFunctor(f_ncw, + arg.output_.mDesc.GetLengths()[0], + arg.output_.mDesc.GetLengths()[1], + arg.output_.mDesc.GetLengths()[2])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, int) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + template <> + struct Invoker<2> : public device::BaseInvoker { using Argument = ReferenceConvFwd::Argument; @@ -69,27 +145,26 @@ struct ReferenceConvFwd : public device::BaseOperator auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { float v_acc = 0; - for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) + for(int c = 0; c < arg.weights_.mDesc.GetLengths()[1]; ++c) { - for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) + for(int y = 0; y < arg.weights_.mDesc.GetLengths()[2]; ++y) { int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - arg.in_left_pads_[0]; - for(int x = 0; x < arg.wei_k_c_y_x_.mDesc.GetLengths()[3]; ++x) + for(int x = 0; x < arg.weights_.mDesc.GetLengths()[3]; ++x) { int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - arg.in_left_pads_[1]; - if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && - wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) + if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 && + wi < arg.input_.mDesc.GetLengths()[3]) { float v_in; float v_wei; arg.in_element_op_( - v_in, - static_cast(arg.in_n_c_hi_wi_(n, c, hi, wi))); + v_in, static_cast(arg.input_(n, c, hi, wi))); arg.wei_element_op_( - v_wei, static_cast(arg.wei_k_c_y_x_(k, c, y, x))); + v_wei, static_cast(arg.weights_(k, c, y, x))); v_acc += v_in * v_wei; } @@ -100,15 +175,14 @@ struct ReferenceConvFwd : public device::BaseOperator float v_out; arg.out_element_op_(v_out, v_acc); - - arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out; + arg.output_(n, k, ho, wo) = v_out; }; make_ParallelTensorFunctor(f_nchw, - arg.out_n_k_ho_wo_.mDesc.GetLengths()[0], - arg.out_n_k_ho_wo_.mDesc.GetLengths()[1], - arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], - arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( + arg.output_.mDesc.GetLengths()[0], + arg.output_.mDesc.GetLengths()[1], + arg.output_.mDesc.GetLengths()[2], + arg.output_.mDesc.GetLengths()[3])( std::thread::hardware_concurrency()); return 0; @@ -128,9 +202,9 @@ struct ReferenceConvFwd : public device::BaseOperator bool IsSupportedArgument(const device::BaseArgument*) override { return true; } - static auto MakeArgument(const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, + static auto MakeArgument(const Tensor& input, + const Tensor& weights, + Tensor& output, std::vector conv_filter_strides, std::vector conv_filter_dilations, std::vector input_left_pads, @@ -139,9 +213,9 @@ struct ReferenceConvFwd : public device::BaseOperator WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) { - return Argument{in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo, + return Argument{input, + weights, + output, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -151,11 +225,11 @@ struct ReferenceConvFwd : public device::BaseOperator out_element_op}; } - static auto MakeInvoker() { return Invoker{}; } + static auto MakeInvoker() { return Invoker{}; } virtual std::unique_ptr MakeInvokerPointer() { - return std::make_unique(Invoker{}); + return std::make_unique>(Invoker{}); } std::string GetTypeString() const override From 5189c802602b9e11fde4992304f232c60577de10 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 11 Feb 2022 11:51:29 +0100 Subject: [PATCH 06/82] Debug printing TensorLayout name. --- .../device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp | 2 +- device_operation/include/tensor_layout.hpp | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index 9d41bd5181e..d4e7d7725ea 100644 --- a/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -851,7 +851,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K return str.str(); } -}; // namespace device +}; } // namespace device } // namespace tensor_operation diff --git a/device_operation/include/tensor_layout.hpp b/device_operation/include/tensor_layout.hpp index 01b6d6eb6ba..1e2f9accb26 100644 --- a/device_operation/include/tensor_layout.hpp +++ b/device_operation/include/tensor_layout.hpp @@ -12,10 +12,12 @@ namespace gemm { struct RowMajor : public BaseTensorLayout { + static constexpr const char* name = "RowMajor"; }; struct ColumnMajor : public BaseTensorLayout { + static constexpr const char* name = "ColumnMajor"; }; } // namespace gemm @@ -24,55 +26,76 @@ namespace convolution { // 1D Conv struct NWC : public BaseTensorLayout { + static constexpr const char* name = "NWC"; }; struct KXC : public BaseTensorLayout { + static constexpr const char* name = "KXC"; }; struct NWK : public BaseTensorLayout { + static constexpr const char* name = "NWK"; }; struct NCW : public BaseTensorLayout { + static constexpr const char* name = "NCW"; }; struct KCX : public BaseTensorLayout { + static constexpr const char* name = "KCX"; }; struct NKW : public BaseTensorLayout { + static constexpr const char* name = "NKW"; }; // 2D Conv struct NHWC : public BaseTensorLayout { + static constexpr const char* name = "NHWC"; }; struct KYXC : public BaseTensorLayout { + static constexpr const char* name = "KYXC"; }; struct NHWK : public BaseTensorLayout { + static constexpr const char* name = "NHWK"; }; struct NCHW : public BaseTensorLayout { + static constexpr const char* name = "NCHW"; }; struct KCYX : public BaseTensorLayout { + static constexpr const char* name = "KCYX"; }; struct NKHW : public BaseTensorLayout { + static constexpr const char* name = "NKHW"; }; } // namespace convolution +template < + typename Layout, + typename std::enable_if::value, bool>::type = false> +std::ostream& operator<<(std::ostream& os, const Layout&) +{ + os << Layout::name; + return os; +} + } // namespace tensor_layout } // namespace ck #endif From a0a6afdf4620b4c5e1dab4ae6286954f6b75508d Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 11 Feb 2022 12:19:15 +0100 Subject: [PATCH 07/82] Conv fwd 1D instance f32 --- ...nd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp | 87 +++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/device_operation/src/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/device_operation/src/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index 2aa81442c47..791facd7a6d 100644 --- a/device_operation/src/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/device_operation/src/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -25,7 +25,94 @@ static constexpr auto ConvFwd1x1P0 = static constexpr auto ConvFwd1x1S1P0 = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; +//------------------------------------------------------------------------------ +// Conv1D +//------------------------------------------------------------------------------ + // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +//------------------------------------------------------------------------------ +// Conv2D +//------------------------------------------------------------------------------ + using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< // clang-format off From 97663eb1fc2341b92b2b702a9959c2e866d13805 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 11 Feb 2022 12:19:44 +0100 Subject: [PATCH 08/82] Refactor conv ND example. Needed to support various conv dimensio. Needed to support various conv dimensions --- example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp | 517 ++++++++++++++------ 1 file changed, 358 insertions(+), 159 deletions(-) diff --git a/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp index 3eff7dd3a0a..49a2508e61f 100644 --- a/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp +++ b/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp @@ -1,9 +1,10 @@ +#include +#include +#include #include #include -#include -#include #include -#include +#include #include "config.hpp" #include "print.hpp" @@ -35,17 +36,22 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough; static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; -using DeviceConvFwdInstance = ck::tensor_operation::device:: +using DeviceConvFwdBasePtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< - // clang-format off + // clang-format off InDataType, // - OutDataType, // + WeiDataType, // + OutDataType, // AccDataType, // - InElementOp, // Input Elementwise Operation - WeiElementOp, // Weights Elementwise Operation - OutElementOp, // Output Elementwise Operation + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation ConvFwdDefault, // ConvForwardSpecialization - 2, // SptialDims + SpatialDims, // SptialDims 256, // BlockSize 256, // MPerBlock 128, // NPerBlock @@ -73,171 +79,348 @@ using DeviceConvFwdInstance = ck::tensor_operation::device:: 1>; // CThreadTransferDstScalarPerVector // clang-format on -using ReferenceConvFwdInstance = ck::tensor_operation::host:: - ReferenceConvFwd; +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +template +HostTensorDescriptor GetHostTensorDescriptor(const std::vector& dims, + const TensorLayout& layout) +{ + std::size_t N = dims[0]; + std::size_t C = dims[1]; + // 1D + if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor(std::vector({N, C, dims[2]}), + std::vector({C * dims[2], dims[2], 1})); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor(std::vector({N, C, dims[2]}), + std::vector({C * dims[2], 1, C})); + } + // 2D + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor( + std::vector({N, C, dims[2], dims[3]}), + std::vector({C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1})); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor( + std::vector({N, C, dims[2], dims[3]}), + std::vector({C * dims[2] * dims[3], 1, dims[3] * C, C})); + } + + std::stringstream err_msg; + err_msg << "Unsupported data layout provided: " << layout << "!"; + throw std::runtime_error(err_msg.str()); +} + +DeviceConvFwdBasePtr GetConvInstance(int spatial_dims) +{ + switch(spatial_dims) + { + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +std::size_t GetFlops(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) +{ + // 2 * N * K * * C * + return ck::index_t(2) * N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + 1, + std::multiplies()) * + C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + 1, + std::multiplies()); +} + +ck::index_t GetBtype(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) +{ + // sizeof(InDataType) * (N * C * ) + + // sizeof(WeiDataType) * (K * C * ) + + // sizeof(OutDataType) * (N * K * ); + return sizeof(InDataType) * (N * C * + std::accumulate(std::begin(input_spatial_lengths), + std::end(input_spatial_lengths), + 1, + std::multiplies())) + + sizeof(WeiDataType) * (K * C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + 1, + std::multiplies())) + + sizeof(OutDataType) * (N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + 1, + std::multiplies())); +} + +void PrintUseMsg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: run kernel # of times (>1)\n" + << "arg4: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +struct ConvParams +{ + ConvParams() + : spatial_dims(2), + N(128), + K(256), + C(192), + filter_spatial_lengths(2, 3), + input_spatial_lengths(2, 71), + conv_filter_strides(2, 2), + conv_filter_dilations(2, 1), + input_left_pads(2, 1), + input_right_pads(2, 1) + { + } + + ck::index_t spatial_dims; + ck::index_t N; + ck::index_t K; + ck::index_t C; + + std::vector filter_spatial_lengths; + std::vector input_spatial_lengths; + + std::vector conv_filter_strides; + std::vector conv_filter_dilations; + + std::vector input_left_pads; + std::vector input_right_pads; + + std::vector GetOutputSpatialLengths() const + { + std::vector out_spatial_len(spatial_dims, 0); + for(ck::index_t i = 0; i < spatial_dims; ++i) + { + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::index_t idx_eff = + (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1; + out_spatial_len[i] = + (input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) / + conv_filter_strides[i] + + 1; + } + return out_spatial_len; + } +}; + +ConvParams ParseConvParams(int spatial_dims, int argc, char* argv[]) +{ + // (N, K, C) + spatial_dims * 6 (filter, input, strides, dilations, pad left, pad right) + int conv_args = 3 + spatial_dims * 6; + int cmdline_nargs = conv_args + 5; + if(cmdline_nargs != argc) + { + PrintUseMsg(); + exit(0); + } + + ConvParams params; + int arg_idx = 5; + + params.spatial_dims = spatial_dims; + params.N = std::stoi(argv[arg_idx++]); + params.K = std::stoi(argv[arg_idx++]); + params.C = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths.resize(spatial_dims); + for(int i = 0; i < spatial_dims; ++i) + { + params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths.resize(spatial_dims); + for(int i = 0; i < spatial_dims; ++i) + { + params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides.resize(spatial_dims); + for(int i = 0; i < spatial_dims; ++i) + { + params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations.resize(spatial_dims); + for(int i = 0; i < spatial_dims; ++i) + { + params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads.resize(spatial_dims); + for(int i = 0; i < spatial_dims; ++i) + { + params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads.resize(spatial_dims); + for(int i = 0; i < spatial_dims; ++i) + { + params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} int main(int argc, char* argv[]) { bool do_verification = 0; int init_method = 0; int nrepeat = 5; + int spatial_dims = 2; + + ConvParams params; - // Conv shape - ck::index_t N = 128; - ck::index_t K = 256; - ck::index_t C = 192; - ck::index_t Y = 3; - ck::index_t X = 3; - ck::index_t Hi = 71; - ck::index_t Wi = 71; - ck::index_t conv_stride_h = 2; - ck::index_t conv_stride_w = 2; - ck::index_t conv_dilation_h = 1; - ck::index_t conv_dilation_w = 1; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; - - if(argc == 4) + if(argc >= 4) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); nrepeat = std::stoi(argv[3]); } - else if(argc == 19) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - N = std::stoi(argv[4]); - K = std::stoi(argv[5]); - C = std::stoi(argv[6]); - Y = std::stoi(argv[7]); - X = std::stoi(argv[8]); - Hi = std::stoi(argv[9]); - Wi = std::stoi(argv[10]); - conv_stride_h = std::stoi(argv[11]); - conv_stride_w = std::stoi(argv[12]); - conv_dilation_h = std::stoi(argv[13]); - conv_dilation_w = std::stoi(argv[14]); - in_left_pad_h = std::stoi(argv[15]); - in_left_pad_w = std::stoi(argv[16]); - in_right_pad_h = std::stoi(argv[17]); - in_right_pad_w = std::stoi(argv[18]); - } - else + if(argc >= 5) { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); - printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(0); + spatial_dims = std::stoi(argv[4]); + params = ParseConvParams(spatial_dims, argc, argv); } - const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; - const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; - - const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - - const std::vector conv_filter_strides{{conv_stride_h, conv_stride_w}}; - const std::vector conv_filter_dilations{{conv_dilation_h, conv_dilation_w}}; - const std::vector input_left_pads{{in_left_pad_h, in_left_pad_w}}; - const std::vector input_right_pads{{in_right_pad_h, in_right_pad_w}}; - - // tensor layout - auto f_host_tensor_descriptor = [](std::size_t N_, - std::size_t C_, - std::size_t H, - std::size_t W, - auto layout) { - if constexpr(ck::is_same::value || - ck::is_same::value || - ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, H * W, W, 1})); - } - else if constexpr(ck::is_same::value || - ck::is_same::value || - ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, 1, W * C_, C_})); - } - }; - - Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); - Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); - Tensor out_n_k_ho_wo_host_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - Tensor out_n_k_ho_wo_device_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - - std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; - std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; - std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(GetHostTensorDescriptor(input_dims, InLayout{})); + Tensor weights(GetHostTensorDescriptor(filter_dims, WeiLayout{})); + Tensor host_output(GetHostTensorDescriptor(output_dims, OutLayout{})); + Tensor device_output(GetHostTensorDescriptor(output_dims, OutLayout{})); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; switch(init_method) { case 0: break; case 1: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } - DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * - out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); - in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); // do GEMM - auto conv = DeviceConvFwdInstance{}; - auto invoker = conv.MakeInvoker(); - auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - N, - K, - C, - std::vector{{Hi, Wi}}, - std::vector{{Y, X}}, - std::vector{{Ho, Wo}}, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - if(!conv.IsSupportedArgument(argument)) + auto conv = GetConvInstance(spatial_dims); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv->IsSupportedArgument(argument.get())) { throw std::runtime_error( "wrong! device_conv with the specified compilation parameters does " "not support this Conv problem"); } - float ave_time = invoker.Run(argument, nrepeat); - - std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; + float ave_time = invoker->Run(argument.get(), nrepeat); - std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + - sizeof(WeiDataType) * (K * C * Y * X) + - sizeof(OutDataType) * (N * K * Ho * Wo); + ck::index_t flop = GetFlops( + params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + ck::index_t num_btype = GetBtype(params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -248,24 +431,40 @@ int main(int argc, char* argv[]) if(do_verification) { - auto ref_conv = ReferenceConvFwdInstance{}; - auto ref_invoker = ref_conv.MakeInvoker(); - - auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); - - out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - - check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + auto verify_f = [&input, &weights, &host_output, ¶ms, &out_device_buf, &device_output]( + const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + check_error(host_output, device_output); + }; + + switch(spatial_dims) + { + case 2: { + auto ref_conv = ReferenceConvNDFwdInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvNDFwdInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } } } From ed3c16a45f09f681057d41e3d8a834a58b1d6eac Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 11 Feb 2022 12:24:26 +0100 Subject: [PATCH 09/82] Rename conv nd example director to prevent conflicts. --- example/{8_convnd_fwd_xdl => 9_convnd_fwd_xdl}/README.md | 0 .../{8_convnd_fwd_xdl => 9_convnd_fwd_xdl}/convnd_fwd_xdl.cpp | 0 example/CMakeLists.txt | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename example/{8_convnd_fwd_xdl => 9_convnd_fwd_xdl}/README.md (100%) rename example/{8_convnd_fwd_xdl => 9_convnd_fwd_xdl}/convnd_fwd_xdl.cpp (100%) diff --git a/example/8_convnd_fwd_xdl/README.md b/example/9_convnd_fwd_xdl/README.md similarity index 100% rename from example/8_convnd_fwd_xdl/README.md rename to example/9_convnd_fwd_xdl/README.md diff --git a/example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp similarity index 100% rename from example/8_convnd_fwd_xdl/convnd_fwd_xdl.cpp rename to example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index fce6f047cc4..ff5d0ab703b 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -19,7 +19,7 @@ set(CONV2D_FWD_XDL_SOURCE 4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp) set(CONV2D_FWD_XDL_BIAS_RELU_SOURCE 5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp) set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp) set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp) -set(CONVND_FWD_XDL_SOURCE 8_convnd_fwd_xdl/convnd_fwd_xdl.cpp) +set(CONVND_FWD_XDL_SOURCE 9_convnd_fwd_xdl/convnd_fwd_xdl.cpp) add_executable(gemm_xdl ${GEMM_XDL_SOURCE}) add_executable(gemm_xdl_bias_relu ${GEMM_XDL_BIAS_RELU_SOURCE}) From 7268be86c118f27c37ec4a6bebef936e326b6c68 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 16 Feb 2022 12:06:49 +0100 Subject: [PATCH 10/82] Refactor some common utility to single file. Plus some tests. --- .../include/utility/conv_utils.hpp | 122 ++++++++++++++ example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp | 144 ++++------------- test/CMakeLists.txt | 3 + test/conv_util/main.cpp | 149 ++++++++++++++++++ 4 files changed, 302 insertions(+), 116 deletions(-) create mode 100644 composable_kernel/include/utility/conv_utils.hpp create mode 100644 test/conv_util/main.cpp diff --git a/composable_kernel/include/utility/conv_utils.hpp b/composable_kernel/include/utility/conv_utils.hpp new file mode 100644 index 00000000000..f7313f786a2 --- /dev/null +++ b/composable_kernel/include/utility/conv_utils.hpp @@ -0,0 +1,122 @@ +#include +#include +#include +#include +#include + +#include "config.hpp" + +namespace ck { +namespace conv_util { + +/** + * @brief Calculate number of FLOPs for Convolution + * + * @param[in] N Batch size. + * @param[in] C Number of input channels. + * @param[in] K Number of output channels. + * @param[in] filter_spatial_lengths Filter spatial dimensions lengths. + * @param[in] output_spatial_lengths Convolution output spatial dimensions + * lengths. + * + * @return The number of flops. + */ +std::size_t GetFlops(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) +{ + // 2 * N * K * * C * + return ck::index_t(2) * N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + 1, + std::multiplies()) * + C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + 1, + std::multiplies()); +} + +template +ck::index_t GetBtype(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) +{ + // sizeof(InDataType) * (N * C * ) + + // sizeof(WeiDataType) * (K * C * ) + + // sizeof(OutDataType) * (N * K * ); + return sizeof(InDataType) * (N * C * + std::accumulate(std::begin(input_spatial_lengths), + std::end(input_spatial_lengths), + 1, + std::multiplies())) + + sizeof(WeiDataType) * (K * C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + 1, + std::multiplies())) + + sizeof(OutDataType) * (N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + 1, + std::multiplies())); +} + +struct ConvParams +{ + ConvParams() + : spatial_dims(2), + N(128), + K(256), + C(192), + filter_spatial_lengths(2, 3), + input_spatial_lengths(2, 71), + conv_filter_strides(2, 2), + conv_filter_dilations(2, 1), + input_left_pads(2, 1), + input_right_pads(2, 1) + { + } + + ck::index_t spatial_dims; + ck::index_t N; + ck::index_t K; + ck::index_t C; + + std::vector filter_spatial_lengths; + std::vector input_spatial_lengths; + + std::vector conv_filter_strides; + std::vector conv_filter_dilations; + + std::vector input_left_pads; + std::vector input_right_pads; + + std::vector GetOutputSpatialLengths() const + { + std::vector out_spatial_len(spatial_dims, 0); + for(ck::index_t i = 0; i < spatial_dims; ++i) + { + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::index_t idx_eff = + (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1; + out_spatial_len[i] = + (input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) / + conv_filter_strides[i] + + 1; + } + return out_spatial_len; + } +}; + +} // namespace conv_util +} // namespace ck diff --git a/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp index 49a2508e61f..cbefd92fd1c 100644 --- a/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp +++ b/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp @@ -1,21 +1,20 @@ #include -#include #include #include #include -#include #include #include "config.hpp" -#include "print.hpp" +#include "conv_utils.hpp" #include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" #include "device_tensor.hpp" -#include "tensor_layout.hpp" -#include "element_wise_operation.hpp" #include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "print.hpp" #include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" using InDataType = float; using WeiDataType = float; @@ -150,52 +149,6 @@ DeviceConvFwdBasePtr GetConvInstance(int spatial_dims) } } -std::size_t GetFlops(ck::index_t N, - ck::index_t C, - ck::index_t K, - const std::vector& filter_spatial_lengths, - const std::vector& output_spatial_lengths) -{ - // 2 * N * K * * C * - return ck::index_t(2) * N * K * - std::accumulate(std::begin(output_spatial_lengths), - std::end(output_spatial_lengths), - 1, - std::multiplies()) * - C * - std::accumulate(std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths), - 1, - std::multiplies()); -} - -ck::index_t GetBtype(ck::index_t N, - ck::index_t C, - ck::index_t K, - const std::vector& input_spatial_lengths, - const std::vector& filter_spatial_lengths, - const std::vector& output_spatial_lengths) -{ - // sizeof(InDataType) * (N * C * ) + - // sizeof(WeiDataType) * (K * C * ) + - // sizeof(OutDataType) * (N * K * ); - return sizeof(InDataType) * (N * C * - std::accumulate(std::begin(input_spatial_lengths), - std::end(input_spatial_lengths), - 1, - std::multiplies())) + - sizeof(WeiDataType) * (K * C * - std::accumulate(std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths), - 1, - std::multiplies())) + - sizeof(OutDataType) * (N * K * - std::accumulate(std::begin(output_spatial_lengths), - std::end(output_spatial_lengths), - 1, - std::multiplies())); -} - void PrintUseMsg() { std::cout << "arg1: verification (0=no, 1=yes)\n" @@ -213,55 +166,7 @@ void PrintUseMsg() << std::endl; } -struct ConvParams -{ - ConvParams() - : spatial_dims(2), - N(128), - K(256), - C(192), - filter_spatial_lengths(2, 3), - input_spatial_lengths(2, 71), - conv_filter_strides(2, 2), - conv_filter_dilations(2, 1), - input_left_pads(2, 1), - input_right_pads(2, 1) - { - } - - ck::index_t spatial_dims; - ck::index_t N; - ck::index_t K; - ck::index_t C; - - std::vector filter_spatial_lengths; - std::vector input_spatial_lengths; - - std::vector conv_filter_strides; - std::vector conv_filter_dilations; - - std::vector input_left_pads; - std::vector input_right_pads; - - std::vector GetOutputSpatialLengths() const - { - std::vector out_spatial_len(spatial_dims, 0); - for(ck::index_t i = 0; i < spatial_dims; ++i) - { - // XEff = (X - 1) * conv_dilation_w + 1; - // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - const ck::index_t idx_eff = - (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1; - out_spatial_len[i] = - (input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) / - conv_filter_strides[i] + - 1; - } - return out_spatial_len; - } -}; - -ConvParams ParseConvParams(int spatial_dims, int argc, char* argv[]) +ck::conv_util::ConvParams ParseConvParams(int spatial_dims, int argc, char* argv[]) { // (N, K, C) + spatial_dims * 6 (filter, input, strides, dilations, pad left, pad right) int conv_args = 3 + spatial_dims * 6; @@ -272,7 +177,7 @@ ConvParams ParseConvParams(int spatial_dims, int argc, char* argv[]) exit(0); } - ConvParams params; + ck::conv_util::ConvParams params; int arg_idx = 5; params.spatial_dims = spatial_dims; @@ -321,19 +226,19 @@ int main(int argc, char* argv[]) int nrepeat = 5; int spatial_dims = 2; - ConvParams params; + ck::conv_util::ConvParams params; - if(argc >= 4) + if(argc >= 5) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); nrepeat = std::stoi(argv[3]); + spatial_dims = std::stoi(argv[4]); } - if(argc >= 5) + if(argc >= 6) { - spatial_dims = std::stoi(argv[4]); - params = ParseConvParams(spatial_dims, argc, argv); + params = ParseConvParams(spatial_dims, argc, argv); } std::vector input_dims{static_cast(params.N), @@ -364,6 +269,9 @@ int main(int argc, char* argv[]) std::cout << "weights: " << weights.mDesc << std::endl; std::cout << "output: " << host_output.mDesc << std::endl; + // std::iota(input.begin(), input.end(), InDataType(0.f)); + // std::fill(weights.begin(), weights.end(), WeiDataType(0.25f)); + switch(init_method) { case 0: break; @@ -413,14 +321,15 @@ int main(int argc, char* argv[]) float ave_time = invoker->Run(argument.get(), nrepeat); - ck::index_t flop = GetFlops( + ck::index_t flop = ck::conv_util::GetFlops( params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); - ck::index_t num_btype = GetBtype(params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, - output_spatial_lengths); + ck::index_t num_btype = + ck::conv_util::GetBtype(params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -447,7 +356,10 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); out_device_buf.FromDevice(device_output.mData.data()); - check_error(host_output, device_output); + // check_error(host_output, device_output); + + // LogRange(std::cout <<"host_output:\n", host_output, ", "); + // LogRange(std::cout <<"device_output:\n", device_output, ", "); }; switch(spatial_dims) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1b3e1e57e5e..0d5b8dc1619 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -21,3 +21,6 @@ set(SPLIT_K_SOURCE split_k/main.cpp) add_executable(test_split_k ${SPLIT_K_SOURCE}) target_link_libraries(test_split_k PRIVATE host_tensor) target_link_libraries(test_split_k PRIVATE device_gemm_instance) +# test_conv_util +set(CONV_UTIL_SOURCE conv_util/main.cpp) +add_executable(test_conv_util ${CONV_UTIL_SOURCE}) diff --git a/test/conv_util/main.cpp b/test/conv_util/main.cpp new file mode 100644 index 00000000000..1f936c26964 --- /dev/null +++ b/test/conv_util/main.cpp @@ -0,0 +1,149 @@ +#include +#include + +#include "config.hpp" +#include "conv_utils.hpp" + +namespace { + +bool cmp_vec(const std::vector& out, const std::vector& ref) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + for(std::size_t i = 0; i < ref.size(); ++i) + { + if(out[i] != ref[i]) + { + std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i] + << std::endl; + return false; + } + } + return true; +} + +} // namespace + +static bool TestConvParams_GetOutputSpatialLengths() +{ + bool res{true}; + // -------------------------- default 2D ------------------------------------ + // input NCHW {128,192,71,71}, + // weights KCYX {256,192,3,3}, + // stride {2,2}, + // dilations {1,1}, + // padding {{1,1}, {1,1}} + ck::conv_util::ConvParams conv_params; + std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); + if(!cmp_vec(out_spatial_len, std::vector{36, 36})) + { + std::cout << "Error: ConvParams 2D default constructor." << std::endl; + res = false; + } + + conv_params.conv_filter_strides = std::vector{1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + if(!cmp_vec(out_spatial_len, std::vector{71, 71})) + { + std::cout << "Error: ConvParams 2D stride {1,1}." << std::endl; + res = false; + } + + conv_params.conv_filter_strides = std::vector{2, 2}; + conv_params.input_left_pads = std::vector{2, 2}; + conv_params.input_right_pads = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + if(!cmp_vec(out_spatial_len, std::vector{37, 37})) + { + std::cout << "Error: ConvParams 2D padding left/right {2,2}." << std::endl; + res = false; + } + + conv_params.conv_filter_dilations = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + if(!cmp_vec(out_spatial_len, std::vector{36, 36})) + { + std::cout << "Error: ConvParams 2D dilation {2,2}." << std::endl; + res = false; + } + + conv_params.conv_filter_strides = std::vector{3, 3}; + conv_params.input_left_pads = std::vector{1, 1}; + conv_params.input_right_pads = std::vector{1, 1}; + conv_params.conv_filter_dilations = std::vector{2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + if(!cmp_vec(out_spatial_len, std::vector{23, 23})) + { + std::cout << "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}." + << std::endl; + res = false; + } + + // -------------------------- 1D ------------------------------------ + conv_params.spatial_dims = 1; + conv_params.filter_spatial_lengths = std::vector{3}; + conv_params.input_spatial_lengths = std::vector{71}; + conv_params.conv_filter_strides = std::vector{2}; + conv_params.conv_filter_dilations = std::vector{1}; + conv_params.input_left_pads = std::vector{1}; + conv_params.input_right_pads = std::vector{1}; + + out_spatial_len = conv_params.GetOutputSpatialLengths(); + if(!cmp_vec(out_spatial_len, std::vector{36})) + { + std::cout << "Error: ConvParams 1D default constructor." << std::endl; + res = false; + } + + conv_params.conv_filter_strides = std::vector{1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + if(!cmp_vec(out_spatial_len, std::vector{71})) + { + std::cout << "Error: ConvParams 1D stride {1}." << std::endl; + res = false; + } + + conv_params.conv_filter_strides = std::vector{2}; + conv_params.input_left_pads = std::vector{2}; + conv_params.input_right_pads = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + if(!cmp_vec(out_spatial_len, std::vector{37})) + { + std::cout << "Error: ConvParams 1D padding left/right {2}." << std::endl; + res = false; + } + + conv_params.conv_filter_dilations = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + if(!cmp_vec(out_spatial_len, std::vector{36})) + { + std::cout << "Error: ConvParams 1D dilation {2}." << std::endl; + res = false; + } + + conv_params.conv_filter_strides = std::vector{3}; + conv_params.input_left_pads = std::vector{1}; + conv_params.input_right_pads = std::vector{1}; + conv_params.conv_filter_dilations = std::vector{2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + if(!cmp_vec(out_spatial_len, std::vector{23})) + { + std::cout << "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}." << std::endl; + res = false; + } + + return res; +} + +int main(void) +{ + bool res = TestConvParams_GetOutputSpatialLengths(); + std::cout << "TestConvParams_GetOutputSpatialLengths ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; + return 0; +} From 9122de156aae37d981310c1bea7154b10b1bc329 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 16 Feb 2022 15:21:31 +0100 Subject: [PATCH 11/82] Refactor GetHostTensorDescriptor + UT. --- .../include/utility/conv_utils.hpp | 76 ++++++++++++ example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp | 56 +-------- test/CMakeLists.txt | 2 + test/conv_util/main.cpp | 113 +++++++++--------- 4 files changed, 140 insertions(+), 107 deletions(-) diff --git a/composable_kernel/include/utility/conv_utils.hpp b/composable_kernel/include/utility/conv_utils.hpp index f7313f786a2..a170356996f 100644 --- a/composable_kernel/include/utility/conv_utils.hpp +++ b/composable_kernel/include/utility/conv_utils.hpp @@ -1,10 +1,17 @@ +#ifndef CONV_UTILS_HPP +#define CONV_UTILS_HPP + #include #include #include #include +#include +#include #include #include "config.hpp" +#include "host_tensor.hpp" +#include "tensor_layout.hpp" namespace ck { namespace conv_util { @@ -40,6 +47,22 @@ std::size_t GetFlops(ck::index_t N, std::multiplies()); } +/** + * @brief Calculate number of bytes read/write by convolution algorithm. + * + * @param[in] N Batch size. + * @param[in] C Number of input channels. + * @param[in] K Number of output channels. + * @param[in] input_spatial_lengths Input spatial dimensions lengths. + * @param[in] filter_spatial_lengths Filter spatial dimensions lengths. + * @param[in] output_spatial_lengths Output spatial dimensions lengths + * + * @tparam InDataType Input tensor data type. + * @tparam WeiDataType Weights tensor data type. + * @tparam OutDataType Output tensor data type. + * + * @return The number of used bytes. + */ template @@ -118,5 +141,58 @@ struct ConvParams } }; +/** + * @brief Gets the host tensor descriptor. + * + * @param[in] dims The tensor dimensions lengths. Always in NCHW format. + * @param[in] layout The tensor data layout. + * + * @tparam TensorLayout Layout type. + * + * @return The host tensor descriptor object. + */ +template +HostTensorDescriptor GetHostTensorDescriptor(const std::vector& dims, + const TensorLayout& layout) +{ + std::size_t C = dims[1]; + // 1D + if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor(dims, std::vector({C * dims[2], dims[2], 1})); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor(dims, std::vector({C * dims[2], 1, C})); + } + // 2D + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor( + dims, std::vector{C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1}); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor( + dims, std::vector{C * dims[2] * dims[3], 1, dims[3] * C, C}); + } + + std::stringstream err_msg; + err_msg << "Unsupported data layout provided: " << layout << "!"; + throw std::runtime_error(err_msg.str()); +} + } // namespace conv_util } // namespace ck + +#endif diff --git a/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp index cbefd92fd1c..10d3c9b534f 100644 --- a/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp +++ b/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp @@ -87,52 +87,6 @@ using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd< OutElementOp, SpatialDims>; -template -HostTensorDescriptor GetHostTensorDescriptor(const std::vector& dims, - const TensorLayout& layout) -{ - std::size_t N = dims[0]; - std::size_t C = dims[1]; - // 1D - if constexpr(std::is_same::value || - std::is_same::value || - std::is_same::value) - { - - return HostTensorDescriptor(std::vector({N, C, dims[2]}), - std::vector({C * dims[2], dims[2], 1})); - } - else if constexpr(std::is_same::value || - std::is_same::value || - std::is_same::value) - { - return HostTensorDescriptor(std::vector({N, C, dims[2]}), - std::vector({C * dims[2], 1, C})); - } - // 2D - else if constexpr(std::is_same::value || - std::is_same::value || - std::is_same::value) - { - - return HostTensorDescriptor( - std::vector({N, C, dims[2], dims[3]}), - std::vector({C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1})); - } - else if constexpr(std::is_same::value || - std::is_same::value || - std::is_same::value) - { - return HostTensorDescriptor( - std::vector({N, C, dims[2], dims[3]}), - std::vector({C * dims[2] * dims[3], 1, dims[3] * C, C})); - } - - std::stringstream err_msg; - err_msg << "Unsupported data layout provided: " << layout << "!"; - throw std::runtime_error(err_msg.str()); -} - DeviceConvFwdBasePtr GetConvInstance(int spatial_dims) { switch(spatial_dims) @@ -260,10 +214,12 @@ int main(int argc, char* argv[]) std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); - Tensor input(GetHostTensorDescriptor(input_dims, InLayout{})); - Tensor weights(GetHostTensorDescriptor(filter_dims, WeiLayout{})); - Tensor host_output(GetHostTensorDescriptor(output_dims, OutLayout{})); - Tensor device_output(GetHostTensorDescriptor(output_dims, OutLayout{})); + Tensor input(ck::conv_util::GetHostTensorDescriptor(input_dims, InLayout{})); + Tensor weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{})); + Tensor host_output( + ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); + Tensor device_output( + ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); std::cout << "input: " << input.mDesc << std::endl; std::cout << "weights: " << weights.mDesc << std::endl; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 0d5b8dc1619..dc5987c0ffb 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -21,6 +21,8 @@ set(SPLIT_K_SOURCE split_k/main.cpp) add_executable(test_split_k ${SPLIT_K_SOURCE}) target_link_libraries(test_split_k PRIVATE host_tensor) target_link_libraries(test_split_k PRIVATE device_gemm_instance) + # test_conv_util set(CONV_UTIL_SOURCE conv_util/main.cpp) add_executable(test_conv_util ${CONV_UTIL_SOURCE}) +target_link_libraries(test_conv_util PRIVATE host_tensor) diff --git a/test/conv_util/main.cpp b/test/conv_util/main.cpp index 1f936c26964..994caf00136 100644 --- a/test/conv_util/main.cpp +++ b/test/conv_util/main.cpp @@ -1,17 +1,21 @@ #include +#include #include #include "config.hpp" #include "conv_utils.hpp" +#include "tensor_layout.hpp" namespace { -bool cmp_vec(const std::vector& out, const std::vector& ref) +template +bool cmp_vec(const std::vector& out, const std::vector& ref, const std::string& msg) { if(out.size() != ref.size()) { std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() - << std::endl; + << std::endl + << msg << std::endl; return false; } @@ -20,16 +24,15 @@ bool cmp_vec(const std::vector& out, const std::vector if(out[i] != ref[i]) { std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i] - << std::endl; + << std::endl + << msg << std::endl; return false; } } return true; } -} // namespace - -static bool TestConvParams_GetOutputSpatialLengths() +bool TestConvParams_GetOutputSpatialLengths() { bool res{true}; // -------------------------- default 2D ------------------------------------ @@ -40,49 +43,36 @@ static bool TestConvParams_GetOutputSpatialLengths() // padding {{1,1}, {1,1}} ck::conv_util::ConvParams conv_params; std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); - if(!cmp_vec(out_spatial_len, std::vector{36, 36})) - { - std::cout << "Error: ConvParams 2D default constructor." << std::endl; - res = false; - } + res = cmp_vec(out_spatial_len, + std::vector{36, 36}, + "Error: ConvParams 2D default constructor."); conv_params.conv_filter_strides = std::vector{1, 1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - if(!cmp_vec(out_spatial_len, std::vector{71, 71})) - { - std::cout << "Error: ConvParams 2D stride {1,1}." << std::endl; - res = false; - } + res = cmp_vec( + out_spatial_len, std::vector{71, 71}, "Error: ConvParams 2D stride {1,1}."); conv_params.conv_filter_strides = std::vector{2, 2}; conv_params.input_left_pads = std::vector{2, 2}; conv_params.input_right_pads = std::vector{2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - if(!cmp_vec(out_spatial_len, std::vector{37, 37})) - { - std::cout << "Error: ConvParams 2D padding left/right {2,2}." << std::endl; - res = false; - } + res = cmp_vec(out_spatial_len, + std::vector{37, 37}, + "Error: ConvParams 2D padding left/right {2,2}."); conv_params.conv_filter_dilations = std::vector{2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - if(!cmp_vec(out_spatial_len, std::vector{36, 36})) - { - std::cout << "Error: ConvParams 2D dilation {2,2}." << std::endl; - res = false; - } + res = cmp_vec( + out_spatial_len, std::vector{36, 36}, "Error: ConvParams 2D dilation {2,2}."); conv_params.conv_filter_strides = std::vector{3, 3}; conv_params.input_left_pads = std::vector{1, 1}; conv_params.input_right_pads = std::vector{1, 1}; conv_params.conv_filter_dilations = std::vector{2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - if(!cmp_vec(out_spatial_len, std::vector{23, 23})) - { - std::cout << "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}." - << std::endl; - res = false; - } + res = cmp_vec(out_spatial_len, + std::vector{23, 23}, + "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."); // -------------------------- 1D ------------------------------------ conv_params.spatial_dims = 1; @@ -94,56 +84,65 @@ static bool TestConvParams_GetOutputSpatialLengths() conv_params.input_right_pads = std::vector{1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - if(!cmp_vec(out_spatial_len, std::vector{36})) - { - std::cout << "Error: ConvParams 1D default constructor." << std::endl; - res = false; - } + res = cmp_vec( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D default constructor."); conv_params.conv_filter_strides = std::vector{1, 1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - if(!cmp_vec(out_spatial_len, std::vector{71})) - { - std::cout << "Error: ConvParams 1D stride {1}." << std::endl; - res = false; - } + res = + cmp_vec(out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}."); conv_params.conv_filter_strides = std::vector{2}; conv_params.input_left_pads = std::vector{2}; conv_params.input_right_pads = std::vector{2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - if(!cmp_vec(out_spatial_len, std::vector{37})) - { - std::cout << "Error: ConvParams 1D padding left/right {2}." << std::endl; - res = false; - } + res = cmp_vec(out_spatial_len, + std::vector{37}, + "Error: ConvParams 1D padding left/right {2}."); conv_params.conv_filter_dilations = std::vector{2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - if(!cmp_vec(out_spatial_len, std::vector{36})) - { - std::cout << "Error: ConvParams 1D dilation {2}." << std::endl; - res = false; - } + res = cmp_vec( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D dilation {2}."); conv_params.conv_filter_strides = std::vector{3}; conv_params.input_left_pads = std::vector{1}; conv_params.input_right_pads = std::vector{1}; conv_params.conv_filter_dilations = std::vector{2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - if(!cmp_vec(out_spatial_len, std::vector{23})) - { - std::cout << "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}." << std::endl; - res = false; - } + res = cmp_vec(out_spatial_len, + std::vector{23}, + "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."); return res; } +bool TestGetHostTensorDescriptor() +{ + bool res{true}; + namespace tl = ck::tensor_layout::convolution; + std::vector dims{2, 3, 4, 5}; + HostTensorDescriptor h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{}); + res = cmp_vec(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!"); + res = + cmp_vec(h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!"); + + h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCHW{}); + res = cmp_vec(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!"); + res = + cmp_vec(h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!"); + + return res; +} + +} // namespace + int main(void) { bool res = TestConvParams_GetOutputSpatialLengths(); std::cout << "TestConvParams_GetOutputSpatialLengths ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestGetHostTensorDescriptor(); + std::cout << "TestGetHostTensorDescriptor ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return 0; } From 62066ac576a1b7ddcfabde5f66ee005e2246c835 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 17 Feb 2022 12:40:30 +0100 Subject: [PATCH 12/82] Add 1D test case. --- test/conv_util/main.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/conv_util/main.cpp b/test/conv_util/main.cpp index 994caf00136..866545fdd65 100644 --- a/test/conv_util/main.cpp +++ b/test/conv_util/main.cpp @@ -132,6 +132,15 @@ bool TestGetHostTensorDescriptor() res = cmp_vec(h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!"); + dims = std::vector{2, 3, 4}; + HostTensorDescriptor h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{}); + res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"); + res = cmp_vec(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"); + + h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCW{}); + res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"); + res = cmp_vec(h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!"); + return res; } From 8b3d2557139692226c5a4a8f4e922ad58ea304e6 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 17 Feb 2022 12:41:03 +0100 Subject: [PATCH 13/82] Test reference convolution 1d/2d --- .../element_wise_operation.hpp | 2 + test/CMakeLists.txt | 6 + test/reference_conv_fwd/main.cpp | 264 ++++++++++++++++++ 3 files changed, 272 insertions(+) create mode 100644 test/reference_conv_fwd/main.cpp diff --git a/composable_kernel/include/tensor_operation/element_wise_operation.hpp b/composable_kernel/include/tensor_operation/element_wise_operation.hpp index c2fe6a9f465..695e807cf7c 100644 --- a/composable_kernel/include/tensor_operation/element_wise_operation.hpp +++ b/composable_kernel/include/tensor_operation/element_wise_operation.hpp @@ -1,6 +1,8 @@ #ifndef CK_ELEMENT_WISE_OPERATION_HPP #define CK_ELEMENT_WISE_OPERATION_HPP +#include "data_type.hpp" + namespace ck { namespace tensor_operation { namespace element_wise { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index dc5987c0ffb..00f9965a6b0 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -9,6 +9,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform ${PROJECT_SOURCE_DIR}/external/rocm/include + ${PROJECT_SOURCE_DIR}/reference_operation/include ) # test_magic_number_division @@ -26,3 +27,8 @@ target_link_libraries(test_split_k PRIVATE device_gemm_instance) set(CONV_UTIL_SOURCE conv_util/main.cpp) add_executable(test_conv_util ${CONV_UTIL_SOURCE}) target_link_libraries(test_conv_util PRIVATE host_tensor) + +# test_reference_conv_fwd +set(REFERENCE_CONV_FWD_SOURCE reference_conv_fwd/main.cpp) +add_executable(test_reference_conv_fwd ${REFERENCE_CONV_FWD_SOURCE}) +target_link_libraries(test_reference_conv_fwd PRIVATE host_tensor) diff --git a/test/reference_conv_fwd/main.cpp b/test/reference_conv_fwd/main.cpp new file mode 100644 index 00000000000..2f756f59857 --- /dev/null +++ b/test/reference_conv_fwd/main.cpp @@ -0,0 +1,264 @@ +#include +#include +#include +#include +#include +#include + +#include "config.hpp" +#include "conv_utils.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "reference_conv_fwd.hpp" + +namespace { +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +template +Tensor RunReferenceConv(const ck::conv_util::ConvParams& params) +{ + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(ck::conv_util::GetHostTensorDescriptor(input_dims, InLayout{})); + Tensor weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{})); + Tensor host_output( + ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); + + // init + std::iota(input.begin(), input.end(), InDataType(0.f)); + std::fill(weights.begin(), weights.end(), WeiDataType(0.5f)); + std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); + + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + return host_output; +} + +template ::value, bool>::type = false> +bool check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg, + T rtol = static_cast(1e-5), + T atol = static_cast(1e-8)) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + for(std::size_t i = 0; i < ref.size(); ++i) + { + if(std::abs(out[i] - ref[i]) > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || + !std::isfinite(ref[i])) + { + std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i] + << std::endl + << msg << std::endl; + return false; + } + } + return true; +} + +template ::value, bool>::type = false> +bool check_err( + const std::vector& out, const std::vector& ref, const std::string& msg, T = 0, T = 0) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + for(std::size_t i = 0; i < ref.size(); ++i) + { + if(out[i] != ref[i]) + { + std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i] + << std::endl + << msg << std::endl; + return false; + } + } + return true; +} + +bool TestConv2DNHWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.N = 1; + params.K = 1; + params.C = 2; + params.filter_spatial_lengths = std::vector{3, 3}; + params.input_spatial_lengths = std::vector{6, 6}; + params.conv_filter_strides = std::vector{1, 1}; + params.conv_filter_dilations = std::vector{1, 1}; + params.input_left_pads = std::vector{0, 0}; + params.input_right_pads = std::vector{0, 0}; + + auto out_tensor = RunReferenceConv<2>(params); + std::vector ref_dims{1, 1, 4, 4}; + std::vector ref_data{130.5, + 148.5, + 166.5, + 184.5, + 238.5, + 256.5, + 274.5, + 292.5, + 346.5, + 364.5, + 382.5, + 400.5, + 454.5, + 472.5, + 490.5, + 508.5}; + res = res && check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + + params.N = 1; + params.K = 2; + params.C = 2; + params.filter_spatial_lengths = std::vector{3, 3}; + params.input_spatial_lengths = std::vector{12, 12}; + params.conv_filter_strides = std::vector{2, 2}; + params.conv_filter_dilations = std::vector{2, 2}; + params.input_left_pads = std::vector{1, 1}; + params.input_right_pads = std::vector{1, 1}; + + out_tensor = RunReferenceConv<2>(params); + ref_dims = std::vector{1, 2, 5, 5}; + ref_data = std::vector{ + 210., 210., 327., 327., 351., 351., 375., 375., 399., 399., + 459., 459., 706.5, 706.5, 742.5, 742.5, 778.5, 778.5, 814.5, 814.5, + 747., 747., 1138.5, 1138.5, 1174.5, 1174.5, 1210.5, 1210.5, 1246.5, 1246.5, + 1035., 1035., 1570.5, 1570.5, 1606.5, 1606.5, 1642.5, 1642.5, 1678.5, 1678.5, + 1323., 1323., 2002.5, 2002.5, 2038.5, 2038.5, 2074.5, 2074.5, 2110.5, 2110.5}; + res = res && check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + + return res; +} + +bool TestConv1DNHWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.spatial_dims = 1; + params.N = 1; + params.K = 1; + params.C = 2; + params.filter_spatial_lengths = std::vector{3}; + params.input_spatial_lengths = std::vector{6}; + params.conv_filter_strides = std::vector{1}; + params.conv_filter_dilations = std::vector{1}; + params.input_left_pads = std::vector{0}; + params.input_right_pads = std::vector{0}; + + auto out_tensor = RunReferenceConv<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); + std::vector ref_dims{1, 1, 4}; + std::vector ref_data{7.5, 13.5, 19.5, 25.5}; + res = res && check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + + params.spatial_dims = 1; + params.N = 1; + params.K = 2; + params.C = 2; + params.filter_spatial_lengths = std::vector{3}; + params.input_spatial_lengths = std::vector{12}; + params.conv_filter_strides = std::vector{2}; + params.conv_filter_dilations = std::vector{2}; + params.input_left_pads = std::vector{1}; + params.input_right_pads = std::vector{1}; + + out_tensor = RunReferenceConv<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); + ref_dims = std::vector{1, 2, 5}; + ref_data = std::vector{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5}; + res = res && check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + + return res; +} + +} // anonymous namespace + +int main(void) +{ + bool res{true}; + res = TestConv2DNHWC(); + std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv1DNHWC(); + std::cout << "TestConv1DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + return 0; +} From 18d2ab6bcdc458565754c2950671feafc51f561a Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 17 Feb 2022 12:41:27 +0100 Subject: [PATCH 14/82] Remove some leftovers. --- example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp index 10d3c9b534f..c4a15f71166 100644 --- a/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp +++ b/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp @@ -225,9 +225,6 @@ int main(int argc, char* argv[]) std::cout << "weights: " << weights.mDesc << std::endl; std::cout << "output: " << host_output.mDesc << std::endl; - // std::iota(input.begin(), input.end(), InDataType(0.f)); - // std::fill(weights.begin(), weights.end(), WeiDataType(0.25f)); - switch(init_method) { case 0: break; @@ -288,9 +285,7 @@ int main(int argc, char* argv[]) output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; @@ -312,10 +307,6 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); out_device_buf.FromDevice(device_output.mData.data()); - // check_error(host_output, device_output); - - // LogRange(std::cout <<"host_output:\n", host_output, ", "); - // LogRange(std::cout <<"device_output:\n", device_output, ", "); }; switch(spatial_dims) From 50e10fa5dbc3e4c1125cfcc123af17ec73d8112d Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 17 Feb 2022 15:05:23 +0100 Subject: [PATCH 15/82] Fix convolution example error for 1D --- example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp | 75 +++++++++++++++++---- 1 file changed, 63 insertions(+), 12 deletions(-) diff --git a/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp index c4a15f71166..50cf402f2bf 100644 --- a/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp +++ b/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp @@ -12,7 +12,6 @@ #include "element_wise_operation.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" -#include "print.hpp" #include "reference_conv_fwd.hpp" #include "tensor_layout.hpp" @@ -24,10 +23,6 @@ using AccDataType = float; template using S = ck::Sequence; -using InLayout = ck::tensor_layout::convolution::NHWC; -using WeiLayout = ck::tensor_layout::convolution::KYXC; -using OutLayout = ck::tensor_layout::convolution::NHWK; - using InElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -173,6 +168,63 @@ ck::conv_util::ConvParams ParseConvParams(int spatial_dims, int argc, char* argv return params; } +HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector& dims, + int spatial_dims = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(spatial_dims) + { + case 2: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWK{}); + } + case 1: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWK{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector& dims, + int spatial_dims = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(spatial_dims) + { + case 2: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::KYXC{}); + } + case 1: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::KXC{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector& dims, + int spatial_dims = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(spatial_dims) + { + case 2: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{}); + } + case 1: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + int main(int argc, char* argv[]) { bool do_verification = 0; @@ -214,12 +266,10 @@ int main(int argc, char* argv[]) std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); - Tensor input(ck::conv_util::GetHostTensorDescriptor(input_dims, InLayout{})); - Tensor weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{})); - Tensor host_output( - ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); - Tensor device_output( - ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); + Tensor input(GetInputHostTensorDescriptor(input_dims, spatial_dims)); + Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, spatial_dims)); + Tensor host_output(GetOutputHostTensorDescriptor(output_dims, spatial_dims)); + Tensor device_output(GetOutputHostTensorDescriptor(output_dims, spatial_dims)); std::cout << "input: " << input.mDesc << std::endl; std::cout << "weights: " << weights.mDesc << std::endl; @@ -284,7 +334,7 @@ int main(int argc, char* argv[]) params.filter_spatial_lengths, output_spatial_lengths); - float tflops = static_cast(flop) / 1.E9 / ave_time; + float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; @@ -307,6 +357,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); out_device_buf.FromDevice(device_output.mData.data()); + check_error(host_output, device_output); }; switch(spatial_dims) From 07b6b6cfadeb1bc501cd467ccfb2fd33fd2b1511 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 18 Feb 2022 17:15:31 +0100 Subject: [PATCH 16/82] Refactor test check errors utility function. --- example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp | 1 - test/include/test_util.hpp | 84 +++++++++++++++++++ test/reference_conv_fwd/main.cpp | 89 +++++---------------- 3 files changed, 103 insertions(+), 71 deletions(-) create mode 100644 test/include/test_util.hpp diff --git a/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp index 50cf402f2bf..46b18060725 100644 --- a/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp +++ b/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include diff --git a/test/include/test_util.hpp b/test/include/test_util.hpp new file mode 100644 index 00000000000..f779c3dd1d6 --- /dev/null +++ b/test/include/test_util.hpp @@ -0,0 +1,84 @@ +#ifndef TEST_UTIL_HPP +#define TEST_UTIL_HPP + +#include +#include +#include +#include +#include +#include +#include + +namespace test_util { + +template +typename std::enable_if::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg, + T rtol = static_cast(1e-5), + T atol = static_cast(1e-8)) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + T err = 0; + T max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + err = std::abs(out[i] - ref[i]); + if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i])) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << out[i] << "!=" << ref[i] << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value, bool>::type check_err( + const std::vector& out, const std::vector& ref, const std::string& msg, T = 0, T = 0) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + for(std::size_t i = 0; i < ref.size(); ++i) + { + if(out[i] != ref[i]) + { + std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i] + << std::endl + << msg << std::endl; + return false; + } + } + return true; +} + +} // namespace test_util + +#endif diff --git a/test/reference_conv_fwd/main.cpp b/test/reference_conv_fwd/main.cpp index 2f756f59857..79e6c47e399 100644 --- a/test/reference_conv_fwd/main.cpp +++ b/test/reference_conv_fwd/main.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -10,6 +11,8 @@ #include "element_wise_operation.hpp" #include "host_tensor.hpp" #include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" +#include "test_util.hpp" namespace { using InElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -77,60 +80,6 @@ Tensor RunReferenceConv(const ck::conv_util::ConvParams& params) return host_output; } -template ::value, bool>::type = false> -bool check_err(const std::vector& out, - const std::vector& ref, - const std::string& msg, - T rtol = static_cast(1e-5), - T atol = static_cast(1e-8)) -{ - if(out.size() != ref.size()) - { - std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() - << std::endl - << msg << std::endl; - return false; - } - - for(std::size_t i = 0; i < ref.size(); ++i) - { - if(std::abs(out[i] - ref[i]) > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || - !std::isfinite(ref[i])) - { - std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i] - << std::endl - << msg << std::endl; - return false; - } - } - return true; -} - -template ::value, bool>::type = false> -bool check_err( - const std::vector& out, const std::vector& ref, const std::string& msg, T = 0, T = 0) -{ - if(out.size() != ref.size()) - { - std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() - << std::endl - << msg << std::endl; - return false; - } - - for(std::size_t i = 0; i < ref.size(); ++i) - { - if(out[i] != ref[i]) - { - std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i] - << std::endl - << msg << std::endl; - return false; - } - } - return true; -} - bool TestConv2DNHWC() { bool res{true}; @@ -163,10 +112,10 @@ bool TestConv2DNHWC() 472.5, 490.5, 508.5}; - res = res && check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); params.N = 1; params.K = 2; @@ -186,10 +135,10 @@ bool TestConv2DNHWC() 747., 747., 1138.5, 1138.5, 1174.5, 1174.5, 1210.5, 1210.5, 1246.5, 1246.5, 1035., 1035., 1570.5, 1570.5, 1606.5, 1606.5, 1642.5, 1642.5, 1678.5, 1678.5, 1323., 1323., 2002.5, 2002.5, 2038.5, 2038.5, 2074.5, 2074.5, 2110.5, 2110.5}; - res = res && check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); return res; } @@ -218,10 +167,10 @@ bool TestConv1DNHWC() ck::tensor_layout::convolution::NWK>(params); std::vector ref_dims{1, 1, 4}; std::vector ref_data{7.5, 13.5, 19.5, 25.5}; - res = res && check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); params.spatial_dims = 1; params.N = 1; @@ -243,10 +192,10 @@ bool TestConv1DNHWC() ck::tensor_layout::convolution::NWK>(params); ref_dims = std::vector{1, 2, 5}; ref_data = std::vector{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5}; - res = res && check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); return res; } From 4bb5783126734a8ba4e7d45168bb86dd3bcc0e5c Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 18 Feb 2022 17:15:58 +0100 Subject: [PATCH 17/82] Test Conv2D Fwd XDL --- test/CMakeLists.txt | 6 + test/convnd_fwd_xdl/main.cpp | 225 +++++++++++++++++++++++++++++++++++ 2 files changed, 231 insertions(+) create mode 100644 test/convnd_fwd_xdl/main.cpp diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 00f9965a6b0..4cd78432e4f 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -10,6 +10,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform ${PROJECT_SOURCE_DIR}/external/rocm/include ${PROJECT_SOURCE_DIR}/reference_operation/include + ${PROJECT_SOURCE_DIR}/test/include ) # test_magic_number_division @@ -32,3 +33,8 @@ target_link_libraries(test_conv_util PRIVATE host_tensor) set(REFERENCE_CONV_FWD_SOURCE reference_conv_fwd/main.cpp) add_executable(test_reference_conv_fwd ${REFERENCE_CONV_FWD_SOURCE}) target_link_libraries(test_reference_conv_fwd PRIVATE host_tensor) + +# test_convnd_fwd_xdl +set(CONVND_FWD_XDL_SOURCE convnd_fwd_xdl/main.cpp) +add_executable(test_convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) +target_link_libraries(test_convnd_fwd_xdl PRIVATE host_tensor) diff --git a/test/convnd_fwd_xdl/main.cpp b/test/convnd_fwd_xdl/main.cpp new file mode 100644 index 00000000000..3a016cd34cc --- /dev/null +++ b/test/convnd_fwd_xdl/main.cpp @@ -0,0 +1,225 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "config.hpp" +#include "conv_utils.hpp" +#include "device.hpp" +#include "device_tensor.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" +#include "test_util.hpp" + +namespace { +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + InDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + SpatialDims, // SptialDims + 64, // BlockSize + 16, // MPerBlock + 16, // NPerBlock + 4, // K0PerBlock + 1, // K1 + 16, // MPerXDL + 16, // NPerXDL + 1, // MXdlPerWave + 1, // NXdlPerWave + S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockTransferAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector +// clang-format on + +template +auto GetHostTensors(const ck::conv_util::ConvParams& params) +{ + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(ck::conv_util::GetHostTensorDescriptor(input_dims, InLayout{})); + Tensor weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{})); + Tensor host_output( + ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); + Tensor device_output( + ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); + + std::generate(input.begin(), input.end(), [n = 0]() mutable { + return InDataType(n++) * InDataType(0.1f); + }); + std::fill(weights.begin(), weights.end(), WeiDataType(0.5f)); + std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); + std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); + + return std::make_tuple(input, weights, host_output, device_output); +} + +template +void RunReferenceConv(const ck::conv_util::ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) +{ + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + output, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); +} + +template +void RunConv(const ck::conv_util::ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) +{ + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + + auto conv = DeviceConvNDFwdInstance(); + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "Error! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + invoker.Run(argument); + out_device_buf.FromDevice(output.mData.data()); +} + +bool TestConv2DNHWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.N = 2; + params.K = 16; + params.C = 4; + params.input_spatial_lengths = std::vector{16, 16}; + params.conv_filter_strides = std::vector{1, 1}; + + auto host_tensors = GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + RunReferenceConv<2>(params, input, weights, host_output); + RunConv<2>(params, input, weights, device_output); + res = res && + test_util::check_err( + device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + + return res; +} + +} // anonymous namespace + +int main() +{ + bool res{true}; + res = TestConv2DNHWC(); + std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; +} From 174e6fa874cb169d63eaa24ca8b64f1fc036ffec Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 21 Feb 2022 12:57:13 +0100 Subject: [PATCH 18/82] More UT for 1D case. * Parameterize input & weight initializers. --- test/convnd_fwd_xdl/main.cpp | 37 ++++++++ test/reference_conv_fwd/main.cpp | 144 ++++++++++++++++++++++++++++--- 2 files changed, 169 insertions(+), 12 deletions(-) diff --git a/test/convnd_fwd_xdl/main.cpp b/test/convnd_fwd_xdl/main.cpp index 3a016cd34cc..6a0d95138b4 100644 --- a/test/convnd_fwd_xdl/main.cpp +++ b/test/convnd_fwd_xdl/main.cpp @@ -215,11 +215,48 @@ bool TestConv2DNHWC() return res; } +bool TestConv1DNWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.spatial_dims = 1; + params.N = 2; + params.K = 16; + params.C = 4; + params.filter_spatial_lengths = std::vector{3}; + params.input_spatial_lengths = std::vector{16}; + params.conv_filter_strides = std::vector{1}; + params.conv_filter_dilations = std::vector{1}; + params.input_left_pads = std::vector{1}; + params.input_right_pads = std::vector{1}; + + auto host_tensors = GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + RunReferenceConv<1>(params, input, weights, host_output); + RunConv<1>(params, input, weights, device_output); + res = res && + test_util::check_err( + device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + + return res; +} + } // anonymous namespace int main() { bool res{true}; + res = TestConv1DNWC(); + std::cout << "TestConv1DNWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv2DNHWC(); std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; } diff --git a/test/reference_conv_fwd/main.cpp b/test/reference_conv_fwd/main.cpp index 79e6c47e399..35590874af7 100644 --- a/test/reference_conv_fwd/main.cpp +++ b/test/reference_conv_fwd/main.cpp @@ -19,14 +19,42 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::PassThrough; +template +struct FillMonotonicSeq +{ + T m_init_value{0}; + + template + void operator(ForwardIter first, ForwardIter last, T init_value = m_init_value) + { + std::iota(first, last, m_init_value); + } +}; + +template +struct FillConstant +{ + T m_value{0}; + + template + void operator(ForwardIter first, ForwardIter last, T value = m_value) + { + std::fill(first, last, value); + } +} + template -Tensor RunReferenceConv(const ck::conv_util::ConvParams& params) + typename InDataType = float, + typename WeiDataType = float, + typename OutDataType = float, + typename InLayout = ck::tensor_layout::convolution::NHWC, + typename WeiLayout = ck::tensor_layout::convolution::KYXC, + typename OutLayout = ck::tensor_layout::convolution::NHWK, + typename FillInputOp = FillMonotonicSeq, + typename FillWeightsOp = FillConstant> +Tensor RunReferenceConv(const ck::conv_util::ConvParams& params, + const FillInputOp& fill_input_op = FillInputOp(), + const FillWeightsOp& fill_weights_op = FillWeightsOp()) { std::vector input_dims{static_cast(params.N), static_cast(params.C)}; @@ -52,9 +80,8 @@ Tensor RunReferenceConv(const ck::conv_util::ConvParams& params) Tensor host_output( ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); - // init - std::iota(input.begin(), input.end(), InDataType(0.f)); - std::fill(weights.begin(), weights.end(), WeiDataType(0.5f)); + fill_input_op(input.begin(), input.end()); + fill_weights_op()(weights.begin(), weights.end(), WeiDataType(0.5f)); std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd{3}; + params.input_spatial_lengths = std::vector{16}; + params.conv_filter_strides = std::vector{1}; + params.conv_filter_dilations = std::vector{1}; + params.input_left_pads = std::vector{1}; + params.input_right_pads = std::vector{1}; + + auto out_tensor = + RunReferenceConv<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params, [](auto first, auto last) { + std::generate(first, last, [n = 0]() mutable { return float(n++) * float(0.1f); }); + }); + + ref_dims = std::vector{2, 16, 16}; + ref_data = std::vector{ + 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, + 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, 1.4, + 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, + 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, + 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, + 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, 5.7, + 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, + 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, 8.1, + 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, + 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, 10.5, + 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, + 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, 12.900001, + 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, + 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, 15.3, + 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, + 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, 17.7, + 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, + 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, 20.1, + 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, + 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, 22.5, + 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, + 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, 24.900002, + 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, + 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, 27.300001, + 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, + 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, 29.7, + 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, + 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, 32.100002, + 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, + 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, 34.5, + 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, + 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, 23.8, + 27., 27., 27., 27., 27., 27., 27., 27., + 27., 27., 27., 27., 27., 27., 27., 27., + 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, + 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, 41.7, + 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, + 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, 44.100002, + 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, + 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, 46.5, + 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, + 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, 48.899998, + 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, + 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, 51.3, + 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, + 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, 53.7, + 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, + 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, 56.100002, + 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, + 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, 58.5, + 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, + 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, 60.899998, + 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, + 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, 63.3, + 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, + 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, 65.7, + 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, + 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, 68.1, + 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, + 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, 70.5, + 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, + 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, + 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, + 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4}; + res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + return res; } @@ -207,7 +327,7 @@ int main(void) bool res{true}; res = TestConv2DNHWC(); std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv1DNHWC(); + res = TestConv1DNWC(); std::cout << "TestConv1DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return 0; } From 19c85bb8537910a03944124d9442a13c9742f5ae Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 21 Feb 2022 13:00:19 +0100 Subject: [PATCH 19/82] Rename example to prevent conflicts. --- example/{9_convnd_fwd_xdl => 10_convnd_fwd_xdl}/README.md | 0 .../{9_convnd_fwd_xdl => 10_convnd_fwd_xdl}/convnd_fwd_xdl.cpp | 0 example/CMakeLists.txt | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename example/{9_convnd_fwd_xdl => 10_convnd_fwd_xdl}/README.md (100%) rename example/{9_convnd_fwd_xdl => 10_convnd_fwd_xdl}/convnd_fwd_xdl.cpp (100%) diff --git a/example/9_convnd_fwd_xdl/README.md b/example/10_convnd_fwd_xdl/README.md similarity index 100% rename from example/9_convnd_fwd_xdl/README.md rename to example/10_convnd_fwd_xdl/README.md diff --git a/example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp similarity index 100% rename from example/9_convnd_fwd_xdl/convnd_fwd_xdl.cpp rename to example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index d882d49d37a..534933adeab 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -20,7 +20,7 @@ set(CONV2D_FWD_XDL_BIAS_RELU_SOURCE 5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bi set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp) set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp) set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp) -set(CONVND_FWD_XDL_SOURCE 9_convnd_fwd_xdl/convnd_fwd_xdl.cpp) +set(CONVND_FWD_XDL_SOURCE 10_convnd_fwd_xdl/convnd_fwd_xdl.cpp) add_executable(gemm_xdl ${GEMM_XDL_SOURCE}) add_executable(gemm_xdl_bias_relu ${GEMM_XDL_BIAS_RELU_SOURCE}) From f4379e190135a172d717c2392ff1900963bc34d3 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 22 Feb 2022 16:41:14 +0100 Subject: [PATCH 20/82] Split convnd instance into separate files for 1d/2d --- device_operation/CMakeLists.txt | 10 + ...onv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp | 112 ++++++++++ ...2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp | 104 ++++----- ...nd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp | 198 ------------------ 4 files changed, 174 insertions(+), 250 deletions(-) create mode 100644 device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp delete mode 100644 device_operation/src/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp diff --git a/device_operation/CMakeLists.txt b/device_operation/CMakeLists.txt index 31fa455301a..7397d688337 100644 --- a/device_operation/CMakeLists.txt +++ b/device_operation/CMakeLists.txt @@ -64,6 +64,11 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; ) +# device_conv1d_fwd_instance +set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp; +) + # device_conv2d_fwd_bias_relu_instance set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; @@ -83,6 +88,7 @@ add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE}) +add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) @@ -92,6 +98,7 @@ target_include_directories(device_gemm_instance SYSTEM PUBLIC $) target_include_directories(device_gemm_bias_relu_add_instance SYSTEM PUBLIC $) target_include_directories(device_batched_gemm_instance SYSTEM PUBLIC $) +target_include_directories(device_conv1d_fwd_instance SYSTEM PUBLIC $) target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $) target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $) target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $) @@ -101,6 +108,7 @@ target_compile_features(device_gemm_instance PUBLIC) target_compile_features(device_gemm_bias_relu_instance PUBLIC) target_compile_features(device_gemm_bias_relu_add_instance PUBLIC) target_compile_features(device_batched_gemm_instance PUBLIC) +target_compile_features(device_conv1d_fwd_instance PUBLIC) target_compile_features(device_conv2d_fwd_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) @@ -110,6 +118,7 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_conv1d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -119,6 +128,7 @@ install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib) install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib) install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) diff --git a/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp b/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp new file mode 100644 index 00000000000..8702d18596c --- /dev/null +++ b/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp @@ -0,0 +1,112 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv1d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +//------------------------------------------------------------------------------ +// Conv1D +//------------------------------------------------------------------------------ + +// Compilation parameters for in[n, wi, c] * wei[k, x, c] = out[n, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv1d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index 402d65a6e00..69ff3919685 100644 --- a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -28,67 +28,67 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> // clang-format on >; diff --git a/device_operation/src/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/device_operation/src/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp deleted file mode 100644 index 791facd7a6d..00000000000 --- a/device_operation/src/device_convnd_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ /dev/null @@ -1,198 +0,0 @@ -#include -#include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_convnd_fwd_instance { - -using F32 = float; - -template -using S = ck::Sequence; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; - -//------------------------------------------------------------------------------ -// Conv1D -//------------------------------------------------------------------------------ - -// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; - -using device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; - -using device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; - -void add_device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{}); - add_device_operation_instances(instances, - device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{}); - add_device_operation_instances(instances, - device_conv1d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); -} - -//------------------------------------------------------------------------------ -// Conv2D -//------------------------------------------------------------------------------ - -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; - -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; - -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward|Sptial| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Dims| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> - // clang-format on - >; - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( - std::vector>& instances) -{ - add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{}); - add_device_operation_instances(instances, - device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{}); - add_device_operation_instances(instances, - device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); -} - -} // namespace device_convnd_fwd_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck From 0806433f29eb739ac949adfdcd7cd5fbae5c991e Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 22 Feb 2022 16:42:21 +0100 Subject: [PATCH 21/82] Address review comments. --- .../include}/conv_utils.hpp | 8 +- .../device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp | 41 +++-- example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp | 78 ++++----- .../include/reference_conv_fwd.hpp | 164 ++++++++---------- test/conv_util/main.cpp | 10 +- test/convnd_fwd_xdl/main.cpp | 2 +- test/reference_conv_fwd/main.cpp | 26 +-- 7 files changed, 159 insertions(+), 170 deletions(-) rename {composable_kernel/include/utility => device_operation/include}/conv_utils.hpp (95%) diff --git a/composable_kernel/include/utility/conv_utils.hpp b/device_operation/include/conv_utils.hpp similarity index 95% rename from composable_kernel/include/utility/conv_utils.hpp rename to device_operation/include/conv_utils.hpp index a170356996f..4425a15a828 100644 --- a/composable_kernel/include/utility/conv_utils.hpp +++ b/device_operation/include/conv_utils.hpp @@ -96,7 +96,7 @@ ck::index_t GetBtype(ck::index_t N, struct ConvParams { ConvParams() - : spatial_dims(2), + : num_dim_spatial(2), N(128), K(256), C(192), @@ -109,7 +109,7 @@ struct ConvParams { } - ck::index_t spatial_dims; + ck::index_t num_dim_spatial; ck::index_t N; ck::index_t K; ck::index_t C; @@ -125,8 +125,8 @@ struct ConvParams std::vector GetOutputSpatialLengths() const { - std::vector out_spatial_len(spatial_dims, 0); - for(ck::index_t i = 0; i < spatial_dims; ++i) + std::vector out_spatial_len(num_dim_spatial, 0); + for(ck::index_t i = 0; i < num_dim_spatial; ++i) { // XEff = (X - 1) * conv_dilation_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; diff --git a/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index d4e7d7725ea..c16482a1e21 100644 --- a/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,8 +1,12 @@ #ifndef DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP #define DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP +#include #include +#include +#include #include + #include "device.hpp" #include "device_base.hpp" #include "device_conv_fwd.hpp" @@ -41,7 +45,7 @@ template {}; static constexpr auto I1 = Number<1>{}; @@ -405,18 +409,18 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K // C = A^T*B // A: const auto in_gemmk0_gemmm_gemmk1_grid_desc = - GetInputTensorDescriptor(N, - C, - GemmMRaw, - GemmK, - GemmMPad, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); + GetInputTensorDescriptor(N, + C, + GemmMRaw, + GemmK, + GemmMPad, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); // B: const auto wei_gemmk0_gemmn_gemmk1_grid_desc = GetWeightTensorDescriptor(GemmN, GemmK); // C: @@ -441,7 +445,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); } - using ABCGridDescs = decltype(GetABCGridDesc()); + using ABCGridDescs = decltype(GetABCGridDesc()); using AGridDesc_K0_M_K1 = remove_cvref_t; using BGridDesc_K0_N_K1 = remove_cvref_t; @@ -703,7 +707,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) { // check if it's 1x1, stride=1 conv - for(ck::index_t i = 0; i < SpatialDims; ++i) + for(ck::index_t i = 0; i < NumDimSpatial; ++i) { if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) @@ -716,7 +720,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ConvolutionForwardSpecialization_t::Filter1x1Pad0) { // check if it's 1x1 conv - for(ck::index_t i = 0; i < SpatialDims; ++i) + for(ck::index_t i = 0; i < NumDimSpatial; ++i) { if(!(arg.filter_spatial_lengths_[i] == 1 && arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) @@ -840,7 +844,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K auto str = std::stringstream(); // clang-format off - str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + str << "DeviceConv" << std::to_string(NumDimSpatial) + << "DFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp index 46b18060725..5d12ec6b91d 100644 --- a/example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp +++ b/example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp @@ -40,9 +40,9 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device:: WeiDataType, // OutDataType, // AccDataType, // - InElementOp, // Input Elementwise Operation - WeiElementOp, // Weights Elementwise Operation - OutElementOp, // Output Elementwise Operation + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation ConvFwdDefault, // ConvForwardSpecialization SpatialDims, // SptialDims 256, // BlockSize @@ -81,9 +81,9 @@ using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd< OutElementOp, SpatialDims>; -DeviceConvFwdBasePtr GetConvInstance(int spatial_dims) +DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial) { - switch(spatial_dims) + switch(num_dim_spatial) { case 2: { return std::make_unique>(); @@ -114,10 +114,10 @@ void PrintUseMsg() << std::endl; } -ck::conv_util::ConvParams ParseConvParams(int spatial_dims, int argc, char* argv[]) +ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* argv[]) { - // (N, K, C) + spatial_dims * 6 (filter, input, strides, dilations, pad left, pad right) - int conv_args = 3 + spatial_dims * 6; + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int conv_args = 3 + num_dim_spatial * 6; int cmdline_nargs = conv_args + 5; if(cmdline_nargs != argc) { @@ -128,38 +128,38 @@ ck::conv_util::ConvParams ParseConvParams(int spatial_dims, int argc, char* argv ck::conv_util::ConvParams params; int arg_idx = 5; - params.spatial_dims = spatial_dims; - params.N = std::stoi(argv[arg_idx++]); - params.K = std::stoi(argv[arg_idx++]); - params.C = std::stoi(argv[arg_idx++]); + params.num_dim_spatial = num_dim_spatial; + params.N = std::stoi(argv[arg_idx++]); + params.K = std::stoi(argv[arg_idx++]); + params.C = std::stoi(argv[arg_idx++]); - params.filter_spatial_lengths.resize(spatial_dims); - for(int i = 0; i < spatial_dims; ++i) + params.filter_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) { params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); } - params.input_spatial_lengths.resize(spatial_dims); - for(int i = 0; i < spatial_dims; ++i) + params.input_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) { params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_strides.resize(spatial_dims); - for(int i = 0; i < spatial_dims; ++i) + params.conv_filter_strides.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) { params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); } - params.conv_filter_dilations.resize(spatial_dims); - for(int i = 0; i < spatial_dims; ++i) + params.conv_filter_dilations.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) { params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); } - params.input_left_pads.resize(spatial_dims); - for(int i = 0; i < spatial_dims; ++i) + params.input_left_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) { params.input_left_pads[i] = std::stoi(argv[arg_idx++]); } - params.input_right_pads.resize(spatial_dims); - for(int i = 0; i < spatial_dims; ++i) + params.input_right_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) { params.input_right_pads[i] = std::stoi(argv[arg_idx++]); } @@ -168,11 +168,11 @@ ck::conv_util::ConvParams ParseConvParams(int spatial_dims, int argc, char* argv } HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector& dims, - int spatial_dims = 2) + int num_dim_spatial = 2) { namespace tl = ck::tensor_layout::convolution; - switch(spatial_dims) + switch(num_dim_spatial) { case 2: { return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWK{}); @@ -187,11 +187,11 @@ HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector& dims, - int spatial_dims = 2) + int num_dim_spatial = 2) { namespace tl = ck::tensor_layout::convolution; - switch(spatial_dims) + switch(num_dim_spatial) { case 2: { return ck::conv_util::GetHostTensorDescriptor(dims, tl::KYXC{}); @@ -206,11 +206,11 @@ HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector& dims, - int spatial_dims = 2) + int num_dim_spatial = 2) { namespace tl = ck::tensor_layout::convolution; - switch(spatial_dims) + switch(num_dim_spatial) { case 2: { return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{}); @@ -229,7 +229,7 @@ int main(int argc, char* argv[]) bool do_verification = 0; int init_method = 0; int nrepeat = 5; - int spatial_dims = 2; + int num_dim_spatial = 2; ck::conv_util::ConvParams params; @@ -238,12 +238,12 @@ int main(int argc, char* argv[]) do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); nrepeat = std::stoi(argv[3]); - spatial_dims = std::stoi(argv[4]); + num_dim_spatial = std::stoi(argv[4]); } if(argc >= 6) { - params = ParseConvParams(spatial_dims, argc, argv); + params = ParseConvParams(num_dim_spatial, argc, argv); } std::vector input_dims{static_cast(params.N), @@ -265,10 +265,10 @@ int main(int argc, char* argv[]) std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); - Tensor input(GetInputHostTensorDescriptor(input_dims, spatial_dims)); - Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, spatial_dims)); - Tensor host_output(GetOutputHostTensorDescriptor(output_dims, spatial_dims)); - Tensor device_output(GetOutputHostTensorDescriptor(output_dims, spatial_dims)); + Tensor input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); + Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); + Tensor host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); std::cout << "input: " << input.mDesc << std::endl; std::cout << "weights: " << weights.mDesc << std::endl; @@ -294,7 +294,7 @@ int main(int argc, char* argv[]) wei_device_buf.ToDevice(weights.mData.data()); // do GEMM - auto conv = GetConvInstance(spatial_dims); + auto conv = GetConvInstance(num_dim_spatial); auto invoker = conv->MakeInvokerPointer(); auto argument = conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), @@ -359,7 +359,7 @@ int main(int argc, char* argv[]) check_error(host_output, device_output); }; - switch(spatial_dims) + switch(num_dim_spatial) { case 2: { auto ref_conv = ReferenceConvNDFwdInstance<2>(); diff --git a/reference_operation/include/reference_conv_fwd.hpp b/reference_operation/include/reference_conv_fwd.hpp index 4287063c93e..0bba22423fb 100644 --- a/reference_operation/include/reference_conv_fwd.hpp +++ b/reference_operation/include/reference_conv_fwd.hpp @@ -25,7 +25,7 @@ namespace host { // operation. // @tparam WeiElementwiseOperation Functor for weights tensor elementwise // operation. -// @tparam SpatialNDims Number of spatial dimensions. +// @tparam NumDimSpatial Number of spatial dimensions. // template = 1 && SpatialNDims <= 3, bool>::type = false> + ck::index_t NumDimSpatial = 2, + typename std::enable_if= 1 && NumDimSpatial <= 3, bool>::type = false> struct ReferenceConvFwd : public device::BaseOperator { // Argument struct Argument : public device::BaseArgument { Argument(const Tensor& input, - const Tensor& weights, + const Tensor& weight, Tensor& output, std::vector conv_filter_strides, std::vector conv_filter_dilations, @@ -51,7 +51,7 @@ struct ReferenceConvFwd : public device::BaseOperator WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) : input_{input}, - weights_{weights}, + weight_{weight}, output_{output}, conv_strides_{conv_filter_strides}, conv_dilations_{conv_filter_dilations}, @@ -64,7 +64,7 @@ struct ReferenceConvFwd : public device::BaseOperator } const Tensor& input_; - const Tensor& weights_; + const Tensor& weight_; Tensor& output_; std::vector conv_strides_; @@ -77,114 +77,98 @@ struct ReferenceConvFwd : public device::BaseOperator OutElementwiseOperation out_element_op_; }; - // Invoker - template struct Invoker : public device::BaseInvoker - { - }; - - template <> - struct Invoker<1> : public device::BaseInvoker { using Argument = ReferenceConvFwd::Argument; float Run(const Argument& arg) { - auto f_ncw = [&](auto n, auto k, auto wo) { - float v_acc = 0; + if constexpr(NumDimSpatial == 1) + { + auto f_ncw = [&](auto n, auto k, auto wo) { + float v_acc = 0; - for(int c = 0; c < arg.weights_.mDesc.GetLengths()[1]; ++c) - { - for(int x = 0; x < arg.weights_.mDesc.GetLengths()[2]; ++x) + for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) { - int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] - - arg.in_left_pads_[0]; - if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2]) + for(int x = 0; x < arg.weight_.mDesc.GetLengths()[2]; ++x) { - float v_in; - float v_wei; + int wi = wo * arg.conv_strides_[0] + x * arg.conv_dilations_[0] - + arg.in_left_pads_[0]; + if(wi >= 0 && wi < arg.input_.mDesc.GetLengths()[2]) + { + float v_in; + float v_wei; - arg.in_element_op_(v_in, - static_cast(arg.input_(n, c, wi))); - arg.wei_element_op_(v_wei, - static_cast(arg.weights_(k, c, x))); + arg.in_element_op_(v_in, + static_cast(arg.input_(n, c, wi))); + arg.wei_element_op_(v_wei, + static_cast(arg.weight_(k, c, x))); - v_acc += v_in * v_wei; + v_acc += v_in * v_wei; + } } } - } - float v_out; + float v_out; - arg.out_element_op_(v_out, v_acc); - arg.output_(n, k, wo) = v_out; - }; + arg.out_element_op_(v_out, v_acc); + arg.output_(n, k, wo) = v_out; + }; - make_ParallelTensorFunctor(f_ncw, - arg.output_.mDesc.GetLengths()[0], - arg.output_.mDesc.GetLengths()[1], - arg.output_.mDesc.GetLengths()[2])( - std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f_ncw, + arg.output_.mDesc.GetLengths()[0], + arg.output_.mDesc.GetLengths()[1], + arg.output_.mDesc.GetLengths()[2])( + std::thread::hardware_concurrency()); - return 0; - } + return 0; + } + else if constexpr(NumDimSpatial == 2) + { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { + float v_acc = 0; - float Run(const device::BaseArgument* p_arg, int) override - { - return Run(*dynamic_cast(p_arg)); - } - }; - - template <> - struct Invoker<2> : public device::BaseInvoker - { - using Argument = ReferenceConvFwd::Argument; - - float Run(const Argument& arg) - { - auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - float v_acc = 0; - - for(int c = 0; c < arg.weights_.mDesc.GetLengths()[1]; ++c) - { - for(int y = 0; y < arg.weights_.mDesc.GetLengths()[2]; ++y) + for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) { - int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - - arg.in_left_pads_[0]; - for(int x = 0; x < arg.weights_.mDesc.GetLengths()[3]; ++x) + for(int y = 0; y < arg.weight_.mDesc.GetLengths()[2]; ++y) { - int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - - arg.in_left_pads_[1]; - if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 && - wi < arg.input_.mDesc.GetLengths()[3]) + int hi = ho * arg.conv_strides_[0] + y * arg.conv_dilations_[0] - + arg.in_left_pads_[0]; + for(int x = 0; x < arg.weight_.mDesc.GetLengths()[3]; ++x) { - float v_in; - float v_wei; - - arg.in_element_op_( - v_in, ck::type_convert(arg.input_(n, c, hi, wi))); - arg.wei_element_op_( - v_wei, ck::type_convert(arg.weights_(k, c, y, x))); - v_acc += v_in * v_wei; + int wi = wo * arg.conv_strides_[1] + x * arg.conv_dilations_[1] - + arg.in_left_pads_[1]; + if(hi >= 0 && hi < arg.input_.mDesc.GetLengths()[2] && wi >= 0 && + wi < arg.input_.mDesc.GetLengths()[3]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, ck::type_convert(arg.input_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, ck::type_convert(arg.weight_(k, c, y, x))); + v_acc += v_in * v_wei; + } } } } - } - float v_out; + float v_out; - arg.out_element_op_(v_out, v_acc); - arg.output_(n, k, ho, wo) = ck::type_convert(v_out); - }; + arg.out_element_op_(v_out, v_acc); + arg.output_(n, k, ho, wo) = ck::type_convert(v_out); + }; - make_ParallelTensorFunctor(f_nchw, - arg.output_.mDesc.GetLengths()[0], - arg.output_.mDesc.GetLengths()[1], - arg.output_.mDesc.GetLengths()[2], - arg.output_.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f_nchw, + arg.output_.mDesc.GetLengths()[0], + arg.output_.mDesc.GetLengths()[1], + arg.output_.mDesc.GetLengths()[2], + arg.output_.mDesc.GetLengths()[3])( + std::thread::hardware_concurrency()); - return 0; + return 0; + } } float Run(const device::BaseArgument* p_arg, int) override @@ -202,7 +186,7 @@ struct ReferenceConvFwd : public device::BaseOperator bool IsSupportedArgument(const device::BaseArgument*) override { return true; } static auto MakeArgument(const Tensor& input, - const Tensor& weights, + const Tensor& weight, Tensor& output, std::vector conv_filter_strides, std::vector conv_filter_dilations, @@ -213,7 +197,7 @@ struct ReferenceConvFwd : public device::BaseOperator OutElementwiseOperation out_element_op) { return Argument{input, - weights, + weight, output, conv_filter_strides, conv_filter_dilations, @@ -224,11 +208,11 @@ struct ReferenceConvFwd : public device::BaseOperator out_element_op}; } - static auto MakeInvoker() { return Invoker{}; } + static auto MakeInvoker() { return Invoker{}; } virtual std::unique_ptr MakeInvokerPointer() { - return std::make_unique>(Invoker{}); + return std::make_unique(Invoker{}); } std::string GetTypeString() const override diff --git a/test/conv_util/main.cpp b/test/conv_util/main.cpp index 866545fdd65..ee194f24629 100644 --- a/test/conv_util/main.cpp +++ b/test/conv_util/main.cpp @@ -75,7 +75,7 @@ bool TestConvParams_GetOutputSpatialLengths() "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."); // -------------------------- 1D ------------------------------------ - conv_params.spatial_dims = 1; + conv_params.num_dim_spatial = 1; conv_params.filter_spatial_lengths = std::vector{3}; conv_params.input_spatial_lengths = std::vector{71}; conv_params.conv_filter_strides = std::vector{2}; @@ -132,10 +132,10 @@ bool TestGetHostTensorDescriptor() res = cmp_vec(h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!"); - dims = std::vector{2, 3, 4}; - HostTensorDescriptor h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{}); - res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"); - res = cmp_vec(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"); + dims = std::vector{2, 3, 4}; + h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{}); + res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"); + res = cmp_vec(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"); h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCW{}); res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"); diff --git a/test/convnd_fwd_xdl/main.cpp b/test/convnd_fwd_xdl/main.cpp index 6a0d95138b4..045becf32fe 100644 --- a/test/convnd_fwd_xdl/main.cpp +++ b/test/convnd_fwd_xdl/main.cpp @@ -219,7 +219,7 @@ bool TestConv1DNWC() { bool res{true}; ck::conv_util::ConvParams params; - params.spatial_dims = 1; + params.num_dim_spatial = 1; params.N = 2; params.K = 16; params.C = 4; diff --git a/test/reference_conv_fwd/main.cpp b/test/reference_conv_fwd/main.cpp index 35590874af7..cc5c113f594 100644 --- a/test/reference_conv_fwd/main.cpp +++ b/test/reference_conv_fwd/main.cpp @@ -25,7 +25,7 @@ struct FillMonotonicSeq T m_init_value{0}; template - void operator(ForwardIter first, ForwardIter last, T init_value = m_init_value) + void operator()(ForwardIter first, ForwardIter last) const { std::iota(first, last, m_init_value); } @@ -37,11 +37,11 @@ struct FillConstant T m_value{0}; template - void operator(ForwardIter first, ForwardIter last, T value = m_value) + void operator()(ForwardIter first, ForwardIter last) const { - std::fill(first, last, value); + std::fill(first, last, m_value); } -} +}; template , typename FillWeightsOp = FillConstant> Tensor RunReferenceConv(const ck::conv_util::ConvParams& params, - const FillInputOp& fill_input_op = FillInputOp(), - const FillWeightsOp& fill_weights_op = FillWeightsOp()) + const FillInputOp& fill_input_op = FillInputOp{0}, + const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) { std::vector input_dims{static_cast(params.N), static_cast(params.C)}; @@ -81,7 +81,7 @@ Tensor RunReferenceConv(const ck::conv_util::ConvParams& params, ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); fill_input_op(input.begin(), input.end()); - fill_weights_op()(weights.begin(), weights.end(), WeiDataType(0.5f)); + fill_weights_op(weights.begin(), weights.end()); std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd{1}; params.input_right_pads = std::vector{1}; - auto out_tensor = + auto out_tensor2 = RunReferenceConv<1, float, float, @@ -312,10 +312,10 @@ bool TestConv1DNWC() 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4}; - res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + res = res && test_util::check_err(out_tensor2.mDesc.GetLengths(), ref_dims, "Error: wrong output tensor dimensions!"); - res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && test_util::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!"); return res; } From f7181893071a2fb6a057191eb63b25fc8c750ea9 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 23 Feb 2022 12:59:40 +0100 Subject: [PATCH 22/82] Fix data type for flops/gbytes calculations. --- device_operation/include/conv_utils.hpp | 24 +++++++++--------- example/10_convnd_fwd_xdl/README.md | 26 +++++++++++++------- example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp | 4 +-- 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/device_operation/include/conv_utils.hpp b/device_operation/include/conv_utils.hpp index deda0712727..9aa616633ee 100644 --- a/device_operation/include/conv_utils.hpp +++ b/device_operation/include/conv_utils.hpp @@ -35,16 +35,16 @@ std::size_t GetFlops(ck::index_t N, const std::vector& output_spatial_lengths) { // 2 * N * K * * C * - return std::size_t(2) * N * K * + return static_cast(2) * N * K * std::accumulate(std::begin(output_spatial_lengths), std::end(output_spatial_lengths), - 1, - std::multiplies()) * + static_cast(1), + std::multiplies()) * C * std::accumulate(std::begin(filter_spatial_lengths), std::end(filter_spatial_lengths), - 1, - std::multiplies()); + static_cast(1), + std::multiplies()); } /** @@ -66,7 +66,7 @@ std::size_t GetFlops(ck::index_t N, template -ck::index_t GetBtype(ck::index_t N, +std::size_t GetBtype(ck::index_t N, ck::index_t C, ck::index_t K, const std::vector& input_spatial_lengths, @@ -79,18 +79,18 @@ ck::index_t GetBtype(ck::index_t N, return sizeof(InDataType) * (N * C * std::accumulate(std::begin(input_spatial_lengths), std::end(input_spatial_lengths), - 1, - std::multiplies())) + + static_cast(1), + std::multiplies())) + sizeof(WeiDataType) * (K * C * std::accumulate(std::begin(filter_spatial_lengths), std::end(filter_spatial_lengths), - 1, - std::multiplies())) + + static_cast(1), + std::multiplies())) + sizeof(OutDataType) * (N * K * std::accumulate(std::begin(output_spatial_lengths), std::end(output_spatial_lengths), - 1, - std::multiplies())); + static_cast(1), + std::multiplies())); } struct ConvParams diff --git a/example/10_convnd_fwd_xdl/README.md b/example/10_convnd_fwd_xdl/README.md index dd6bd721006..d85a4091650 100644 --- a/example/10_convnd_fwd_xdl/README.md +++ b/example/10_convnd_fwd_xdl/README.md @@ -38,20 +38,28 @@ cmake \ #arg1: verification (0=no, 1=yes) #arg2: initialization (0=no init, 1=integer value, 2=decimal value) #arg3: run kernel # of times (>1) -#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx -./example/convnd_fwd_xdl 0 1 5 +#arg4: N spatial dimensions (default 2) +#Following arguments (depending on number of spatial dims): +# N, K, C, +# , (ie Y, X for 2D) +# , (ie Hi, Wi for 2D) +# , (ie Sy, Sx for 2D) +# , (ie Dy, Dx for 2D) +# , (ie LeftPy, LeftPx for 2D) +# , (ie RightPy, RightPx for 2D) +./example/convnd_fwd_xdl 0 1 100 ``` Result (MI100 @ 1087Mhz, 33.4TFlops peak FP32) ``` -in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} -wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} -out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} +input: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} +weights: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} +output: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} +arg.a_grid_desc_k0_m_k1_{432, 165888, 4} +arg.b_grid_desc_k0_n_k1_{432, 256, 4} arg.c_grid_desc_m_n_{ 165888, 256} launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} Warm up -Start running 5 times... -Perf: 1.43206 ms, 102.486 TFlops, 232.947 GB/s +Start running 100 times... +Perf: 4.43736 ms, 33.0753 TFlops, 150.357 GB/s ``` diff --git a/example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp index 4955ff40ecc..614303a188b 100644 --- a/example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp +++ b/example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp @@ -323,9 +323,9 @@ int main(int argc, char* argv[]) float ave_time = invoker->Run(argument.get(), nrepeat); - ck::index_t flop = ck::conv_util::GetFlops( + std::size_t flop = ck::conv_util::GetFlops( params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); - ck::index_t num_btype = + std::size_t num_btype = ck::conv_util::GetBtype(params.N, params.C, params.K, From d5021368911cee354cccc259d91a594bebb7de6b Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 23 Feb 2022 13:01:29 +0100 Subject: [PATCH 23/82] Assign example number 11. --- example/{10_convnd_fwd_xdl => 11_convnd_fwd_xdl}/README.md | 0 .../convnd_fwd_xdl.cpp | 0 example/CMakeLists.txt | 6 +++--- 3 files changed, 3 insertions(+), 3 deletions(-) rename example/{10_convnd_fwd_xdl => 11_convnd_fwd_xdl}/README.md (100%) rename example/{10_convnd_fwd_xdl => 11_convnd_fwd_xdl}/convnd_fwd_xdl.cpp (100%) diff --git a/example/10_convnd_fwd_xdl/README.md b/example/11_convnd_fwd_xdl/README.md similarity index 100% rename from example/10_convnd_fwd_xdl/README.md rename to example/11_convnd_fwd_xdl/README.md diff --git a/example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp b/example/11_convnd_fwd_xdl/convnd_fwd_xdl.cpp similarity index 100% rename from example/10_convnd_fwd_xdl/convnd_fwd_xdl.cpp rename to example/11_convnd_fwd_xdl/convnd_fwd_xdl.cpp diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 2931e93805f..7e6daa7ad6e 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -22,8 +22,8 @@ set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fw set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp) set(GEMM_XDL_ALPHA_BETA_SOURCE 8_gemm_xdl_alpha_beta/gemm_xdl_alpha_beta.cpp) set(CONV2D_FWD_XDL_INT8_SOURCE 9_conv2d_fwd_xdl_int8/conv2d_fwd_xdl_int8.cpp) -set(CONVND_FWD_XDL_SOURCE 10_convnd_fwd_xdl/convnd_fwd_xdl.cpp) set(CONV3D_FWD_XDL_SOURCE 10_conv3d_fwd_xdl/conv3d_fwd_xdl.cpp) +set(CONVND_FWD_XDL_SOURCE 11_convnd_fwd_xdl/convnd_fwd_xdl.cpp) add_executable(gemm_xdl ${GEMM_XDL_SOURCE}) add_executable(gemm_xdl_bias_relu ${GEMM_XDL_BIAS_RELU_SOURCE}) @@ -34,8 +34,8 @@ add_executable(conv2d_fwd_xdl_bias_relu_add ${CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURC add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE}) add_executable(gemm_xdl_alpha_beta ${GEMM_XDL_ALPHA_BETA_SOURCE}) add_executable(conv2d_fwd_xdl_int8 ${CONV2D_FWD_XDL_INT8_SOURCE}) -add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) add_executable(conv3d_fwd_xdl ${CONV3D_FWD_XDL_SOURCE}) +add_executable(convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) target_link_libraries(gemm_xdl PRIVATE host_tensor) target_link_libraries(gemm_xdl_bias_relu PRIVATE host_tensor) @@ -46,5 +46,5 @@ target_link_libraries(conv2d_fwd_xdl_bias_relu_add PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_bias_relu_atomic_add PRIVATE host_tensor) target_link_libraries(gemm_xdl_alpha_beta PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_int8 PRIVATE host_tensor) -target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor) target_link_libraries(conv3d_fwd_xdl PRIVATE host_tensor) +target_link_libraries(convnd_fwd_xdl PRIVATE host_tensor) From 79b2ef32fa1916da6351f1b6112d151ad6d3eb54 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 25 Feb 2022 12:01:48 +0100 Subject: [PATCH 24/82] 3D cases for convolution utility functions. --- device_operation/include/conv_utils.hpp | 24 ++++++++ device_operation/include/tensor_layout.hpp | 19 +++++++ test/conv_util/main.cpp | 65 +++++++++++++++++++++- test/include/test_util.hpp | 9 +++ 4 files changed, 115 insertions(+), 2 deletions(-) diff --git a/device_operation/include/conv_utils.hpp b/device_operation/include/conv_utils.hpp index 9aa616633ee..fdfb796754d 100644 --- a/device_operation/include/conv_utils.hpp +++ b/device_operation/include/conv_utils.hpp @@ -186,6 +186,30 @@ HostTensorDescriptor GetHostTensorDescriptor(const std::vector& dim return HostTensorDescriptor( dims, std::vector{C * dims[2] * dims[3], 1, dims[3] * C, C}); } + // 3D + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor( + dims, std::vector{C * dims[2] * dims[3] * dims[4], + dims[2] * dims[3] * dims[4], + dims[3] * dims[4], + dims[4], + 1}); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor( + dims, std::vector{C * dims[2] * dims[3] * dims[4], + 1, + C * dims[3] * dims[4], + C * dims[4], + C}); + } std::stringstream err_msg; err_msg << "Unsupported data layout provided: " << layout << "!"; diff --git a/device_operation/include/tensor_layout.hpp b/device_operation/include/tensor_layout.hpp index 4904f004a04..06ac439c5f7 100644 --- a/device_operation/include/tensor_layout.hpp +++ b/device_operation/include/tensor_layout.hpp @@ -85,16 +85,35 @@ struct NKHW : public BaseTensorLayout static constexpr const char* name = "NKHW"; }; +// 3D Conv struct NDHWC : public BaseTensorLayout { + static constexpr const char* name = "NDHWC"; }; struct KZYXC : public BaseTensorLayout { + static constexpr const char* name = "KZYXC"; }; struct NDHWK : public BaseTensorLayout { + static constexpr const char* name = "NDHWK"; +}; + +struct NCDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NCDHW"; +}; + +struct KCZYX : public BaseTensorLayout +{ + static constexpr const char* name = "KCZYX"; +}; + +struct NKDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NKDHW"; }; } // namespace convolution diff --git a/test/conv_util/main.cpp b/test/conv_util/main.cpp index ee194f24629..77e1f631aea 100644 --- a/test/conv_util/main.cpp +++ b/test/conv_util/main.cpp @@ -84,8 +84,7 @@ bool TestConvParams_GetOutputSpatialLengths() conv_params.input_right_pads = std::vector{1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec( - out_spatial_len, std::vector{36}, "Error: ConvParams 1D default constructor."); + res = cmp_vec(out_spatial_len, std::vector{36}, "Error: ConvParams 1D."); conv_params.conv_filter_strides = std::vector{1, 1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); @@ -114,6 +113,47 @@ bool TestConvParams_GetOutputSpatialLengths() std::vector{23}, "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."); + // -------------------------- 3D ------------------------------------ + conv_params.num_dim_spatial = 3; + conv_params.filter_spatial_lengths = std::vector{3, 3, 3}; + conv_params.input_spatial_lengths = std::vector{71, 71, 71}; + conv_params.conv_filter_strides = std::vector{2, 2, 2}; + conv_params.conv_filter_dilations = std::vector{1, 1, 1}; + conv_params.input_left_pads = std::vector{1, 1, 1}; + conv_params.input_right_pads = std::vector{1, 1, 1}; + + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec(out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D."); + + conv_params.conv_filter_strides = std::vector{1, 1, 1}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec(out_spatial_len, + std::vector{71, 71, 71}, + "Error: ConvParams 3D stride {1, 1, 1}."); + + conv_params.conv_filter_strides = std::vector{2, 2, 2}; + conv_params.input_left_pads = std::vector{2, 2, 2}; + conv_params.input_right_pads = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec(out_spatial_len, + std::vector{37, 37, 37}, + "Error: ConvParams 3D padding left/right {2, 2, 2}."); + + conv_params.conv_filter_dilations = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec(out_spatial_len, + std::vector{36, 36, 36}, + "Error: ConvParams 3D dilation {2, 2, 2}."); + + conv_params.conv_filter_strides = std::vector{3, 3, 3}; + conv_params.input_left_pads = std::vector{1, 1, 1}; + conv_params.input_right_pads = std::vector{1, 1, 1}; + conv_params.conv_filter_dilations = std::vector{2, 2, 2}; + out_spatial_len = conv_params.GetOutputSpatialLengths(); + res = cmp_vec(out_spatial_len, + std::vector{23, 23, 23}, + "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."); + return res; } @@ -141,6 +181,27 @@ bool TestGetHostTensorDescriptor() res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"); res = cmp_vec(h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!"); + dims = std::vector{2, 3, 4, 5, 6}; + h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWC{}); + res = cmp_vec(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!"); + res = cmp_vec(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 1, // C + 3 * 5 * 6, // D + 3 * 6, // H + 3}, // W + "Error: wrong NDHWC dimensions strides!"); + + h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCDHW{}); + res = cmp_vec(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!"); + res = cmp_vec(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 4 * 5 * 6, // C + 5 * 6, // D + 6, // H + 1}, // W + "Error: wrong NCDHW dimensions strides!"); + return res; } diff --git a/test/include/test_util.hpp b/test/include/test_util.hpp index f779c3dd1d6..9ec795cdc3a 100644 --- a/test/include/test_util.hpp +++ b/test/include/test_util.hpp @@ -1,10 +1,12 @@ #ifndef TEST_UTIL_HPP #define TEST_UTIL_HPP +#include #include #include #include #include +#include #include #include #include @@ -81,4 +83,11 @@ typename std::enable_if::value, bool>::type check_err( } // namespace test_util +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); + return os; +} + #endif From 563ad880d16c6ef8cb973ff7c8605fb38e22536d Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 25 Feb 2022 12:14:45 +0100 Subject: [PATCH 25/82] 3D reference convolution. --- .../include/reference_conv_fwd.hpp | 65 +++++++++- test/reference_conv_fwd/main.cpp | 118 ++++++++++++++++-- 2 files changed, 166 insertions(+), 17 deletions(-) diff --git a/reference_operation/include/reference_conv_fwd.hpp b/reference_operation/include/reference_conv_fwd.hpp index 0bba22423fb..5ee5c622b8c 100644 --- a/reference_operation/include/reference_conv_fwd.hpp +++ b/reference_operation/include/reference_conv_fwd.hpp @@ -14,9 +14,9 @@ namespace host { // // @brief Reference implementation for forward convolution. // -// @paragraph Supported tensor layouts. Input tensor supports NCHiWi data layout. -// Weights tensor supports KCYX data layout. Output tensor supports -// NKHoWo data layout. +// @paragraph Supports both NCHW as well as NHWC formats (and their respective +// counterparts for weight and output) as long as tensor descriptor +// lengths is in NCHW. // // @tparam InDataType Input tensor data type. // @tparam WeiDataType Weights tensor data type. @@ -100,9 +100,9 @@ struct ReferenceConvFwd : public device::BaseOperator float v_wei; arg.in_element_op_(v_in, - static_cast(arg.input_(n, c, wi))); + ck::type_convert(arg.input_(n, c, wi))); arg.wei_element_op_(v_wei, - static_cast(arg.weight_(k, c, x))); + ck::type_convert(arg.weight_(k, c, x))); v_acc += v_in * v_wei; } @@ -169,6 +169,61 @@ struct ReferenceConvFwd : public device::BaseOperator return 0; } + else if constexpr(NumDimSpatial == 3) + { + auto f_nchw = [&](auto n, auto k, auto d_o, auto ho, auto wo) { + float v_acc = 0; + + for(int c = 0; c < arg.weight_.mDesc.GetLengths()[1]; ++c) + { + for(int z = 0; z < arg.weight_.mDesc.GetLengths()[2]; ++z) + { + int di = d_o * arg.conv_strides_[0] + z * arg.conv_dilations_[0] - + arg.in_left_pads_[0]; + for(int y = 0; y < arg.weight_.mDesc.GetLengths()[3]; ++y) + { + int hi = ho * arg.conv_strides_[1] + y * arg.conv_dilations_[1] - + arg.in_left_pads_[1]; + for(int x = 0; x < arg.weight_.mDesc.GetLengths()[4]; ++x) + { + int wi = wo * arg.conv_strides_[2] + + x * arg.conv_dilations_[2] - arg.in_left_pads_[2]; + if(di >= 0 && di < arg.input_.mDesc.GetLengths()[2] && + hi >= 0 && hi < arg.input_.mDesc.GetLengths()[3] && + wi >= 0 && wi < arg.input_.mDesc.GetLengths()[4]) + { + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, + ck::type_convert(arg.input_(n, c, di, hi, wi))); + arg.wei_element_op_( + v_wei, + ck::type_convert(arg.weight_(k, c, z, y, x))); + v_acc += v_in * v_wei; + } + } + } + } + } + + float v_out; + + arg.out_element_op_(v_out, v_acc); + arg.output_(n, k, d_o, ho, wo) = ck::type_convert(v_out); + }; + + make_ParallelTensorFunctor(f_nchw, + arg.output_.mDesc.GetLengths()[0], + arg.output_.mDesc.GetLengths()[1], + arg.output_.mDesc.GetLengths()[2], + arg.output_.mDesc.GetLengths()[3], + arg.output_.mDesc.GetLengths()[4])( + std::thread::hardware_concurrency()); + + return 0; + } } float Run(const device::BaseArgument* p_arg, int) override diff --git a/test/reference_conv_fwd/main.cpp b/test/reference_conv_fwd/main.cpp index cc5c113f594..29a8e102d5c 100644 --- a/test/reference_conv_fwd/main.cpp +++ b/test/reference_conv_fwd/main.cpp @@ -23,11 +23,16 @@ template struct FillMonotonicSeq { T m_init_value{0}; + T m_step{1}; template void operator()(ForwardIter first, ForwardIter last) const { - std::iota(first, last, m_init_value); + std::generate(first, last, [=, n = m_init_value]() mutable { + auto tmp = n; + n += m_step; + return tmp; + }); } }; @@ -53,7 +58,7 @@ template , typename FillWeightsOp = FillConstant> Tensor RunReferenceConv(const ck::conv_util::ConvParams& params, - const FillInputOp& fill_input_op = FillInputOp{0}, + const FillInputOp& fill_input_op = FillInputOp{}, const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) { std::vector input_dims{static_cast(params.N), @@ -84,6 +89,9 @@ Tensor RunReferenceConv(const ck::conv_util::ConvParams& params, fill_weights_op(weights.begin(), weights.end()); std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); + // std::cout <<"input: " << input.mDesc << std::endl << input.mData << std::endl; + // std::cout <<"weight: " << weights.mDesc << std::endl << weights.mData << std::endl; + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd RunReferenceConv(const ck::conv_util::ConvParams& params, OutElementOp{}); ref_invoker.Run(ref_argument); + // std::cout <<"output: " << host_output.mDesc << std::endl << host_output.mData << std::endl; return host_output; } @@ -235,16 +244,14 @@ bool TestConv1DNWC() params.input_left_pads = std::vector{1}; params.input_right_pads = std::vector{1}; - auto out_tensor2 = - RunReferenceConv<1, - float, - float, - float, - ck::tensor_layout::convolution::NWC, - ck::tensor_layout::convolution::KXC, - ck::tensor_layout::convolution::NWK>(params, [](auto first, auto last) { - std::generate(first, last, [n = 0]() mutable { return float(n++) * float(0.1f); }); - }); + auto out_tensor2 = RunReferenceConv<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( + params, FillMonotonicSeq{0.f, 0.1f}); ref_dims = std::vector{2, 16, 16}; ref_data = std::vector{ @@ -320,6 +327,91 @@ bool TestConv1DNWC() return res; } +bool TestConv3DNCDHW() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.num_dim_spatial = 3; + params.N = 1; + params.K = 1; + params.C = 2; + params.filter_spatial_lengths = std::vector{3, 3, 3}; + params.input_spatial_lengths = std::vector{6, 6, 6}; + params.conv_filter_strides = std::vector{1, 1, 1}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{0, 0, 0}; + params.input_right_pads = std::vector{0, 0, 0}; + + auto out_tensor = RunReferenceConv<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( + params, FillMonotonicSeq{0.f, 0.1f}); + std::vector ref_dims{1, 1, 4, 4, 4}; + std::vector ref_data{ + 407.7, 410.40002, 413.09998, 415.80002, 423.90002, 426.6, 429.30002, 432., + 440.1, 442.80002, 445.5, 448.2, 456.30002, 459., 461.7, 464.40002, + 504.90002, 507.6, 510.30002, 513., 521.1, 523.8, 526.5, 529.2001, + 537.3, 540., 542.7001, 545.4, 553.5, 556.2001, 558.9, 561.6, + 602.10004, 604.8, 607.5, 610.2, 618.3, 621., 623.7, 626.4, + 634.5, 637.2, 639.9, 642.60004, 650.7, 653.4, 656.10004, 658.8, + 699.3, 702., 704.7, 707.4, 715.5, 718.2, 720.9, 723.60004, + 731.7, 734.4001, 737.10004, 739.8, 747.9001, 750.60004, 753.3, 756.}; + res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 1]: wrong output tensor dimensions!"); + res = res && + test_util::check_err(out_tensor.mData, ref_data, "Error [case 1]: incorrect results!"); + + params.N = 1; + params.K = 2; + params.C = 2; + params.filter_spatial_lengths = std::vector{3, 3, 3}; + params.input_spatial_lengths = std::vector{12, 12, 12}; + params.conv_filter_strides = std::vector{3, 3, 3}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{0, 0, 0}; + params.input_right_pads = std::vector{0, 0, 0}; + + out_tensor = RunReferenceConv<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( + params, FillMonotonicSeq{0.f, 0.1f}); + ref_dims = std::vector{1, 2, 4, 4, 4}; + ref_data = std::vector{ + 2756.7002, 2764.7998, 2772.9001, 2781., 2853.9001, 2862., 2870.1, 2878.2002, + 2951.1, 2959.2002, 2967.2998, 2975.4001, 3048.2998, 3056.4001, 3064.5, 3072.6, + 3923.1, 3931.2, 3939.2998, 3947.4, 4020.2998, 4028.4001, 4036.5002, 4044.5999, + 4117.5, 4125.6, 4133.7, 4141.8, 4214.7, 4222.8, 4230.9004, 4239., + 5089.5, 5097.5996, 5105.7, 5113.8, 5186.7, 5194.8, 5202.9, 5211., + 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, + 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, + 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801, + 2756.7002, 2764.7998, 2772.9001, 2781., 2853.9001, 2862., 2870.1, 2878.2002, + 2951.1, 2959.2002, 2967.2998, 2975.4001, 3048.2998, 3056.4001, 3064.5, 3072.6, + 3923.1, 3931.2, 3939.2998, 3947.4, 4020.2998, 4028.4001, 4036.5002, 4044.5999, + 4117.5, 4125.6, 4133.7, 4141.8, 4214.7, 4222.8, 4230.9004, 4239., + 5089.5, 5097.5996, 5105.7, 5113.8, 5186.7, 5194.8, 5202.9, 5211., + 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, + 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, + 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801}; + res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 2]: wrong output tensor dimensions!"); + res = + res && test_util::check_err( + out_tensor.mData, ref_data, "Error [case 2]: incorrect results!", 1e-4f, 1e-6f); + + return res; +} + } // anonymous namespace int main(void) @@ -329,5 +421,7 @@ int main(void) std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv1DNWC(); std::cout << "TestConv1DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNCDHW(); + std::cout << "TestConv3DNCDHW ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return 0; } From dcb6bac6846f248c566e81b912c20a45fbee1789 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 28 Feb 2022 15:42:54 +0100 Subject: [PATCH 26/82] Add support for 3D convolution. --- .../device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp | 158 ++++++++++++++++++ test/convnd_fwd_xdl/main.cpp | 63 +++++-- test/include/test_util.hpp | 4 +- 3 files changed, 210 insertions(+), 15 deletions(-) diff --git a/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index 2997652c82f..2fd466519c3 100644 --- a/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -367,6 +367,155 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } } + template ::type = false> + static auto GetInputTensorDescriptor(ck::index_t N, + ck::index_t C, + ck::index_t gemm_m, + ck::index_t gemm_k, + ck::index_t gemm_m_pad, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths, + const std::vector& conv_filter_strides, + const std::vector& conv_filter_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads) + { + const ck::index_t gemm_k0 = gemm_k / GemmK1Number; + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(gemm_m, gemm_k)); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_right_pad_transform(gemm_m, gemm_m_pad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + const auto in_n_do_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_do_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + else + { + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(gemm_k0, GemmK1Number)), + make_pass_through_transform(gemm_m)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // in_gemmk0_gemmm_gemmk1_grid_desc + return transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(gemm_k0), + make_right_pad_transform(gemm_m, gemm_m_pad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + static index_t GetGemmMRaw(ck::index_t N, const std::vector& output_spatial_lengths) { @@ -445,6 +594,13 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}); } + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}); + } + using ABCGridDescs = decltype(GetABCGridDesc()); using AGridDesc_K0_M_K1 = remove_cvref_t; @@ -704,6 +860,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { + // TODO add 2GB check + if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) { diff --git a/test/convnd_fwd_xdl/main.cpp b/test/convnd_fwd_xdl/main.cpp index 045becf32fe..4072618a037 100644 --- a/test/convnd_fwd_xdl/main.cpp +++ b/test/convnd_fwd_xdl/main.cpp @@ -190,6 +190,41 @@ void RunConv(const ck::conv_util::ConvParams& params, out_device_buf.FromDevice(output.mData.data()); } +bool TestConv1DNWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.num_dim_spatial = 1; + params.N = 2; + params.K = 16; + params.C = 4; + params.filter_spatial_lengths = std::vector{3}; + params.input_spatial_lengths = std::vector{16}; + params.conv_filter_strides = std::vector{1}; + params.conv_filter_dilations = std::vector{1}; + params.input_left_pads = std::vector{1}; + params.input_right_pads = std::vector{1}; + + auto host_tensors = GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + RunReferenceConv<1>(params, input, weights, host_output); + RunConv<1>(params, input, weights, device_output); + res = res && + test_util::check_err( + device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + + return res; +} + bool TestConv2DNHWC() { bool res{true}; @@ -215,34 +250,34 @@ bool TestConv2DNHWC() return res; } -bool TestConv1DNWC() +bool TestConv3DNDHWC() { bool res{true}; ck::conv_util::ConvParams params; - params.num_dim_spatial = 1; + params.num_dim_spatial = 3; params.N = 2; params.K = 16; params.C = 4; - params.filter_spatial_lengths = std::vector{3}; - params.input_spatial_lengths = std::vector{16}; - params.conv_filter_strides = std::vector{1}; - params.conv_filter_dilations = std::vector{1}; - params.input_left_pads = std::vector{1}; - params.input_right_pads = std::vector{1}; + params.filter_spatial_lengths = std::vector{3, 3, 3}; + params.input_spatial_lengths = std::vector{16, 16, 16}; + params.conv_filter_strides = std::vector{1, 1, 1}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{1, 1, 1}; + params.input_right_pads = std::vector{1, 1, 1}; auto host_tensors = GetHostTensors(params); + ck::tensor_layout::convolution::NDHWC, + ck::tensor_layout::convolution::KZYXC, + ck::tensor_layout::convolution::NDHWK>(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - RunReferenceConv<1>(params, input, weights, host_output); - RunConv<1>(params, input, weights, device_output); + RunReferenceConv<3>(params, input, weights, host_output); + RunConv<3>(params, input, weights, device_output); res = res && test_util::check_err( device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); @@ -259,4 +294,6 @@ int main() std::cout << "TestConv1DNWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv2DNHWC(); std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNDHWC(); + std::cout << "TestConv3DNDHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; } diff --git a/test/include/test_util.hpp b/test/include/test_util.hpp index 9ec795cdc3a..1fe8adfeaa3 100644 --- a/test/include/test_util.hpp +++ b/test/include/test_util.hpp @@ -43,7 +43,7 @@ check_err(const std::vector& out, if(err_count < 5) { std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" - << i << "]: " << out[i] << "!=" << ref[i] << std::endl + << i << "]: " << out[i] << " != " << ref[i] << std::endl << msg << std::endl; } res = false; @@ -72,7 +72,7 @@ typename std::enable_if::value, bool>::type check_err( { if(out[i] != ref[i]) { - std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i] + std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl << msg << std::endl; return false; From 25b122021c293c484d7b97ef8fa339bb1506e0c3 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 4 Mar 2022 12:03:28 +0100 Subject: [PATCH 27/82] Check for inputs bigger than 2GB. --- .../device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp | 16 ++- test/convnd_fwd_xdl/main.cpp | 136 ++++++++++++++++++ 2 files changed, 151 insertions(+), 1 deletion(-) diff --git a/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index 2fd466519c3..7b9dc0c8ffd 100644 --- a/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/device_operation/include/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -860,7 +860,21 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { - // TODO add 2GB check + // Input tensors can't be bigger than 2GB each. + constexpr std::size_t GB2 = 2 * 1e9; + + if(arg.a_grid_desc_k0_m_k1_.GetElementSpaceSize() > GB2) + { + return false; + } + if(arg.b_grid_desc_k0_n_k1_.GetElementSpaceSize() > GB2) + { + return false; + } + if(arg.c_grid_desc_m_n_.GetElementSpaceSize() > GB2) + { + return false; + } if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) diff --git a/test/convnd_fwd_xdl/main.cpp b/test/convnd_fwd_xdl/main.cpp index 4072618a037..eb38d4924e3 100644 --- a/test/convnd_fwd_xdl/main.cpp +++ b/test/convnd_fwd_xdl/main.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -285,6 +286,135 @@ bool TestConv3DNDHWC() return res; } +bool TestConv3DNDHWC2GBInput() +{ + // >2GB Input + ck::conv_util::ConvParams params; + params.num_dim_spatial = 3; + params.N = 2; + params.K = 16; + params.C = 32; + params.filter_spatial_lengths = std::vector{3, 3, 3}; + params.input_spatial_lengths = std::vector{32, 1000, 1000}; + params.conv_filter_strides = std::vector{1, 1, 1}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{1, 1, 1}; + params.input_right_pads = std::vector{1, 1, 1}; + + auto host_tensors = GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + try + { + RunConv<3>(params, input, weights, device_output); + } + catch(const std::runtime_error& err) + { + std::string err_msg{"Error! device_conv with the specified compilation parameters does " + "not support this Conv problem"}; + if(err.what() != err_msg) + { + return false; + } + return true; + } + std::cout << "Error: Failure checking oversized tensor!" << std::endl; + return false; +} + +bool TestConv3DNDHWC2GBFilters() +{ + // >2GB Filters + ck::conv_util::ConvParams params; + params.num_dim_spatial = 3; + params.N = 2; + params.K = 16; + params.C = 32; + params.filter_spatial_lengths = std::vector{4, 1000, 1000}; + params.input_spatial_lengths = std::vector{16, 16, 16}; + params.conv_filter_strides = std::vector{1, 1, 1}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{1, 1, 1}; + params.input_right_pads = std::vector{1, 1, 1}; + + auto host_tensors = GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + try + { + RunConv<3>(params, input, weights, device_output); + } + catch(const std::runtime_error& err) + { + std::string err_msg{"Error! device_conv with the specified compilation parameters does " + "not support this Conv problem"}; + if(err.what() != err_msg) + { + return false; + } + return true; + } + std::cout << "Error: Failure checking oversized tensor!" << std::endl; + return false; +} + +bool TestConv3DNDHWC2GBOutput() +{ + // >2GB Output + ck::conv_util::ConvParams params; + params.num_dim_spatial = 3; + params.N = 2; + params.K = 16; + params.C = 2; + params.filter_spatial_lengths = std::vector{1, 1, 1}; + params.input_spatial_lengths = std::vector{1000, 1000, 30}; + params.conv_filter_strides = std::vector{1, 1, 1}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{2, 2, 2}; + params.input_right_pads = std::vector{2, 2, 2}; + + auto host_tensors = GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + try + { + RunConv<3>(params, input, weights, device_output); + } + catch(const std::runtime_error& err) + { + std::string err_msg{"Error! device_conv with the specified compilation parameters does " + "not support this Conv problem"}; + if(err.what() != err_msg) + { + return false; + } + return true; + } + std::cout << "Error: Failure checking oversized tensor!" << std::endl; + return false; +} + } // anonymous namespace int main() @@ -296,4 +426,10 @@ int main() std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWC(); std::cout << "TestConv3DNDHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNDHWC2GBInput(); + std::cout << "TestConv3DNDHWC2GBInput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNDHWC2GBFilters(); + std::cout << "TestConv3DNDHWC2GBFilters ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNDHWC2GBOutput(); + std::cout << "TestConv3DNDHWC2GBOutput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; } From ff6f9d540f904e736cacbecca0433799fe22db5b Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 4 Mar 2022 23:58:22 +0100 Subject: [PATCH 28/82] Formatting --- device_operation/include/conv_utils.hpp | 24 +++++++++++------------- profiler/README.md | 4 ++-- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/device_operation/include/conv_utils.hpp b/device_operation/include/conv_utils.hpp index fdfb796754d..3e4d65311f8 100644 --- a/device_operation/include/conv_utils.hpp +++ b/device_operation/include/conv_utils.hpp @@ -39,12 +39,12 @@ std::size_t GetFlops(ck::index_t N, std::accumulate(std::begin(output_spatial_lengths), std::end(output_spatial_lengths), static_cast(1), - std::multiplies()) * + std::multiplies()) * C * std::accumulate(std::begin(filter_spatial_lengths), std::end(filter_spatial_lengths), static_cast(1), - std::multiplies()); + std::multiplies()); } /** @@ -192,23 +192,21 @@ HostTensorDescriptor GetHostTensorDescriptor(const std::vector& dim std::is_same::value) { - return HostTensorDescriptor( - dims, std::vector{C * dims[2] * dims[3] * dims[4], - dims[2] * dims[3] * dims[4], - dims[3] * dims[4], - dims[4], - 1}); + return HostTensorDescriptor(dims, + std::vector{C * dims[2] * dims[3] * dims[4], + dims[2] * dims[3] * dims[4], + dims[3] * dims[4], + dims[4], + 1}); } else if constexpr(std::is_same::value || std::is_same::value || std::is_same::value) { return HostTensorDescriptor( - dims, std::vector{C * dims[2] * dims[3] * dims[4], - 1, - C * dims[3] * dims[4], - C * dims[4], - C}); + dims, + std::vector{ + C * dims[2] * dims[3] * dims[4], 1, C * dims[3] * dims[4], C * dims[4], C}); } std::stringstream err_msg; diff --git a/profiler/README.md b/profiler/README.md index 9aed7e501f1..55942e4834e 100644 --- a/profiler/README.md +++ b/profiler/README.md @@ -67,8 +67,8 @@ Best Perf: 1.1933 ms, 107.977 TFlops, 79.0848 GB/s #arg8: print matrix value (0=no, 1=yes) #arg9: run kernel # of times (>1) #arg10 to 24: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx - ##################### op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads - ./profiler/ckProfiler conv 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 + ##################### op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads + ./profiler/ckProfiler conv_fwd 1 1 1 1 1 1 0 5 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 ``` Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) From 10232bde0b27f6cb28b344c727eacf3b9dd4a785 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Sat, 5 Mar 2022 00:00:20 +0100 Subject: [PATCH 29/82] Support for bf16/f16/f32/i8 - conv instances + UT. --- device_operation/CMakeLists.txt | 27 +- .../device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp | 720 ------------------ ...nv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp | 111 +++ ...onv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp | 109 +++ ...nv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp | 111 +++ ...d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp | 131 ++-- ...2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp | 104 +-- ...d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp | 131 ++-- ...wd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp | 112 +++ ...fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp | 110 +++ ...fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp | 109 +++ ...wd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp | 112 +++ test/CMakeLists.txt | 3 + test/convnd_fwd_xdl/main.cpp | 249 +++++- test/include/test_util.hpp | 72 +- 15 files changed, 1295 insertions(+), 916 deletions(-) delete mode 100644 device_operation/include/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp create mode 100644 device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp create mode 100644 device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp create mode 100644 device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp create mode 100644 device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp create mode 100644 device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp create mode 100644 device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp create mode 100644 device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp diff --git a/device_operation/CMakeLists.txt b/device_operation/CMakeLists.txt index 5872b69b99d..a7ea175b23c 100644 --- a/device_operation/CMakeLists.txt +++ b/device_operation/CMakeLists.txt @@ -67,6 +67,14 @@ set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp; ) +# device_conv1d_fwd_instance +set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp; +) + # device_conv2d_fwd_instance set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; @@ -76,11 +84,6 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; ) -# device_conv1d_fwd_instance -set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp; -) - # device_conv2d_fwd_bias_relu_instance set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; @@ -96,6 +99,15 @@ set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; ) +# device_conv3d_fwd_instance +set(DEVICE_CONV3D_FWD_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; +) + + add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE}) add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) @@ -106,6 +118,7 @@ add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURC add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) +add_library(device_conv3d_fwd_instance SHARED ${DEVICE_CONV3D_FWD_INSTANCE_SOURCE}) target_include_directories(device_gemm_instance SYSTEM PUBLIC $) target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $) @@ -117,6 +130,7 @@ target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $) target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $) target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $) +target_include_directories(device_conv3d_fwd_instance SYSTEM PUBLIC $) target_compile_features(device_gemm_instance PUBLIC) target_compile_features(device_gemm_bias_2d_instance PUBLIC) @@ -128,6 +142,7 @@ target_compile_features(device_conv2d_fwd_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC) +target_compile_features(device_conv3d_fwd_instance PUBLIC) set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_gemm_bias_2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -139,6 +154,7 @@ set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_conv3d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib) @@ -150,3 +166,4 @@ install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv3d_fwd_instance LIBRARY DESTINATION lib) diff --git a/device_operation/include/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 3888e5e9c8d..00000000000 --- a/device_operation/include/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,720 +0,0 @@ -#ifndef DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP -#define DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP - -#include -#include -#include "device.hpp" -#include "device_base.hpp" -#include "device_conv_fwd.hpp" -#include "convolution_forward_specialization.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r3.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] -template -struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K - : public DeviceConvFwd -{ - using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; - - using ADataType = InDataType; - using BDataType = WeiDataType; - using CDataType = OutDataType; - - // TODO make A/B datatype different - using ABDataType = InDataType; - - static constexpr index_t NDimSpatial = 2; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr auto K1Number = Number{}; - static constexpr auto GemmK1Number = K1Number; - - static auto - MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads) - { - using namespace ck; - - const index_t Hi = input_spatial_lengths[0]; - const index_t Wi = input_spatial_lengths[1]; - - const index_t Ho = output_spatial_lengths[0]; - const index_t Wo = output_spatial_lengths[1]; - - const index_t Y = filter_spatial_lengths[0]; - const index_t X = filter_spatial_lengths[1]; - - const index_t ConvStrideH = conv_filter_strides[0]; - const index_t ConvStrideW = conv_filter_strides[1]; - - const index_t ConvDilationH = conv_filter_dilations[0]; - const index_t ConvDilationW = conv_filter_dilations[1]; - - const index_t InLeftPadH = input_left_pads[0]; - const index_t InLeftPadW = input_left_pads[1]; - - const index_t InRightPadH = input_right_pads[0]; - const index_t InRightPadW = input_right_pads[1]; - - const index_t GemmMRaw = N * Ho * Wo; - const index_t GemmN = K; - const index_t GemmK = Y * X * C; - - const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; - - assert(GemmK % GemmK1Number == 0); - - const index_t GemmK0 = GemmK / GemmK1Number; - - if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) - { - // A: input tensor - const auto in_gemmmraw_gemmk_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmmraw_gemmk_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), - make_right_pad_transform(GemmMRaw, GemmMPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // B: weight tensor - const auto wei_gemmn_gemmk_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, C)); - - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - wei_gemmn_gemmk_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C: output tensor - const auto out_gemmmraw_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - - const auto out_gemmm_gemmn_grid_desc = - transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc); - } - else if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization_t::Filter1x1Pad0) - { - // A: input tensor - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), - make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( - in_n_ho_wo_c_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmk0_gemmmraw_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmMRaw, GemmMPad), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - // B: weight tensor - const auto wei_gemmn_gemmk_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, C)); - - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - wei_gemmn_gemmk_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C: output tensor - const auto out_gemmmraw_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - - const auto out_gemmm_gemmn_grid_desc = - transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc); - } - else - { - // A: input tensor - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple( - make_pass_through_transform(N), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), - make_pass_through_transform(C)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmk_gemmmraw_grid_desc = - transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(Y, X, C)), - make_merge_transform(make_tuple(N, Ho, Wo))), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmk_gemmmraw_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), - make_pass_through_transform(GemmMRaw)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - in_gemmk0_gemmmraw_gemmk1_grid_desc, - make_tuple(make_pass_through_transform(GemmK0), - make_right_pad_transform(GemmMRaw, GemmMPad), - make_pass_through_transform(GemmK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - // B: weight tensor - const auto wei_k_yxc_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - - const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( - wei_k_yxc_grid_desc, - make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0>{})); - - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - wei_gemmk_gemmn_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - // C: output tensor - const auto out_nhowo_k_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - - const auto out_gemmmraw_gemmn_grid_desc = - transform_tensor_descriptor(out_nhowo_k_grid_desc, - make_tuple(make_pass_through_transform(N * Ho * Wo), - make_pass_through_transform(K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto out_gemmm_gemmn_grid_desc = - transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, - make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), - make_pass_through_transform(GemmN)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc); - } - } - - using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); - - using AGridDesc_K0_M_K1 = remove_cvref_t; - using BGridDesc_K0_N_K1 = remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; - - // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< - BlockSize, - ABDataType, // TODO: distinguish A/B datatype - AccDataType, - CDataType, - InMemoryDataOperationEnum_t::Set, - AGridDesc_K0_M_K1, - BGridDesc_K0_N_K1, - CGridDesc_M_N, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - MPerBlock, - NPerBlock, - K0PerBlock, - MPerXDL, - NPerXDL, - K1, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_K0_M_K1, - Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, - Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, - 2, // ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - ABlockLdsAddExtraM, - BBlockTransferThreadClusterLengths_K0_N_K1, - Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, - Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, - 2, // BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1, - false, // BThreadTransferSrcResetCoordinateAfterRun, - BBlockLdsAddExtraN, - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, - 7, // CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector>; - - // Argument - struct Argument : public BaseArgument - { - Argument(const InDataType* p_in_grid, - const WeiDataType* p_wei_grid, - OutDataType* p_out_grid, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - ck::index_t M01, - ck::index_t N01, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) - : p_a_grid_{p_in_grid}, - p_b_grid_{p_wei_grid}, - p_c_grid_{p_out_grid}, - a_grid_desc_k0_m_k1_{}, - b_grid_desc_k0_n_k1_{}, - c_grid_desc_m_n_{}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - block_2_ctile_map_{}, - M01_{M01}, - N01_{N01}, - in_element_op_{in_element_op}, - wei_element_op_{wei_element_op}, - out_element_op_{out_element_op}, - Conv_N_{N}, - Conv_K_{K}, - Conv_C_{C}, - filter_spatial_lengths_{filter_spatial_lengths}, - conv_filter_strides_{conv_filter_strides}, - input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads} - { - const auto descs = - DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); - - a_grid_desc_k0_m_k1_ = descs[I0]; - b_grid_desc_k0_n_k1_ = descs[I1]; - c_grid_desc_m_n_ = descs[I2]; - - if(GridwiseGemm::CheckValidity( - a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) - { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - - block_2_ctile_map_ = - GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); - } - } - - // private: - const ADataType* p_a_grid_; - const BDataType* p_b_grid_; - CDataType* p_c_grid_; - AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; - BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; - CGridDesc_M_N c_grid_desc_m_n_; - typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; - index_t M01_; - index_t N01_; - InElementwiseOperation in_element_op_; - WeiElementwiseOperation wei_element_op_; - OutElementwiseOperation out_element_op_; - // for checking IsSupportedArgument() - index_t Conv_N_; - index_t Conv_K_; - index_t Conv_C_; - std::vector filter_spatial_lengths_; - std::vector conv_filter_strides_; - std::vector input_left_pads_; - std::vector input_right_pads_; - }; - - // Invoker - struct Invoker : public BaseInvoker - { - using Argument = DeviceOp::Argument; - - float Run(const Argument& arg, int nrepeat = 1) - { - { - std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) - << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " - << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) - << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " - << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; - - std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " - << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - } - - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_)) - { - throw std::runtime_error( - "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); - } - - const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); - - const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); - - const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); - - float ave_time = 0; - - if(has_main_k0_block_loop) - { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - remove_reference_t, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.in_element_op_, - arg.wei_element_op_, - arg.out_element_op_, - arg.block_2_ctile_map_); - } - else - { - const auto kernel = kernel_gemm_xdlops_v2r3< - GridwiseGemm, - ADataType, // TODO: distiguish A/B datatype - CDataType, - remove_reference_t, - remove_reference_t, - remove_reference_t, - InElementwiseOperation, - WeiElementwiseOperation, - OutElementwiseOperation, - remove_reference_t, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.in_element_op_, - arg.wei_element_op_, - arg.out_element_op_, - arg.block_2_ctile_map_); - } - - return ave_time; - } - - float Run(const BaseArgument* p_arg, int nrepeat = 1) override - { - return Run(*dynamic_cast(p_arg), nrepeat); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) - { - // check if it's 1x1, stride=1 conv - if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && - arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && - arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && - arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) - { - return false; - } - } - else if constexpr(ConvForwardSpecialization == - ConvolutionForwardSpecialization_t::Filter1x1Pad0) - { - // check if it's 1x1 conv - if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && - arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && - arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) - { - return false; - } - } - - // vector load A/B matrix from global memory - if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && - arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && - arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) - { - return false; - } - - // vector store C matrix into global memory - if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) - { - return false; - } - - // Gridwise GEMM size - return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m_n_, - arg.M01_, - arg.N01_); - } - - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const InDataType* p_in_grid, - const WeiDataType* p_wei_grid, - OutDataType* p_out_grid, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) - { - return Argument{p_in_grid, - p_wei_grid, - p_out_grid, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - std::unique_ptr - MakeArgumentPointer(const void* p_in_grid, - const void* p_wei_grid, - void* p_out_grid, - ck::index_t N, - ck::index_t K, - ck::index_t C, - std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector output_spatial_lengths, - std::vector conv_filter_strides, - std::vector conv_filter_dilations, - std::vector input_left_pads, - std::vector input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op) override - { - return std::make_unique(static_cast(p_in_grid), - static_cast(p_wei_grid), - static_cast(p_out_grid), - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - 1, - 1, - in_element_op, - wei_element_op, - out_element_op); - } - - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - // clang-format off - str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" - << "<" - << BlockSize << ", " - << MPerBlock << ", " - << NPerBlock << ", " - << K0PerBlock - << ">"; - // clang-format on - - return str.str(); - } -}; // namespace device - -} // namespace device -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp b/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp new file mode 100644 index 00000000000..48a13d67734 --- /dev/null +++ b/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -0,0 +1,111 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv1d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_bf16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv1d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp b/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp new file mode 100644 index 00000000000..11301ee8e66 --- /dev/null +++ b/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp @@ -0,0 +1,109 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv1d_fwd_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_f16_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv1d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp b/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp new file mode 100644 index 00000000000..eeabd008759 --- /dev/null +++ b/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp @@ -0,0 +1,111 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv1d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_int8_instances{}); + add_device_operation_instances(instances, + device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv1d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index 575048399bb..817df62ce82 100644 --- a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -26,71 +26,74 @@ static constexpr auto ConvFwd1x1S1P0 = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( std::vector>& instances) diff --git a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp index beaad1d3b4e..4222413f91e 100644 --- a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -29,67 +29,67 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; diff --git a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp index c9af26ed396..e55f4fe2d7f 100644 --- a/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -26,71 +26,74 @@ static constexpr auto ConvFwd1x1S1P0 = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( std::vector>& instances) diff --git a/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp new file mode 100644 index 00000000000..565325d8b7e --- /dev/null +++ b/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -0,0 +1,112 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv3d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_bf16_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv3d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp b/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp new file mode 100644 index 00000000000..406c56d2b44 --- /dev/null +++ b/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp @@ -0,0 +1,110 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv3d_fwd_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f16_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv3d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp b/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp new file mode 100644 index 00000000000..2bf65ba0783 --- /dev/null +++ b/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp @@ -0,0 +1,109 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv3d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_f32_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv3d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp b/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp new file mode 100644 index 00000000000..ea0259a3f1f --- /dev/null +++ b/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp @@ -0,0 +1,112 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv3d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances{}); + add_device_operation_instances(instances, + device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_int8_instances{}); + add_device_operation_instances( + instances, device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv3d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index ff483b81170..1d11e4e363c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -45,3 +45,6 @@ target_link_libraries(test_reference_conv_fwd PRIVATE host_tensor) set(CONVND_FWD_XDL_SOURCE convnd_fwd_xdl/main.cpp) add_executable(test_convnd_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) target_link_libraries(test_convnd_fwd_xdl PRIVATE host_tensor) +target_link_libraries(test_convnd_fwd_xdl PRIVATE device_conv1d_fwd_instance) +target_link_libraries(test_convnd_fwd_xdl PRIVATE device_conv2d_fwd_instance) +target_link_libraries(test_convnd_fwd_xdl PRIVATE device_conv3d_fwd_instance) diff --git a/test/convnd_fwd_xdl/main.cpp b/test/convnd_fwd_xdl/main.cpp index eb38d4924e3..03d9dcc29b3 100644 --- a/test/convnd_fwd_xdl/main.cpp +++ b/test/convnd_fwd_xdl/main.cpp @@ -18,7 +18,48 @@ #include "tensor_layout.hpp" #include "test_util.hpp" +// Forward declarations for conv instances. + +using DeviceConvFwdNoOpPtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv1d_fwd_instance { + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector&); +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector&); +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector&); +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector&); + +} // namespace device_conv1d_fwd_instance +namespace device_conv2d_fwd_instance { + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector&); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector&); + +} // namespace device_conv2d_fwd_instance +namespace device_conv3d_fwd_instance { + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector&); +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector&); +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector&); +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector&); + +} // namespace device_conv3d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + namespace { + +using bhalf_t = test_util::bhalf_t; + template using S = ck::Sequence; @@ -103,9 +144,7 @@ auto GetHostTensors(const ck::conv_util::ConvParams& params) Tensor device_output( ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); - std::generate(input.begin(), input.end(), [n = 0]() mutable { - return InDataType(n++) * InDataType(0.1f); - }); + std::generate(input.begin(), input.end(), [n = 0]() mutable { return InDataType(n++ * 0.1f); }); std::fill(weights.begin(), weights.end(), WeiDataType(0.5f)); std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); @@ -191,6 +230,61 @@ void RunConv(const ck::conv_util::ConvParams& params, out_device_buf.FromDevice(output.mData.data()); } +template +bool RunConvInstances(const ck::conv_util::ConvParams& params, + const std::vector& conv_ptrs, + const Tensor& input, + const Tensor& weights, + Tensor& output, + const Tensor& host_output) +{ + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + + bool res{true}; + for(auto& conv_ptr : conv_ptrs) + { + auto invoker = conv_ptr->MakeInvokerPointer(); + auto argument = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(conv_ptr->IsSupportedArgument(argument.get())) + { + invoker->Run(argument.get()); + out_device_buf.FromDevice(output.mData.data()); + res = res && + test_util::check_err( + output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + hipGetErrorString( + hipMemset(out_device_buf.GetDeviceBuffer(), 0, out_device_buf.mMemSize)); + } + } + return res; +} + bool TestConv1DNWC() { bool res{true}; @@ -415,6 +509,125 @@ bool TestConv3DNDHWC2GBOutput() return false; } +template +bool TestConv1DNWCInstances() +{ + ck::conv_util::ConvParams params; + params.num_dim_spatial = 1; + params.filter_spatial_lengths = std::vector{3}; + params.input_spatial_lengths = std::vector{71}; + params.conv_filter_strides = std::vector{2}; + params.conv_filter_dilations = std::vector{1}; + params.input_left_pads = std::vector{1}; + params.input_right_pads = std::vector{1}; + + auto host_tensors = GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + RunReferenceConv<1>(params, input, weights, host_output); + + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv1d_fwd_instance:: + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); + + return RunConvInstances<1>(params, conv_ptrs, input, weights, device_output, host_output); +} +bool TestConv1DNWCBF16Instances() { return TestConv1DNWCInstances(); } + +bool TestConv1DNWCF16Instances() { return TestConv1DNWCInstances(); } + +bool TestConv1DNWCF32Instances() { return TestConv1DNWCInstances(); } + +bool TestConv1DNWCInt8Instances() { return TestConv1DNWCInstances(); } + +template +bool TestConv2DNHWCInstances() +{ + ck::conv_util::ConvParams params; + params.num_dim_spatial = 2; + params.filter_spatial_lengths = std::vector{3, 3}; + params.input_spatial_lengths = std::vector{71, 71}; + params.conv_filter_strides = std::vector{2, 2}; + params.conv_filter_dilations = std::vector{1, 1}; + params.input_left_pads = std::vector{1, 1}; + params.input_right_pads = std::vector{1, 1}; + + auto host_tensors = GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + RunReferenceConv<2>(params, input, weights, host_output); + + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + + return RunConvInstances<2>(params, conv_ptrs, input, weights, device_output, host_output); +} + +bool TestConv2DNHWCBF16Instances() { return TestConv2DNHWCInstances(); } + +bool TestConv2DNHWCF16Instances() { return TestConv2DNHWCInstances(); } + +bool TestConv2DNHWCF32Instances() { return TestConv2DNHWCInstances(); } + +bool TestConv2DNHWCInt8Instances() { return TestConv2DNHWCInstances(); } + +template +bool TestConv3DNDHWCInstances() +{ + ck::conv_util::ConvParams params; + params.num_dim_spatial = 3; + params.filter_spatial_lengths = std::vector{3, 3, 3}; + params.input_spatial_lengths = std::vector{71, 71, 71}; + params.conv_filter_strides = std::vector{2, 2, 2}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{1, 1, 1}; + params.input_right_pads = std::vector{1, 1, 1}; + + auto host_tensors = GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + RunReferenceConv<3>(params, input, weights, host_output); + + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv3d_fwd_instance:: + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); + + return RunConvInstances<3>(params, conv_ptrs, input, weights, device_output, host_output); +} + +bool TestConv3DNDHWCBF16Instances() { return TestConv3DNDHWCInstances(); } + +bool TestConv3DNDHWCF16Instances() { return TestConv3DNDHWCInstances(); } + +bool TestConv3DNDHWCF32Instances() { return TestConv3DNDHWCInstances(); } + +bool TestConv3DNDHWCInt8Instances() { return TestConv3DNDHWCInstances(); } + } // anonymous namespace int main() @@ -426,10 +639,40 @@ int main() std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWC(); std::cout << "TestConv3DNDHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNDHWC2GBInput(); std::cout << "TestConv3DNDHWC2GBInput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWC2GBFilters(); std::cout << "TestConv3DNDHWC2GBFilters ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWC2GBOutput(); std::cout << "TestConv3DNDHWC2GBOutput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + + res = TestConv1DNWCBF16Instances(); + std::cout << "TestConv1DNWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv1DNWCF16Instances(); + std::cout << "TestConv1DNWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv1DNWCF32Instances(); + std::cout << "TestConv1DNWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv1DNWCInt8Instances(); + std::cout << "TestConv1DNWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + + res = TestConv2DNHWCBF16Instances(); + std::cout << "TestConv2DNHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv2DNHWCF16Instances(); + std::cout << "TestConv2DNHWCF16Instances ....." << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv2DNHWCF32Instances(); + std::cout << "TestConv2DNHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv2DNHWCInt8Instances(); + std::cout << "TestConv2DNHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + + res = TestConv3DNDHWCBF16Instances(); + std::cout << "TestConv3DNDHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; + res = TestConv3DNDHWCF16Instances(); + std::cout << "TestConv3DNDHWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNDHWCF32Instances(); + std::cout << "TestConv3DNDHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNDHWCInt8Instances(); + std::cout << "TestConv3DNDHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; } diff --git a/test/include/test_util.hpp b/test/include/test_util.hpp index 1fe8adfeaa3..f61f5fc2584 100644 --- a/test/include/test_util.hpp +++ b/test/include/test_util.hpp @@ -11,15 +11,21 @@ #include #include +#include "data_type.hpp" + namespace test_util { +// This will be removed when bf16 will be properly integrated to CK. +using bhalf_t = ushort; + template -typename std::enable_if::value, bool>::type +typename std::enable_if::value && !std::is_same::value, + bool>::type check_err(const std::vector& out, const std::vector& ref, const std::string& msg, - T rtol = static_cast(1e-5), - T atol = static_cast(1e-8)) + double rtol = 1e-5, + double atol = 1e-8) { if(out.size() != ref.size()) { @@ -30,9 +36,9 @@ check_err(const std::vector& out, } bool res{true}; - int err_count = 0; - T err = 0; - T max_err = std::numeric_limits::min(); + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { err = std::abs(out[i] - ref[i]); @@ -57,8 +63,58 @@ check_err(const std::vector& out, } template -typename std::enable_if::value, bool>::type check_err( - const std::vector& out, const std::vector& ref, const std::string& msg, T = 0, T = 0) +typename std::enable_if::value || std::is_same::value, + bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg, + double rtol = 1e-5, + double atol = 1e-8) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = ck::type_convert(ck::NumericLimits::Min()); + for(std::size_t i = 0; i < ref.size(); ++i) + { + float o = ck::type_convert(out[i]); + float r = ck::type_convert(ref[i]); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << o << " != " << r << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value && !std::is_same::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg, + double = 0, + double = 0) { if(out.size() != ref.size()) { From bde66c26588e2554a747cae47905096a2386f852 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 9 Mar 2022 15:17:45 +0100 Subject: [PATCH 30/82] Use check_err from test_util.hpp. --- test/conv2d_fwd/conv2d_fwd.cpp | 26 ++---- test/conv_util/conv_util.cpp | 147 ++++++++++++++------------------- test/include/test_util.hpp | 12 ++- 3 files changed, 75 insertions(+), 110 deletions(-) diff --git a/test/conv2d_fwd/conv2d_fwd.cpp b/test/conv2d_fwd/conv2d_fwd.cpp index 164d4a1cc10..bba4aa7ce09 100644 --- a/test/conv2d_fwd/conv2d_fwd.cpp +++ b/test/conv2d_fwd/conv2d_fwd.cpp @@ -8,6 +8,7 @@ #include "device_conv_fwd.hpp" #include "element_wise_operation.hpp" #include "reference_conv_fwd.hpp" +#include "test_util.hpp" namespace ck { namespace tensor_operation { @@ -37,23 +38,6 @@ using InElementOp = ck::tensor_operation::element_wise::PassThrough; using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::PassThrough; -template -static bool check_out(const Tensor& ref, const Tensor& result) -{ - float max_diff = 1e-6; - - for(int i = 0; i < ref.mData.size(); ++i) - { - float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); - if(max_diff < diff) - { - return false; - } - } - - return true; -} - int main(int argc, char* argv[]) { int data_type = 0; @@ -264,9 +248,13 @@ int main(int argc, char* argv[]) if(conv_ptr->IsSupportedArgument(argument_ptr.get())) { invoker_ptr->Run(argument_ptr.get(), 0); - out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - if(!check_out(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result)) + bool res = test::check_err(out_n_k_ho_wo_device_result.mData, + out_n_k_ho_wo_host_result.mData, + "Error: incorrect results!", + 1e-5f, + 1e-4f); + if(!res) { success = false; break; diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp index 77e1f631aea..1dff3f28a20 100644 --- a/test/conv_util/conv_util.cpp +++ b/test/conv_util/conv_util.cpp @@ -5,33 +5,10 @@ #include "config.hpp" #include "conv_utils.hpp" #include "tensor_layout.hpp" +#include "test_util.hpp" namespace { -template -bool cmp_vec(const std::vector& out, const std::vector& ref, const std::string& msg) -{ - if(out.size() != ref.size()) - { - std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() - << std::endl - << msg << std::endl; - return false; - } - - for(std::size_t i = 0; i < ref.size(); ++i) - { - if(out[i] != ref[i]) - { - std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << "!=" << ref[i] - << std::endl - << msg << std::endl; - return false; - } - } - return true; -} - bool TestConvParams_GetOutputSpatialLengths() { bool res{true}; @@ -43,26 +20,26 @@ bool TestConvParams_GetOutputSpatialLengths() // padding {{1,1}, {1,1}} ck::conv_util::ConvParams conv_params; std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec(out_spatial_len, - std::vector{36, 36}, - "Error: ConvParams 2D default constructor."); + res = test::check_err(out_spatial_len, + std::vector{36, 36}, + "Error: ConvParams 2D default constructor."); conv_params.conv_filter_strides = std::vector{1, 1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec( + res = test::check_err( out_spatial_len, std::vector{71, 71}, "Error: ConvParams 2D stride {1,1}."); conv_params.conv_filter_strides = std::vector{2, 2}; conv_params.input_left_pads = std::vector{2, 2}; conv_params.input_right_pads = std::vector{2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec(out_spatial_len, - std::vector{37, 37}, - "Error: ConvParams 2D padding left/right {2,2}."); + res = test::check_err(out_spatial_len, + std::vector{37, 37}, + "Error: ConvParams 2D padding left/right {2,2}."); conv_params.conv_filter_dilations = std::vector{2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec( + res = test::check_err( out_spatial_len, std::vector{36, 36}, "Error: ConvParams 2D dilation {2,2}."); conv_params.conv_filter_strides = std::vector{3, 3}; @@ -70,9 +47,9 @@ bool TestConvParams_GetOutputSpatialLengths() conv_params.input_right_pads = std::vector{1, 1}; conv_params.conv_filter_dilations = std::vector{2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec(out_spatial_len, - std::vector{23, 23}, - "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."); + res = test::check_err(out_spatial_len, + std::vector{23, 23}, + "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."); // -------------------------- 1D ------------------------------------ conv_params.num_dim_spatial = 1; @@ -84,24 +61,24 @@ bool TestConvParams_GetOutputSpatialLengths() conv_params.input_right_pads = std::vector{1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec(out_spatial_len, std::vector{36}, "Error: ConvParams 1D."); + res = test::check_err(out_spatial_len, std::vector{36}, "Error: ConvParams 1D."); conv_params.conv_filter_strides = std::vector{1, 1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = - cmp_vec(out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}."); + res = test::check_err( + out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}."); conv_params.conv_filter_strides = std::vector{2}; conv_params.input_left_pads = std::vector{2}; conv_params.input_right_pads = std::vector{2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec(out_spatial_len, - std::vector{37}, - "Error: ConvParams 1D padding left/right {2}."); + res = test::check_err(out_spatial_len, + std::vector{37}, + "Error: ConvParams 1D padding left/right {2}."); conv_params.conv_filter_dilations = std::vector{2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec( + res = test::check_err( out_spatial_len, std::vector{36}, "Error: ConvParams 1D dilation {2}."); conv_params.conv_filter_strides = std::vector{3}; @@ -109,9 +86,9 @@ bool TestConvParams_GetOutputSpatialLengths() conv_params.input_right_pads = std::vector{1}; conv_params.conv_filter_dilations = std::vector{2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec(out_spatial_len, - std::vector{23}, - "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."); + res = test::check_err(out_spatial_len, + std::vector{23}, + "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."); // -------------------------- 3D ------------------------------------ conv_params.num_dim_spatial = 3; @@ -123,36 +100,38 @@ bool TestConvParams_GetOutputSpatialLengths() conv_params.input_right_pads = std::vector{1, 1, 1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec(out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D."); + res = test::check_err( + out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D."); conv_params.conv_filter_strides = std::vector{1, 1, 1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec(out_spatial_len, - std::vector{71, 71, 71}, - "Error: ConvParams 3D stride {1, 1, 1}."); + res = test::check_err(out_spatial_len, + std::vector{71, 71, 71}, + "Error: ConvParams 3D stride {1, 1, 1}."); conv_params.conv_filter_strides = std::vector{2, 2, 2}; conv_params.input_left_pads = std::vector{2, 2, 2}; conv_params.input_right_pads = std::vector{2, 2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec(out_spatial_len, - std::vector{37, 37, 37}, - "Error: ConvParams 3D padding left/right {2, 2, 2}."); + res = test::check_err(out_spatial_len, + std::vector{37, 37, 37}, + "Error: ConvParams 3D padding left/right {2, 2, 2}."); conv_params.conv_filter_dilations = std::vector{2, 2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec(out_spatial_len, - std::vector{36, 36, 36}, - "Error: ConvParams 3D dilation {2, 2, 2}."); + res = test::check_err(out_spatial_len, + std::vector{36, 36, 36}, + "Error: ConvParams 3D dilation {2, 2, 2}."); conv_params.conv_filter_strides = std::vector{3, 3, 3}; conv_params.input_left_pads = std::vector{1, 1, 1}; conv_params.input_right_pads = std::vector{1, 1, 1}; conv_params.conv_filter_dilations = std::vector{2, 2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = cmp_vec(out_spatial_len, - std::vector{23, 23, 23}, - "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."); + res = test::check_err( + out_spatial_len, + std::vector{23, 23, 23}, + "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."); return res; } @@ -163,44 +142,44 @@ bool TestGetHostTensorDescriptor() namespace tl = ck::tensor_layout::convolution; std::vector dims{2, 3, 4, 5}; HostTensorDescriptor h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{}); - res = cmp_vec(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!"); - res = - cmp_vec(h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!"); + res = test::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!"); + res = test::check_err( + h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!"); h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCHW{}); - res = cmp_vec(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!"); - res = - cmp_vec(h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!"); + res = test::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!"); + res = test::check_err( + h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!"); dims = std::vector{2, 3, 4}; h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{}); - res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"); - res = cmp_vec(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"); + res = test::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"); + res = test::check_err(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"); h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCW{}); - res = cmp_vec(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"); - res = cmp_vec(h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!"); + res = test::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"); + res = test::check_err(h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!"); dims = std::vector{2, 3, 4, 5, 6}; h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWC{}); - res = cmp_vec(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!"); - res = cmp_vec(h.GetStrides(), - {3 * 4 * 5 * 6, // N - 1, // C - 3 * 5 * 6, // D - 3 * 6, // H - 3}, // W - "Error: wrong NDHWC dimensions strides!"); + res = test::check_err(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!"); + res = test::check_err(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 1, // C + 3 * 5 * 6, // D + 3 * 6, // H + 3}, // W + "Error: wrong NDHWC dimensions strides!"); h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCDHW{}); - res = cmp_vec(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!"); - res = cmp_vec(h.GetStrides(), - {3 * 4 * 5 * 6, // N - 4 * 5 * 6, // C - 5 * 6, // D - 6, // H - 1}, // W - "Error: wrong NCDHW dimensions strides!"); + res = test::check_err(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!"); + res = test::check_err(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 4 * 5 * 6, // C + 5 * 6, // D + 6, // H + 1}, // W + "Error: wrong NCDHW dimensions strides!"); return res; } diff --git a/test/include/test_util.hpp b/test/include/test_util.hpp index f61f5fc2584..f9d7027f712 100644 --- a/test/include/test_util.hpp +++ b/test/include/test_util.hpp @@ -13,10 +13,7 @@ #include "data_type.hpp" -namespace test_util { - -// This will be removed when bf16 will be properly integrated to CK. -using bhalf_t = ushort; +namespace test { template typename std::enable_if::value && !std::is_same::value, @@ -63,7 +60,7 @@ check_err(const std::vector& out, } template -typename std::enable_if::value || std::is_same::value, +typename std::enable_if::value || std::is_same::value, bool>::type check_err(const std::vector& out, const std::vector& ref, @@ -109,7 +106,8 @@ check_err(const std::vector& out, } template -typename std::enable_if::value && !std::is_same::value, bool>::type +typename std::enable_if::value && !std::is_same::value, + bool>::type check_err(const std::vector& out, const std::vector& ref, const std::string& msg, @@ -137,7 +135,7 @@ check_err(const std::vector& out, return true; } -} // namespace test_util +} // namespace test template std::ostream& operator<<(std::ostream& os, const std::vector& v) From 3188b535fc38070f167b29013704a91464f8ad82 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 9 Mar 2022 15:21:31 +0100 Subject: [PATCH 31/82] Split convnd test into separate files for each dim. --- test/CMakeLists.txt | 17 + test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp | 128 +++++ test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp | 121 +++++ test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp | 269 ++++++++++ test/convnd_fwd_xdl/convnd_fwd_xdl.cpp | 678 ------------------------- test/include/conv_test_util.hpp | 262 ++++++++++ 6 files changed, 797 insertions(+), 678 deletions(-) create mode 100644 test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp create mode 100644 test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp create mode 100644 test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp delete mode 100644 test/convnd_fwd_xdl/convnd_fwd_xdl.cpp create mode 100644 test/include/conv_test_util.hpp diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4de43065cc0..f9635115408 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -34,3 +34,20 @@ foreach(TEST ${TESTS}) message("adding test ${BASE_NAME}") add_test_executeable(test_${BASE_NAME} ${TEST}) endforeach(TEST ${TESTS}) + +# test_convnd_fwd_xdl +set(CONVND_FWD_XDL_SOURCE convnd_fwd_xdl/conv1d_fwd_xdl.cpp) +add_executable(ut_conv1d_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) +target_link_libraries(ut_conv1d_fwd_xdl PRIVATE host_tensor) +target_link_libraries(ut_conv1d_fwd_xdl PRIVATE device_conv1d_fwd_instance) + + +set(CONVND_FWD_XDL_SOURCE convnd_fwd_xdl/conv2d_fwd_xdl.cpp) +add_executable(ut_conv2d_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) +target_link_libraries(ut_conv2d_fwd_xdl PRIVATE host_tensor) +target_link_libraries(ut_conv2d_fwd_xdl PRIVATE device_conv2d_fwd_instance) + +set(CONVND_FWD_XDL_SOURCE convnd_fwd_xdl/conv3d_fwd_xdl.cpp) +add_executable(ut_conv3d_fwd_xdl ${CONVND_FWD_XDL_SOURCE}) +target_link_libraries(ut_conv3d_fwd_xdl PRIVATE host_tensor) +target_link_libraries(ut_conv3d_fwd_xdl PRIVATE device_conv3d_fwd_instance) \ No newline at end of file diff --git a/test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp b/test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp new file mode 100644 index 00000000000..69872dddb27 --- /dev/null +++ b/test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp @@ -0,0 +1,128 @@ +#include +#include +#include +#include + +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "conv_test_util.hpp" +#include "host_tensor.hpp" +#include "tensor_layout.hpp" +#include "test_util.hpp" + +// Forward declarations for conv instances. + +using DeviceConvFwdNoOpPtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv1d_fwd_instance { + +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector&); +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector&); +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector&); +void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector&); + +} // namespace device_conv1d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace { + +bool TestConv1DNWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.num_dim_spatial = 1; + params.N = 2; + params.K = 16; + params.C = 4; + params.filter_spatial_lengths = std::vector{3}; + params.input_spatial_lengths = std::vector{16}; + params.conv_filter_strides = std::vector{1}; + params.conv_filter_dilations = std::vector{1}; + params.input_left_pads = std::vector{1}; + params.input_right_pads = std::vector{1}; + + auto host_tensors = test::conv::GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + test::conv::RunReferenceConv<1>(params, input, weights, host_output); + test::conv::RunConv<1>(params, input, weights, device_output); + res = res && + test::check_err( + device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + + return res; +} + +template +bool TestConv1DNWCInstances() +{ + ck::conv_util::ConvParams params; + params.num_dim_spatial = 1; + params.filter_spatial_lengths = std::vector{3}; + params.input_spatial_lengths = std::vector{71}; + params.conv_filter_strides = std::vector{2}; + params.conv_filter_dilations = std::vector{1}; + params.input_left_pads = std::vector{1}; + params.input_right_pads = std::vector{1}; + + auto host_tensors = test::conv::GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + test::conv::RunReferenceConv<1>(params, input, weights, host_output); + + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv1d_fwd_instance:: + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); + + return test::conv::RunConvInstances<1>( + params, conv_ptrs, input, weights, device_output, host_output); +} +bool TestConv1DNWCBF16Instances() { return TestConv1DNWCInstances(); } + +bool TestConv1DNWCF16Instances() { return TestConv1DNWCInstances(); } + +bool TestConv1DNWCF32Instances() { return TestConv1DNWCInstances(); } + +bool TestConv1DNWCInt8Instances() { return TestConv1DNWCInstances(); } + +} // anonymous namespace + +int main() +{ + bool res{true}; + res = TestConv1DNWC(); + std::cout << "TestConv1DNWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + + res = TestConv1DNWCBF16Instances(); + std::cout << "TestConv1DNWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv1DNWCF16Instances(); + std::cout << "TestConv1DNWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv1DNWCF32Instances(); + std::cout << "TestConv1DNWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv1DNWCInt8Instances(); + std::cout << "TestConv1DNWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; +} diff --git a/test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp b/test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp new file mode 100644 index 00000000000..c34c9a46776 --- /dev/null +++ b/test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp @@ -0,0 +1,121 @@ +#include +#include +#include +#include +#include + +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "conv_test_util.hpp" +#include "host_tensor.hpp" +#include "tensor_layout.hpp" +#include "test_util.hpp" + +// Forward declarations for conv instances. +using DeviceConvFwdNoOpPtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector&); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); +void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector&); + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace { + +bool TestConv2DNHWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.N = 2; + params.K = 16; + params.C = 4; + params.input_spatial_lengths = std::vector{16, 16}; + params.conv_filter_strides = std::vector{1, 1}; + + auto host_tensors = test::conv::GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + test::conv::RunReferenceConv<2>(params, input, weights, host_output); + test::conv::RunConv<2>(params, input, weights, device_output); + res = res && + test::check_err( + device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + + return res; +} + +template +bool TestConv2DNHWCInstances() +{ + ck::conv_util::ConvParams params; + params.num_dim_spatial = 2; + params.filter_spatial_lengths = std::vector{3, 3}; + params.input_spatial_lengths = std::vector{71, 71}; + params.conv_filter_strides = std::vector{2, 2}; + params.conv_filter_dilations = std::vector{1, 1}; + params.input_left_pads = std::vector{1, 1}; + params.input_right_pads = std::vector{1, 1}; + + auto host_tensors = test::conv::GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + test::conv::RunReferenceConv<2>(params, input, weights, host_output); + + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + + return test::conv::RunConvInstances<2>( + params, conv_ptrs, input, weights, device_output, host_output); +} + +bool TestConv2DNHWCBF16Instances() { return TestConv2DNHWCInstances(); } + +bool TestConv2DNHWCF16Instances() { return TestConv2DNHWCInstances(); } + +bool TestConv2DNHWCF32Instances() { return TestConv2DNHWCInstances(); } + +bool TestConv2DNHWCInt8Instances() { return TestConv2DNHWCInstances(); } + +} // anonymous namespace + +int main() +{ + bool res{true}; + res = TestConv2DNHWC(); + std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + + res = TestConv2DNHWCBF16Instances(); + std::cout << "TestConv2DNHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv2DNHWCF16Instances(); + std::cout << "TestConv2DNHWCF16Instances ....." << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv2DNHWCF32Instances(); + std::cout << "TestConv2DNHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv2DNHWCInt8Instances(); + std::cout << "TestConv2DNHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + + return 0; +} diff --git a/test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp b/test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp new file mode 100644 index 00000000000..237a9001f80 --- /dev/null +++ b/test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp @@ -0,0 +1,269 @@ +#include +#include +#include +#include +#include + +#include "data_type.hpp" +#include "element_wise_operation.hpp" +#include "conv_test_util.hpp" +#include "host_tensor.hpp" +#include "tensor_layout.hpp" +#include "test_util.hpp" + +// Forward declarations for conv instances. +using DeviceConvFwdNoOpPtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv3d_fwd_instance { + +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector&); +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector&); +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector&); +void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector&); + +} // namespace device_conv3d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace { + +bool TestConv3DNDHWC() +{ + bool res{true}; + ck::conv_util::ConvParams params; + params.num_dim_spatial = 3; + params.N = 2; + params.K = 16; + params.C = 4; + params.filter_spatial_lengths = std::vector{3, 3, 3}; + params.input_spatial_lengths = std::vector{16, 16, 16}; + params.conv_filter_strides = std::vector{1, 1, 1}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{1, 1, 1}; + params.input_right_pads = std::vector{1, 1, 1}; + + auto host_tensors = test::conv::GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + test::conv::RunReferenceConv<3>(params, input, weights, host_output); + test::conv::RunConv<3>(params, input, weights, device_output); + res = res && + test::check_err( + device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + + return res; +} + +bool TestConv3DNDHWC2GBInput() +{ + // >2GB Input + ck::conv_util::ConvParams params; + params.num_dim_spatial = 3; + params.N = 2; + params.K = 16; + params.C = 32; + params.filter_spatial_lengths = std::vector{3, 3, 3}; + params.input_spatial_lengths = std::vector{32, 1000, 1000}; + params.conv_filter_strides = std::vector{1, 1, 1}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{1, 1, 1}; + params.input_right_pads = std::vector{1, 1, 1}; + + auto host_tensors = test::conv::GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + try + { + test::conv::RunConv<3>(params, input, weights, device_output); + } + catch(const std::runtime_error& err) + { + std::string err_msg{"Error! device_conv with the specified compilation parameters does " + "not support this Conv problem"}; + if(err.what() != err_msg) + { + return false; + } + return true; + } + std::cout << "Error: Failure checking oversized tensor!" << std::endl; + return false; +} + +bool TestConv3DNDHWC2GBFilters() +{ + // >2GB Filters + ck::conv_util::ConvParams params; + params.num_dim_spatial = 3; + params.N = 2; + params.K = 16; + params.C = 32; + params.filter_spatial_lengths = std::vector{4, 1000, 1000}; + params.input_spatial_lengths = std::vector{16, 16, 16}; + params.conv_filter_strides = std::vector{1, 1, 1}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{1, 1, 1}; + params.input_right_pads = std::vector{1, 1, 1}; + + auto host_tensors = test::conv::GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + try + { + test::conv::RunConv<3>(params, input, weights, device_output); + } + catch(const std::runtime_error& err) + { + std::string err_msg{"Error! device_conv with the specified compilation parameters does " + "not support this Conv problem"}; + if(err.what() != err_msg) + { + return false; + } + return true; + } + std::cout << "Error: Failure checking oversized tensor!" << std::endl; + return false; +} + +bool TestConv3DNDHWC2GBOutput() +{ + // >2GB Output + ck::conv_util::ConvParams params; + params.num_dim_spatial = 3; + params.N = 2; + params.K = 16; + params.C = 2; + params.filter_spatial_lengths = std::vector{1, 1, 1}; + params.input_spatial_lengths = std::vector{1000, 1000, 30}; + params.conv_filter_strides = std::vector{1, 1, 1}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{2, 2, 2}; + params.input_right_pads = std::vector{2, 2, 2}; + + auto host_tensors = test::conv::GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + try + { + test::conv::RunConv<3>(params, input, weights, device_output); + } + catch(const std::runtime_error& err) + { + std::string err_msg{"Error! device_conv with the specified compilation parameters does " + "not support this Conv problem"}; + if(err.what() != err_msg) + { + return false; + } + return true; + } + std::cout << "Error: Failure checking oversized tensor!" << std::endl; + return false; +} + +template +bool TestConv3DNDHWCInstances() +{ + ck::conv_util::ConvParams params; + params.num_dim_spatial = 3; + params.filter_spatial_lengths = std::vector{3, 3, 3}; + params.input_spatial_lengths = std::vector{71, 71, 71}; + params.conv_filter_strides = std::vector{2, 2, 2}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{1, 1, 1}; + params.input_right_pads = std::vector{1, 1, 1}; + + auto host_tensors = test::conv::GetHostTensors(params); + const Tensor& input = std::get<0>(host_tensors); + const Tensor& weights = std::get<1>(host_tensors); + Tensor& host_output = std::get<2>(host_tensors); + Tensor& device_output = std::get<3>(host_tensors); + + test::conv::RunReferenceConv<3>(params, input, weights, host_output); + + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv3d_fwd_instance:: + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); + + return test::conv::RunConvInstances<3>( + params, conv_ptrs, input, weights, device_output, host_output); +} + +bool TestConv3DNDHWCBF16Instances() { return TestConv3DNDHWCInstances(); } + +bool TestConv3DNDHWCF16Instances() { return TestConv3DNDHWCInstances(); } + +bool TestConv3DNDHWCF32Instances() { return TestConv3DNDHWCInstances(); } + +bool TestConv3DNDHWCInt8Instances() { return TestConv3DNDHWCInstances(); } + +} // anonymous namespace + +int main() +{ + bool res{true}; + res = TestConv3DNDHWC(); + std::cout << "TestConv3DNDHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + + res = TestConv3DNDHWC2GBInput(); + std::cout << "TestConv3DNDHWC2GBInput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNDHWC2GBFilters(); + std::cout << "TestConv3DNDHWC2GBFilters ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNDHWC2GBOutput(); + std::cout << "TestConv3DNDHWC2GBOutput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + + res = TestConv3DNDHWCBF16Instances(); + std::cout << "TestConv3DNDHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; + res = TestConv3DNDHWCF16Instances(); + std::cout << "TestConv3DNDHWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNDHWCF32Instances(); + std::cout << "TestConv3DNDHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = TestConv3DNDHWCInt8Instances(); + std::cout << "TestConv3DNDHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; + + return 0; +} diff --git a/test/convnd_fwd_xdl/convnd_fwd_xdl.cpp b/test/convnd_fwd_xdl/convnd_fwd_xdl.cpp deleted file mode 100644 index 03d9dcc29b3..00000000000 --- a/test/convnd_fwd_xdl/convnd_fwd_xdl.cpp +++ /dev/null @@ -1,678 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include - -#include "config.hpp" -#include "conv_utils.hpp" -#include "device.hpp" -#include "device_tensor.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "host_tensor.hpp" -#include "reference_conv_fwd.hpp" -#include "tensor_layout.hpp" -#include "test_util.hpp" - -// Forward declarations for conv instances. - -using DeviceConvFwdNoOpPtr = - ck::tensor_operation::device::DeviceConvFwdPtr; - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_conv1d_fwd_instance { - -void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(std::vector&); -void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(std::vector&); -void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(std::vector&); -void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector&); - -} // namespace device_conv1d_fwd_instance -namespace device_conv2d_fwd_instance { - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector&); -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector&); - -} // namespace device_conv2d_fwd_instance -namespace device_conv3d_fwd_instance { - -void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(std::vector&); -void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(std::vector&); -void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(std::vector&); -void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector&); - -} // namespace device_conv3d_fwd_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -namespace { - -using bhalf_t = test_util::bhalf_t; - -template -using S = ck::Sequence; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; - -template -using DeviceConvNDFwdInstance = ck::tensor_operation::device:: - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< - // clang-format off - InDataType, // - WeiDataType, // - OutDataType, // - InDataType, // - InElementOp, // Input Elementwise Operation - WeiElementOp, // Weights Elementwise Operation - OutElementOp, // Output Elementwise Operation - ConvFwdDefault, // ConvForwardSpecialization - SpatialDims, // SptialDims - 64, // BlockSize - 16, // MPerBlock - 16, // NPerBlock - 4, // K0PerBlock - 1, // K1 - 16, // MPerXDL - 16, // NPerXDL - 1, // MXdlPerWave - 1, // NXdlPerWave - S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 1, // ABlockTransferSrcScalarPerVector - 1, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 1, // BBlockTransferSrcScalarPerVector - 1, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockTransferAddExtraN - 7, // CThreadTransferSrcDstVectorDim - 1>; // CThreadTransferDstScalarPerVector -// clang-format on - -template -auto GetHostTensors(const ck::conv_util::ConvParams& params) -{ - std::vector input_dims{static_cast(params.N), - static_cast(params.C)}; - input_dims.insert(std::end(input_dims), - std::begin(params.input_spatial_lengths), - std::end(params.input_spatial_lengths)); - - std::vector filter_dims{static_cast(params.K), - static_cast(params.C)}; - filter_dims.insert(std::end(filter_dims), - std::begin(params.filter_spatial_lengths), - std::end(params.filter_spatial_lengths)); - - const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - std::vector output_dims{static_cast(params.N), - static_cast(params.K)}; - output_dims.insert(std::end(output_dims), - std::begin(output_spatial_lengths), - std::end(output_spatial_lengths)); - - Tensor input(ck::conv_util::GetHostTensorDescriptor(input_dims, InLayout{})); - Tensor weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{})); - Tensor host_output( - ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); - Tensor device_output( - ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); - - std::generate(input.begin(), input.end(), [n = 0]() mutable { return InDataType(n++ * 0.1f); }); - std::fill(weights.begin(), weights.end(), WeiDataType(0.5f)); - std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); - std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); - - return std::make_tuple(input, weights, host_output, device_output); -} - -template -void RunReferenceConv(const ck::conv_util::ConvParams& params, - const Tensor& input, - const Tensor& weights, - Tensor& output) -{ - auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(input, - weights, - output, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); -} - -template -void RunConv(const ck::conv_util::ConvParams& params, - const Tensor& input, - const Tensor& weights, - Tensor& output) -{ - DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); - - in_device_buf.ToDevice(input.mData.data()); - wei_device_buf.ToDevice(weights.mData.data()); - const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - - auto conv = DeviceConvNDFwdInstance(); - auto invoker = conv.MakeInvoker(); - auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, - output_spatial_lengths, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - if(!conv.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "Error! device_conv with the specified compilation parameters does " - "not support this Conv problem"); - } - - invoker.Run(argument); - out_device_buf.FromDevice(output.mData.data()); -} - -template -bool RunConvInstances(const ck::conv_util::ConvParams& params, - const std::vector& conv_ptrs, - const Tensor& input, - const Tensor& weights, - Tensor& output, - const Tensor& host_output) -{ - DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); - - in_device_buf.ToDevice(input.mData.data()); - wei_device_buf.ToDevice(weights.mData.data()); - const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - - bool res{true}; - for(auto& conv_ptr : conv_ptrs) - { - auto invoker = conv_ptr->MakeInvokerPointer(); - auto argument = conv_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, - output_spatial_lengths, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - if(conv_ptr->IsSupportedArgument(argument.get())) - { - invoker->Run(argument.get()); - out_device_buf.FromDevice(output.mData.data()); - res = res && - test_util::check_err( - output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); - hipGetErrorString( - hipMemset(out_device_buf.GetDeviceBuffer(), 0, out_device_buf.mMemSize)); - } - } - return res; -} - -bool TestConv1DNWC() -{ - bool res{true}; - ck::conv_util::ConvParams params; - params.num_dim_spatial = 1; - params.N = 2; - params.K = 16; - params.C = 4; - params.filter_spatial_lengths = std::vector{3}; - params.input_spatial_lengths = std::vector{16}; - params.conv_filter_strides = std::vector{1}; - params.conv_filter_dilations = std::vector{1}; - params.input_left_pads = std::vector{1}; - params.input_right_pads = std::vector{1}; - - auto host_tensors = GetHostTensors(params); - const Tensor& input = std::get<0>(host_tensors); - const Tensor& weights = std::get<1>(host_tensors); - Tensor& host_output = std::get<2>(host_tensors); - Tensor& device_output = std::get<3>(host_tensors); - - RunReferenceConv<1>(params, input, weights, host_output); - RunConv<1>(params, input, weights, device_output); - res = res && - test_util::check_err( - device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); - - return res; -} - -bool TestConv2DNHWC() -{ - bool res{true}; - ck::conv_util::ConvParams params; - params.N = 2; - params.K = 16; - params.C = 4; - params.input_spatial_lengths = std::vector{16, 16}; - params.conv_filter_strides = std::vector{1, 1}; - - auto host_tensors = GetHostTensors(params); - const Tensor& input = std::get<0>(host_tensors); - const Tensor& weights = std::get<1>(host_tensors); - Tensor& host_output = std::get<2>(host_tensors); - Tensor& device_output = std::get<3>(host_tensors); - - RunReferenceConv<2>(params, input, weights, host_output); - RunConv<2>(params, input, weights, device_output); - res = res && - test_util::check_err( - device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); - - return res; -} - -bool TestConv3DNDHWC() -{ - bool res{true}; - ck::conv_util::ConvParams params; - params.num_dim_spatial = 3; - params.N = 2; - params.K = 16; - params.C = 4; - params.filter_spatial_lengths = std::vector{3, 3, 3}; - params.input_spatial_lengths = std::vector{16, 16, 16}; - params.conv_filter_strides = std::vector{1, 1, 1}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{1, 1, 1}; - params.input_right_pads = std::vector{1, 1, 1}; - - auto host_tensors = GetHostTensors(params); - const Tensor& input = std::get<0>(host_tensors); - const Tensor& weights = std::get<1>(host_tensors); - Tensor& host_output = std::get<2>(host_tensors); - Tensor& device_output = std::get<3>(host_tensors); - - RunReferenceConv<3>(params, input, weights, host_output); - RunConv<3>(params, input, weights, device_output); - res = res && - test_util::check_err( - device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); - - return res; -} - -bool TestConv3DNDHWC2GBInput() -{ - // >2GB Input - ck::conv_util::ConvParams params; - params.num_dim_spatial = 3; - params.N = 2; - params.K = 16; - params.C = 32; - params.filter_spatial_lengths = std::vector{3, 3, 3}; - params.input_spatial_lengths = std::vector{32, 1000, 1000}; - params.conv_filter_strides = std::vector{1, 1, 1}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{1, 1, 1}; - params.input_right_pads = std::vector{1, 1, 1}; - - auto host_tensors = GetHostTensors(params); - const Tensor& input = std::get<0>(host_tensors); - const Tensor& weights = std::get<1>(host_tensors); - Tensor& device_output = std::get<3>(host_tensors); - - try - { - RunConv<3>(params, input, weights, device_output); - } - catch(const std::runtime_error& err) - { - std::string err_msg{"Error! device_conv with the specified compilation parameters does " - "not support this Conv problem"}; - if(err.what() != err_msg) - { - return false; - } - return true; - } - std::cout << "Error: Failure checking oversized tensor!" << std::endl; - return false; -} - -bool TestConv3DNDHWC2GBFilters() -{ - // >2GB Filters - ck::conv_util::ConvParams params; - params.num_dim_spatial = 3; - params.N = 2; - params.K = 16; - params.C = 32; - params.filter_spatial_lengths = std::vector{4, 1000, 1000}; - params.input_spatial_lengths = std::vector{16, 16, 16}; - params.conv_filter_strides = std::vector{1, 1, 1}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{1, 1, 1}; - params.input_right_pads = std::vector{1, 1, 1}; - - auto host_tensors = GetHostTensors(params); - const Tensor& input = std::get<0>(host_tensors); - const Tensor& weights = std::get<1>(host_tensors); - Tensor& device_output = std::get<3>(host_tensors); - - try - { - RunConv<3>(params, input, weights, device_output); - } - catch(const std::runtime_error& err) - { - std::string err_msg{"Error! device_conv with the specified compilation parameters does " - "not support this Conv problem"}; - if(err.what() != err_msg) - { - return false; - } - return true; - } - std::cout << "Error: Failure checking oversized tensor!" << std::endl; - return false; -} - -bool TestConv3DNDHWC2GBOutput() -{ - // >2GB Output - ck::conv_util::ConvParams params; - params.num_dim_spatial = 3; - params.N = 2; - params.K = 16; - params.C = 2; - params.filter_spatial_lengths = std::vector{1, 1, 1}; - params.input_spatial_lengths = std::vector{1000, 1000, 30}; - params.conv_filter_strides = std::vector{1, 1, 1}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{2, 2, 2}; - params.input_right_pads = std::vector{2, 2, 2}; - - auto host_tensors = GetHostTensors(params); - const Tensor& input = std::get<0>(host_tensors); - const Tensor& weights = std::get<1>(host_tensors); - Tensor& device_output = std::get<3>(host_tensors); - - try - { - RunConv<3>(params, input, weights, device_output); - } - catch(const std::runtime_error& err) - { - std::string err_msg{"Error! device_conv with the specified compilation parameters does " - "not support this Conv problem"}; - if(err.what() != err_msg) - { - return false; - } - return true; - } - std::cout << "Error: Failure checking oversized tensor!" << std::endl; - return false; -} - -template -bool TestConv1DNWCInstances() -{ - ck::conv_util::ConvParams params; - params.num_dim_spatial = 1; - params.filter_spatial_lengths = std::vector{3}; - params.input_spatial_lengths = std::vector{71}; - params.conv_filter_strides = std::vector{2}; - params.conv_filter_dilations = std::vector{1}; - params.input_left_pads = std::vector{1}; - params.input_right_pads = std::vector{1}; - - auto host_tensors = GetHostTensors(params); - const Tensor& input = std::get<0>(host_tensors); - const Tensor& weights = std::get<1>(host_tensors); - Tensor& host_output = std::get<2>(host_tensors); - Tensor& device_output = std::get<3>(host_tensors); - - RunReferenceConv<1>(params, input, weights, host_output); - - std::vector conv_ptrs; - ck::tensor_operation::device::device_conv1d_fwd_instance:: - add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); - - return RunConvInstances<1>(params, conv_ptrs, input, weights, device_output, host_output); -} -bool TestConv1DNWCBF16Instances() { return TestConv1DNWCInstances(); } - -bool TestConv1DNWCF16Instances() { return TestConv1DNWCInstances(); } - -bool TestConv1DNWCF32Instances() { return TestConv1DNWCInstances(); } - -bool TestConv1DNWCInt8Instances() { return TestConv1DNWCInstances(); } - -template -bool TestConv2DNHWCInstances() -{ - ck::conv_util::ConvParams params; - params.num_dim_spatial = 2; - params.filter_spatial_lengths = std::vector{3, 3}; - params.input_spatial_lengths = std::vector{71, 71}; - params.conv_filter_strides = std::vector{2, 2}; - params.conv_filter_dilations = std::vector{1, 1}; - params.input_left_pads = std::vector{1, 1}; - params.input_right_pads = std::vector{1, 1}; - - auto host_tensors = GetHostTensors(params); - const Tensor& input = std::get<0>(host_tensors); - const Tensor& weights = std::get<1>(host_tensors); - Tensor& host_output = std::get<2>(host_tensors); - Tensor& device_output = std::get<3>(host_tensors); - - RunReferenceConv<2>(params, input, weights, host_output); - - std::vector conv_ptrs; - ck::tensor_operation::device::device_conv2d_fwd_instance:: - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); - - return RunConvInstances<2>(params, conv_ptrs, input, weights, device_output, host_output); -} - -bool TestConv2DNHWCBF16Instances() { return TestConv2DNHWCInstances(); } - -bool TestConv2DNHWCF16Instances() { return TestConv2DNHWCInstances(); } - -bool TestConv2DNHWCF32Instances() { return TestConv2DNHWCInstances(); } - -bool TestConv2DNHWCInt8Instances() { return TestConv2DNHWCInstances(); } - -template -bool TestConv3DNDHWCInstances() -{ - ck::conv_util::ConvParams params; - params.num_dim_spatial = 3; - params.filter_spatial_lengths = std::vector{3, 3, 3}; - params.input_spatial_lengths = std::vector{71, 71, 71}; - params.conv_filter_strides = std::vector{2, 2, 2}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{1, 1, 1}; - params.input_right_pads = std::vector{1, 1, 1}; - - auto host_tensors = GetHostTensors(params); - const Tensor& input = std::get<0>(host_tensors); - const Tensor& weights = std::get<1>(host_tensors); - Tensor& host_output = std::get<2>(host_tensors); - Tensor& device_output = std::get<3>(host_tensors); - - RunReferenceConv<3>(params, input, weights, host_output); - - std::vector conv_ptrs; - ck::tensor_operation::device::device_conv3d_fwd_instance:: - add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); - - return RunConvInstances<3>(params, conv_ptrs, input, weights, device_output, host_output); -} - -bool TestConv3DNDHWCBF16Instances() { return TestConv3DNDHWCInstances(); } - -bool TestConv3DNDHWCF16Instances() { return TestConv3DNDHWCInstances(); } - -bool TestConv3DNDHWCF32Instances() { return TestConv3DNDHWCInstances(); } - -bool TestConv3DNDHWCInt8Instances() { return TestConv3DNDHWCInstances(); } - -} // anonymous namespace - -int main() -{ - bool res{true}; - res = TestConv1DNWC(); - std::cout << "TestConv1DNWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv2DNHWC(); - std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNDHWC(); - std::cout << "TestConv3DNDHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - - res = TestConv3DNDHWC2GBInput(); - std::cout << "TestConv3DNDHWC2GBInput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNDHWC2GBFilters(); - std::cout << "TestConv3DNDHWC2GBFilters ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNDHWC2GBOutput(); - std::cout << "TestConv3DNDHWC2GBOutput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - - res = TestConv1DNWCBF16Instances(); - std::cout << "TestConv1DNWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv1DNWCF16Instances(); - std::cout << "TestConv1DNWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv1DNWCF32Instances(); - std::cout << "TestConv1DNWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv1DNWCInt8Instances(); - std::cout << "TestConv1DNWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - - res = TestConv2DNHWCBF16Instances(); - std::cout << "TestConv2DNHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv2DNHWCF16Instances(); - std::cout << "TestConv2DNHWCF16Instances ....." << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv2DNHWCF32Instances(); - std::cout << "TestConv2DNHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv2DNHWCInt8Instances(); - std::cout << "TestConv2DNHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - - res = TestConv3DNDHWCBF16Instances(); - std::cout << "TestConv3DNDHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") - << std::endl; - res = TestConv3DNDHWCF16Instances(); - std::cout << "TestConv3DNDHWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNDHWCF32Instances(); - std::cout << "TestConv3DNDHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNDHWCInt8Instances(); - std::cout << "TestConv3DNDHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") - << std::endl; -} diff --git a/test/include/conv_test_util.hpp b/test/include/conv_test_util.hpp new file mode 100644 index 00000000000..747e19c292b --- /dev/null +++ b/test/include/conv_test_util.hpp @@ -0,0 +1,262 @@ +#ifndef TEST_CONV_UTIL_HPP +#define TEST_CONV_UTIL_HPP + +#include +#include +#include +#include +#include +#include + +#include "config.hpp" +#include "conv_utils.hpp" +#include "device.hpp" +#include "device_tensor.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" +#include "test_util.hpp" + +namespace { + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + InDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + SpatialDims, // SptialDims + 64, // BlockSize + 16, // MPerBlock + 16, // NPerBlock + 4, // K0PerBlock + 1, // K1 + 16, // MPerXDL + 16, // NPerXDL + 1, // MXdlPerWave + 1, // NXdlPerWave + S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockTransferAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector +// clang-format on + +} // namespace + +namespace test { +namespace conv { + +using DeviceConvFwdNoOpPtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +template +auto GetHostTensors(const ck::conv_util::ConvParams& params) +{ + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(ck::conv_util::GetHostTensorDescriptor(input_dims, InLayout{})); + Tensor weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{})); + Tensor host_output( + ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); + Tensor device_output( + ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); + + std::generate(input.begin(), input.end(), [n = 0]() mutable { return InDataType(n++ * 0.1f); }); + std::fill(weights.begin(), weights.end(), WeiDataType(0.5f)); + std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); + std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); + + return std::make_tuple(input, weights, host_output, device_output); +} + +template +void RunReferenceConv(const ck::conv_util::ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) +{ + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + output, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); +} + +template +void RunConv(const ck::conv_util::ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) +{ + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + + auto conv = DeviceConvNDFwdInstance(); + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "Error! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + invoker.Run(argument); + out_device_buf.FromDevice(output.mData.data()); +} + +template +bool RunConvInstances(const ck::conv_util::ConvParams& params, + const std::vector& conv_ptrs, + const Tensor& input, + const Tensor& weights, + Tensor& output, + const Tensor& host_output) +{ + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + + bool res{true}; + for(auto& conv_ptr : conv_ptrs) + { + auto invoker = conv_ptr->MakeInvokerPointer(); + auto argument = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(conv_ptr->IsSupportedArgument(argument.get())) + { + invoker->Run(argument.get()); + out_device_buf.FromDevice(output.mData.data()); + res = res && + test::check_err( + output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + hipGetErrorString( + hipMemset(out_device_buf.GetDeviceBuffer(), 0, out_device_buf.mMemSize)); + } + } + return res; +} + +} // namespace conv +} // namespace test + +#endif From e38a182e9b32885c75796a8ddd904fc8c20600aa Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 10 Mar 2022 14:10:29 +0100 Subject: [PATCH 32/82] Fix data generation and use proper instances. --- test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp | 45 ++++++++++++++++------- test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp | 45 ++++++++++++++++------- test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp | 51 ++++++++++++++++++-------- test/include/conv_test_util.hpp | 13 ++++++- 4 files changed, 111 insertions(+), 43 deletions(-) diff --git a/test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp b/test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp index 69872dddb27..4533bc7353b 100644 --- a/test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp +++ b/test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp @@ -70,7 +70,7 @@ bool TestConv1DNWC() } template -bool TestConv1DNWCInstances() +bool TestConv1DNWCInstances(const std::vector& conv_ptrs) { ck::conv_util::ConvParams params; params.num_dim_spatial = 1; @@ -93,21 +93,40 @@ bool TestConv1DNWCInstances() Tensor& device_output = std::get<3>(host_tensors); test::conv::RunReferenceConv<1>(params, input, weights, host_output); - + return test::conv::RunConvInstances<1>( + params, conv_ptrs, input, weights, device_output, host_output); +} +bool TestConv1DNWCBF16Instances() +{ std::vector conv_ptrs; ck::tensor_operation::device::device_conv1d_fwd_instance:: add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); - - return test::conv::RunConvInstances<1>( - params, conv_ptrs, input, weights, device_output, host_output); + return TestConv1DNWCInstances(conv_ptrs); } -bool TestConv1DNWCBF16Instances() { return TestConv1DNWCInstances(); } -bool TestConv1DNWCF16Instances() { return TestConv1DNWCInstances(); } +bool TestConv1DNWCF16Instances() +{ + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv1d_fwd_instance:: + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); + return TestConv1DNWCInstances(conv_ptrs); +} -bool TestConv1DNWCF32Instances() { return TestConv1DNWCInstances(); } +bool TestConv1DNWCF32Instances() +{ + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv1d_fwd_instance:: + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); + return TestConv1DNWCInstances(conv_ptrs); +} -bool TestConv1DNWCInt8Instances() { return TestConv1DNWCInstances(); } +bool TestConv1DNWCInt8Instances() +{ + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv1d_fwd_instance:: + add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); + return TestConv1DNWCInstances(conv_ptrs); +} } // anonymous namespace @@ -118,11 +137,11 @@ int main() std::cout << "TestConv1DNWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv1DNWCBF16Instances(); - std::cout << "TestConv1DNWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv1DNWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv1DNWCF16Instances(); - std::cout << "TestConv1DNWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv1DNWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv1DNWCF32Instances(); - std::cout << "TestConv1DNWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv1DNWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv1DNWCInt8Instances(); - std::cout << "TestConv1DNWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv1DNWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; } diff --git a/test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp b/test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp index c34c9a46776..2ff6e64b4a5 100644 --- a/test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp +++ b/test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp @@ -60,7 +60,7 @@ bool TestConv2DNHWC() } template -bool TestConv2DNHWCInstances() +bool TestConv2DNHWCInstances(const std::vector& conv_ptrs) { ck::conv_util::ConvParams params; params.num_dim_spatial = 2; @@ -83,22 +83,41 @@ bool TestConv2DNHWCInstances() Tensor& device_output = std::get<3>(host_tensors); test::conv::RunReferenceConv<2>(params, input, weights, host_output); + return test::conv::RunConvInstances<2>( + params, conv_ptrs, input, weights, device_output, host_output); +} +bool TestConv2DNHWCBF16Instances() +{ std::vector conv_ptrs; ck::tensor_operation::device::device_conv2d_fwd_instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); - - return test::conv::RunConvInstances<2>( - params, conv_ptrs, input, weights, device_output, host_output); + return TestConv2DNHWCInstances(conv_ptrs); } -bool TestConv2DNHWCBF16Instances() { return TestConv2DNHWCInstances(); } - -bool TestConv2DNHWCF16Instances() { return TestConv2DNHWCInstances(); } +bool TestConv2DNHWCF16Instances() +{ + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + return TestConv2DNHWCInstances(conv_ptrs); +} -bool TestConv2DNHWCF32Instances() { return TestConv2DNHWCInstances(); } +bool TestConv2DNHWCF32Instances() +{ + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + return TestConv2DNHWCInstances(conv_ptrs); +} -bool TestConv2DNHWCInt8Instances() { return TestConv2DNHWCInstances(); } +bool TestConv2DNHWCInt8Instances() +{ + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + return TestConv2DNHWCInstances(conv_ptrs); +} } // anonymous namespace @@ -109,13 +128,13 @@ int main() std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv2DNHWCBF16Instances(); - std::cout << "TestConv2DNHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv2DNHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv2DNHWCF16Instances(); - std::cout << "TestConv2DNHWCF16Instances ....." << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv2DNHWCF16Instances ....." << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv2DNHWCF32Instances(); - std::cout << "TestConv2DNHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv2DNHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv2DNHWCInt8Instances(); - std::cout << "TestConv2DNHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv2DNHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return 0; } diff --git a/test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp b/test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp index 237a9001f80..85ed49eb693 100644 --- a/test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp +++ b/test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp @@ -199,7 +199,7 @@ bool TestConv3DNDHWC2GBOutput() } template -bool TestConv3DNDHWCInstances() +bool TestConv3DNDHWCInstances(const std::vector& conv_ptrs) { ck::conv_util::ConvParams params; params.num_dim_spatial = 3; @@ -222,22 +222,41 @@ bool TestConv3DNDHWCInstances() Tensor& device_output = std::get<3>(host_tensors); test::conv::RunReferenceConv<3>(params, input, weights, host_output); + return test::conv::RunConvInstances<3>( + params, conv_ptrs, input, weights, device_output, host_output); +} +bool TestConv3DNDHWCBF16Instances() +{ std::vector conv_ptrs; ck::tensor_operation::device::device_conv3d_fwd_instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); - - return test::conv::RunConvInstances<3>( - params, conv_ptrs, input, weights, device_output, host_output); + return TestConv3DNDHWCInstances(conv_ptrs); } -bool TestConv3DNDHWCBF16Instances() { return TestConv3DNDHWCInstances(); } - -bool TestConv3DNDHWCF16Instances() { return TestConv3DNDHWCInstances(); } +bool TestConv3DNDHWCF16Instances() +{ + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv3d_fwd_instance:: + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); + return TestConv3DNDHWCInstances(conv_ptrs); +} -bool TestConv3DNDHWCF32Instances() { return TestConv3DNDHWCInstances(); } +bool TestConv3DNDHWCF32Instances() +{ + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv3d_fwd_instance:: + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); + return TestConv3DNDHWCInstances(conv_ptrs); +} -bool TestConv3DNDHWCInt8Instances() { return TestConv3DNDHWCInstances(); } +bool TestConv3DNDHWCInt8Instances() +{ + std::vector conv_ptrs; + ck::tensor_operation::device::device_conv3d_fwd_instance:: + add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); + return TestConv3DNDHWCInstances(conv_ptrs); +} } // anonymous namespace @@ -248,21 +267,21 @@ int main() std::cout << "TestConv3DNDHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWC2GBInput(); - std::cout << "TestConv3DNDHWC2GBInput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv3DNDHWC2GBInput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWC2GBFilters(); - std::cout << "TestConv3DNDHWC2GBFilters ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv3DNDHWC2GBFilters ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWC2GBOutput(); - std::cout << "TestConv3DNDHWC2GBOutput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv3DNDHWC2GBOutput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWCBF16Instances(); - std::cout << "TestConv3DNDHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") + std::cout << "\nTestConv3DNDHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWCF16Instances(); - std::cout << "TestConv3DNDHWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv3DNDHWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWCF32Instances(); - std::cout << "TestConv3DNDHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv3DNDHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWCInt8Instances(); - std::cout << "TestConv3DNDHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") + std::cout << "\nTestConv3DNDHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return 0; diff --git a/test/include/conv_test_util.hpp b/test/include/conv_test_util.hpp index 747e19c292b..2bf758ba590 100644 --- a/test/include/conv_test_util.hpp +++ b/test/include/conv_test_util.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -115,7 +116,17 @@ auto GetHostTensors(const ck::conv_util::ConvParams& params) Tensor device_output( ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); - std::generate(input.begin(), input.end(), [n = 0]() mutable { return InDataType(n++ * 0.1f); }); + std::mt19937 gen(11939); + if constexpr (std::is_same::value) + { + std::uniform_int_distribution<> dis(-5, 5); + std::generate(input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); + } + else + { + std::uniform_real_distribution<> dis(0.f, 1.f); + std::generate(input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); + } std::fill(weights.begin(), weights.end(), WeiDataType(0.5f)); std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); From ef5296c72cf679cc4c8e06e3da79616d08531840 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 10 Mar 2022 15:20:52 +0100 Subject: [PATCH 33/82] Formatting --- test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp | 6 ++++-- test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp | 17 ++++++++++------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp b/test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp index 4533bc7353b..7da85cbf4e6 100644 --- a/test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp +++ b/test/convnd_fwd_xdl/conv1d_fwd_xdl.cpp @@ -137,11 +137,13 @@ int main() std::cout << "TestConv1DNWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv1DNWCBF16Instances(); - std::cout << "\nTestConv1DNWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv1DNWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; res = TestConv1DNWCF16Instances(); std::cout << "\nTestConv1DNWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv1DNWCF32Instances(); std::cout << "\nTestConv1DNWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv1DNWCInt8Instances(); - std::cout << "\nTestConv1DNWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv1DNWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; } diff --git a/test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp b/test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp index 2ff6e64b4a5..008c39f1db7 100644 --- a/test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp +++ b/test/convnd_fwd_xdl/conv2d_fwd_xdl.cpp @@ -88,7 +88,7 @@ bool TestConv2DNHWCInstances(const std::vector& conv_ptrs) } bool TestConv2DNHWCBF16Instances() -{ +{ std::vector conv_ptrs; ck::tensor_operation::device::device_conv2d_fwd_instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); @@ -96,7 +96,7 @@ bool TestConv2DNHWCBF16Instances() } bool TestConv2DNHWCF16Instances() -{ +{ std::vector conv_ptrs; ck::tensor_operation::device::device_conv2d_fwd_instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); @@ -104,7 +104,7 @@ bool TestConv2DNHWCF16Instances() } bool TestConv2DNHWCF32Instances() -{ +{ std::vector conv_ptrs; ck::tensor_operation::device::device_conv2d_fwd_instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); @@ -112,7 +112,7 @@ bool TestConv2DNHWCF32Instances() } bool TestConv2DNHWCInt8Instances() -{ +{ std::vector conv_ptrs; ck::tensor_operation::device::device_conv2d_fwd_instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); @@ -128,13 +128,16 @@ int main() std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv2DNHWCBF16Instances(); - std::cout << "\nTestConv2DNHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv2DNHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; res = TestConv2DNHWCF16Instances(); std::cout << "\nTestConv2DNHWCF16Instances ....." << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv2DNHWCF32Instances(); - std::cout << "\nTestConv2DNHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv2DNHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; res = TestConv2DNHWCInt8Instances(); - std::cout << "\nTestConv2DNHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv2DNHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; return 0; } From f8524e173ec6e6cac8e26e439bb8da1fb0e33817 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 10 Mar 2022 15:21:14 +0100 Subject: [PATCH 34/82] Skip tensor initialization if not necessary. --- test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp | 55 ++++++++++++++------------ test/include/conv_test_util.hpp | 31 +++++++++------ 2 files changed, 48 insertions(+), 38 deletions(-) diff --git a/test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp b/test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp index 85ed49eb693..e9478df7d9e 100644 --- a/test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp +++ b/test/convnd_fwd_xdl/conv3d_fwd_xdl.cpp @@ -84,12 +84,13 @@ bool TestConv3DNDHWC2GBInput() params.input_left_pads = std::vector{1, 1, 1}; params.input_right_pads = std::vector{1, 1, 1}; - auto host_tensors = test::conv::GetHostTensors(params); + auto host_tensors = + test::conv::GetHostTensors(params, false); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); @@ -127,12 +128,13 @@ bool TestConv3DNDHWC2GBFilters() params.input_left_pads = std::vector{1, 1, 1}; params.input_right_pads = std::vector{1, 1, 1}; - auto host_tensors = test::conv::GetHostTensors(params); + auto host_tensors = + test::conv::GetHostTensors(params, false); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); @@ -170,12 +172,13 @@ bool TestConv3DNDHWC2GBOutput() params.input_left_pads = std::vector{2, 2, 2}; params.input_right_pads = std::vector{2, 2, 2}; - auto host_tensors = test::conv::GetHostTensors(params); + auto host_tensors = + test::conv::GetHostTensors(params, false); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); @@ -204,7 +207,7 @@ bool TestConv3DNDHWCInstances(const std::vector& conv_ptrs ck::conv_util::ConvParams params; params.num_dim_spatial = 3; params.filter_spatial_lengths = std::vector{3, 3, 3}; - params.input_spatial_lengths = std::vector{71, 71, 71}; + params.input_spatial_lengths = std::vector{32, 32, 32}; params.conv_filter_strides = std::vector{2, 2, 2}; params.conv_filter_dilations = std::vector{1, 1, 1}; params.input_left_pads = std::vector{1, 1, 1}; @@ -227,7 +230,7 @@ bool TestConv3DNDHWCInstances(const std::vector& conv_ptrs } bool TestConv3DNDHWCBF16Instances() -{ +{ std::vector conv_ptrs; ck::tensor_operation::device::device_conv3d_fwd_instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); @@ -235,7 +238,7 @@ bool TestConv3DNDHWCBF16Instances() } bool TestConv3DNDHWCF16Instances() -{ +{ std::vector conv_ptrs; ck::tensor_operation::device::device_conv3d_fwd_instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); @@ -243,7 +246,7 @@ bool TestConv3DNDHWCF16Instances() } bool TestConv3DNDHWCF32Instances() -{ +{ std::vector conv_ptrs; ck::tensor_operation::device::device_conv3d_fwd_instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); @@ -251,7 +254,7 @@ bool TestConv3DNDHWCF32Instances() } bool TestConv3DNDHWCInt8Instances() -{ +{ std::vector conv_ptrs; ck::tensor_operation::device::device_conv3d_fwd_instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); @@ -277,9 +280,11 @@ int main() std::cout << "\nTestConv3DNDHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; res = TestConv3DNDHWCF16Instances(); - std::cout << "\nTestConv3DNDHWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv3DNDHWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; res = TestConv3DNDHWCF32Instances(); - std::cout << "\nTestConv3DNDHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + std::cout << "\nTestConv3DNDHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; res = TestConv3DNDHWCInt8Instances(); std::cout << "\nTestConv3DNDHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; diff --git a/test/include/conv_test_util.hpp b/test/include/conv_test_util.hpp index 2bf758ba590..c67e7eb412b 100644 --- a/test/include/conv_test_util.hpp +++ b/test/include/conv_test_util.hpp @@ -88,7 +88,7 @@ template -auto GetHostTensors(const ck::conv_util::ConvParams& params) +auto GetHostTensors(const ck::conv_util::ConvParams& params, bool init = true) { std::vector input_dims{static_cast(params.N), static_cast(params.C)}; @@ -116,20 +116,25 @@ auto GetHostTensors(const ck::conv_util::ConvParams& params) Tensor device_output( ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); - std::mt19937 gen(11939); - if constexpr (std::is_same::value) + if(init) { - std::uniform_int_distribution<> dis(-5, 5); - std::generate(input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); - } - else - { - std::uniform_real_distribution<> dis(0.f, 1.f); - std::generate(input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); + std::mt19937 gen(11939); + if constexpr(std::is_same::value) + { + std::uniform_int_distribution<> dis(-5, 5); + std::generate( + input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); + } + else + { + std::uniform_real_distribution<> dis(0.f, 1.f); + std::generate( + input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); + } + std::fill(weights.begin(), weights.end(), WeiDataType(0.5f)); + std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); + std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); } - std::fill(weights.begin(), weights.end(), WeiDataType(0.5f)); - std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); - std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); return std::make_tuple(input, weights, host_output, device_output); } From 30aabf38060b9ce6bb359c1f2fbfe76c752b8581 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 10 Mar 2022 23:00:14 +0100 Subject: [PATCH 35/82] Fix CMakefiles. --- .../gpu/CMakeLists.txt | 1 + .../gpu/conv1d_fwd/CMakeLists.txt | 3 +++ .../gpu/conv3d_fwd/CMakeLists.txt | 4 ++-- test/convnd_fwd/CMakeLists.txt | 18 ++++++++++++------ 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 52277f0ee3d..70dd62a9af4 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -23,6 +23,7 @@ add_subdirectory(gemm_bias_relu_add) add_subdirectory(batched_gemm) add_subdirectory(conv1d_fwd) add_subdirectory(conv2d_fwd) +add_subdirectory(conv3d_fwd) add_subdirectory(conv2d_fwd_bias_relu) add_subdirectory(conv2d_fwd_bias_relu_add) add_subdirectory(conv2d_fwd_bias_relu_atomic_add) diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt index cadc374d831..6c7c3e4f788 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/CMakeLists.txt @@ -1,6 +1,9 @@ # device_conv1d_fwd_instance set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE + device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp; + device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp; device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp; + device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp; ) add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE}) diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt index 44ef8c06b15..f6849a7bb20 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/CMakeLists.txt @@ -1,11 +1,11 @@ # device_conv3d_fwd_instance -set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE +set(DEVICE_CONV3D_FWD_INSTANCE_SOURCE device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp; device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp; device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; ) -add_library(device_conv3d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) +add_library(device_conv3d_fwd_instance SHARED ${DEVICE_CONV3D_FWD_INSTANCE_SOURCE}) target_compile_features(device_conv3d_fwd_instance PUBLIC) set_target_properties(device_conv3d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) install(TARGETS device_conv3d_fwd_instance LIBRARY DESTINATION lib) diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt index 6afb10ed826..4608cdbe86a 100644 --- a/test/convnd_fwd/CMakeLists.txt +++ b/test/convnd_fwd/CMakeLists.txt @@ -1,11 +1,17 @@ +add_custom_target(test_convnd_fwd) + add_test_executable(test_conv1d_fwd conv1d_fwd.cpp) -target_link_libraries(conv1d_fwd PRIVATE host_tensor) -target_link_libraries(conv1d_fwd PRIVATE device_conv1d_fwd_instance) +target_link_libraries(test_conv1d_fwd PRIVATE host_tensor) +target_link_libraries(test_conv1d_fwd PRIVATE device_conv1d_fwd_instance) +add_dependencies(test_convnd_fwd test_conv1d_fwd) add_test_executable(test_conv2d_fwd conv2d_fwd.cpp) -target_link_libraries(conv2d_fwd PRIVATE host_tensor) -target_link_libraries(conv2d_fwd PRIVATE device_conv2d_fwd_instance) +target_link_libraries(test_conv2d_fwd PRIVATE host_tensor) +target_link_libraries(test_conv2d_fwd PRIVATE device_conv2d_fwd_instance) +add_dependencies(test_convnd_fwd test_conv2d_fwd) add_test_executable(test_conv3d_fwd conv3d_fwd.cpp) -target_link_libraries(conv3d_fwd PRIVATE host_tensor) -target_link_libraries(conv3d_fwd PRIVATE device_conv3d_fwd_instance) \ No newline at end of file +target_link_libraries(test_conv3d_fwd PRIVATE host_tensor) +target_link_libraries(test_conv3d_fwd PRIVATE device_conv3d_fwd_instance) +add_dependencies(test_convnd_fwd test_conv3d_fwd) + From bd39c81f09af15dea79913097722964c14455b20 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 10 Mar 2022 23:00:49 +0100 Subject: [PATCH 36/82] Remove redundant conv2d_fwd test. --- test/CMakeLists.txt | 1 - test/conv2d_fwd/CMakeLists.txt | 3 - test/conv2d_fwd/conv2d_fwd.cpp | 296 --------------------------------- test/convnd_fwd/conv2d_fwd.cpp | 4 + 4 files changed, 4 insertions(+), 300 deletions(-) delete mode 100644 test/conv2d_fwd/CMakeLists.txt delete mode 100644 test/conv2d_fwd/conv2d_fwd.cpp diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index eec8b5b852e..2d6e28f54c8 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -36,6 +36,5 @@ add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) add_subdirectory(gemm_split_k) -add_subdirectory(conv2d_fwd) add_subdirectory(convnd_fwd) add_subdirectory(conv2d_bwd_data) diff --git a/test/conv2d_fwd/CMakeLists.txt b/test/conv2d_fwd/CMakeLists.txt deleted file mode 100644 index b0e55797e5d..00000000000 --- a/test/conv2d_fwd/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -add_test_executable(test_conv2d_fwd conv2d_fwd.cpp) -target_link_libraries(test_conv2d_fwd PRIVATE host_tensor) -target_link_libraries(test_conv2d_fwd PRIVATE device_conv2d_fwd_instance) diff --git a/test/conv2d_fwd/conv2d_fwd.cpp b/test/conv2d_fwd/conv2d_fwd.cpp deleted file mode 100644 index bba4aa7ce09..00000000000 --- a/test/conv2d_fwd/conv2d_fwd.cpp +++ /dev/null @@ -1,296 +0,0 @@ -#include "config.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_conv.hpp" -#include "tensor_layout.hpp" -#include "device_tensor.hpp" -#include "device_conv_fwd.hpp" -#include "element_wise_operation.hpp" -#include "reference_conv_fwd.hpp" -#include "test_util.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace device_conv2d_fwd_instance { - -using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr; - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); - -void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( - std::vector&); - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector&); - -void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector&); -} // namespace device_conv2d_fwd_instance -} // namespace device -} // namespace tensor_operation -} // namespace ck - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -int main(int argc, char* argv[]) -{ - int data_type = 0; - int init_method = 0; - - // Conv shape - ck::index_t N = 128; - ck::index_t K = 256; - ck::index_t C = 192; - ck::index_t Y = 3; - ck::index_t X = 3; - ck::index_t Hi = 71; - ck::index_t Wi = 71; - ck::index_t conv_stride_h = 2; - ck::index_t conv_stride_w = 2; - ck::index_t conv_dilation_h = 1; - ck::index_t conv_dilation_w = 1; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; - if(argc == 1) - { - data_type = 1; - init_method = 1; - } - else if(argc == 3) - { - data_type = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - } - else if(argc == 18) - { - data_type = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - - N = std::stoi(argv[3]); - K = std::stoi(argv[4]); - C = std::stoi(argv[5]); - Y = std::stoi(argv[6]); - X = std::stoi(argv[7]); - Hi = std::stoi(argv[8]); - Wi = std::stoi(argv[9]); - conv_stride_h = std::stoi(argv[10]); - conv_stride_w = std::stoi(argv[11]); - conv_dilation_h = std::stoi(argv[12]); - conv_dilation_w = std::stoi(argv[13]); - in_left_pad_h = std::stoi(argv[14]); - in_left_pad_w = std::stoi(argv[15]); - in_right_pad_h = std::stoi(argv[16]); - in_right_pad_w = std::stoi(argv[17]); - } - else - { - printf("arg1: data type (0=fp32, 1=fp16, 2= bfp16, 3= int8_t )\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3 to 17: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(1); - } - - auto Run = [&](auto input_type, auto wei_type, auto out_type) { - using InDataType = decltype(input_type); - using WeiDataType = decltype(wei_type); - using OutDataType = decltype(out_type); - - using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; - - const ck::index_t YEff = (Y - 1) * conv_dilation_h + 1; - const ck::index_t XEff = (X - 1) * conv_dilation_w + 1; - - const ck::index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const ck::index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - - const std::vector input_spatial_lengths{Hi, Wi}; - const std::vector filter_spatial_lengths{Y, X}; - const std::vector output_spatial_lengths{Ho, Wo}; - const std::vector conv_filter_strides{conv_stride_h, conv_stride_w}; - const std::vector conv_filter_dilations{conv_dilation_h, conv_dilation_w}; - const std::vector input_left_pads{in_left_pad_h, in_left_pad_w}; - const std::vector input_right_pads{in_right_pad_h, in_right_pad_w}; - - auto f_host_tensor_descriptor = - [](std::size_t N_, std::size_t C_, std::size_t H, std::size_t W) { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, 1, W * C_, C_})); - }; - - Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi)); - Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X)); - Tensor out_n_k_ho_wo_host_result(f_host_tensor_descriptor(N, K, Ho, Wo)); - Tensor out_n_k_ho_wo_device_result(f_host_tensor_descriptor(N, K, Ho, Wo)); - - std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; - std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; - std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0, 1}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-1, 1}); - } - - DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * - out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); - - in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - - using PassThrough = ck::tensor_operation::element_wise::PassThrough; - - using DeviceConvFwdNoOpPtr = - ck::tensor_operation::device::DeviceConvFwdPtr; - - // add device Conv instances - std::vector conv_ptrs; - - if constexpr(ck::is_same_v, float> && - ck::is_same_v, float> && - ck::is_same_v, float>) - { - ck::tensor_operation::device::device_conv2d_fwd_instance:: - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); - } - else if constexpr(ck::is_same_v, ck::half_t> && - ck::is_same_v, ck::half_t> && - ck::is_same_v, ck::half_t>) - { - ck::tensor_operation::device::device_conv2d_fwd_instance:: - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); - - ck::tensor_operation::device::device_conv2d_fwd_instance:: - add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); - } - else if constexpr(ck::is_same_v, ck::bhalf_t> && - ck::is_same_v, ck::bhalf_t> && - ck::is_same_v, ck::bhalf_t>) - { - ck::tensor_operation::device::device_conv2d_fwd_instance:: - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); - } - else if constexpr(ck::is_same_v, int8_t> && - ck::is_same_v, int8_t> && - ck::is_same_v, int8_t>) - { - ck::tensor_operation::device::device_conv2d_fwd_instance:: - add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); - } - - if(conv_ptrs.size() <= 0) - { - throw std::runtime_error("wrong! no device Conv instance found"); - } - - auto ref_conv = ReferenceConvFwdInstance{}; - auto ref_invoker = ref_conv.MakeInvoker(); - - auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); - - // profile device Conv instances - bool success = false; - for(auto& conv_ptr : conv_ptrs) - { - auto argument_ptr = conv_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - auto invoker_ptr = conv_ptr->MakeInvokerPointer(); - - if(conv_ptr->IsSupportedArgument(argument_ptr.get())) - { - invoker_ptr->Run(argument_ptr.get(), 0); - out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - bool res = test::check_err(out_n_k_ho_wo_device_result.mData, - out_n_k_ho_wo_host_result.mData, - "Error: incorrect results!", - 1e-5f, - 1e-4f); - if(!res) - { - success = false; - break; - } - success = true; - } - } - - if(success) - { - std::cout << "test conv2d fwd : Pass" << std::endl; - return 0; - } - else - { - std::cout << "test conv2d fwd: Fail " << std::endl; - return -1; - } - }; - int res = -1; - if(data_type == 0) - { - res = Run(float(), float(), float()); - } - else if(data_type == 1) - { - res = Run(ck::half_t(), ck::half_t(), ck::half_t()); - } - else if(data_type == 2) - { - Run(ck::bhalf_t(), ck::bhalf_t(), ck::bhalf_t()); - } - else if(data_type == 3) - { - res = Run(int8_t(), int8_t(), int8_t()); - } - - return res; -} diff --git a/test/convnd_fwd/conv2d_fwd.cpp b/test/convnd_fwd/conv2d_fwd.cpp index 008c39f1db7..624db66b9e1 100644 --- a/test/convnd_fwd/conv2d_fwd.cpp +++ b/test/convnd_fwd/conv2d_fwd.cpp @@ -24,6 +24,8 @@ namespace device_conv2d_fwd_instance { void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector&); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); +void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances( + std::vector&); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector&); @@ -100,6 +102,8 @@ bool TestConv2DNHWCF16Instances() std::vector conv_ptrs; ck::tensor_operation::device::device_conv2d_fwd_instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); return TestConv2DNHWCInstances(conv_ptrs); } From 413839c42267c6001b5184eb6046b81831f363d9 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 10 Mar 2022 23:34:02 +0100 Subject: [PATCH 37/82] Lower problem size for conv3D UT. --- test/convnd_fwd/conv3d_fwd.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/convnd_fwd/conv3d_fwd.cpp b/test/convnd_fwd/conv3d_fwd.cpp index 4f0d8bbfa56..6aa05c41d28 100644 --- a/test/convnd_fwd/conv3d_fwd.cpp +++ b/test/convnd_fwd/conv3d_fwd.cpp @@ -205,13 +205,13 @@ template bool TestConv3DNDHWCInstances(const std::vector& conv_ptrs) { ck::conv_util::ConvParams params; - params.num_dim_spatial = 3; - params.filter_spatial_lengths = std::vector{3, 3, 2}; - params.input_spatial_lengths = std::vector{71, 71, 2}; - params.conv_filter_strides = std::vector{2, 2, 1}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{1, 1, 1}; - params.input_right_pads = std::vector{1, 1, 1}; + params.N = 64 params.num_dim_spatial = 3; + params.filter_spatial_lengths = std::vector{3, 3, 2}; + params.input_spatial_lengths = std::vector{32, 32, 2}; + params.conv_filter_strides = std::vector{2, 2, 2}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{1, 1, 1}; + params.input_right_pads = std::vector{1, 1, 1}; auto host_tensors = test::conv::GetHostTensors Date: Thu, 10 Mar 2022 23:34:30 +0100 Subject: [PATCH 38/82] 3D case for convnd example. --- example/09_convnd_fwd/convnd_fwd_xdl.cpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl.cpp b/example/09_convnd_fwd/convnd_fwd_xdl.cpp index 6342e8f6200..d26a52b2fdb 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl.cpp @@ -84,6 +84,9 @@ DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial) { switch(num_dim_spatial) { + case 3: { + return std::make_unique>(); + } case 2: { return std::make_unique>(); } @@ -173,6 +176,9 @@ HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector switch(num_dim_spatial) { + case 3: { + return ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWC{}); + } case 2: { return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{}); } @@ -360,6 +372,11 @@ int main(int argc, char* argv[]) switch(num_dim_spatial) { + case 3: { + auto ref_conv = ReferenceConvNDFwdInstance<3>(); + verify_f(ref_conv); + break; + } case 2: { auto ref_conv = ReferenceConvNDFwdInstance<2>(); verify_f(ref_conv); From b7e49f9c42f3615cf01a6bceeb478d96d98f32bc Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 11 Mar 2022 09:25:12 +0100 Subject: [PATCH 39/82] Remove leftovers after merge. --- device_operation/CMakeLists.txt | 220 -------------------------------- 1 file changed, 220 deletions(-) delete mode 100644 device_operation/CMakeLists.txt diff --git a/device_operation/CMakeLists.txt b/device_operation/CMakeLists.txt deleted file mode 100644 index 1cc3b8d4039..00000000000 --- a/device_operation/CMakeLists.txt +++ /dev/null @@ -1,220 +0,0 @@ -include_directories(BEFORE - include - ${PROJECT_SOURCE_DIR}/host/host_tensor/include - ${PROJECT_SOURCE_DIR}/device/include - ${PROJECT_SOURCE_DIR}/device_operation/include - ${PROJECT_SOURCE_DIR}/profiler/include - ${PROJECT_SOURCE_DIR}/composable_kernel/include - ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility - ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description - ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation - ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform - ${PROJECT_SOURCE_DIR}/external/rocm/include -) - -# device_gemm_instance -set(DEVICE_GEMM_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp; -) - -# device_gemm_bias_2d_instance -set(DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instance.cpp; -) - -# device_gemm_bias_relu_instance -set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp; -) - -# device_gemm_bias_relu_add_instance -set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp; -) - -set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instance.cpp; -) - -# device_conv1d_fwd_instance -set(DEVICE_CONV1D_FWD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instance.cpp; -) - -# device_conv2d_fwd_instance -set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -# device_conv2d_fwd_bias_relu_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -# device_conv2d_fwd_bias_relu_add_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -# device_conv2d_fwd_bias_relu_atomic_add_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -# device_conv2d_bwd_data_instance -set(DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; -) - -# device_conv3d_fwd_instance -set(DEVICE_CONV3D_FWD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instance.cpp; -) - -# device_reduce_instance -set(DEVICE_REDUCE_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f16_f16_f16.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f16_f32_f16.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f32_f32_f32.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f32_f64_f32.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_f64_f64_f64.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f16_f16_f16.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f16_f32_f16.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f32_f32_f32.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f32_f64_f32.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_threadwise_f64_f64_f64.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f16_f16_f16.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f32_f32_f16.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f32_f32_f32.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f64_f64_f32.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_blockwise_second_call_f64_f64_f64.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f16_f32_f32.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f32_f32_f32.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_atomic_add_f32_f64_f32.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/src/device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.cpp; -) - -add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) -add_library(device_gemm_bias_2d_instance SHARED ${DEVICE_GEMM_BIAS_2D_INSTANCE_SOURCE}) -add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) -add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) -add_library(device_batched_gemm_instance SHARED ${DEVICE_BATCHED_GEMM_INSTANCE_SOURCE}) -add_library(device_conv1d_fwd_instance SHARED ${DEVICE_CONV1D_FWD_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) -add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) -add_library(device_conv2d_bwd_data_instance SHARED ${DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE}) -add_library(device_conv3d_fwd_instance SHARED ${DEVICE_CONV3D_FWD_INSTANCE_SOURCE}) -add_library(device_reduce_instance SHARED ${DEVICE_REDUCE_INSTANCE_SOURCE}) - -target_include_directories(device_gemm_instance SYSTEM PUBLIC $) -target_include_directories(device_gemm_bias_2d_instance SYSTEM PUBLIC $) -target_include_directories(device_gemm_bias_relu_instance SYSTEM PUBLIC $) -target_include_directories(device_gemm_bias_relu_add_instance SYSTEM PUBLIC $) -target_include_directories(device_batched_gemm_instance SYSTEM PUBLIC $) -target_include_directories(device_conv1d_fwd_instance SYSTEM PUBLIC $) -target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $) -target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $) -target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $) -target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $) -target_include_directories(device_conv2d_bwd_data_instance SYSTEM PUBLIC $) -target_include_directories(device_conv3d_fwd_instance SYSTEM PUBLIC $) -target_include_directories(device_reduce_instance SYSTEM PUBLIC $) - -target_compile_features(device_gemm_instance PUBLIC) -target_compile_features(device_gemm_bias_2d_instance PUBLIC) -target_compile_features(device_gemm_bias_relu_instance PUBLIC) -target_compile_features(device_gemm_bias_relu_add_instance PUBLIC) -target_compile_features(device_batched_gemm_instance PUBLIC) -target_compile_features(device_conv1d_fwd_instance PUBLIC) -target_compile_features(device_conv2d_fwd_instance PUBLIC) -target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) -target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) -target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC) -target_compile_features(device_conv2d_bwd_data_instance PUBLIC) -target_compile_features(device_conv3d_fwd_instance PUBLIC) -target_compile_features(device_reduce_instance PUBLIC) - -set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_gemm_bias_2d_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_batched_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_conv1d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_conv2d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_conv3d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -set_target_properties(device_reduce_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) - -install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) -install(TARGETS device_gemm_bias_2d_instance LIBRARY DESTINATION lib) -install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib) -install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib) -install(TARGETS device_batched_gemm_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv1d_fwd_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv2d_bwd_data_instance LIBRARY DESTINATION lib) -install(TARGETS device_conv3d_fwd_instance LIBRARY DESTINATION lib) -install(TARGETS device_reduce_instance LIBRARY DESTINATION lib) From 4e6dfda459a416efe05899d44e541db3fd7bc82a Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 11 Mar 2022 15:20:26 +0100 Subject: [PATCH 40/82] Add Conv Specialization string to GetTypeString --- .../device/convolution_forward_specialization.hpp | 14 ++++++++++++++ ...ice_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 3 ++- .../device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp | 3 ++- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp index e047acee76f..d1c0eb8cca2 100644 --- a/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp +++ b/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -1,6 +1,8 @@ #ifndef CONVOLUTION_FORWARD_SPECIALIZATION #define CONVOLUTION_FORWARD_SPECIALIZATION +#include + namespace ck { namespace tensor_operation { namespace device { @@ -13,6 +15,18 @@ enum ConvolutionForwardSpecialization_t OddC, }; +inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization_t& s) +{ + switch(s) + { + case Default: return "Default"; + case Filter1x1Pad0: return "Filter1x1Pad0"; + case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + case OddC: return "OddC"; + default: return "Unrecognized specialization!"; + } +} + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 3280b9ea30a..219f76062a5 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -875,7 +875,8 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock + << K0PerBlock << ", " + << getConvFwdSpecializationStr(ConvForwardSpecialization) << ">"; // clang-format on diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index 7b9dc0c8ffd..b600d7f78d4 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1023,7 +1023,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " - << K0PerBlock + << K0PerBlock << ", " + << getConvFwdSpecializationStr(ConvForwardSpecialization) << ">"; // clang-format on From 37f92a37ebb0741c1660ea4b92030bd16f0b4409 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 16 Mar 2022 16:56:12 +0100 Subject: [PATCH 41/82] Skip instance causing numerical errors. --- .../device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp | 3 ++- .../device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp | 3 ++- .../device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp index 48a13d67734..bcb3a776a1a 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -33,7 +33,8 @@ using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + // FIXME: this instance causes numerical errors. + // DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index 826c2fbdce6..c0f4a874df8 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -34,7 +34,8 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + // FIXME: this instance causes numerical errors. + // DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp index 565325d8b7e..a59b1513f86 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -33,7 +33,8 @@ using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + // FIXME: this instance causes numerical errors. + // DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, From ca9b5377a9963ec235c4ea5618c1e3b958f24e23 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 16 Mar 2022 18:06:23 +0100 Subject: [PATCH 42/82] Small fixes. --- .../device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp | 3 ++- test/convnd_fwd/conv3d_fwd.cpp | 3 ++- test/include/conv_test_util.hpp | 12 ++++++++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index b600d7f78d4..4612e92de95 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -749,6 +749,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K float Run(const Argument& arg, int nrepeat = 1) { +#if 0 { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " @@ -761,7 +762,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - +#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, diff --git a/test/convnd_fwd/conv3d_fwd.cpp b/test/convnd_fwd/conv3d_fwd.cpp index 6aa05c41d28..45438616bc6 100644 --- a/test/convnd_fwd/conv3d_fwd.cpp +++ b/test/convnd_fwd/conv3d_fwd.cpp @@ -205,7 +205,8 @@ template bool TestConv3DNDHWCInstances(const std::vector& conv_ptrs) { ck::conv_util::ConvParams params; - params.N = 64 params.num_dim_spatial = 3; + params.N = 64; + params.num_dim_spatial = 3; params.filter_spatial_lengths = std::vector{3, 3, 2}; params.input_spatial_lengths = std::vector{32, 32, 2}; params.conv_filter_strides = std::vector{2, 2, 2}; diff --git a/test/include/conv_test_util.hpp b/test/include/conv_test_util.hpp index c67e7eb412b..406d6f3e840 100644 --- a/test/include/conv_test_util.hpp +++ b/test/include/conv_test_util.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "config.hpp" @@ -131,7 +132,7 @@ auto GetHostTensors(const ck::conv_util::ConvParams& params, bool init = true) std::generate( input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); } - std::fill(weights.begin(), weights.end(), WeiDataType(0.5f)); + std::fill(weights.begin(), weights.end(), WeiDataType(1.5f)); std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); } @@ -260,11 +261,18 @@ bool RunConvInstances(const ck::conv_util::ConvParams& params, if(conv_ptr->IsSupportedArgument(argument.get())) { + float atol{1e-5f}; + float rtol{1e-4f}; + if constexpr (std::is_same_v) + { + atol = 1e-4f; + rtol = 2.5e-3f; + } invoker->Run(argument.get()); out_device_buf.FromDevice(output.mData.data()); res = res && test::check_err( - output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + output.mData, host_output.mData, "Error: incorrect results!", atol, rtol); hipGetErrorString( hipMemset(out_device_buf.GetDeviceBuffer(), 0, out_device_buf.mMemSize)); } From 67d845e93375baa713553b075d8e2f43d9eb6f1d Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 16 Mar 2022 18:13:16 +0100 Subject: [PATCH 43/82] Remove redundant includes. --- .../tensor_operation/gpu/element/element_wise_operation.hpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 72d3c16ae6e..b2513fb4d3e 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -2,10 +2,6 @@ #define CK_ELEMENT_WISE_OPERATION_HPP #include "data_type.hpp" -#include "data_type.hpp" - -#include "data_type.hpp" - namespace ck { namespace tensor_operation { namespace element_wise { From 66cb8504049b336ede8c9ebe4746e9673c75ec1e Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 17 Mar 2022 11:42:57 +0100 Subject: [PATCH 44/82] Fix namespace name error. --- test/gemm/gemm_bf16.cpp | 2 +- test/gemm/gemm_fp32.cpp | 4 +- test/gemm/gemm_int8.cpp | 2 +- .../reference_conv_fwd/reference_conv_fwd.cpp | 57 +++++++++---------- 4 files changed, 32 insertions(+), 33 deletions(-) diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp index b6d54fcae80..7dd6b706f49 100644 --- a/test/gemm/gemm_bf16.cpp +++ b/test/gemm/gemm_bf16.cpp @@ -136,7 +136,7 @@ bool TestGemm(DeviceGemmPtr_& gemmPtr) bf16_to_f32_(c_device_bf16, c_device_fp32); // Assert - bool res = test_util::check_err( + bool res = test::check_err( c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp index a4cae6db2bc..8e14a6f31a3 100644 --- a/test/gemm/gemm_fp32.cpp +++ b/test/gemm/gemm_fp32.cpp @@ -111,8 +111,8 @@ bool TestGemm(DeviceGemmPtr_& gemmPtr) gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); // Assert - bool res = test_util::check_err( - c_device.mData, c_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + bool res = + test::check_err(c_device.mData, c_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp index 464689bf160..5de4acea690 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_int8.cpp @@ -111,7 +111,7 @@ bool TestGemm(DeviceGemmPtr_& gemmPtr) gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); // Assert - bool res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); + bool res = test::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp index 29a8e102d5c..aaf3cb4763a 100644 --- a/test/reference_conv_fwd/reference_conv_fwd.cpp +++ b/test/reference_conv_fwd/reference_conv_fwd.cpp @@ -148,10 +148,10 @@ bool TestConv2DNHWC() 472.5, 490.5, 508.5}; - res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && test::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); params.N = 1; params.K = 2; @@ -171,10 +171,10 @@ bool TestConv2DNHWC() 747., 747., 1138.5, 1138.5, 1174.5, 1174.5, 1210.5, 1210.5, 1246.5, 1246.5, 1035., 1035., 1570.5, 1570.5, 1606.5, 1606.5, 1642.5, 1642.5, 1678.5, 1678.5, 1323., 1323., 2002.5, 2002.5, 2038.5, 2038.5, 2074.5, 2074.5, 2110.5, 2110.5}; - res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && test::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); return res; } @@ -203,10 +203,10 @@ bool TestConv1DNWC() ck::tensor_layout::convolution::NWK>(params); std::vector ref_dims{1, 1, 4}; std::vector ref_data{7.5, 13.5, 19.5, 25.5}; - res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && test::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); params.num_dim_spatial = 1; params.N = 1; @@ -228,10 +228,10 @@ bool TestConv1DNWC() ck::tensor_layout::convolution::NWK>(params); ref_dims = std::vector{1, 2, 5}; ref_data = std::vector{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5}; - res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && test_util::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && test::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); params.num_dim_spatial = 1; params.N = 2; @@ -319,10 +319,10 @@ bool TestConv1DNWC() 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4}; - res = res && test_util::check_err(out_tensor2.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && test_util::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!"); + res = res && test::check_err(out_tensor2.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && test::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!"); return res; } @@ -360,11 +360,10 @@ bool TestConv3DNCDHW() 634.5, 637.2, 639.9, 642.60004, 650.7, 653.4, 656.10004, 658.8, 699.3, 702., 704.7, 707.4, 715.5, 718.2, 720.9, 723.60004, 731.7, 734.4001, 737.10004, 739.8, 747.9001, 750.60004, 753.3, 756.}; - res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error [case 1]: wrong output tensor dimensions!"); - res = res && - test_util::check_err(out_tensor.mData, ref_data, "Error [case 1]: incorrect results!"); + res = res && test::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 1]: wrong output tensor dimensions!"); + res = res && test::check_err(out_tensor.mData, ref_data, "Error [case 1]: incorrect results!"); params.N = 1; params.K = 2; @@ -402,11 +401,11 @@ bool TestConv3DNCDHW() 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801}; - res = res && test_util::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error [case 2]: wrong output tensor dimensions!"); + res = res && test::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 2]: wrong output tensor dimensions!"); res = - res && test_util::check_err( + res && test::check_err( out_tensor.mData, ref_data, "Error [case 2]: incorrect results!", 1e-4f, 1e-6f); return res; From 2868f2e0929ddd6db49e7cc7ea3330bc054304c6 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 17 Mar 2022 11:44:01 +0100 Subject: [PATCH 45/82] Script for automatic testing and logging convolution fwd UTs --- script/test_convnd_fwd.sh | 110 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 script/test_convnd_fwd.sh diff --git a/script/test_convnd_fwd.sh b/script/test_convnd_fwd.sh new file mode 100644 index 00000000000..1311f3c9ae1 --- /dev/null +++ b/script/test_convnd_fwd.sh @@ -0,0 +1,110 @@ +#!/usr/bin/env bash + +# set -e + +DIM1=False +DIM2=True +DIM3=False +DATE=220317 +GIT_HASH=4e6dfda +LOG_DIR=${DATE}_${GIT_HASH} +SUFFIX=${GIT_HASH} + + +#-------------------------------------------------------------------------- +# Commandline arguments parsing +# like: cmd -key[--key] value +#-------------------------------------------------------------------------- + +POSITIONAL=() +while [[ $# -gt 0 ]] +do +key="$1" + +case $key in + -d1|--d1) + DIM1=True + echo DIM1: "${DIM1}" + shift # past argument + ;; + -d2|--d2) + DIM2=True + echo DIM2: "${DIM2}" + shift # past argument + ;; + -d3|--d3) + DIM3=True + echo DIM3: "${DIM3}" + shift # past argument + ;; + -all|--all) + DIM1=True + DIM2=True + DIM3=True + echo DIM1: "${DIM1}" + echo DIM2: "${DIM2}" + echo DIM3: "${DIM3}" + shift # past argument + ;; + -s|--suffix) + SUFFIX=${SUFFIX}_"$2" + echo SUFFIX: "${SUFFIX}" + shift # past argument + shift # past value + ;; + *) # unknown option + POSITIONAL+=("$1") # save it in an array for later + shift # past argument + ;; +esac +done +set -- "${POSITIONAL[@]}" # restore positional parameters + +#-------------------------------------------------------------------------- + +NUMACTL=numactl --cpunodebind=1 --membind=1 +NUMACTL= +# ENV_CONF= +GPU=mi100 +PROF_ITER_COUNT=10000 +LOG_DIR_PATH=../log/${LOG_DIR} +set -x + +#------------------------------------------------------------------------------- +# 1D +#------------------------------------------------------------------------------- + +if [[ "${DIM1}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN test conv1d nwc <<<<<<<<<<" + CMD="./../build/bin/test_conv1d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv1d_fwd_nwc_${SUFFIX}_${GPU}.log + +fi + +#------------------------------------------------------------------------------- +# 2D +#------------------------------------------------------------------------------- + +if [[ "${DIM2}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN test conv2d nhwc <<<<<<<<<<" + CMD="./../build/bin/test_conv2d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv2d_fwd_nhwc_${SUFFIX}_${GPU}.log + +fi + +#------------------------------------------------------------------------------- +# 3D +#------------------------------------------------------------------------------- + +if [[ "${DIM3}" == "True" ]]; then + mkdir -p ${LOG_DIR_PATH} + echo ">>>>>>>> RUN test conv3d ndhwc <<<<<<<<<<" + CMD="./../build/bin/test_conv3d_fwd" + ${NUMACTL} ${CMD} 2>&1 \ + | tee ${LOG_DIR_PATH}/test_conv3d_fwd_ndhwc_${SUFFIX}_${GPU}.log + +fi From e6f2d58eb79faf3640d3fd48facf230e231fdcc9 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 17 Mar 2022 14:32:47 +0100 Subject: [PATCH 46/82] Comment out numactl cmd. --- script/test_convnd_fwd.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/script/test_convnd_fwd.sh b/script/test_convnd_fwd.sh index 1311f3c9ae1..1bd7a6b5d71 100644 --- a/script/test_convnd_fwd.sh +++ b/script/test_convnd_fwd.sh @@ -62,7 +62,7 @@ set -- "${POSITIONAL[@]}" # restore positional parameters #-------------------------------------------------------------------------- -NUMACTL=numactl --cpunodebind=1 --membind=1 +# NUMACTL="numactl --cpunodebind=1 --membind=1" NUMACTL= # ENV_CONF= GPU=mi100 From a5d3b9256a4bb0cb159eb1d1e86c1c814071cd07 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 17 Mar 2022 15:42:16 +0100 Subject: [PATCH 47/82] Refine weights initalization and relax rtol for fp16 --- test/include/conv_test_util.hpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/include/conv_test_util.hpp b/test/include/conv_test_util.hpp index 406d6f3e840..5056f221632 100644 --- a/test/include/conv_test_util.hpp +++ b/test/include/conv_test_util.hpp @@ -131,8 +131,9 @@ auto GetHostTensors(const ck::conv_util::ConvParams& params, bool init = true) std::uniform_real_distribution<> dis(0.f, 1.f); std::generate( input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); + std::generate( + weights.begin(), weights.end(), [&dis, &gen]() { return WeiDataType(dis(gen)); }); } - std::fill(weights.begin(), weights.end(), WeiDataType(1.5f)); std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); } @@ -263,10 +264,10 @@ bool RunConvInstances(const ck::conv_util::ConvParams& params, { float atol{1e-5f}; float rtol{1e-4f}; - if constexpr (std::is_same_v) + if constexpr(std::is_same_v) { atol = 1e-4f; - rtol = 2.5e-3f; + rtol = 2,5e-3f; } invoker->Run(argument.get()); out_device_buf.FromDevice(output.mData.data()); From 4164c0a4e42e7e630544a6e8c5e960cf531fdaa6 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 18 Mar 2022 10:15:41 +0100 Subject: [PATCH 48/82] Move test_util.hpp to check_err.hpp --- example/CMakeLists.txt | 1 + .../include/ck/library/utility/check_err.hpp | 295 +++++++++--------- profiler/CMakeLists.txt | 1 + test/CMakeLists.txt | 1 + test/conv_util/conv_util.cpp | 120 +++---- test/convnd_fwd/conv1d_fwd.cpp | 4 +- test/convnd_fwd/conv2d_fwd.cpp | 4 +- test/convnd_fwd/conv3d_fwd.cpp | 20 +- test/gemm/gemm_bf16.cpp | 4 +- test/gemm/gemm_fp32.cpp | 6 +- test/gemm/gemm_int8.cpp | 4 +- test/include/conv_test_util.hpp | 6 +- .../reference_conv_fwd/reference_conv_fwd.cpp | 59 ++-- 13 files changed, 268 insertions(+), 257 deletions(-) rename test/include/test_util.hpp => library/include/ck/library/utility/check_err.hpp (86%) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 6f9201d8351..afc2f8c19a8 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -13,6 +13,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/library/include/ck/library/host_tensor ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu + ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility ${PROJECT_SOURCE_DIR}/external/include/half ) diff --git a/test/include/test_util.hpp b/library/include/ck/library/utility/check_err.hpp similarity index 86% rename from test/include/test_util.hpp rename to library/include/ck/library/utility/check_err.hpp index f9d7027f712..3ff7fcdfbcd 100644 --- a/test/include/test_util.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -1,147 +1,148 @@ -#ifndef TEST_UTIL_HPP -#define TEST_UTIL_HPP - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "data_type.hpp" - -namespace test { - -template -typename std::enable_if::value && !std::is_same::value, - bool>::type -check_err(const std::vector& out, - const std::vector& ref, - const std::string& msg, - double rtol = 1e-5, - double atol = 1e-8) -{ - if(out.size() != ref.size()) - { - std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() - << std::endl - << msg << std::endl; - return false; - } - - bool res{true}; - int err_count = 0; - double err = 0; - double max_err = std::numeric_limits::min(); - for(std::size_t i = 0; i < ref.size(); ++i) - { - err = std::abs(out[i] - ref[i]); - if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i])) - { - max_err = err > max_err ? err : max_err; - err_count++; - if(err_count < 5) - { - std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" - << i << "]: " << out[i] << " != " << ref[i] << std::endl - << msg << std::endl; - } - res = false; - } - } - if(!res) - { - std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; - } - return res; -} - -template -typename std::enable_if::value || std::is_same::value, - bool>::type -check_err(const std::vector& out, - const std::vector& ref, - const std::string& msg, - double rtol = 1e-5, - double atol = 1e-8) -{ - if(out.size() != ref.size()) - { - std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() - << std::endl - << msg << std::endl; - return false; - } - - bool res{true}; - int err_count = 0; - double err = 0; - double max_err = ck::type_convert(ck::NumericLimits::Min()); - for(std::size_t i = 0; i < ref.size(); ++i) - { - float o = ck::type_convert(out[i]); - float r = ck::type_convert(ref[i]); - err = std::abs(o - r); - if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) - { - max_err = err > max_err ? err : max_err; - err_count++; - if(err_count < 5) - { - std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" - << i << "]: " << o << " != " << r << std::endl - << msg << std::endl; - } - res = false; - } - } - if(!res) - { - std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; - } - return res; -} - -template -typename std::enable_if::value && !std::is_same::value, - bool>::type -check_err(const std::vector& out, - const std::vector& ref, - const std::string& msg, - double = 0, - double = 0) -{ - if(out.size() != ref.size()) - { - std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() - << std::endl - << msg << std::endl; - return false; - } - - for(std::size_t i = 0; i < ref.size(); ++i) - { - if(out[i] != ref[i]) - { - std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] - << std::endl - << msg << std::endl; - return false; - } - } - return true; -} - -} // namespace test - -template -std::ostream& operator<<(std::ostream& os, const std::vector& v) -{ - std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); - return os; -} - -#endif +#ifndef CHECK_ERR_HPP +#define CHECK_ERR_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "data_type.hpp" + +namespace ck { +namespace utils { + +template +typename std::enable_if::value && !std::is_same::value, + bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg, + double rtol = 1e-5, + double atol = 1e-8) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + err = std::abs(out[i] - ref[i]); + if(err > atol + rtol * std::abs(ref[i]) || !std::isfinite(out[i]) || !std::isfinite(ref[i])) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << out[i] << " != " << ref[i] << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value || std::is_same::value, + bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg, + double rtol = 1e-5, + double atol = 1e-8) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = type_convert(NumericLimits::Min()); + for(std::size_t i = 0; i < ref.size(); ++i) + { + float o = type_convert(out[i]); + float r = type_convert(ref[i]); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << std::setw(12) << std::setprecision(7) << "out[" << i << "] != ref[" + << i << "]: " << o << " != " << r << std::endl + << msg << std::endl; + } + res = false; + } + } + if(!res) + { + std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +typename std::enable_if::value && !std::is_same::value, bool>::type +check_err(const std::vector& out, + const std::vector& ref, + const std::string& msg, + double = 0, + double = 0) +{ + if(out.size() != ref.size()) + { + std::cout << "out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl + << msg << std::endl; + return false; + } + + for(std::size_t i = 0; i < ref.size(); ++i) + { + if(out[i] != ref[i]) + { + std::cout << "out[" << i << "] != ref[" << i << "]: " << out[i] << " != " << ref[i] + << std::endl + << msg << std::endl; + return false; + } + } + return true; +} + +} // namespace utils +} // namespace ck + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); + return os; +} + +#endif diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 5e7156a3996..4a213506f12 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -15,6 +15,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance/gpu/reduce ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu + ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility ${PROJECT_SOURCE_DIR}/profiler/include ${PROJECT_SOURCE_DIR}/external/include/half ) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 2d6e28f54c8..e11b4ded579 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -15,6 +15,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/library/include/ck/library/tensor_operation_instance/gpu/reduce ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/cpu ${PROJECT_SOURCE_DIR}/library/include/ck/library/reference_tensor_operation/gpu + ${PROJECT_SOURCE_DIR}/library/include/ck/library/utility ${PROJECT_SOURCE_DIR}/test/include ${PROJECT_SOURCE_DIR}/external/include/half ) diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp index 1dff3f28a20..8d952511ebd 100644 --- a/test/conv_util/conv_util.cpp +++ b/test/conv_util/conv_util.cpp @@ -5,7 +5,7 @@ #include "config.hpp" #include "conv_utils.hpp" #include "tensor_layout.hpp" -#include "test_util.hpp" +#include "check_err.hpp" namespace { @@ -20,26 +20,26 @@ bool TestConvParams_GetOutputSpatialLengths() // padding {{1,1}, {1,1}} ck::conv_util::ConvParams conv_params; std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err(out_spatial_len, - std::vector{36, 36}, - "Error: ConvParams 2D default constructor."); + res = ck::utils::check_err(out_spatial_len, + std::vector{36, 36}, + "Error: ConvParams 2D default constructor."); conv_params.conv_filter_strides = std::vector{1, 1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err( + res = ck::utils::check_err( out_spatial_len, std::vector{71, 71}, "Error: ConvParams 2D stride {1,1}."); conv_params.conv_filter_strides = std::vector{2, 2}; conv_params.input_left_pads = std::vector{2, 2}; conv_params.input_right_pads = std::vector{2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err(out_spatial_len, - std::vector{37, 37}, - "Error: ConvParams 2D padding left/right {2,2}."); + res = ck::utils::check_err(out_spatial_len, + std::vector{37, 37}, + "Error: ConvParams 2D padding left/right {2,2}."); conv_params.conv_filter_dilations = std::vector{2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err( + res = ck::utils::check_err( out_spatial_len, std::vector{36, 36}, "Error: ConvParams 2D dilation {2,2}."); conv_params.conv_filter_strides = std::vector{3, 3}; @@ -47,9 +47,10 @@ bool TestConvParams_GetOutputSpatialLengths() conv_params.input_right_pads = std::vector{1, 1}; conv_params.conv_filter_dilations = std::vector{2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err(out_spatial_len, - std::vector{23, 23}, - "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."); + res = + ck::utils::check_err(out_spatial_len, + std::vector{23, 23}, + "Error: ConvParams 2D strides{3,3}, padding {1,1}, dilations {2,2}."); // -------------------------- 1D ------------------------------------ conv_params.num_dim_spatial = 1; @@ -61,24 +62,25 @@ bool TestConvParams_GetOutputSpatialLengths() conv_params.input_right_pads = std::vector{1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err(out_spatial_len, std::vector{36}, "Error: ConvParams 1D."); + res = ck::utils::check_err( + out_spatial_len, std::vector{36}, "Error: ConvParams 1D."); conv_params.conv_filter_strides = std::vector{1, 1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err( + res = ck::utils::check_err( out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}."); conv_params.conv_filter_strides = std::vector{2}; conv_params.input_left_pads = std::vector{2}; conv_params.input_right_pads = std::vector{2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err(out_spatial_len, - std::vector{37}, - "Error: ConvParams 1D padding left/right {2}."); + res = ck::utils::check_err(out_spatial_len, + std::vector{37}, + "Error: ConvParams 1D padding left/right {2}."); conv_params.conv_filter_dilations = std::vector{2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err( + res = ck::utils::check_err( out_spatial_len, std::vector{36}, "Error: ConvParams 1D dilation {2}."); conv_params.conv_filter_strides = std::vector{3}; @@ -86,9 +88,9 @@ bool TestConvParams_GetOutputSpatialLengths() conv_params.input_right_pads = std::vector{1}; conv_params.conv_filter_dilations = std::vector{2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err(out_spatial_len, - std::vector{23}, - "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."); + res = ck::utils::check_err(out_spatial_len, + std::vector{23}, + "Error: ConvParams 1D strides{3}, padding {1}, dilations {2}."); // -------------------------- 3D ------------------------------------ conv_params.num_dim_spatial = 3; @@ -100,35 +102,35 @@ bool TestConvParams_GetOutputSpatialLengths() conv_params.input_right_pads = std::vector{1, 1, 1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err( + res = ck::utils::check_err( out_spatial_len, std::vector{36, 36, 36}, "Error: ConvParams 3D."); conv_params.conv_filter_strides = std::vector{1, 1, 1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err(out_spatial_len, - std::vector{71, 71, 71}, - "Error: ConvParams 3D stride {1, 1, 1}."); + res = ck::utils::check_err(out_spatial_len, + std::vector{71, 71, 71}, + "Error: ConvParams 3D stride {1, 1, 1}."); conv_params.conv_filter_strides = std::vector{2, 2, 2}; conv_params.input_left_pads = std::vector{2, 2, 2}; conv_params.input_right_pads = std::vector{2, 2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err(out_spatial_len, - std::vector{37, 37, 37}, - "Error: ConvParams 3D padding left/right {2, 2, 2}."); + res = ck::utils::check_err(out_spatial_len, + std::vector{37, 37, 37}, + "Error: ConvParams 3D padding left/right {2, 2, 2}."); conv_params.conv_filter_dilations = std::vector{2, 2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err(out_spatial_len, - std::vector{36, 36, 36}, - "Error: ConvParams 3D dilation {2, 2, 2}."); + res = ck::utils::check_err(out_spatial_len, + std::vector{36, 36, 36}, + "Error: ConvParams 3D dilation {2, 2, 2}."); conv_params.conv_filter_strides = std::vector{3, 3, 3}; conv_params.input_left_pads = std::vector{1, 1, 1}; conv_params.input_right_pads = std::vector{1, 1, 1}; conv_params.conv_filter_dilations = std::vector{2, 2, 2}; out_spatial_len = conv_params.GetOutputSpatialLengths(); - res = test::check_err( + res = ck::utils::check_err( out_spatial_len, std::vector{23, 23, 23}, "Error: ConvParams 3D strides{3, 3, 3}, padding {1, 1, 1}, dilations {2, 2, 2}."); @@ -142,44 +144,48 @@ bool TestGetHostTensorDescriptor() namespace tl = ck::tensor_layout::convolution; std::vector dims{2, 3, 4, 5}; HostTensorDescriptor h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{}); - res = test::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!"); - res = test::check_err( + res = + ck::utils::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!"); + res = ck::utils::check_err( h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!"); - h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCHW{}); - res = test::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!"); - res = test::check_err( + h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCHW{}); + res = + ck::utils::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!"); + res = ck::utils::check_err( h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!"); dims = std::vector{2, 3, 4}; h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{}); - res = test::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"); - res = test::check_err(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"); + res = ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"); + res = + ck::utils::check_err(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"); h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCW{}); - res = test::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"); - res = test::check_err(h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!"); + res = ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"); + res = + ck::utils::check_err(h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!"); dims = std::vector{2, 3, 4, 5, 6}; h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWC{}); - res = test::check_err(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!"); - res = test::check_err(h.GetStrides(), - {3 * 4 * 5 * 6, // N - 1, // C - 3 * 5 * 6, // D - 3 * 6, // H - 3}, // W - "Error: wrong NDHWC dimensions strides!"); + res = ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!"); + res = ck::utils::check_err(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 1, // C + 3 * 5 * 6, // D + 3 * 6, // H + 3}, // W + "Error: wrong NDHWC dimensions strides!"); h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCDHW{}); - res = test::check_err(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!"); - res = test::check_err(h.GetStrides(), - {3 * 4 * 5 * 6, // N - 4 * 5 * 6, // C - 5 * 6, // D - 6, // H - 1}, // W - "Error: wrong NCDHW dimensions strides!"); + res = ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!"); + res = ck::utils::check_err(h.GetStrides(), + {3 * 4 * 5 * 6, // N + 4 * 5 * 6, // C + 5 * 6, // D + 6, // H + 1}, // W + "Error: wrong NCDHW dimensions strides!"); return res; } diff --git a/test/convnd_fwd/conv1d_fwd.cpp b/test/convnd_fwd/conv1d_fwd.cpp index 7da85cbf4e6..1c4437d78fd 100644 --- a/test/convnd_fwd/conv1d_fwd.cpp +++ b/test/convnd_fwd/conv1d_fwd.cpp @@ -8,7 +8,7 @@ #include "conv_test_util.hpp" #include "host_tensor.hpp" #include "tensor_layout.hpp" -#include "test_util.hpp" +#include "check_err.hpp" // Forward declarations for conv instances. @@ -63,7 +63,7 @@ bool TestConv1DNWC() test::conv::RunReferenceConv<1>(params, input, weights, host_output); test::conv::RunConv<1>(params, input, weights, device_output); res = res && - test::check_err( + ck::utils::check_err( device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); return res; diff --git a/test/convnd_fwd/conv2d_fwd.cpp b/test/convnd_fwd/conv2d_fwd.cpp index 624db66b9e1..8855e7f250f 100644 --- a/test/convnd_fwd/conv2d_fwd.cpp +++ b/test/convnd_fwd/conv2d_fwd.cpp @@ -9,7 +9,7 @@ #include "conv_test_util.hpp" #include "host_tensor.hpp" #include "tensor_layout.hpp" -#include "test_util.hpp" +#include "check_err.hpp" // Forward declarations for conv instances. using DeviceConvFwdNoOpPtr = @@ -55,7 +55,7 @@ bool TestConv2DNHWC() test::conv::RunReferenceConv<2>(params, input, weights, host_output); test::conv::RunConv<2>(params, input, weights, device_output); res = res && - test::check_err( + ck::utils::check_err( device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); return res; diff --git a/test/convnd_fwd/conv3d_fwd.cpp b/test/convnd_fwd/conv3d_fwd.cpp index 45438616bc6..632a890d490 100644 --- a/test/convnd_fwd/conv3d_fwd.cpp +++ b/test/convnd_fwd/conv3d_fwd.cpp @@ -9,7 +9,7 @@ #include "conv_test_util.hpp" #include "host_tensor.hpp" #include "tensor_layout.hpp" -#include "test_util.hpp" +#include "check_err.hpp" // Forward declarations for conv instances. using DeviceConvFwdNoOpPtr = @@ -63,7 +63,7 @@ bool TestConv3DNDHWC() test::conv::RunReferenceConv<3>(params, input, weights, host_output); test::conv::RunConv<3>(params, input, weights, device_output); res = res && - test::check_err( + ck::utils::check_err( device_output.mData, host_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); return res; @@ -205,14 +205,14 @@ template bool TestConv3DNDHWCInstances(const std::vector& conv_ptrs) { ck::conv_util::ConvParams params; - params.N = 64; - params.num_dim_spatial = 3; - params.filter_spatial_lengths = std::vector{3, 3, 2}; - params.input_spatial_lengths = std::vector{32, 32, 2}; - params.conv_filter_strides = std::vector{2, 2, 2}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{1, 1, 1}; - params.input_right_pads = std::vector{1, 1, 1}; + params.N = 64; + params.num_dim_spatial = 3; + params.filter_spatial_lengths = std::vector{3, 3, 2}; + params.input_spatial_lengths = std::vector{32, 32, 2}; + params.conv_filter_strides = std::vector{2, 2, 2}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{1, 1, 1}; + params.input_right_pads = std::vector{1, 1, 1}; auto host_tensors = test::conv::GetHostTensors #include +#include "check_err.hpp" #include "gemm_util.hpp" #include "config.hpp" #include "print.hpp" @@ -19,7 +20,6 @@ #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" -#include "test_util.hpp" using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -136,7 +136,7 @@ bool TestGemm(DeviceGemmPtr_& gemmPtr) bf16_to_f32_(c_device_bf16, c_device_fp32); // Assert - bool res = test::check_err( + bool res = ck::utils::check_err( c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp index 8e14a6f31a3..d0ccaac3167 100644 --- a/test/gemm/gemm_fp32.cpp +++ b/test/gemm/gemm_fp32.cpp @@ -19,7 +19,7 @@ #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" -#include "test_util.hpp" +#include "check_err.hpp" using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -111,8 +111,8 @@ bool TestGemm(DeviceGemmPtr_& gemmPtr) gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); // Assert - bool res = - test::check_err(c_device.mData, c_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + bool res = ck::utils::check_err( + c_device.mData, c_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp index 5de4acea690..c12755b4120 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_int8.cpp @@ -19,7 +19,7 @@ #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" -#include "test_util.hpp" +#include "check_err.hpp" using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -111,7 +111,7 @@ bool TestGemm(DeviceGemmPtr_& gemmPtr) gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); // Assert - bool res = test::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); + bool res = ck::utils::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; diff --git a/test/include/conv_test_util.hpp b/test/include/conv_test_util.hpp index 5056f221632..12fbc041d09 100644 --- a/test/include/conv_test_util.hpp +++ b/test/include/conv_test_util.hpp @@ -10,6 +10,7 @@ #include #include +#include "check_err.hpp" #include "config.hpp" #include "conv_utils.hpp" #include "device.hpp" @@ -19,7 +20,6 @@ #include "host_tensor.hpp" #include "reference_conv_fwd.hpp" #include "tensor_layout.hpp" -#include "test_util.hpp" namespace { @@ -267,12 +267,12 @@ bool RunConvInstances(const ck::conv_util::ConvParams& params, if constexpr(std::is_same_v) { atol = 1e-4f; - rtol = 2,5e-3f; + rtol = 2.5e-3f; } invoker->Run(argument.get()); out_device_buf.FromDevice(output.mData.data()); res = res && - test::check_err( + ck::utils::check_err( output.mData, host_output.mData, "Error: incorrect results!", atol, rtol); hipGetErrorString( hipMemset(out_device_buf.GetDeviceBuffer(), 0, out_device_buf.mMemSize)); diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp index aaf3cb4763a..d02c2f175f2 100644 --- a/test/reference_conv_fwd/reference_conv_fwd.cpp +++ b/test/reference_conv_fwd/reference_conv_fwd.cpp @@ -6,13 +6,13 @@ #include #include +#include "check_err.hpp" #include "config.hpp" #include "conv_utils.hpp" #include "element_wise_operation.hpp" #include "host_tensor.hpp" #include "reference_conv_fwd.hpp" #include "tensor_layout.hpp" -#include "test_util.hpp" namespace { using InElementOp = ck::tensor_operation::element_wise::PassThrough; @@ -148,10 +148,10 @@ bool TestConv2DNHWC() 472.5, 490.5, 508.5}; - res = res && test::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && test::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); params.N = 1; params.K = 2; @@ -171,10 +171,10 @@ bool TestConv2DNHWC() 747., 747., 1138.5, 1138.5, 1174.5, 1174.5, 1210.5, 1210.5, 1246.5, 1246.5, 1035., 1035., 1570.5, 1570.5, 1606.5, 1606.5, 1642.5, 1642.5, 1678.5, 1678.5, 1323., 1323., 2002.5, 2002.5, 2038.5, 2038.5, 2074.5, 2074.5, 2110.5, 2110.5}; - res = res && test::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && test::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); return res; } @@ -203,10 +203,10 @@ bool TestConv1DNWC() ck::tensor_layout::convolution::NWK>(params); std::vector ref_dims{1, 1, 4}; std::vector ref_data{7.5, 13.5, 19.5, 25.5}; - res = res && test::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && test::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); params.num_dim_spatial = 1; params.N = 1; @@ -228,10 +228,10 @@ bool TestConv1DNWC() ck::tensor_layout::convolution::NWK>(params); ref_dims = std::vector{1, 2, 5}; ref_data = std::vector{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5}; - res = res && test::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && test::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); + res = res && ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && ck::utils::check_err(out_tensor.mData, ref_data, "Error: incorrect results!"); params.num_dim_spatial = 1; params.N = 2; @@ -319,10 +319,10 @@ bool TestConv1DNWC() 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 72.9, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4, 49.4}; - res = res && test::check_err(out_tensor2.mDesc.GetLengths(), - ref_dims, - "Error: wrong output tensor dimensions!"); - res = res && test::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!"); + res = res && ck::utils::check_err(out_tensor2.mDesc.GetLengths(), + ref_dims, + "Error: wrong output tensor dimensions!"); + res = res && ck::utils::check_err(out_tensor2.mData, ref_data, "Error: incorrect results!"); return res; } @@ -360,10 +360,11 @@ bool TestConv3DNCDHW() 634.5, 637.2, 639.9, 642.60004, 650.7, 653.4, 656.10004, 658.8, 699.3, 702., 704.7, 707.4, 715.5, 718.2, 720.9, 723.60004, 731.7, 734.4001, 737.10004, 739.8, 747.9001, 750.60004, 753.3, 756.}; - res = res && test::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error [case 1]: wrong output tensor dimensions!"); - res = res && test::check_err(out_tensor.mData, ref_data, "Error [case 1]: incorrect results!"); + res = res && ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 1]: wrong output tensor dimensions!"); + res = res && + ck::utils::check_err(out_tensor.mData, ref_data, "Error [case 1]: incorrect results!"); params.N = 1; params.K = 2; @@ -401,11 +402,11 @@ bool TestConv3DNCDHW() 5283.9004, 5292., 5300.0996, 5308.2, 5381.0996, 5389.2, 5397.3, 5405.4004, 6255.9004, 6264.0005, 6272.1, 6280.2, 6353.1, 6361.2, 6369.301, 6377.4, 6450.301, 6458.4, 6466.5, 6474.6, 6547.5, 6555.6, 6563.699, 6571.801}; - res = res && test::check_err(out_tensor.mDesc.GetLengths(), - ref_dims, - "Error [case 2]: wrong output tensor dimensions!"); + res = res && ck::utils::check_err(out_tensor.mDesc.GetLengths(), + ref_dims, + "Error [case 2]: wrong output tensor dimensions!"); res = - res && test::check_err( + res && ck::utils::check_err( out_tensor.mData, ref_data, "Error [case 2]: incorrect results!", 1e-4f, 1e-6f); return res; From 2f5549629ce7c5ec0b0a241febbc3c1874d5119a Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 17 Mar 2022 15:42:16 +0100 Subject: [PATCH 49/82] Refine weights initalization and relax rtol for fp16 --- test/include/conv_test_util.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/include/conv_test_util.hpp b/test/include/conv_test_util.hpp index 406d6f3e840..9228158c9fd 100644 --- a/test/include/conv_test_util.hpp +++ b/test/include/conv_test_util.hpp @@ -131,8 +131,9 @@ auto GetHostTensors(const ck::conv_util::ConvParams& params, bool init = true) std::uniform_real_distribution<> dis(0.f, 1.f); std::generate( input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); + std::generate( + weights.begin(), weights.end(), [&dis, &gen]() { return WeiDataType(dis(gen)); }); } - std::fill(weights.begin(), weights.end(), WeiDataType(1.5f)); std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); } @@ -263,7 +264,7 @@ bool RunConvInstances(const ck::conv_util::ConvParams& params, { float atol{1e-5f}; float rtol{1e-4f}; - if constexpr (std::is_same_v) + if constexpr(std::is_same_v) { atol = 1e-4f; rtol = 2.5e-3f; From 8098df965c58e3e217dfaa3f254350a4020c7c50 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 18 Mar 2022 11:49:58 +0100 Subject: [PATCH 50/82] Refactor common part of test conv utils. * Move utility function to single common place. --- example/09_convnd_fwd/convnd_fwd_xdl.cpp | 42 +- .../gpu/device/conv_utils.hpp | 220 --------- .../ck/library/utility/conv_fwd_util.hpp | 440 ++++++++++++++++++ test/conv_util/conv_util.cpp | 16 +- test/convnd_fwd/conv1d_fwd.cpp | 39 +- test/convnd_fwd/conv2d_fwd.cpp | 28 +- test/convnd_fwd/conv3d_fwd.cpp | 81 ++-- test/convnd_fwd/conv_util.hpp | 87 ++++ test/include/conv_test_util.hpp | 287 ------------ .../reference_conv_fwd/reference_conv_fwd.cpp | 94 ++-- 10 files changed, 681 insertions(+), 653 deletions(-) delete mode 100644 include/ck/tensor_operation/gpu/device/conv_utils.hpp create mode 100644 library/include/ck/library/utility/conv_fwd_util.hpp create mode 100644 test/convnd_fwd/conv_util.hpp delete mode 100644 test/include/conv_test_util.hpp diff --git a/example/09_convnd_fwd/convnd_fwd_xdl.cpp b/example/09_convnd_fwd/convnd_fwd_xdl.cpp index d26a52b2fdb..188431cd9e1 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl.cpp @@ -3,7 +3,7 @@ #include #include #include "config.hpp" -#include "conv_utils.hpp" +#include "conv_fwd_util.hpp" #include "device.hpp" #include "device_tensor.hpp" #include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" @@ -116,7 +116,7 @@ void PrintUseMsg() << std::endl; } -ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* argv[]) +ck::utils::conv::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* argv[]) { // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) int conv_args = 3 + num_dim_spatial * 6; @@ -127,7 +127,7 @@ ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* a exit(0); } - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; int arg_idx = 5; params.num_dim_spatial = num_dim_spatial; @@ -177,13 +177,13 @@ HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector switch(num_dim_spatial) { case 3: { - return ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWC{}); + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWC{}); } case 2: { - return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{}); + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWC{}); } case 1: { - return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{}); + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWC{}); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); @@ -242,7 +242,7 @@ int main(int argc, char* argv[]) int nrepeat = 5; int num_dim_spatial = 2; - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; if(argc >= 5) { @@ -334,15 +334,15 @@ int main(int argc, char* argv[]) float ave_time = invoker->Run(argument.get(), nrepeat); - std::size_t flop = ck::conv_util::GetFlops( + std::size_t flop = ck::utils::conv::GetFlops( params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); - std::size_t num_btype = - ck::conv_util::GetBtype(params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, - output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::GetBtype( + params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; diff --git a/include/ck/tensor_operation/gpu/device/conv_utils.hpp b/include/ck/tensor_operation/gpu/device/conv_utils.hpp deleted file mode 100644 index 3e4d65311f8..00000000000 --- a/include/ck/tensor_operation/gpu/device/conv_utils.hpp +++ /dev/null @@ -1,220 +0,0 @@ -#ifndef CONV_UTILS_HPP -#define CONV_UTILS_HPP - -#include -#include -#include -#include -#include -#include -#include - -#include "config.hpp" -#include "host_tensor.hpp" -#include "tensor_layout.hpp" - -namespace ck { -namespace conv_util { - -/** - * @brief Calculate number of FLOPs for Convolution - * - * @param[in] N Batch size. - * @param[in] C Number of input channels. - * @param[in] K Number of output channels. - * @param[in] filter_spatial_lengths Filter spatial dimensions lengths. - * @param[in] output_spatial_lengths Convolution output spatial dimensions - * lengths. - * - * @return The number of flops. - */ -std::size_t GetFlops(ck::index_t N, - ck::index_t C, - ck::index_t K, - const std::vector& filter_spatial_lengths, - const std::vector& output_spatial_lengths) -{ - // 2 * N * K * * C * - return static_cast(2) * N * K * - std::accumulate(std::begin(output_spatial_lengths), - std::end(output_spatial_lengths), - static_cast(1), - std::multiplies()) * - C * - std::accumulate(std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths), - static_cast(1), - std::multiplies()); -} - -/** - * @brief Calculate number of bytes read/write by convolution algorithm. - * - * @param[in] N Batch size. - * @param[in] C Number of input channels. - * @param[in] K Number of output channels. - * @param[in] input_spatial_lengths Input spatial dimensions lengths. - * @param[in] filter_spatial_lengths Filter spatial dimensions lengths. - * @param[in] output_spatial_lengths Output spatial dimensions lengths - * - * @tparam InDataType Input tensor data type. - * @tparam WeiDataType Weights tensor data type. - * @tparam OutDataType Output tensor data type. - * - * @return The number of used bytes. - */ -template -std::size_t GetBtype(ck::index_t N, - ck::index_t C, - ck::index_t K, - const std::vector& input_spatial_lengths, - const std::vector& filter_spatial_lengths, - const std::vector& output_spatial_lengths) -{ - // sizeof(InDataType) * (N * C * ) + - // sizeof(WeiDataType) * (K * C * ) + - // sizeof(OutDataType) * (N * K * ); - return sizeof(InDataType) * (N * C * - std::accumulate(std::begin(input_spatial_lengths), - std::end(input_spatial_lengths), - static_cast(1), - std::multiplies())) + - sizeof(WeiDataType) * (K * C * - std::accumulate(std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths), - static_cast(1), - std::multiplies())) + - sizeof(OutDataType) * (N * K * - std::accumulate(std::begin(output_spatial_lengths), - std::end(output_spatial_lengths), - static_cast(1), - std::multiplies())); -} - -struct ConvParams -{ - ConvParams() - : num_dim_spatial(2), - N(128), - K(256), - C(192), - filter_spatial_lengths(2, 3), - input_spatial_lengths(2, 71), - conv_filter_strides(2, 2), - conv_filter_dilations(2, 1), - input_left_pads(2, 1), - input_right_pads(2, 1) - { - } - - ck::index_t num_dim_spatial; - ck::index_t N; - ck::index_t K; - ck::index_t C; - - std::vector filter_spatial_lengths; - std::vector input_spatial_lengths; - - std::vector conv_filter_strides; - std::vector conv_filter_dilations; - - std::vector input_left_pads; - std::vector input_right_pads; - - std::vector GetOutputSpatialLengths() const - { - std::vector out_spatial_len(num_dim_spatial, 0); - for(ck::index_t i = 0; i < num_dim_spatial; ++i) - { - // XEff = (X - 1) * conv_dilation_w + 1; - // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; - const ck::index_t idx_eff = - (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1; - out_spatial_len[i] = - (input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) / - conv_filter_strides[i] + - 1; - } - return out_spatial_len; - } -}; - -/** - * @brief Gets the host tensor descriptor. - * - * @param[in] dims The tensor dimensions lengths. Always in NCHW format. - * @param[in] layout The tensor data layout. - * - * @tparam TensorLayout Layout type. - * - * @return The host tensor descriptor object. - */ -template -HostTensorDescriptor GetHostTensorDescriptor(const std::vector& dims, - const TensorLayout& layout) -{ - std::size_t C = dims[1]; - // 1D - if constexpr(std::is_same::value || - std::is_same::value || - std::is_same::value) - { - - return HostTensorDescriptor(dims, std::vector({C * dims[2], dims[2], 1})); - } - else if constexpr(std::is_same::value || - std::is_same::value || - std::is_same::value) - { - return HostTensorDescriptor(dims, std::vector({C * dims[2], 1, C})); - } - // 2D - else if constexpr(std::is_same::value || - std::is_same::value || - std::is_same::value) - { - - return HostTensorDescriptor( - dims, std::vector{C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1}); - } - else if constexpr(std::is_same::value || - std::is_same::value || - std::is_same::value) - { - return HostTensorDescriptor( - dims, std::vector{C * dims[2] * dims[3], 1, dims[3] * C, C}); - } - // 3D - else if constexpr(std::is_same::value || - std::is_same::value || - std::is_same::value) - { - - return HostTensorDescriptor(dims, - std::vector{C * dims[2] * dims[3] * dims[4], - dims[2] * dims[3] * dims[4], - dims[3] * dims[4], - dims[4], - 1}); - } - else if constexpr(std::is_same::value || - std::is_same::value || - std::is_same::value) - { - return HostTensorDescriptor( - dims, - std::vector{ - C * dims[2] * dims[3] * dims[4], 1, C * dims[3] * dims[4], C * dims[4], C}); - } - - std::stringstream err_msg; - err_msg << "Unsupported data layout provided: " << layout << "!"; - throw std::runtime_error(err_msg.str()); -} - -} // namespace conv_util -} // namespace ck - -#endif diff --git a/library/include/ck/library/utility/conv_fwd_util.hpp b/library/include/ck/library/utility/conv_fwd_util.hpp new file mode 100644 index 00000000000..bba28638852 --- /dev/null +++ b/library/include/ck/library/utility/conv_fwd_util.hpp @@ -0,0 +1,440 @@ +#ifndef CONV_FWD_UTIL_HPP +#define CONV_FWD_UTIL_HPP + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "device.hpp" +#include "device_conv_fwd.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" + +namespace ck { +namespace utils { +namespace conv { + +using DeviceConvFwdNoOpPtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +/** + * @brief Calculate number of FLOPs for Convolution + * + * @param[in] N Batch size. + * @param[in] C Number of input channels. + * @param[in] K Number of output channels. + * @param[in] filter_spatial_lengths Filter spatial dimensions lengths. + * @param[in] output_spatial_lengths Convolution output spatial dimensions + * lengths. + * + * @return The number of flops. + */ +std::size_t GetFlops(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) +{ + // 2 * N * K * * C * + return static_cast(2) * N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies()) * + C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + static_cast(1), + std::multiplies()); +} + +/** + * @brief Calculate number of bytes read/write by convolution algorithm. + * + * @param[in] N Batch size. + * @param[in] C Number of input channels. + * @param[in] K Number of output channels. + * @param[in] input_spatial_lengths Input spatial dimensions lengths. + * @param[in] filter_spatial_lengths Filter spatial dimensions lengths. + * @param[in] output_spatial_lengths Output spatial dimensions lengths + * + * @tparam InDataType Input tensor data type. + * @tparam WeiDataType Weights tensor data type. + * @tparam OutDataType Output tensor data type. + * + * @return The number of used bytes. + */ +template +std::size_t GetBtype(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) +{ + // sizeof(InDataType) * (N * C * ) + + // sizeof(WeiDataType) * (K * C * ) + + // sizeof(OutDataType) * (N * K * ); + return sizeof(InDataType) * (N * C * + std::accumulate(std::begin(input_spatial_lengths), + std::end(input_spatial_lengths), + static_cast(1), + std::multiplies())) + + sizeof(WeiDataType) * (K * C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + static_cast(1), + std::multiplies())) + + sizeof(OutDataType) * (N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies())); +} + +struct ConvParams +{ + ConvParams() + : num_dim_spatial(2), + N(128), + K(256), + C(192), + filter_spatial_lengths(2, 3), + input_spatial_lengths(2, 71), + conv_filter_strides(2, 2), + conv_filter_dilations(2, 1), + input_left_pads(2, 1), + input_right_pads(2, 1) + { + } + + ck::index_t num_dim_spatial; + ck::index_t N; + ck::index_t K; + ck::index_t C; + + std::vector filter_spatial_lengths; + std::vector input_spatial_lengths; + + std::vector conv_filter_strides; + std::vector conv_filter_dilations; + + std::vector input_left_pads; + std::vector input_right_pads; + + std::vector GetOutputSpatialLengths() const + { + std::vector out_spatial_len(num_dim_spatial, 0); + for(ck::index_t i = 0; i < num_dim_spatial; ++i) + { + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::index_t idx_eff = + (filter_spatial_lengths[i] - 1) * conv_filter_dilations[i] + 1; + out_spatial_len[i] = + (input_spatial_lengths[i] + input_left_pads[i] + input_right_pads[i] - idx_eff) / + conv_filter_strides[i] + + 1; + } + return out_spatial_len; + } +}; + +/** + * @brief Gets the host tensor descriptor. + * + * @param[in] dims The tensor dimensions lengths. Always in NCHW format. + * @param[in] layout The tensor data layout. + * + * @tparam TensorLayout Layout type. + * + * @return The host tensor descriptor object. + */ +template +HostTensorDescriptor GetHostTensorDescriptor(const std::vector& dims, + const TensorLayout& layout) +{ + std::size_t C = dims[1]; + // 1D + if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor(dims, std::vector({C * dims[2], dims[2], 1})); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor(dims, std::vector({C * dims[2], 1, C})); + } + // 2D + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor( + dims, std::vector{C * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1}); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor( + dims, std::vector{C * dims[2] * dims[3], 1, dims[3] * C, C}); + } + // 3D + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + + return HostTensorDescriptor(dims, + std::vector{C * dims[2] * dims[3] * dims[4], + dims[2] * dims[3] * dims[4], + dims[3] * dims[4], + dims[4], + 1}); + } + else if constexpr(std::is_same::value || + std::is_same::value || + std::is_same::value) + { + return HostTensorDescriptor( + dims, + std::vector{ + C * dims[2] * dims[3] * dims[4], 1, C * dims[3] * dims[4], C * dims[4], C}); + } + + std::stringstream err_msg; + err_msg << "Unsupported data layout provided: " << layout << "!"; + throw std::runtime_error(err_msg.str()); +} + +template +auto GetHostTensors(const ConvParams& params, bool init = true) +{ + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(ck::utils::conv::GetHostTensorDescriptor(input_dims, InLayout{})); + Tensor weights(ck::utils::conv::GetHostTensorDescriptor(filter_dims, WeiLayout{})); + Tensor host_output( + ck::utils::conv::GetHostTensorDescriptor(output_dims, OutLayout{})); + Tensor device_output( + ck::utils::conv::GetHostTensorDescriptor(output_dims, OutLayout{})); + + if(init) + { + std::mt19937 gen(11939); + if constexpr(std::is_same::value) + { + std::uniform_int_distribution<> dis(-5, 5); + std::generate( + input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); + } + else + { + std::uniform_real_distribution<> dis(0.f, 1.f); + std::generate( + input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); + std::generate( + weights.begin(), weights.end(), [&dis, &gen]() { return WeiDataType(dis(gen)); }); + } + std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); + std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); + } + + return std::make_tuple(input, weights, host_output, device_output); +} + +template +void RunReferenceConvFwd(const ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) +{ + auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + output, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); +} + +template + class DeviceConvNDFwdInstance> +void RunConvFwd(const ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) +{ + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + + auto conv = DeviceConvNDFwdInstance(); + auto invoker = conv.MakeInvoker(); + auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "Error! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + invoker.Run(argument); + out_device_buf.FromDevice(output.mData.data()); +} + +template +bool RunConvInstances(const ConvParams& params, + const std::vector& conv_ptrs, + const Tensor& input, + const Tensor& weights, + Tensor& output, + const Tensor& host_output) +{ + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + + bool res{true}; + for(auto& conv_ptr : conv_ptrs) + { + auto invoker = conv_ptr->MakeInvokerPointer(); + auto argument = conv_ptr->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(conv_ptr->IsSupportedArgument(argument.get())) + { + float atol{1e-5f}; + float rtol{1e-4f}; + if constexpr(std::is_same_v) + { + atol = 1e-4f; + rtol = 2.5e-3f; + } + invoker->Run(argument.get()); + out_device_buf.FromDevice(output.mData.data()); + res = res && + ck::utils::check_err( + output.mData, host_output.mData, "Error: incorrect results!", atol, rtol); + hipGetErrorString( + hipMemset(out_device_buf.GetDeviceBuffer(), 0, out_device_buf.mMemSize)); + } + } + return res; +} + +} // namespace conv +} // namespace utils +} // namespace ck + +#endif diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp index 8d952511ebd..349477b493a 100644 --- a/test/conv_util/conv_util.cpp +++ b/test/conv_util/conv_util.cpp @@ -3,7 +3,7 @@ #include #include "config.hpp" -#include "conv_utils.hpp" +#include "conv_fwd_util.hpp" #include "tensor_layout.hpp" #include "check_err.hpp" @@ -18,7 +18,7 @@ bool TestConvParams_GetOutputSpatialLengths() // stride {2,2}, // dilations {1,1}, // padding {{1,1}, {1,1}} - ck::conv_util::ConvParams conv_params; + ck::utils::conv::ConvParams conv_params; std::vector out_spatial_len = conv_params.GetOutputSpatialLengths(); res = ck::utils::check_err(out_spatial_len, std::vector{36, 36}, @@ -143,31 +143,31 @@ bool TestGetHostTensorDescriptor() bool res{true}; namespace tl = ck::tensor_layout::convolution; std::vector dims{2, 3, 4, 5}; - HostTensorDescriptor h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{}); + HostTensorDescriptor h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWC{}); res = ck::utils::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!"); res = ck::utils::check_err( h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!"); - h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCHW{}); + h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NCHW{}); res = ck::utils::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!"); res = ck::utils::check_err( h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!"); dims = std::vector{2, 3, 4}; - h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{}); + h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWC{}); res = ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"); res = ck::utils::check_err(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"); - h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCW{}); + h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NCW{}); res = ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"); res = ck::utils::check_err(h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!"); dims = std::vector{2, 3, 4, 5, 6}; - h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWC{}); + h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWC{}); res = ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!"); res = ck::utils::check_err(h.GetStrides(), {3 * 4 * 5 * 6, // N @@ -177,7 +177,7 @@ bool TestGetHostTensorDescriptor() 3}, // W "Error: wrong NDHWC dimensions strides!"); - h = ck::conv_util::GetHostTensorDescriptor(dims, tl::NCDHW{}); + h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NCDHW{}); res = ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!"); res = ck::utils::check_err(h.GetStrides(), {3 * 4 * 5 * 6, // N diff --git a/test/convnd_fwd/conv1d_fwd.cpp b/test/convnd_fwd/conv1d_fwd.cpp index 1c4437d78fd..ef49871e35f 100644 --- a/test/convnd_fwd/conv1d_fwd.cpp +++ b/test/convnd_fwd/conv1d_fwd.cpp @@ -5,7 +5,8 @@ #include "data_type.hpp" #include "element_wise_operation.hpp" -#include "conv_test_util.hpp" +#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "host_tensor.hpp" #include "tensor_layout.hpp" #include "check_err.hpp" @@ -37,7 +38,7 @@ namespace { bool TestConv1DNWC() { bool res{true}; - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.num_dim_spatial = 1; params.N = 2; params.K = 16; @@ -49,18 +50,19 @@ bool TestConv1DNWC() params.input_left_pads = std::vector{1}; params.input_right_pads = std::vector{1}; - auto host_tensors = test::conv::GetHostTensors(params); + auto host_tensors = + ck::utils::conv::GetHostTensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - test::conv::RunReferenceConv<1>(params, input, weights, host_output); + ck::utils::conv::RunReferenceConvFwd<1>(params, input, weights, host_output); test::conv::RunConv<1>(params, input, weights, device_output); res = res && ck::utils::check_err( @@ -72,7 +74,7 @@ bool TestConv1DNWC() template bool TestConv1DNWCInstances(const std::vector& conv_ptrs) { - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.num_dim_spatial = 1; params.filter_spatial_lengths = std::vector{3}; params.input_spatial_lengths = std::vector{71}; @@ -81,19 +83,20 @@ bool TestConv1DNWCInstances(const std::vector& conv_ptrs) params.input_left_pads = std::vector{1}; params.input_right_pads = std::vector{1}; - auto host_tensors = test::conv::GetHostTensors(params); + auto host_tensors = + ck::utils::conv::GetHostTensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - test::conv::RunReferenceConv<1>(params, input, weights, host_output); - return test::conv::RunConvInstances<1>( + ck::utils::conv::RunReferenceConvFwd<1>(params, input, weights, host_output); + return ck::utils::conv::RunConvInstances<1>( params, conv_ptrs, input, weights, device_output, host_output); } bool TestConv1DNWCBF16Instances() diff --git a/test/convnd_fwd/conv2d_fwd.cpp b/test/convnd_fwd/conv2d_fwd.cpp index 8855e7f250f..b76bec37261 100644 --- a/test/convnd_fwd/conv2d_fwd.cpp +++ b/test/convnd_fwd/conv2d_fwd.cpp @@ -6,7 +6,8 @@ #include "data_type.hpp" #include "element_wise_operation.hpp" -#include "conv_test_util.hpp" +#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "host_tensor.hpp" #include "tensor_layout.hpp" #include "check_err.hpp" @@ -39,20 +40,20 @@ namespace { bool TestConv2DNHWC() { bool res{true}; - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.N = 2; params.K = 16; params.C = 4; params.input_spatial_lengths = std::vector{16, 16}; params.conv_filter_strides = std::vector{1, 1}; - auto host_tensors = test::conv::GetHostTensors(params); + auto host_tensors = ck::utils::conv::GetHostTensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - test::conv::RunReferenceConv<2>(params, input, weights, host_output); + ck::utils::conv::RunReferenceConvFwd<2>(params, input, weights, host_output); test::conv::RunConv<2>(params, input, weights, device_output); res = res && ck::utils::check_err( @@ -64,7 +65,7 @@ bool TestConv2DNHWC() template bool TestConv2DNHWCInstances(const std::vector& conv_ptrs) { - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.num_dim_spatial = 2; params.filter_spatial_lengths = std::vector{3, 3}; params.input_spatial_lengths = std::vector{71, 71}; @@ -73,19 +74,20 @@ bool TestConv2DNHWCInstances(const std::vector& conv_ptrs) params.input_left_pads = std::vector{1, 1}; params.input_right_pads = std::vector{1, 1}; - auto host_tensors = test::conv::GetHostTensors(params); + auto host_tensors = + ck::utils::conv::GetHostTensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - test::conv::RunReferenceConv<2>(params, input, weights, host_output); - return test::conv::RunConvInstances<2>( + ck::utils::conv::RunReferenceConvFwd<2>(params, input, weights, host_output); + return ck::utils::conv::RunConvInstances<2>( params, conv_ptrs, input, weights, device_output, host_output); } diff --git a/test/convnd_fwd/conv3d_fwd.cpp b/test/convnd_fwd/conv3d_fwd.cpp index 632a890d490..9532a1f9990 100644 --- a/test/convnd_fwd/conv3d_fwd.cpp +++ b/test/convnd_fwd/conv3d_fwd.cpp @@ -6,7 +6,8 @@ #include "data_type.hpp" #include "element_wise_operation.hpp" -#include "conv_test_util.hpp" +#include "conv_fwd_util.hpp" +#include "conv_util.hpp" #include "host_tensor.hpp" #include "tensor_layout.hpp" #include "check_err.hpp" @@ -37,7 +38,7 @@ namespace { bool TestConv3DNDHWC() { bool res{true}; - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.num_dim_spatial = 3; params.N = 2; params.K = 16; @@ -49,18 +50,19 @@ bool TestConv3DNDHWC() params.input_left_pads = std::vector{1, 1, 1}; params.input_right_pads = std::vector{1, 1, 1}; - auto host_tensors = test::conv::GetHostTensors(params); + auto host_tensors = + ck::utils::conv::GetHostTensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - test::conv::RunReferenceConv<3>(params, input, weights, host_output); + ck::utils::conv::RunReferenceConvFwd<3>(params, input, weights, host_output); test::conv::RunConv<3>(params, input, weights, device_output); res = res && ck::utils::check_err( @@ -72,7 +74,7 @@ bool TestConv3DNDHWC() bool TestConv3DNDHWC2GBInput() { // >2GB Input - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.num_dim_spatial = 3; params.N = 2; params.K = 16; @@ -85,12 +87,12 @@ bool TestConv3DNDHWC2GBInput() params.input_right_pads = std::vector{1, 1, 1}; auto host_tensors = - test::conv::GetHostTensors(params, false); + ck::utils::conv::GetHostTensors(params, false); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); @@ -116,7 +118,7 @@ bool TestConv3DNDHWC2GBInput() bool TestConv3DNDHWC2GBFilters() { // >2GB Filters - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.num_dim_spatial = 3; params.N = 2; params.K = 16; @@ -129,12 +131,12 @@ bool TestConv3DNDHWC2GBFilters() params.input_right_pads = std::vector{1, 1, 1}; auto host_tensors = - test::conv::GetHostTensors(params, false); + ck::utils::conv::GetHostTensors(params, false); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); @@ -160,7 +162,7 @@ bool TestConv3DNDHWC2GBFilters() bool TestConv3DNDHWC2GBOutput() { // >2GB Output - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.num_dim_spatial = 3; params.N = 2; params.K = 16; @@ -173,12 +175,12 @@ bool TestConv3DNDHWC2GBOutput() params.input_right_pads = std::vector{2, 2, 2}; auto host_tensors = - test::conv::GetHostTensors(params, false); + ck::utils::conv::GetHostTensors(params, false); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); @@ -204,7 +206,7 @@ bool TestConv3DNDHWC2GBOutput() template bool TestConv3DNDHWCInstances(const std::vector& conv_ptrs) { - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.N = 64; params.num_dim_spatial = 3; params.filter_spatial_lengths = std::vector{3, 3, 2}; @@ -214,19 +216,20 @@ bool TestConv3DNDHWCInstances(const std::vector& conv_ptrs params.input_left_pads = std::vector{1, 1, 1}; params.input_right_pads = std::vector{1, 1, 1}; - auto host_tensors = test::conv::GetHostTensors(params); + auto host_tensors = + ck::utils::conv::GetHostTensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - test::conv::RunReferenceConv<3>(params, input, weights, host_output); - return test::conv::RunConvInstances<3>( + ck::utils::conv::RunReferenceConvFwd<3>(params, input, weights, host_output); + return ck::utils::conv::RunConvInstances<3>( params, conv_ptrs, input, weights, device_output, host_output); } diff --git a/test/convnd_fwd/conv_util.hpp b/test/convnd_fwd/conv_util.hpp new file mode 100644 index 00000000000..299ce4afaa1 --- /dev/null +++ b/test/convnd_fwd/conv_util.hpp @@ -0,0 +1,87 @@ +#ifndef TEST_CONV_UTIL_HPP +#define TEST_CONV_UTIL_HPP + +#include + +#include "config.hpp" +#include "conv_fwd_util.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "sequence.hpp" + +namespace { + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + InDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + SpatialDims, // SptialDims + 64, // BlockSize + 16, // MPerBlock + 16, // NPerBlock + 4, // K0PerBlock + 1, // K1 + 16, // MPerXDL + 16, // NPerXDL + 1, // MXdlPerWave + 1, // NXdlPerWave + S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 1, // ABlockTransferSrcScalarPerVector + 1, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 1, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockTransferAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector +// clang-format on + +} // namespace + +namespace test { +namespace conv { + +template +void RunConv(const ck::utils::conv::ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) +{ + ck::utils::conv:: + RunConvFwd( + params, input, weights, output); +} + +} // namespace conv +} // namespace test + +#endif diff --git a/test/include/conv_test_util.hpp b/test/include/conv_test_util.hpp deleted file mode 100644 index 12fbc041d09..00000000000 --- a/test/include/conv_test_util.hpp +++ /dev/null @@ -1,287 +0,0 @@ -#ifndef TEST_CONV_UTIL_HPP -#define TEST_CONV_UTIL_HPP - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "check_err.hpp" -#include "config.hpp" -#include "conv_utils.hpp" -#include "device.hpp" -#include "device_tensor.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "host_tensor.hpp" -#include "reference_conv_fwd.hpp" -#include "tensor_layout.hpp" - -namespace { - -template -using S = ck::Sequence; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; - -template -using DeviceConvNDFwdInstance = ck::tensor_operation::device:: - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< - // clang-format off - InDataType, // - WeiDataType, // - OutDataType, // - InDataType, // - InElementOp, // Input Elementwise Operation - WeiElementOp, // Weights Elementwise Operation - OutElementOp, // Output Elementwise Operation - ConvFwdDefault, // ConvForwardSpecialization - SpatialDims, // SptialDims - 64, // BlockSize - 16, // MPerBlock - 16, // NPerBlock - 4, // K0PerBlock - 1, // K1 - 16, // MPerXDL - 16, // NPerXDL - 1, // MXdlPerWave - 1, // NXdlPerWave - S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 1, // ABlockTransferSrcScalarPerVector - 1, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 1, // BBlockTransferSrcScalarPerVector - 1, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockTransferAddExtraN - 7, // CThreadTransferSrcDstVectorDim - 1>; // CThreadTransferDstScalarPerVector -// clang-format on - -} // namespace - -namespace test { -namespace conv { - -using DeviceConvFwdNoOpPtr = - ck::tensor_operation::device::DeviceConvFwdPtr; - -template -auto GetHostTensors(const ck::conv_util::ConvParams& params, bool init = true) -{ - std::vector input_dims{static_cast(params.N), - static_cast(params.C)}; - input_dims.insert(std::end(input_dims), - std::begin(params.input_spatial_lengths), - std::end(params.input_spatial_lengths)); - - std::vector filter_dims{static_cast(params.K), - static_cast(params.C)}; - filter_dims.insert(std::end(filter_dims), - std::begin(params.filter_spatial_lengths), - std::end(params.filter_spatial_lengths)); - - const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - std::vector output_dims{static_cast(params.N), - static_cast(params.K)}; - output_dims.insert(std::end(output_dims), - std::begin(output_spatial_lengths), - std::end(output_spatial_lengths)); - - Tensor input(ck::conv_util::GetHostTensorDescriptor(input_dims, InLayout{})); - Tensor weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{})); - Tensor host_output( - ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); - Tensor device_output( - ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); - - if(init) - { - std::mt19937 gen(11939); - if constexpr(std::is_same::value) - { - std::uniform_int_distribution<> dis(-5, 5); - std::generate( - input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); - } - else - { - std::uniform_real_distribution<> dis(0.f, 1.f); - std::generate( - input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); - std::generate( - weights.begin(), weights.end(), [&dis, &gen]() { return WeiDataType(dis(gen)); }); - } - std::fill(host_output.begin(), host_output.end(), OutDataType(0.f)); - std::fill(device_output.begin(), device_output.end(), OutDataType(0.f)); - } - - return std::make_tuple(input, weights, host_output, device_output); -} - -template -void RunReferenceConv(const ck::conv_util::ConvParams& params, - const Tensor& input, - const Tensor& weights, - Tensor& output) -{ - auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); - auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(input, - weights, - output, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); -} - -template -void RunConv(const ck::conv_util::ConvParams& params, - const Tensor& input, - const Tensor& weights, - Tensor& output) -{ - DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); - - in_device_buf.ToDevice(input.mData.data()); - wei_device_buf.ToDevice(weights.mData.data()); - const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - - auto conv = DeviceConvNDFwdInstance(); - auto invoker = conv.MakeInvoker(); - auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, - output_spatial_lengths, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - if(!conv.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "Error! device_conv with the specified compilation parameters does " - "not support this Conv problem"); - } - - invoker.Run(argument); - out_device_buf.FromDevice(output.mData.data()); -} - -template -bool RunConvInstances(const ck::conv_util::ConvParams& params, - const std::vector& conv_ptrs, - const Tensor& input, - const Tensor& weights, - Tensor& output, - const Tensor& host_output) -{ - DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); - - in_device_buf.ToDevice(input.mData.data()); - wei_device_buf.ToDevice(weights.mData.data()); - const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); - - bool res{true}; - for(auto& conv_ptr : conv_ptrs) - { - auto invoker = conv_ptr->MakeInvokerPointer(); - auto argument = conv_ptr->MakeArgumentPointer( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - params.N, - params.K, - params.C, - params.input_spatial_lengths, - params.filter_spatial_lengths, - output_spatial_lengths, - params.conv_filter_strides, - params.conv_filter_dilations, - params.input_left_pads, - params.input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - if(conv_ptr->IsSupportedArgument(argument.get())) - { - float atol{1e-5f}; - float rtol{1e-4f}; - if constexpr(std::is_same_v) - { - atol = 1e-4f; - rtol = 2.5e-3f; - } - invoker->Run(argument.get()); - out_device_buf.FromDevice(output.mData.data()); - res = res && - ck::utils::check_err( - output.mData, host_output.mData, "Error: incorrect results!", atol, rtol); - hipGetErrorString( - hipMemset(out_device_buf.GetDeviceBuffer(), 0, out_device_buf.mMemSize)); - } - } - return res; -} - -} // namespace conv -} // namespace test - -#endif diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp index d02c2f175f2..2810ff9eba8 100644 --- a/test/reference_conv_fwd/reference_conv_fwd.cpp +++ b/test/reference_conv_fwd/reference_conv_fwd.cpp @@ -8,7 +8,7 @@ #include "check_err.hpp" #include "config.hpp" -#include "conv_utils.hpp" +#include "conv_fwd_util.hpp" #include "element_wise_operation.hpp" #include "host_tensor.hpp" #include "reference_conv_fwd.hpp" @@ -57,9 +57,9 @@ template , typename FillWeightsOp = FillConstant> -Tensor RunReferenceConv(const ck::conv_util::ConvParams& params, - const FillInputOp& fill_input_op = FillInputOp{}, - const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) +Tensor RunReferenceConvFwd(const ck::utils::conv::ConvParams& params, + const FillInputOp& fill_input_op = FillInputOp{}, + const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) { std::vector input_dims{static_cast(params.N), static_cast(params.C)}; @@ -80,10 +80,10 @@ Tensor RunReferenceConv(const ck::conv_util::ConvParams& params, std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); - Tensor input(ck::conv_util::GetHostTensorDescriptor(input_dims, InLayout{})); - Tensor weights(ck::conv_util::GetHostTensorDescriptor(filter_dims, WeiLayout{})); + Tensor input(ck::utils::conv::GetHostTensorDescriptor(input_dims, InLayout{})); + Tensor weights(ck::utils::conv::GetHostTensorDescriptor(filter_dims, WeiLayout{})); Tensor host_output( - ck::conv_util::GetHostTensorDescriptor(output_dims, OutLayout{})); + ck::utils::conv::GetHostTensorDescriptor(output_dims, OutLayout{})); fill_input_op(input.begin(), input.end()); fill_weights_op(weights.begin(), weights.end()); @@ -119,7 +119,7 @@ Tensor RunReferenceConv(const ck::conv_util::ConvParams& params, bool TestConv2DNHWC() { bool res{true}; - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.N = 1; params.K = 1; params.C = 2; @@ -130,7 +130,7 @@ bool TestConv2DNHWC() params.input_left_pads = std::vector{0, 0}; params.input_right_pads = std::vector{0, 0}; - auto out_tensor = RunReferenceConv<2>(params); + auto out_tensor = RunReferenceConvFwd<2>(params); std::vector ref_dims{1, 1, 4, 4}; std::vector ref_data{130.5, 148.5, @@ -163,7 +163,7 @@ bool TestConv2DNHWC() params.input_left_pads = std::vector{1, 1}; params.input_right_pads = std::vector{1, 1}; - out_tensor = RunReferenceConv<2>(params); + out_tensor = RunReferenceConvFwd<2>(params); ref_dims = std::vector{1, 2, 5, 5}; ref_data = std::vector{ 210., 210., 327., 327., 351., 351., 375., 375., 399., 399., @@ -182,7 +182,7 @@ bool TestConv2DNHWC() bool TestConv1DNWC() { bool res{true}; - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.num_dim_spatial = 1; params.N = 1; params.K = 1; @@ -194,13 +194,13 @@ bool TestConv1DNWC() params.input_left_pads = std::vector{0}; params.input_right_pads = std::vector{0}; - auto out_tensor = RunReferenceConv<1, - float, - float, - float, - ck::tensor_layout::convolution::NWC, - ck::tensor_layout::convolution::KXC, - ck::tensor_layout::convolution::NWK>(params); + auto out_tensor = RunReferenceConvFwd<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); std::vector ref_dims{1, 1, 4}; std::vector ref_data{7.5, 13.5, 19.5, 25.5}; res = res && ck::utils::check_err(out_tensor.mDesc.GetLengths(), @@ -219,13 +219,13 @@ bool TestConv1DNWC() params.input_left_pads = std::vector{1}; params.input_right_pads = std::vector{1}; - out_tensor = RunReferenceConv<1, - float, - float, - float, - ck::tensor_layout::convolution::NWC, - ck::tensor_layout::convolution::KXC, - ck::tensor_layout::convolution::NWK>(params); + out_tensor = RunReferenceConvFwd<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); ref_dims = std::vector{1, 2, 5}; ref_data = std::vector{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5}; res = res && ck::utils::check_err(out_tensor.mDesc.GetLengths(), @@ -244,13 +244,13 @@ bool TestConv1DNWC() params.input_left_pads = std::vector{1}; params.input_right_pads = std::vector{1}; - auto out_tensor2 = RunReferenceConv<1, - float, - float, - float, - ck::tensor_layout::convolution::NWC, - ck::tensor_layout::convolution::KXC, - ck::tensor_layout::convolution::NWK>( + auto out_tensor2 = RunReferenceConvFwd<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( params, FillMonotonicSeq{0.f, 0.1f}); ref_dims = std::vector{2, 16, 16}; @@ -330,7 +330,7 @@ bool TestConv1DNWC() bool TestConv3DNCDHW() { bool res{true}; - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.num_dim_spatial = 3; params.N = 1; params.K = 1; @@ -342,13 +342,13 @@ bool TestConv3DNCDHW() params.input_left_pads = std::vector{0, 0, 0}; params.input_right_pads = std::vector{0, 0, 0}; - auto out_tensor = RunReferenceConv<3, - float, - float, - float, - ck::tensor_layout::convolution::NCDHW, - ck::tensor_layout::convolution::KCZYX, - ck::tensor_layout::convolution::NKDHW>( + auto out_tensor = RunReferenceConvFwd<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( params, FillMonotonicSeq{0.f, 0.1f}); std::vector ref_dims{1, 1, 4, 4, 4}; std::vector ref_data{ @@ -376,13 +376,13 @@ bool TestConv3DNCDHW() params.input_left_pads = std::vector{0, 0, 0}; params.input_right_pads = std::vector{0, 0, 0}; - out_tensor = RunReferenceConv<3, - float, - float, - float, - ck::tensor_layout::convolution::NCDHW, - ck::tensor_layout::convolution::KCZYX, - ck::tensor_layout::convolution::NKDHW>( + out_tensor = RunReferenceConvFwd<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( params, FillMonotonicSeq{0.f, 0.1f}); ref_dims = std::vector{1, 2, 4, 4, 4}; ref_data = std::vector{ From 0bdc51e24cef649aa722d738c19f4f908743eb3d Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 21 Mar 2022 09:58:31 +0100 Subject: [PATCH 51/82] Add additional common functions to utility. --- .../ck/library/utility/conv_fwd_util.hpp | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/library/include/ck/library/utility/conv_fwd_util.hpp b/library/include/ck/library/utility/conv_fwd_util.hpp index bba28638852..4c882e72942 100644 --- a/library/include/ck/library/utility/conv_fwd_util.hpp +++ b/library/include/ck/library/utility/conv_fwd_util.hpp @@ -291,6 +291,72 @@ auto GetHostTensors(const ConvParams& params, bool init = true) return std::make_tuple(input, weights, host_output, device_output); } +HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWK{}); + } + case 2: { + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWK{}); + } + case 1: { + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWK{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KZYXC{}); + } + case 2: { + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KYXC{}); + } + case 1: { + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KXC{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector& dims, + int num_dim_spatial = 2) +{ + namespace tl = ck::tensor_layout::convolution; + + switch(num_dim_spatial) + { + case 3: { + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWC{}); + } + case 2: { + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWC{}); + } + case 1: { + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWC{}); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + template Date: Mon, 21 Mar 2022 09:59:24 +0100 Subject: [PATCH 52/82] Refactor convnd_fwd_xdl examples. * Remove redundant files. * Unify structure. --- example/05_conv2d_fwd/CMakeLists.txt | 2 - example/05_conv2d_fwd/README.md | 57 --- example/05_conv2d_fwd/conv2d_fwd_xdl_fp16.cpp | 274 -------------- example/05_conv2d_fwd/conv2d_fwd_xdl_int8.cpp | 275 -------------- example/08_conv3d_fwd/CMakeLists.txt | 1 - example/08_conv3d_fwd/README.md | 57 --- example/08_conv3d_fwd/conv3d_fwd_xdl.cpp | 281 -------------- example/09_convnd_fwd/CMakeLists.txt | 2 + example/09_convnd_fwd/convnd_fwd_xdl.cpp | 93 +---- example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp | 341 +++++++++++++++++ example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp | 343 ++++++++++++++++++ example/CMakeLists.txt | 2 - 12 files changed, 704 insertions(+), 1024 deletions(-) delete mode 100644 example/05_conv2d_fwd/CMakeLists.txt delete mode 100644 example/05_conv2d_fwd/README.md delete mode 100644 example/05_conv2d_fwd/conv2d_fwd_xdl_fp16.cpp delete mode 100644 example/05_conv2d_fwd/conv2d_fwd_xdl_int8.cpp delete mode 100644 example/08_conv3d_fwd/CMakeLists.txt delete mode 100644 example/08_conv3d_fwd/README.md delete mode 100644 example/08_conv3d_fwd/conv3d_fwd_xdl.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp create mode 100644 example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp diff --git a/example/05_conv2d_fwd/CMakeLists.txt b/example/05_conv2d_fwd/CMakeLists.txt deleted file mode 100644 index 5f0e118fd6e..00000000000 --- a/example/05_conv2d_fwd/CMakeLists.txt +++ /dev/null @@ -1,2 +0,0 @@ -add_example_executable(example_conv2d_fwd_xdl_fp16 conv2d_fwd_xdl_fp16.cpp) -add_example_executable(example_conv2d_fwd_xdl_int8 conv2d_fwd_xdl_int8.cpp) diff --git a/example/05_conv2d_fwd/README.md b/example/05_conv2d_fwd/README.md deleted file mode 100644 index 4114571afe4..00000000000 --- a/example/05_conv2d_fwd/README.md +++ /dev/null @@ -1,57 +0,0 @@ -# Instructions for ```conv2d_fwd_xdl``` Example - -## Docker script -```bash -docker run \ --it \ ---rm \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ -rocm/tensorflow:rocm4.3.1-tf2.6-dev \ -/bin/bash -``` - -## Build ```conv2d_fwd_xdl``` -```bash -mkdir build && cd build -``` - -```bash -# Need to specify target ID, example below is gfx908 -cmake \ --D BUILD_DEV=OFF \ --D CMAKE_BUILD_TYPE=Release \ --D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_PREFIX_PATH=/opt/rocm \ -.. -``` - -```bash - make -j conv2d_fwd_xdl -``` - -## Run ```conv2d_fwd_xdl``` -```bash -#arg1: verification (0=no, 1=yes) -#arg2: initialization (0=no init, 1=integer value, 2=decimal value) -#arg3: run kernel # of times (>1) -#arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx -./example/conv2d_fwd_xdl 0 1 5 -``` - -Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) -``` -in_n_c_hi_wi: dim 4, lengths {128, 192, 71, 71}, strides {967872, 1, 13632, 192} -wei_k_c_y_x: dim 4, lengths {256, 192, 3, 3}, strides {1728, 1, 576, 192} -out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256} -arg.a_grid_desc_k0_m_k1_{216, 165888, 8} -arg.b_grid_desc_k0_n_k1_{216, 256, 8} -arg.c_grid_desc_m_n_{ 165888, 256} -launch_and_time_kernel: grid_dim {1296, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 1.43206 ms, 102.486 TFlops, 232.947 GB/s -``` diff --git a/example/05_conv2d_fwd/conv2d_fwd_xdl_fp16.cpp b/example/05_conv2d_fwd/conv2d_fwd_xdl_fp16.cpp deleted file mode 100644 index 4f255fda9d5..00000000000 --- a/example/05_conv2d_fwd/conv2d_fwd_xdl_fp16.cpp +++ /dev/null @@ -1,274 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "tensor_layout.hpp" -#include "element_wise_operation.hpp" -#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp" -#include "reference_conv_fwd.hpp" -#include "convolution_utility.hpp" - -using InDataType = ck::half_t; -using WeiDataType = ck::half_t; -using OutDataType = ck::half_t; -using AccDataType = float; - -template -using S = ck::Sequence; - -using InLayout = ck::tensor_layout::convolution::NHWC; -using WeiLayout = ck::tensor_layout::convolution::KYXC; -using OutLayout = ck::tensor_layout::convolution::NHWK; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; - -using DeviceConvFwdInstance = ck::tensor_operation::device:: - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< - InDataType, // InDataType - WeiDataType, // WeiDataType - OutDataType, // OutDataType - AccDataType, // AccDataType - InElementOp, // InElementwiseOperation - WeiElementOp, // WeiElementwiseOperation - OutElementOp, // OutElementwiseOperation - ConvFwdDefault, // ConvForwardSpecialization - 256, // BlockSize - 128, // MPerBlock - 256, // NPerBlock - 4, // K0PerBlock - 8, // K1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 7, // CThreadTransferSrcDstVectorDim - 1>; // CThreadTransferDstScalarPerVector - -using ReferenceConvFwdInstance = ck::tensor_operation::host:: - ReferenceConvFwd; - -int main(int argc, char* argv[]) -{ - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; - - // Conv shape - ck::index_t N = 128; - ck::index_t K = 256; - ck::index_t C = 192; - ck::index_t Y = 3; - ck::index_t X = 3; - ck::index_t Hi = 71; - ck::index_t Wi = 71; - ck::index_t conv_stride_h = 2; - ck::index_t conv_stride_w = 2; - ck::index_t conv_dilation_h = 1; - ck::index_t conv_dilation_w = 1; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; - - if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - } - else if(argc == 19) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - - N = std::stoi(argv[4]); - K = std::stoi(argv[5]); - C = std::stoi(argv[6]); - Y = std::stoi(argv[7]); - X = std::stoi(argv[8]); - Hi = std::stoi(argv[9]); - Wi = std::stoi(argv[10]); - conv_stride_h = std::stoi(argv[11]); - conv_stride_w = std::stoi(argv[12]); - conv_dilation_h = std::stoi(argv[13]); - conv_dilation_w = std::stoi(argv[14]); - in_left_pad_h = std::stoi(argv[15]); - in_left_pad_w = std::stoi(argv[16]); - in_right_pad_h = std::stoi(argv[17]); - in_right_pad_w = std::stoi(argv[18]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); - printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(0); - } - - const std::vector conv_filter_strides{conv_stride_h, conv_stride_w}; - const std::vector conv_filter_dilations{conv_dilation_h, conv_dilation_w}; - const std::vector input_left_pads{in_left_pad_h, in_left_pad_w}; - const std::vector input_right_pads{in_right_pad_h, in_right_pad_w}; - const auto output_spatial_lengths = - ck::tensor_operation::ConvolutionUtility::ComputeOutputSpatialLengths({Hi, Wi}, - {Y, X}, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); - - const ck::index_t Ho = output_spatial_lengths[0]; - const ck::index_t Wo = output_spatial_lengths[1]; - - // tensor layout - auto f_host_tensor_descriptor = [](std::size_t N_, - std::size_t C_, - std::size_t H, - std::size_t W, - auto layout) { - if constexpr(ck::is_same::value || - ck::is_same::value || - ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, H * W, W, 1})); - } - else if constexpr(ck::is_same::value || - ck::is_same::value || - ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, 1, W * C_, C_})); - } - }; - - Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); - Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); - Tensor out_n_k_ho_wo_host_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - Tensor out_n_k_ho_wo_device_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - - std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; - std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; - std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * - out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); - - in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - - // do GEMM - auto conv = DeviceConvFwdInstance{}; - auto invoker = conv.MakeInvoker(); - auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - N, - K, - C, - std::vector{Hi, Wi}, - std::vector{Y, X}, - std::vector{Ho, Wo}, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - if(!conv.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_conv with the specified compilation parameters does " - "not support this Conv problem"); - } - - float ave_time = invoker.Run(argument, nrepeat); - - std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; - - std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + - sizeof(WeiDataType) * (K * C * Y * X) + - sizeof(OutDataType) * (N * K * Ho * Wo); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - if(do_verification) - { - auto ref_conv = ReferenceConvFwdInstance{}; - auto ref_invoker = ref_conv.MakeInvoker(); - - auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); - - out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - - check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); - } -} diff --git a/example/05_conv2d_fwd/conv2d_fwd_xdl_int8.cpp b/example/05_conv2d_fwd/conv2d_fwd_xdl_int8.cpp deleted file mode 100644 index 8614f534728..00000000000 --- a/example/05_conv2d_fwd/conv2d_fwd_xdl_int8.cpp +++ /dev/null @@ -1,275 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "tensor_layout.hpp" -#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" -#include "element_wise_operation.hpp" -#include "reference_conv_fwd.hpp" -#include "convolution_utility.hpp" - -using InDataType = int8_t; -using WeiDataType = int8_t; -using OutDataType = int8_t; -using AccDataType = int32_t; - -template -using S = ck::Sequence; - -using InLayout = ck::tensor_layout::convolution::NHWC; -using WeiLayout = ck::tensor_layout::convolution::KYXC; -using OutLayout = ck::tensor_layout::convolution::NHWK; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; - -using DeviceConvFwdInstance = ck::tensor_operation::device:: - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< - int8_t, // InDataType - int8_t, // WeiDataType - int8_t, // OutDataType - int32_t, // AccDataType - PassThrough, // InElementwiseOperation - PassThrough, // WeiElementwiseOperation - PassThrough, // OutElementwiseOperation - ConvFwdDefault, // ConvForwardSpecialization - 256, // BlockSize - 128, // MPerBlock - 256, // NPerBlock - 4, // K0PerBlock - 16, // K1 - 32, // MPerXdl - 32, // NPerXdl - 2, // MXdlPerWave - 4, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 7, // CThreadTransferSrcDstVectorDim - 1>; // CThreadTransferDstScalarPerVector - -using ReferenceConvFwdInstance = ck::tensor_operation::host:: - ReferenceConvFwd; - -int main(int argc, char* argv[]) -{ - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; - - // Conv shape - ck::index_t N = 128; - ck::index_t K = 256; - ck::index_t C = 192; - ck::index_t Y = 3; - ck::index_t X = 3; - ck::index_t Hi = 71; - ck::index_t Wi = 71; - ck::index_t conv_stride_h = 2; - ck::index_t conv_stride_w = 2; - ck::index_t conv_dilation_h = 1; - ck::index_t conv_dilation_w = 1; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; - - if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - } - else if(argc == 19) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - - N = std::stoi(argv[4]); - K = std::stoi(argv[5]); - C = std::stoi(argv[6]); - Y = std::stoi(argv[7]); - X = std::stoi(argv[8]); - Hi = std::stoi(argv[9]); - Wi = std::stoi(argv[10]); - conv_stride_h = std::stoi(argv[11]); - conv_stride_w = std::stoi(argv[12]); - conv_dilation_h = std::stoi(argv[13]); - conv_dilation_w = std::stoi(argv[14]); - in_left_pad_h = std::stoi(argv[15]); - in_left_pad_w = std::stoi(argv[16]); - in_right_pad_h = std::stoi(argv[17]); - in_right_pad_w = std::stoi(argv[18]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); - printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(0); - } - - const std::vector conv_filter_strides{conv_stride_h, conv_stride_w}; - const std::vector conv_filter_dilations{conv_dilation_h, conv_dilation_w}; - const std::vector input_left_pads{in_left_pad_h, in_left_pad_w}; - const std::vector input_right_pads{in_right_pad_h, in_right_pad_w}; - const auto output_spatial_lengths = - ck::tensor_operation::ConvolutionUtility::ComputeOutputSpatialLengths({Hi, Wi}, - {Y, X}, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); - - const ck::index_t Ho = output_spatial_lengths[0]; - const ck::index_t Wo = output_spatial_lengths[1]; - - // tensor layout - auto f_host_tensor_descriptor = [](std::size_t N_, - std::size_t C_, - std::size_t H, - std::size_t W, - auto layout) { - if constexpr(ck::is_same::value || - ck::is_same::value || - ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, H * W, W, 1})); - } - else if constexpr(ck::is_same::value || - ck::is_same::value || - ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, 1, W * C_, C_})); - } - }; - - Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); - Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); - Tensor out_n_k_ho_wo_host_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - Tensor out_n_k_ho_wo_device_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - - std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; - std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; - std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-1, 1}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-1, 1}); - break; - default: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0, 1}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-1, 1}); - } - - DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * - out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); - - in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - - // do GEMM - auto conv = DeviceConvFwdInstance{}; - auto invoker = conv.MakeInvoker(); - auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - N, - K, - C, - std::vector{Hi, Wi}, - std::vector{Y, X}, - std::vector{Ho, Wo}, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - if(!conv.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_conv with the specified compilation parameters does " - "not support this Conv problem"); - } - - float ave_time = invoker.Run(argument, nrepeat); - - std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; - - std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + - sizeof(WeiDataType) * (K * C * Y * X) + - sizeof(OutDataType) * (N * K * Ho * Wo); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - if(do_verification) - { - auto ref_conv = ReferenceConvFwdInstance{}; - auto ref_invoker = ref_conv.MakeInvoker(); - - auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - ref_invoker.Run(ref_argument); - - out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - - check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); - } -} diff --git a/example/08_conv3d_fwd/CMakeLists.txt b/example/08_conv3d_fwd/CMakeLists.txt deleted file mode 100644 index 49fb1fe1ce5..00000000000 --- a/example/08_conv3d_fwd/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_example_executable(example_conv3d_fwd_xdl conv3d_fwd_xdl.cpp) diff --git a/example/08_conv3d_fwd/README.md b/example/08_conv3d_fwd/README.md deleted file mode 100644 index 06339b74e52..00000000000 --- a/example/08_conv3d_fwd/README.md +++ /dev/null @@ -1,57 +0,0 @@ -# Instructions for ```conv3d_fwd_xdl``` Example - -## Docker script -```bash -docker run \ --it \ ---rm \ ---privileged \ ---group-add sudo \ --w /root/workspace \ --v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ -rocm/tensorflow:rocm4.3.1-tf2.6-dev \ -/bin/bash -``` - -## Build ```conv3d_fwd_xdl``` -```bash -mkdir build && cd build -``` - -```bash -# Need to specify target ID, example below is gfx908 -cmake \ --D BUILD_DEV=OFF \ --D CMAKE_BUILD_TYPE=Release \ --D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ --D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_PREFIX_PATH=/opt/rocm \ -.. -``` - -```bash - make -j conv3d_fwd_xdl -``` - -## Run ```conv3d_fwd_xdl``` -```bash -#arg1: verification (0=no, 1=yes) -#arg2: initialization (0=no init, 1=integer value, 2=decimal value) -#arg3: run kernel # of times (>1) -#arg4 to 24: N, K, C, Z, Y, X, Di, Hi, Wi, Sz, Sy, Sx, Dz, Dy, Dx, leftPz, LeftPy, LeftPx, RightPz, RightPy, RightPx -./example/conv3d_fwd_xdl 0 1 5 -``` - -Result (MI100 dynamic frequency) -``` -in: dim 5, lengths {4, 71, 71, 71, 192}, strides {68718912, 967872, 13632, 192, 1} -wei: dim 5, lengths {256, 3, 3, 3, 192}, strides {5184, 1728, 576, 192, 1} -out: dim 5, lengths {4, 36, 36, 36, 256}, strides {11943936, 331776, 9216, 256, 1} -a_grid_desc_b_k0_m_k1{1, 648, 186624, 8} -b_grid_desc_b_k0_n_k1{1, 648, 256, 8} -launch_and_time_kernel: grid_dim {1458, 1, 1}, block_dim {256, 1, 1} -Warm up -Start running 5 times... -Perf: 4.49466 ms, 110.206 TFlops, 144.161 GB/s -``` - diff --git a/example/08_conv3d_fwd/conv3d_fwd_xdl.cpp b/example/08_conv3d_fwd/conv3d_fwd_xdl.cpp deleted file mode 100644 index 89d29336196..00000000000 --- a/example/08_conv3d_fwd/conv3d_fwd_xdl.cpp +++ /dev/null @@ -1,281 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_base.hpp" -#include "device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp" -#include "device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp" -#include "convolution_utility.hpp" - -// convolution data type -using InDataType = ck::half_t; -using WeiDataType = ck::half_t; -using OutDataType = ck::half_t; -using AccDataType = float; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -using F16 = ck::half_t; -using F32 = float; - -template -using S = ck::Sequence; - -using InLayout = ck::tensor_layout::convolution::NDHWC; -using WeiLayout = ck::tensor_layout::convolution::KZYXC; -using OutLayout = ck::tensor_layout::convolution::NDHWK; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; - -using DeviceConv3dFwdInstance = ck::tensor_operation::device:: - DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< - InDataType, // InData - WeiDataType, // WeiData - OutDataType, // OutData - AccDataType, // AccData - InElementOp, // InElementwise Operation - WeiElementOp, // WeiElementwise Operation - OutElementOp, // OutElementwise Operation - ConvFwdDefault, // ConvForwardSpecialization - 256, // BlockSize - 128, // MPerBlock - 256, // NPerBlock - 4, // K0PerBlock - 8, // K1. K0PerBlock * K1 = KPerBlock - 32, // MPerXDL - 32, // NPerXDL. Each XDL computes a matrix of size (MPerXDL, NPerBlock) - 2, // MXdlPerWave - 4, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 8, // ABlockTransferSrcScalarPerVector - 8, // ABlockTransferDstScalarPerVector_K1 - true, // ABlockLdsAddExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 8, // BBlockTransferSrcScalarPerVector - 8, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockLdsAddExtraN - 7, // CThreadTransferSrcDstVectorDim - 1>; // CThreadTransferDstScalarPerVector - -int main(int argc, char* argv[]) -{ - bool do_verification = false; - int init_method = 0; - int nrepeat = 5; - - // convolution shape - ck::index_t N = 4; - ck::index_t K = 256; - ck::index_t C = 192; - std::vector in_spatial_lengths = {71, 71, 71}; - std::vector filter_spatial_lengths = {3, 3, 3}; - std::vector conv_filter_strides = {2, 2, 2}; - std::vector conv_filter_dilations = {1, 1, 1}; - std::vector in_left_pads = {1, 1, 1}; - std::vector in_right_pads = {1, 1, 1}; - - if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - } - else if(argc == 25) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); - - N = std::stoi(argv[4]); - K = std::stoi(argv[5]); - C = std::stoi(argv[6]); - filter_spatial_lengths[0] = std::stoi(argv[7]); - filter_spatial_lengths[1] = std::stoi(argv[8]); - filter_spatial_lengths[2] = std::stoi(argv[9]); - in_spatial_lengths[0] = std::stoi(argv[10]); - in_spatial_lengths[1] = std::stoi(argv[11]); - in_spatial_lengths[2] = std::stoi(argv[12]); - conv_filter_strides[0] = std::stoi(argv[13]); - conv_filter_strides[1] = std::stoi(argv[14]); - conv_filter_strides[2] = std::stoi(argv[15]); - conv_filter_dilations[0] = std::stoi(argv[16]); - conv_filter_dilations[1] = std::stoi(argv[17]); - conv_filter_dilations[2] = std::stoi(argv[18]); - in_left_pads[0] = std::stoi(argv[19]); - in_left_pads[1] = std::stoi(argv[20]); - in_left_pads[2] = std::stoi(argv[21]); - in_right_pads[0] = std::stoi(argv[22]); - in_right_pads[1] = std::stoi(argv[23]); - in_right_pads[2] = std::stoi(argv[24]); - } - else - { - printf("Usage: 3 or 24 input arguments\n"); - printf(" arg1: verification (0=no, 1=yes)\n"); - printf(" arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf(" arg3: run kernel # of times (>1)\n"); - printf(" arg4 to 24: N, K, C, Z, Y, X, Di, Hi, Wi, Sz, Sy, Sz, Dz, Dy, Dx, LeftPz, LeftPy, " - "LeftPz, RightPz, RightPy, RightPx\n"); - exit(0); - } - - auto conv3d = DeviceConv3dFwdInstance{}; - - const auto out_spatial_lengths = - ck::tensor_operation::ConvolutionUtility::ComputeOutputSpatialLengths( - in_spatial_lengths, - filter_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - in_left_pads, - in_right_pads); - Tensor in( - {N, in_spatial_lengths[0], in_spatial_lengths[1], in_spatial_lengths[2], C}); - Tensor wei( - {K, filter_spatial_lengths[0], filter_spatial_lengths[1], filter_spatial_lengths[2], C}); - Tensor out( - {N, out_spatial_lengths[0], out_spatial_lengths[1], out_spatial_lengths[2], K}); - - std::cout << "in: " << in.mDesc << std::endl; - std::cout << "wei: " << wei.mDesc << std::endl; - std::cout << "out: " << out.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - default: - in.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - wei.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - } - - DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * out.mDesc.GetElementSpace()); - - in_device_buf.ToDevice(in.mData.data()); - wei_device_buf.ToDevice(wei.mData.data()); - - // do Convolution - auto invoker = conv3d.MakeInvoker(); - auto argument = conv3d.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer()), - N, - K, - C, - in_spatial_lengths, - filter_spatial_lengths, - out_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - in_left_pads, - in_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - if(!conv3d.IsSupportedArgument(argument)) - { - throw std::runtime_error( - "wrong! device_conv3d with the specified compilation parameters does " - "not support this GEMM problem"); - } - - float ave_time = invoker.Run(argument, nrepeat); - - const auto Di = in_spatial_lengths[0]; - const auto Hi = in_spatial_lengths[1]; - const auto Wi = in_spatial_lengths[2]; - const auto Do = out_spatial_lengths[0]; - const auto Ho = out_spatial_lengths[1]; - const auto Wo = out_spatial_lengths[2]; - const auto Z = filter_spatial_lengths[0]; - const auto Y = filter_spatial_lengths[1]; - const auto X = filter_spatial_lengths[2]; - - std::size_t flop = std::size_t(2) * N * K * Do * Ho * Wo * C * Z * Y * X; - std::size_t num_btype = sizeof(InDataType) * N * Di * Hi * Wi * C + - sizeof(WeiDataType) * K * Z * Y * X * C + - sizeof(OutDataType) * N * Do * Ho * Wo * K; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - out_device_buf.FromDevice(out.mData.data()); - - if(do_verification) - { - DeviceMem out_ref_device_buf(sizeof(OutDataType) * N * Do * Ho * Wo * K); - - using DeviceConv3dFwdNaive = ck::tensor_operation::device:: - DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K< - InDataType, - WeiDataType, - OutDataType, - AccDataType, - InElementOp, - WeiElementOp, - OutElementOp>; - auto conv3d_naive = DeviceConv3dFwdNaive{}; - auto invoker_naive = conv3d_naive.MakeInvoker(); - auto argument_naive = conv3d_naive.MakeArgument( - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_ref_device_buf.GetDeviceBuffer()), - N, - K, - C, - in_spatial_lengths, - filter_spatial_lengths, - out_spatial_lengths, - conv_filter_strides, - conv_filter_dilations, - in_left_pads, - in_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - - if(!conv3d_naive.IsSupportedArgument(argument_naive)) - { - throw std::runtime_error( - "wrong! device_conv3d_naive does NOT support the specified compilation parameters"); - } - invoker_naive.Run(argument_naive); - - Tensor out_ref( - {N, out_spatial_lengths[0], out_spatial_lengths[1], out_spatial_lengths[2], K}); - - out_ref_device_buf.FromDevice(out_ref.mData.data()); - - check_error(out_ref, out); - } - - return 0; -} diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 61299b521e7..fd6d11d9ff2 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1 +1,3 @@ add_example_executable(example_convnd_fwd_xdl convnd_fwd_xdl.cpp) +add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) +add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl.cpp b/example/09_convnd_fwd/convnd_fwd_xdl.cpp index 188431cd9e1..322182ad7d7 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl.cpp @@ -2,6 +2,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "conv_fwd_util.hpp" #include "device.hpp" @@ -13,6 +15,8 @@ #include "reference_conv_fwd.hpp" #include "tensor_layout.hpp" +namespace { + using InDataType = float; using WeiDataType = float; using OutDataType = float; @@ -169,80 +173,18 @@ ck::utils::conv::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* return params; } -HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWK{}); - } - case 2: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWK{}); - } - case 1: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWK{}); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} - -HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KZYXC{}); - } - case 2: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KYXC{}); - } - case 1: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KXC{}); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} - -HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWC{}); - } - case 2: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWC{}); - } - case 1: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWC{}); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} +} // anonymous namespace int main(int argc, char* argv[]) { + using namespace ck::utils::conv; + bool do_verification = 0; int init_method = 0; int nrepeat = 5; int num_dim_spatial = 2; - ck::utils::conv::ConvParams params; + ConvParams params; if(argc >= 5) { @@ -334,15 +276,15 @@ int main(int argc, char* argv[]) float ave_time = invoker->Run(argument.get(), nrepeat); - std::size_t flop = ck::utils::conv::GetFlops( + std::size_t flop = GetFlops( params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); - std::size_t num_btype = ck::utils::conv::GetBtype( - params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, - output_spatial_lengths); + std::size_t num_btype = + GetBtype(params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -367,7 +309,8 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); out_device_buf.FromDevice(device_output.mData.data()); - check_error(host_output, device_output); + ck::utils::check_err( + host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); }; switch(num_dim_spatial) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp new file mode 100644 index 00000000000..3fdc133314c --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -0,0 +1,341 @@ +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "conv_fwd_util.hpp" +#include "device.hpp" +#include "device_tensor.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" + +namespace { + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; +using AccDataType = float; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +using DeviceConvFwdBasePtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector + +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +void PrintUseMsg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: run kernel # of times (>1)\n" + << "arg4: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 5; + if(cmdline_nargs != argc) + { + PrintUseMsg(); + exit(0); + } + + ck::utils::conv::ConvParams params; + int arg_idx = 5; + + params.num_dim_spatial = num_dim_spatial; + params.N = std::stoi(argv[arg_idx++]); + params.K = std::stoi(argv[arg_idx++]); + params.C = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + int num_dim_spatial = 2; + + ConvParams params; + + if(argc >= 5) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + } + + if(argc >= 6) + { + params = ParseConvParams(num_dim_spatial, argc, argv); + } + + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); + Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); + Tensor host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + + // do GEMM + auto conv = GetConvInstance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv->IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker->Run(argument.get(), nrepeat); + + std::size_t flop = GetFlops( + params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = GetBtype( + params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&input, &weights, &host_output, ¶ms, &out_device_buf, &device_output]( + const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + ck::utils::check_err( + host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvNDFwdInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvNDFwdInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvNDFwdInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } +} diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp new file mode 100644 index 00000000000..0e2bca9625c --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -0,0 +1,343 @@ +#include +#include +#include +#include + +#include "check_err.hpp" +#include "config.hpp" +#include "conv_fwd_util.hpp" +#include "device.hpp" +#include "device_tensor.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "reference_conv_fwd.hpp" +#include "tensor_layout.hpp" + +namespace { + +using InDataType = int8_t; +using WeiDataType = int8_t; +using OutDataType = int8_t; +using AccDataType = int32_t; + +template +using S = ck::Sequence; + +using InLayout = ck::tensor_layout::convolution::NHWC; +using WeiLayout = ck::tensor_layout::convolution::KYXC; +using OutLayout = ck::tensor_layout::convolution::NHWK; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + +using DeviceConvFwdBasePtr = + ck::tensor_operation::device::DeviceConvFwdPtr; + +template +using DeviceConvNDFwdInstance = ck::tensor_operation::device:: + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< + // clang-format off + InDataType, // + WeiDataType, // + OutDataType, // + AccDataType, // + InElementOp, // Input Elementwise Operation + WeiElementOp, // Weights Elementwise Operation + OutElementOp, // Output Elementwise Operation + ConvFwdDefault, // ConvForwardSpecialization + NumDimSpatial, // NumDimSpatial + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock + 4, // K0PerBlock + 16, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 7, // CThreadTransferSrcDstVectorDim + 1>; // CThreadTransferDstScalarPerVector + +template +using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + +DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + +void PrintUseMsg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: run kernel # of times (>1)\n" + << "arg4: N spatial dimensions (default 2)\n" + << "Following arguments (depending on number of spatial dims):\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 5; + if(cmdline_nargs != argc) + { + PrintUseMsg(); + exit(0); + } + + ck::utils::conv::ConvParams params; + int arg_idx = 5; + + params.num_dim_spatial = num_dim_spatial; + params.N = std::stoi(argv[arg_idx++]); + params.K = std::stoi(argv[arg_idx++]); + params.C = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + int num_dim_spatial = 2; + + ConvParams params; + + if(argc >= 5) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + num_dim_spatial = std::stoi(argv[4]); + } + + if(argc >= 6) + { + params = ParseConvParams(num_dim_spatial, argc, argv); + } + + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); + Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); + Tensor host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + } + + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + + // do GEMM + auto conv = GetConvInstance(num_dim_spatial); + auto invoker = conv->MakeInvokerPointer(); + auto argument = + conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + if(!conv->IsSupportedArgument(argument.get())) + { + throw std::runtime_error( + "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem"); + } + + float ave_time = invoker->Run(argument.get(), nrepeat); + + std::size_t flop = GetFlops( + params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = GetBtype( + params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths); + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + if(do_verification) + { + auto verify_f = [&input, &weights, &host_output, ¶ms, &out_device_buf, &device_output]( + const auto& ref_conv) { + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); + out_device_buf.FromDevice(device_output.mData.data()); + ck::utils::check_err( + host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); + }; + + switch(num_dim_spatial) + { + case 3: { + auto ref_conv = ReferenceConvNDFwdInstance<3>(); + verify_f(ref_conv); + break; + } + case 2: { + auto ref_conv = ReferenceConvNDFwdInstance<2>(); + verify_f(ref_conv); + break; + } + case 1: { + auto ref_conv = ReferenceConvNDFwdInstance<1>(); + verify_f(ref_conv); + break; + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } + } +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index afc2f8c19a8..929dd6282d2 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -30,10 +30,8 @@ add_subdirectory(01_gemm) add_subdirectory(02_gemm_alpha_beta) add_subdirectory(03_gemm_bias_relu) add_subdirectory(04_gemm_bias_relu_add) -add_subdirectory(05_conv2d_fwd) add_subdirectory(06_conv2d_fwd_bias_relu) add_subdirectory(07_conv2d_fwd_bias_relu_add) -add_subdirectory(08_conv3d_fwd) add_subdirectory(09_convnd_fwd) add_subdirectory(10_conv2d_bwd_data) add_subdirectory(11_conv2d_bwd_wgt) From 873d918c2f9d14d49f409b53d2cbcc01096f296c Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 21 Mar 2022 16:23:08 +0100 Subject: [PATCH 53/82] Add constructor to ConvParams. * And add input parameters validation. --- .../ck/library/utility/conv_fwd_util.hpp | 77 +++++++++++++++---- 1 file changed, 61 insertions(+), 16 deletions(-) diff --git a/library/include/ck/library/utility/conv_fwd_util.hpp b/library/include/ck/library/utility/conv_fwd_util.hpp index 4c882e72942..f6b1b8107b4 100644 --- a/library/include/ck/library/utility/conv_fwd_util.hpp +++ b/library/include/ck/library/utility/conv_fwd_util.hpp @@ -31,10 +31,6 @@ using DeviceConvFwdNoOpPtr = ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough>; -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - /** * @brief Calculate number of FLOPs for Convolution * @@ -128,6 +124,39 @@ struct ConvParams { } + ConvParams(ck::index_t n_dim, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) + : num_dim_spatial(n_dim), + N(n_batch), + K(n_out_channels), + C(n_in_channels), + filter_spatial_lengths(filters_len), + input_spatial_lengths(input_len), + conv_filter_strides(strides), + conv_filter_dilations(dilations), + input_left_pads(left_pads), + input_right_pads(right_pads) + { + if(filter_spatial_lengths.size() != num_dim_spatial || + input_spatial_lengths.size() != num_dim_spatial || + conv_filter_strides.size() != num_dim_spatial || + conv_filter_dilations.size() != num_dim_spatial || + input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial) + { + throw(std::runtime_error( + "ConvParams::GetOutputSpatialLengths: " + "parameter size is different from number of declared dimensions!")); + } + } + ck::index_t num_dim_spatial; ck::index_t N; ck::index_t K; @@ -144,6 +173,17 @@ struct ConvParams std::vector GetOutputSpatialLengths() const { + if(filter_spatial_lengths.size() != num_dim_spatial || + input_spatial_lengths.size() != num_dim_spatial || + conv_filter_strides.size() != num_dim_spatial || + conv_filter_dilations.size() != num_dim_spatial || + input_left_pads.size() != num_dim_spatial || input_right_pads.size() != num_dim_spatial) + { + throw(std::runtime_error( + "ConvParams::GetOutputSpatialLengths: " + "parameter size is different from number of declared dimensions!")); + } + std::vector out_spatial_len(num_dim_spatial, 0); for(ck::index_t i = 0; i < num_dim_spatial; ++i) { @@ -366,12 +406,13 @@ void RunReferenceConvFwd(const ConvParams& params, const Tensor& weights, Tensor& output) { + using PassThrough = ck::tensor_operation::element_wise::PassThrough; auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd(); auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(input, @@ -381,9 +422,9 @@ void RunReferenceConvFwd(const ConvParams& params, params.conv_filter_dilations, params.input_left_pads, params.input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + PassThrough{}, + PassThrough{}, + PassThrough{}); ref_invoker.Run(ref_argument); } @@ -399,6 +440,8 @@ void RunConvFwd(const ConvParams& params, const Tensor& weights, Tensor& output) { + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); @@ -422,9 +465,9 @@ void RunConvFwd(const ConvParams& params, params.conv_filter_dilations, params.input_left_pads, params.input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + PassThrough{}, + PassThrough{}, + PassThrough{}); if(!conv.IsSupportedArgument(argument)) { @@ -448,6 +491,8 @@ bool RunConvInstances(const ConvParams& params, Tensor& output, const Tensor& host_output) { + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); DeviceMem out_device_buf(sizeof(OutDataType) * output.mDesc.GetElementSpace()); @@ -474,9 +519,9 @@ bool RunConvInstances(const ConvParams& params, params.conv_filter_dilations, params.input_left_pads, params.input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + PassThrough{}, + PassThrough{}, + PassThrough{}); if(conv_ptr->IsSupportedArgument(argument.get())) { From 557a958566825bb4d3710237a2b549afc7d50c49 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 21 Mar 2022 16:24:30 +0100 Subject: [PATCH 54/82] Modify conv examples to use single utility file. --- .../conv2d_fwd_xdl_bias_relu.cpp | 322 +++++++++-------- .../conv2d_fwd_xdl_bias_relu_add.cpp | 338 +++++++++--------- example/09_convnd_fwd/convnd_fwd_xdl.cpp | 2 +- example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp | 2 +- example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp | 2 +- .../gpu/device/convolution_utility.hpp | 73 ---- ...ice_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp | 40 +-- 7 files changed, 362 insertions(+), 417 deletions(-) delete mode 100644 include/ck/tensor_operation/gpu/device/convolution_utility.hpp diff --git a/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp index d251aa35e12..11f1ed85ca1 100644 --- a/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp +++ b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp @@ -4,17 +4,20 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" -#include "print.hpp" +#include "conv_fwd_util.hpp" #include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" #include "device_tensor.hpp" -#include "tensor_layout.hpp" #include "element_wise_operation.hpp" -#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" #include "reference_conv_fwd_bias_activation.hpp" -#include "convolution_utility.hpp" +#include "tensor_layout.hpp" + +namespace { using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -86,146 +89,155 @@ using ReferenceConvFwdInstance = WeiElementOp, OutElementOp>; -int main(int argc, char* argv[]) +void PrintUseMsg() { - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; - - // Conv shape - ck::index_t N = 128; - ck::index_t K = 256; - ck::index_t C = 192; - ck::index_t Y = 3; - ck::index_t X = 3; - ck::index_t Hi = 71; - ck::index_t Wi = 71; - ck::index_t conv_stride_h = 2; - ck::index_t conv_stride_w = 2; - ck::index_t conv_dilation_h = 1; - ck::index_t conv_dilation_w = 1; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; - - if(argc == 4) + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: run kernel # of times (>1)\n" + << "Following arguments:\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[]) +{ + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int num_dim_spatial = 2; + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 4; + if(cmdline_nargs != argc) { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + PrintUseMsg(); + exit(0); } - else if(argc == 19) + + ck::utils::conv::ConvParams params; + int arg_idx = 4; + + params.num_dim_spatial = num_dim_spatial; + params.N = std::stoi(argv[arg_idx++]); + params.K = std::stoi(argv[arg_idx++]); + params.C = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_strides.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + const int num_dim_spatial = 2; + + ck::utils::conv::ConvParams params; + + if(argc >= 4) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); nrepeat = std::stoi(argv[3]); - - N = std::stoi(argv[4]); - K = std::stoi(argv[5]); - C = std::stoi(argv[6]); - Y = std::stoi(argv[7]); - X = std::stoi(argv[8]); - Hi = std::stoi(argv[9]); - Wi = std::stoi(argv[10]); - conv_stride_h = std::stoi(argv[11]); - conv_stride_w = std::stoi(argv[12]); - conv_dilation_h = std::stoi(argv[13]); - conv_dilation_w = std::stoi(argv[14]); - in_left_pad_h = std::stoi(argv[15]); - in_left_pad_w = std::stoi(argv[16]); - in_right_pad_h = std::stoi(argv[17]); - in_right_pad_w = std::stoi(argv[18]); } - else + + if(argc >= 5) { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); - printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(0); + params = ParseConvParams(argc, argv); } - const std::vector conv_filter_strides{conv_stride_h, conv_stride_w}; - const std::vector conv_filter_dilations{conv_dilation_h, conv_dilation_w}; - const std::vector input_left_pads{in_left_pad_h, in_left_pad_w}; - const std::vector input_right_pads{in_right_pad_h, in_right_pad_w}; - const auto output_spatial_lengths = - ck::tensor_operation::ConvolutionUtility::ComputeOutputSpatialLengths({Hi, Wi}, - {Y, X}, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); - - const ck::index_t Ho = output_spatial_lengths[0]; - const ck::index_t Wo = output_spatial_lengths[1]; - - // tensor layout - auto f_host_tensor_descriptor = [](std::size_t N_, - std::size_t C_, - std::size_t H, - std::size_t W, - auto layout) { - if constexpr(ck::is_same::value || - ck::is_same::value || - ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, H * W, W, 1})); - } - else if constexpr(ck::is_same::value || - ck::is_same::value || - ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, 1, W * C_, C_})); - } - }; - - Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); - Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); - Tensor out_n_k_ho_wo_host_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - Tensor out_n_k_ho_wo_device_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); + Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); + Tensor host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); // bias: assume contiguous 1d vector - Tensor bias_k( - HostTensorDescriptor(std::vector({static_cast(K)}))); + Tensor bias( + HostTensorDescriptor(std::vector({static_cast(params.K)}))); - std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; - std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; - std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; - std::cout << "bias_k: " << bias_k.mDesc << std::endl; + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; switch(init_method) { case 0: break; case 1: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } - DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * - out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); - DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpace()); - in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - bias_device_buf.ToDevice(bias_k.mData.data()); + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); auto conv = DeviceConvFwdInstance{}; auto invoker = conv.MakeInvoker(); @@ -234,16 +246,16 @@ int main(int argc, char* argv[]) static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), static_cast(bias_device_buf.GetDeviceBuffer()), - N, - K, - C, - std::vector{Hi, Wi}, - std::vector{Y, X}, - std::vector{Ho, Wo}, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, InElementOp{}, WeiElementOp{}, OutElementOp{}); @@ -257,16 +269,19 @@ int main(int argc, char* argv[]) float ave_time = invoker.Run(argument, nrepeat); - std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; - - std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + - sizeof(WeiDataType) * (K * C * Y * X) + - sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - + std::size_t flop = GetFlops( + params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = + GetBtype(params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths) + + sizeof(OutDataType) * (params.K); + + float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; @@ -275,21 +290,20 @@ int main(int argc, char* argv[]) auto ref_conv = ReferenceConvFwdInstance{}; auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - bias_k, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + bias, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, InElementOp{}, WeiElementOp{}, OutElementOp{}); ref_invoker.Run(ref_argument); - - out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - - check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + out_device_buf.FromDevice(device_output.mData.data()); + ck::utils::check_err( + host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); } } diff --git a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp index d6011b98a90..b45b2de7432 100644 --- a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp +++ b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp @@ -4,17 +4,20 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" -#include "print.hpp" +#include "conv_fwd_util.hpp" #include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" +#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" #include "device_tensor.hpp" -#include "tensor_layout.hpp" #include "element_wise_operation.hpp" -#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" #include "reference_conv_fwd_bias_activation_add.hpp" -#include "convolution_utility.hpp" +#include "tensor_layout.hpp" + +namespace { using InDataType = ck::half_t; using WeiDataType = ck::half_t; @@ -83,154 +86,164 @@ using ReferenceConvFwdInstance = WeiElementOp, OutElementOp>; -int main(int argc, char* argv[]) +void PrintUseMsg() +{ + std::cout << "arg1: verification (0=no, 1=yes)\n" + << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" + << "arg3: run kernel # of times (>1)\n" + << "Following arguments:\n" + << " N, K, C, \n" + << " , (ie Y, X for 2D)\n" + << " , (ie Hi, Wi for 2D)\n" + << " , (ie Sy, Sx for 2D)\n" + << " , (ie Dy, Dx for 2D)\n" + << " , (ie LeftPy, LeftPx for 2D)\n" + << " , (ie RightPy, RightPx for 2D)\n" + << std::endl; +} + +ck::utils::conv::ConvParams ParseConvParams(int argc, char* argv[]) { - bool do_verification = 0; - int init_method = 0; - int nrepeat = 5; - - // Conv shape - ck::index_t N = 128; - ck::index_t K = 256; - ck::index_t C = 192; - ck::index_t Y = 3; - ck::index_t X = 3; - ck::index_t Hi = 71; - ck::index_t Wi = 71; - ck::index_t conv_stride_h = 2; - ck::index_t conv_stride_w = 2; - ck::index_t conv_dilation_h = 1; - ck::index_t conv_dilation_w = 1; - ck::index_t in_left_pad_h = 1; - ck::index_t in_left_pad_w = 1; - ck::index_t in_right_pad_h = 1; - ck::index_t in_right_pad_w = 1; - - if(argc == 4) + // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) + int num_dim_spatial = 2; + int conv_args = 3 + num_dim_spatial * 6; + int cmdline_nargs = conv_args + 4; + if(cmdline_nargs != argc) { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - nrepeat = std::stoi(argv[3]); + PrintUseMsg(); + exit(0); + } + + ck::utils::conv::ConvParams params; + int arg_idx = 4; + + params.num_dim_spatial = num_dim_spatial; + params.N = std::stoi(argv[arg_idx++]); + params.K = std::stoi(argv[arg_idx++]); + params.C = std::stoi(argv[arg_idx++]); + + params.filter_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.filter_spatial_lengths[i] = std::stoi(argv[arg_idx++]); + } + params.input_spatial_lengths.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_spatial_lengths[i] = std::stoi(argv[arg_idx++]); } - else if(argc == 19) + params.conv_filter_strides.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_strides[i] = std::stoi(argv[arg_idx++]); + } + params.conv_filter_dilations.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.conv_filter_dilations[i] = std::stoi(argv[arg_idx++]); + } + params.input_left_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_left_pads[i] = std::stoi(argv[arg_idx++]); + } + params.input_right_pads.resize(num_dim_spatial); + for(int i = 0; i < num_dim_spatial; ++i) + { + params.input_right_pads[i] = std::stoi(argv[arg_idx++]); + } + + return params; +} + +} // anonymous namespace + +int main(int argc, char* argv[]) +{ + using namespace ck::utils::conv; + + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + const int num_dim_spatial = 2; + + ck::utils::conv::ConvParams params; + + if(argc >= 4) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); nrepeat = std::stoi(argv[3]); - - N = std::stoi(argv[4]); - K = std::stoi(argv[5]); - C = std::stoi(argv[6]); - Y = std::stoi(argv[7]); - X = std::stoi(argv[8]); - Hi = std::stoi(argv[9]); - Wi = std::stoi(argv[10]); - conv_stride_h = std::stoi(argv[11]); - conv_stride_w = std::stoi(argv[12]); - conv_dilation_h = std::stoi(argv[13]); - conv_dilation_w = std::stoi(argv[14]); - in_left_pad_h = std::stoi(argv[15]); - in_left_pad_w = std::stoi(argv[16]); - in_right_pad_h = std::stoi(argv[17]); - in_right_pad_w = std::stoi(argv[18]); } - else + + if(argc >= 5) { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: run kernel # of times (>1)\n"); - printf("arg4 to 18: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, " - "RightPx\n"); - exit(0); + params = ParseConvParams(argc, argv); } - const std::vector conv_filter_strides{conv_stride_h, conv_stride_w}; - const std::vector conv_filter_dilations{conv_dilation_h, conv_dilation_w}; - const std::vector input_left_pads{in_left_pad_h, in_left_pad_w}; - const std::vector input_right_pads{in_right_pad_h, in_right_pad_w}; - const auto output_spatial_lengths = - ck::tensor_operation::ConvolutionUtility::ComputeOutputSpatialLengths({Hi, Wi}, - {Y, X}, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); - - const ck::index_t Ho = output_spatial_lengths[0]; - const ck::index_t Wo = output_spatial_lengths[1]; - - // tensor layout - auto f_host_tensor_descriptor = [](std::size_t N_, - std::size_t C_, - std::size_t H, - std::size_t W, - auto layout) { - if constexpr(ck::is_same::value || - ck::is_same::value || - ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, H * W, W, 1})); - } - else if constexpr(ck::is_same::value || - ck::is_same::value || - ck::is_same::value) - { - return HostTensorDescriptor(std::vector({N_, C_, H, W}), - std::vector({C_ * H * W, 1, W * C_, C_})); - } - }; - - Tensor in_n_c_hi_wi(f_host_tensor_descriptor(N, C, Hi, Wi, InLayout{})); - Tensor wei_k_c_y_x(f_host_tensor_descriptor(K, C, Y, X, WeiLayout{})); - Tensor out_n_k_ho_wo_host_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); - Tensor out_n_k_ho_wo_device_result( - f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + std::vector input_dims{static_cast(params.N), + static_cast(params.C)}; + input_dims.insert(std::end(input_dims), + std::begin(params.input_spatial_lengths), + std::end(params.input_spatial_lengths)); + + std::vector filter_dims{static_cast(params.K), + static_cast(params.C)}; + filter_dims.insert(std::end(filter_dims), + std::begin(params.filter_spatial_lengths), + std::end(params.filter_spatial_lengths)); + + const std::vector& output_spatial_lengths = params.GetOutputSpatialLengths(); + std::vector output_dims{static_cast(params.N), + static_cast(params.K)}; + output_dims.insert(std::end(output_dims), + std::begin(output_spatial_lengths), + std::end(output_spatial_lengths)); + + Tensor input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); + Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); + Tensor host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); // bias: assume contiguous 1d vector - Tensor bias_k( - HostTensorDescriptor(std::vector({static_cast(K)}))); + Tensor bias( + HostTensorDescriptor(std::vector({static_cast(params.K)}))); // residual: assume same layout as output tensor - Tensor resi_n_k_ho_wo(f_host_tensor_descriptor(N, K, Ho, Wo, OutLayout{})); + Tensor residual(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); - std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi.mDesc << std::endl; - std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; - std::cout << "out_n_k_ho_wo: " << out_n_k_ho_wo_host_result.mDesc << std::endl; - std::cout << "bias_k: " << bias_k.mDesc << std::endl; - std::cout << "resi_n_k_ho_wo: " << resi_n_k_ho_wo.mDesc << std::endl; + std::cout << "input: " << input.mDesc << std::endl; + std::cout << "weights: " << weights.mDesc << std::endl; + std::cout << "output: " << host_output.mDesc << std::endl; + std::cout << "bias: " << bias.mDesc << std::endl; + std::cout << "residual: " << residual.mDesc << std::endl; switch(init_method) { case 0: break; case 1: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - bias_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + input.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + weights.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + bias.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + residual.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - bias_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - resi_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + input.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + weights.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + bias.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + residual.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } - DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_device_buf(sizeof(OutDataType) * - out_n_k_ho_wo_device_result.mDesc.GetElementSpace()); - DeviceMem bias_device_buf(sizeof(OutDataType) * bias_k.mDesc.GetElementSpace()); - DeviceMem resi_device_buf(sizeof(OutDataType) * resi_n_k_ho_wo.mDesc.GetElementSpace()); + DeviceMem in_device_buf(sizeof(InDataType) * input.mDesc.GetElementSpace()); + DeviceMem wei_device_buf(sizeof(WeiDataType) * weights.mDesc.GetElementSpace()); + DeviceMem out_device_buf(sizeof(OutDataType) * device_output.mDesc.GetElementSpace()); + DeviceMem bias_device_buf(sizeof(OutDataType) * bias.mDesc.GetElementSpace()); + DeviceMem resi_device_buf(sizeof(OutDataType) * residual.mDesc.GetElementSpace()); - in_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - bias_device_buf.ToDevice(bias_k.mData.data()); - resi_device_buf.ToDevice(resi_n_k_ho_wo.mData.data()); + in_device_buf.ToDevice(input.mData.data()); + wei_device_buf.ToDevice(weights.mData.data()); + bias_device_buf.ToDevice(bias.mData.data()); + resi_device_buf.ToDevice(residual.mData.data()); const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; @@ -244,16 +257,16 @@ int main(int argc, char* argv[]) static_cast(out_device_buf.GetDeviceBuffer()), static_cast(bias_device_buf.GetDeviceBuffer()), static_cast(resi_device_buf.GetDeviceBuffer()), - N, - K, - C, - std::vector{Hi, Wi}, - std::vector{Y, X}, - std::vector{Ho, Wo}, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, + params.N, + params.K, + params.C, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, in_element_op, wei_element_op, out_element_op); @@ -267,17 +280,21 @@ int main(int argc, char* argv[]) float ave_time = invoker.Run(argument, nrepeat); - std::size_t flop = std::size_t(2) * N * K * Ho * Wo * C * Y * X; - - std::size_t num_btype = sizeof(InDataType) * (N * C * Hi * Wi) + - sizeof(WeiDataType) * (K * C * Y * X) + - sizeof(OutDataType) * (N * K * Ho * Wo) + sizeof(OutDataType) * (K) + - sizeof(OutDataType) * (N * K * Ho * Wo); - - float tflops = static_cast(flop) / 1.E9 / ave_time; - + std::size_t flop = GetFlops( + params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = + GetBtype(params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths) + + sizeof(OutDataType) * (params.K) + + sizeof(OutDataType) * + (params.N * params.K * output_spatial_lengths[0] * output_spatial_lengths[1]); + + float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl; @@ -286,23 +303,22 @@ int main(int argc, char* argv[]) auto ref_conv = ReferenceConvFwdInstance{}; auto ref_invoker = ref_conv.MakeInvoker(); - auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - bias_k, - resi_n_k_ho_wo, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, + auto ref_argument = ref_conv.MakeArgument(input, + weights, + host_output, + bias, + residual, + params.conv_filter_strides, + params.conv_filter_dilations, + params.input_left_pads, + params.input_right_pads, in_element_op, wei_element_op, out_element_op); ref_invoker.Run(ref_argument); - - out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - - check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + out_device_buf.FromDevice(device_output.mData.data()); + ck::utils::check_err( + host_output.mData, device_output.mData, "Error: incorrect results!", 1e-5f, 1e-4f); } } diff --git a/example/09_convnd_fwd/convnd_fwd_xdl.cpp b/example/09_convnd_fwd/convnd_fwd_xdl.cpp index 322182ad7d7..bb07386e801 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl.cpp @@ -184,7 +184,7 @@ int main(int argc, char* argv[]) int nrepeat = 5; int num_dim_spatial = 2; - ConvParams params; + ck::utils::conv::ConvParams params; if(argc >= 5) { diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index 3fdc133314c..f280cdbfa7e 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -187,7 +187,7 @@ int main(int argc, char* argv[]) int nrepeat = 5; int num_dim_spatial = 2; - ConvParams params; + ck::utils::conv::ConvParams params; if(argc >= 5) { diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index 0e2bca9625c..0c8c9baf7cc 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -189,7 +189,7 @@ int main(int argc, char* argv[]) int nrepeat = 5; int num_dim_spatial = 2; - ConvParams params; + ck::utils::conv::ConvParams params; if(argc >= 5) { diff --git a/include/ck/tensor_operation/gpu/device/convolution_utility.hpp b/include/ck/tensor_operation/gpu/device/convolution_utility.hpp deleted file mode 100644 index a6b891dab29..00000000000 --- a/include/ck/tensor_operation/gpu/device/convolution_utility.hpp +++ /dev/null @@ -1,73 +0,0 @@ -#ifndef CONVOLUTION_UTILITY_HPP -#define CONVOLUTION_UTILITY_HPP - -#include - -namespace ck { -namespace tensor_operation { - -struct ConvolutionUtility -{ - static std::vector - ComputeOutputSpatialLengths(std::vector input_spatial_lengths, - std::vector filter_spatial_lengths, - std::vector conv_strides, - std::vector conv_dilations, - std::vector in_left_pads, - std::vector in_right_pads) - { - if(input_spatial_lengths.size() == 2) - { - assert(filter_spatial_lengths.size() == 2); - assert(conv_strides.size() == 2); - assert(conv_dilations.size() == 2); - assert(in_left_pads.size() == 2); - assert(in_right_pads.size() == 2); - - const index_t YEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1; - const index_t XEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1; - - const index_t Hi = input_spatial_lengths[0]; - const index_t Wi = input_spatial_lengths[1]; - - const index_t Ho = - (Hi + in_left_pads[0] + in_right_pads[0] - YEff) / conv_strides[0] + 1; - const index_t Wo = - (Wi + in_left_pads[1] + in_right_pads[1] - XEff) / conv_strides[1] + 1; - - return {Ho, Wo}; - } - else if(input_spatial_lengths.size() == 3) - { - assert(filter_spatial_lengths.size() == 3); - assert(conv_strides.size() == 3); - assert(conv_dilations.size() == 3); - assert(in_left_pads.size() == 3); - assert(in_right_pads.size() == 3); - - const index_t ZEff = (filter_spatial_lengths[0] - 1) * conv_dilations[0] + 1; - const index_t YEff = (filter_spatial_lengths[1] - 1) * conv_dilations[1] + 1; - const index_t XEff = (filter_spatial_lengths[2] - 1) * conv_dilations[2] + 1; - - const index_t Di = input_spatial_lengths[0]; - const index_t Hi = input_spatial_lengths[1]; - const index_t Wi = input_spatial_lengths[2]; - - const index_t Do = - (Di + in_left_pads[0] + in_right_pads[0] - ZEff) / conv_strides[0] + 1; - const index_t Ho = - (Hi + in_left_pads[1] + in_right_pads[1] - YEff) / conv_strides[1] + 1; - const index_t Wo = - (Wi + in_left_pads[2] + in_right_pads[2] - XEff) / conv_strides[2] + 1; - return {Do, Ho, Wo}; - } - else - { - return {}; - } - } -}; - -} // namespace tensor_operation -} // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp index 0371c4ab0d5..c3ebe588657 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp @@ -4,7 +4,7 @@ #include #include #include -#include "convolution_utility.hpp" +#include "conv_fwd_util.hpp" #include "device.hpp" #include "device_conv_fwd.hpp" #include "common_header.hpp" @@ -53,36 +53,30 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) - : N_{N}, - K_{K}, - C_{C}, - in_spatial_lengths_{input_spatial_lengths}, - filter_spatial_lengths_{filter_spatial_lengths}, + : params_{3, + N, + K, + C, + filter_spatial_lengths, + input_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, out_spatial_lengths_{output_spatial_lengths}, - conv_filter_strides_{conv_filter_strides}, - conv_filter_dilations_{conv_filter_dilations}, - in_left_pads_{input_left_pads}, - in_right_pads_{input_right_pads}, p_in_{p_in}, p_wei_{p_wei}, p_out_{p_out}, in_element_op_{in_element_op}, wei_element_op_{wei_element_op}, out_element_op_{out_element_op} + { } // private: - index_t N_; - index_t K_; - index_t C_; - std::vector in_spatial_lengths_; - std::vector filter_spatial_lengths_; + utils::conv::ConvParams params_; std::vector out_spatial_lengths_; - std::vector conv_filter_strides_; - std::vector conv_filter_dilations_; - std::vector in_left_pads_; - std::vector in_right_pads_; const InDataType* p_in_; const WeiDataType* p_wei_; @@ -157,13 +151,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W static bool IsSupportedArgument(const Argument& arg) { - std::vector out_spatial_lengths = - ConvolutionUtility::ComputeOutputSpatialLengths(arg.in_spatial_lengths_, - arg.filter_spatial_lengths_, - arg.conv_filter_strides_, - arg.conv_filter_dilations_, - arg.in_left_pads_, - arg.in_right_pads_); + std::vector out_spatial_lengths = arg.params_.GetOutputSpatialLengths(); bool out_lengths_are_consistent = out_spatial_lengths[0] == arg.out_spatial_lengths_[0] && out_spatial_lengths[1] == arg.out_spatial_lengths_[1] && From a1919c00ac19c474b3d32da751fc2ade935ccf31 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 21 Mar 2022 21:48:36 +0100 Subject: [PATCH 55/82] Remove check_error from host_tensor.hpp --- example/01_gemm/gemm_xdl_bf16.cpp | 4 +- example/01_gemm/gemm_xdl_fp16.cpp | 4 +- example/01_gemm/gemm_xdl_int8.cpp | 4 +- .../gemm_xdl_alpha_beta.cpp | 4 +- .../03_gemm_bias_relu/gemm_xdl_bias_relu.cpp | 4 +- .../gemm_xdl_bias_relu_add.cpp | 4 +- .../conv2d_bwd_data_xdl.cpp | 5 ++- .../11_conv2d_bwd_wgt/conv2d_bwd_wgt_xdl.cpp | 4 +- example/12_reduce/reduce_blockwise.cpp | 4 +- example/13_pool2d_fwd/pool2d_fwd.cpp | 4 +- .../ck/library/host_tensor/host_tensor.hpp | 40 ------------------- .../include/ck/library/utility/check_err.hpp | 6 +-- .../conv_add_fwd_driver_offline_nchwc.cpp | 4 +- .../conv_bwd_driver_offline.cpp | 4 +- .../conv_fwd_driver_offline.cpp | 4 +- .../conv_fwd_driver_offline_nchwc.cpp | 4 +- .../conv_maxpool_fwd_driver_offline_nchwc.cpp | 6 ++- .../conv_wrw_driver_offline.cpp | 4 +- .../gemm_driver_offline.cpp | 4 +- .../include/profile_batched_gemm_impl.hpp | 4 +- .../include/profile_conv_bwd_data_impl.hpp | 5 ++- .../profile_conv_fwd_bias_relu_add_impl.hpp | 5 ++- ...ile_conv_fwd_bias_relu_atomic_add_impl.hpp | 4 +- .../profile_conv_fwd_bias_relu_impl.hpp | 4 +- profiler/include/profile_conv_fwd_impl.hpp | 5 ++- .../include/profile_gemm_bias_2d_impl.hpp | 4 +- .../profile_gemm_bias_relu_add_impl.hpp | 4 +- .../include/profile_gemm_bias_relu_impl.hpp | 4 +- profiler/include/profile_gemm_impl.hpp | 7 +++- profiler/include/profile_reduce_impl.hpp | 6 ++- .../magic_number_division.cpp | 33 +++------------ 31 files changed, 101 insertions(+), 101 deletions(-) diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 5a9091a2361..7ce14143c40 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -229,7 +231,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - check_error(c_m_n_host_result, c_m_n_device_f32_result); + ck::utils::conv::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData); } return 0; diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index ad369e774d4..478f26d9f47 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -236,7 +238,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - check_error(c_m_n_host_result, c_m_n_device_result); + ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } return 0; diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index ba24aa4e85e..26ab7cc51dc 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -225,7 +227,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - check_error(c_m_n_host_result, c_m_n_device_result); + ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } return 0; diff --git a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp index 51a31bcfb76..c8f5a72cb87 100644 --- a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp +++ b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -244,6 +246,6 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - check_error(c_m_n_host_result, c_m_n_device_result); + ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } } diff --git a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp index 4dc8d0b7883..a01f13669c4 100644 --- a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp +++ b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -230,6 +232,6 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - check_error(c_m_n_host_result, c_m_n_device_result); + ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } } diff --git a/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp b/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp index 3ce7e9848b3..ea184efa44f 100644 --- a/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp +++ b/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -248,6 +250,6 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - check_error(c_m_n_host_result, c_m_n_device_result); + ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } } diff --git a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp index 7f289c19383..36325f703ba 100644 --- a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp +++ b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -242,6 +244,7 @@ int main(int argc, char* argv[]) in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); - check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result); + ck::utils::conv::check_err(in_n_c_hi_wi_device_result.mData, + in_n_c_hi_wi_host_result.mData); } } diff --git a/example/11_conv2d_bwd_wgt/conv2d_bwd_wgt_xdl.cpp b/example/11_conv2d_bwd_wgt/conv2d_bwd_wgt_xdl.cpp index 41415875836..0a0d9d1f8cf 100644 --- a/example/11_conv2d_bwd_wgt/conv2d_bwd_wgt_xdl.cpp +++ b/example/11_conv2d_bwd_wgt/conv2d_bwd_wgt_xdl.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -284,6 +286,6 @@ int main(int argc, char* argv[]) LogRangeAsType(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") << std::endl; } - check_error(wei_k_c_y_x_host_result, wei_k_c_y_x_device_result); + ck::utils::conv::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData); } } diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index 6a5864ede07..0b0b1660545 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -356,7 +358,7 @@ int main(int argc, char* argv[]) if(args.do_verification) { out_dev.FromDevice(out.mData.data()); - check_error(out_ref, out); + ck::utils::conv::check_err(out.mData, out_ref.mData); if(NeedIndices) { diff --git a/example/13_pool2d_fwd/pool2d_fwd.cpp b/example/13_pool2d_fwd/pool2d_fwd.cpp index a0cb61136f6..b18e329003b 100644 --- a/example/13_pool2d_fwd/pool2d_fwd.cpp +++ b/example/13_pool2d_fwd/pool2d_fwd.cpp @@ -3,6 +3,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -299,7 +301,7 @@ int main(int argc, char* argv[]) out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); - check_error(out_n_c_ho_wo_host, out_n_c_ho_wo_device); + ck::utils::conv::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData); if constexpr(NeedIndices) { diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index f9f462d7fd8..9582fba4f56 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -316,46 +316,6 @@ float bf16_to_f32_(ck::bhalf_t src_val); void bf16_to_f32_(const Tensor& src, Tensor& dst); -template -void check_error(const Tensor& ref, const Tensor& result) -{ - float error = 0; - float max_diff = -1; - float ref_value = 0, result_value = 0; - - if constexpr(std::is_same::value) - { - for(int i = 0; i < ref.mData.size(); ++i) - { - error += std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i])); - float diff = std::abs(bf16_to_f32_(ref.mData[i]) - bf16_to_f32_(result.mData[i])); - if(max_diff < diff) - { - max_diff = diff; - ref_value = bf16_to_f32_(ref.mData[i]); - result_value = bf16_to_f32_(result.mData[i]); - } - } - } - else - { - for(int i = 0; i < ref.mData.size(); ++i) - { - error += std::abs(double(ref.mData[i]) - double(result.mData[i])); - float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); - if(max_diff < diff) - { - max_diff = diff; - ref_value = ref.mData[i]; - result_value = result.mData[i]; - } - } - } - - std::cout << "error: " << error << std::endl; - std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl; -} - template void check_indices(const Tensor& ref, const Tensor& result) { diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 3ff7fcdfbcd..e23aa8cd000 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -21,9 +21,9 @@ typename std::enable_if::value && !std::is_same::type check_err(const std::vector& out, const std::vector& ref, - const std::string& msg, - double rtol = 1e-5, - double atol = 1e-8) + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-5, + double atol = 1e-8) { if(out.size() != ref.size()) { diff --git a/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp index d818f3c950e..7b70a5578b0 100644 --- a/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "debug.hpp" #include "print.hpp" @@ -401,7 +403,7 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), activ_type); - check_error(add_host, add_device); + ck::utils::conv::check_err(add_device.mData, add_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp index 7082f1050c9..9d5c06a0e8f 100644 --- a/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "debug.hpp" #include "print.hpp" @@ -473,7 +475,7 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), layout); - check_error(in_host, in_device); + ck::utils::conv::check_err(in_device.mData, in_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp index a6f47c5de5a..1f20fed44a8 100644 --- a/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "debug.hpp" #include "print.hpp" @@ -534,7 +536,7 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), layout); - check_error(out_host, out_device); + ck::utils::conv::check_err(out_device.mData, out_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp index 6b34254c74f..20fee073116 100644 --- a/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "debug.hpp" #include "print.hpp" @@ -377,7 +379,7 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), activ_type); - check_error(out_host, out_device); + ck::utils::conv::check_err(out_device.mData, out_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp index d8a22bda337..3f116fa168f 100644 --- a/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "debug.hpp" #include "print.hpp" @@ -397,8 +399,8 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), activ_type); - check_error(out_host, out_device); - check_error(max_host, max_device); + ck::utils::conv::check_err(out_device.mData, out_host.mData); + ck::utils::conv::check_err(max_device.mData, max_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp b/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp index 0151fea9e50..caa898dd720 100644 --- a/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "debug.hpp" #include "print.hpp" @@ -517,7 +519,7 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), layout); - check_error(wei_host, wei_device); + ck::utils::conv::check_err(wei_device.mData, wei_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/gemm_driver_offline.cpp b/library/src/obselete_driver_offline/gemm_driver_offline.cpp index bd8cb00390c..58d83cad651 100644 --- a/library/src/obselete_driver_offline/gemm_driver_offline.cpp +++ b/library/src/obselete_driver_offline/gemm_driver_offline.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "debug.hpp" #include "print.hpp" @@ -441,7 +443,7 @@ int main(int argc, char* argv[]) { host_gemm(a, b, c_host, layout); - check_error(c_host, c_device); + ck::utils::conv::check_err(c_device.mData, c_host.mData); if(do_log) { diff --git a/profiler/include/profile_batched_gemm_impl.hpp b/profiler/include/profile_batched_gemm_impl.hpp index aaab0aa355c..8f4d63b0050 100644 --- a/profiler/include/profile_batched_gemm_impl.hpp +++ b/profiler/include/profile_batched_gemm_impl.hpp @@ -1,4 +1,6 @@ #pragma once + +#include "check_err.hpp" #include "reference_batched_gemm.hpp" namespace ck { @@ -218,7 +220,7 @@ void profile_batched_gemm_impl(int do_verification, { c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); - check_error(c_g_m_n_host_result, c_g_m_n_device_result); + ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_conv_bwd_data_impl.hpp b/profiler/include/profile_conv_bwd_data_impl.hpp index 019020c2ace..6cd7cc4d695 100644 --- a/profiler/include/profile_conv_bwd_data_impl.hpp +++ b/profiler/include/profile_conv_bwd_data_impl.hpp @@ -1,4 +1,6 @@ #pragma once + +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -251,7 +253,8 @@ void profile_conv_bwd_data_impl(int do_verification, { in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); - check_error(in_n_c_hi_wi_host_result, in_n_c_hi_wi_device_result); + ck::utils::check_err(in_n_c_hi_wi_device_result.mData, + in_n_c_hi_wi_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp index 286323c629d..d0de7307d25 100644 --- a/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp +++ b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp @@ -1,4 +1,6 @@ #pragma once + +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -245,7 +247,8 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, { out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + ck::utils::check_err(out_n_k_ho_wo_device_result.mData, + out_n_k_ho_wo_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp index c17d184e848..9bdfa612832 100644 --- a/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp +++ b/profiler/include/profile_conv_fwd_bias_relu_atomic_add_impl.hpp @@ -1,4 +1,5 @@ #pragma once +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -301,7 +302,8 @@ void profile_conv_fwd_bias_relu_atomic_add_impl(int do_verification, { out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + ck::utils::check_err(out_n_k_ho_wo_device_result.mData, + out_n_k_ho_wo_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp index cd68f992e90..f34e52048e9 100644 --- a/profiler/include/profile_conv_fwd_bias_relu_impl.hpp +++ b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp @@ -1,4 +1,5 @@ #pragma once +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -233,7 +234,8 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, { out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + ck::utils::check_err(out_n_k_ho_wo_device_result.mData, + out_n_k_ho_wo_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_conv_fwd_impl.hpp b/profiler/include/profile_conv_fwd_impl.hpp index 95d65354856..6038cd4612f 100644 --- a/profiler/include/profile_conv_fwd_impl.hpp +++ b/profiler/include/profile_conv_fwd_impl.hpp @@ -1,4 +1,6 @@ #pragma once + +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -253,7 +255,8 @@ void profile_conv_fwd_impl(int do_verification, { out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); - check_error(out_n_k_ho_wo_host_result, out_n_k_ho_wo_device_result); + ck::utils::check_err(out_n_k_ho_wo_device_result.mData, + out_n_k_ho_wo_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_gemm_bias_2d_impl.hpp b/profiler/include/profile_gemm_bias_2d_impl.hpp index 94223c4f7a9..227d858c89f 100644 --- a/profiler/include/profile_gemm_bias_2d_impl.hpp +++ b/profiler/include/profile_gemm_bias_2d_impl.hpp @@ -1,4 +1,6 @@ #pragma once + +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -283,7 +285,7 @@ void profile_gemm_bias_2d_impl(int do_verification, { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - check_error(c_m_n_host_result, c_m_n_device_result); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_gemm_bias_relu_add_impl.hpp b/profiler/include/profile_gemm_bias_relu_add_impl.hpp index f6625a8b22e..75ed78075ba 100644 --- a/profiler/include/profile_gemm_bias_relu_add_impl.hpp +++ b/profiler/include/profile_gemm_bias_relu_add_impl.hpp @@ -1,4 +1,6 @@ #pragma once + +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -257,7 +259,7 @@ void profile_gemm_bias_relu_add_impl(int do_verification, { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - check_error(c_m_n_host_result, c_m_n_device_result); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_gemm_bias_relu_impl.hpp b/profiler/include/profile_gemm_bias_relu_impl.hpp index e403a88d586..70688f5184c 100644 --- a/profiler/include/profile_gemm_bias_relu_impl.hpp +++ b/profiler/include/profile_gemm_bias_relu_impl.hpp @@ -1,4 +1,6 @@ #pragma once + +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -236,7 +238,7 @@ void profile_gemm_bias_relu_impl(int do_verification, { c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - check_error(c_m_n_host_result, c_m_n_device_result); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 30778351fa2..26d0de6406e 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -1,5 +1,7 @@ #pragma once #include + +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -400,7 +402,8 @@ void profile_gemm_impl(int do_verification, ref_invoker.Run(ref_argument); - check_error(c_m_n_host_result, c_m_n_device_f32_result); + ck::utils::conv::check_err(c_m_n_device_f32_result.mData, + c_m_n_host_result.mData); if(do_log) { @@ -429,7 +432,7 @@ void profile_gemm_impl(int do_verification, a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); ref_invoker.Run(ref_argument); - check_error(c_m_n_host_result, c_m_n_device_result); + ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_reduce_impl.hpp b/profiler/include/profile_reduce_impl.hpp index 8ed93b94ebe..eaaea1e9824 100644 --- a/profiler/include/profile_reduce_impl.hpp +++ b/profiler/include/profile_reduce_impl.hpp @@ -1,4 +1,6 @@ #pragma once + +#include "check_err.hpp" #include "device_reduce.hpp" #include "device_reduce_instance.hpp" #include "reduction_enums.hpp" @@ -409,7 +411,7 @@ void profile_reduce_impl_impl(bool do_verification, if(do_verification) { out_dev.FromDevice(out.mData.data()); - check_error(out_ref, out); + ck::utils::conv::check_err(out.mData, out_ref.mData); if(NeedIndices) { @@ -523,7 +525,7 @@ void profile_reduce_impl_impl(bool do_verification, if(do_verification) { out_dev.FromDevice(out.mData.data()); - check_error(out_ref, out); + ck::utils::conv::check_err(out.mData, out_ref.mData); if(NeedIndices) { diff --git a/test/magic_number_division/magic_number_division.cpp b/test/magic_number_division/magic_number_division.cpp index ec53996349a..ccb4425277a 100644 --- a/test/magic_number_division/magic_number_division.cpp +++ b/test/magic_number_division/magic_number_division.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -54,29 +56,6 @@ __host__ void cpu_magic_number_division(uint32_t magic_multiplier, } } -template -T check_error(const std::vector& ref, const std::vector& result) -{ - T error = 0; - T max_diff = 0; - T ref_value = 0, result_value = 0; - - for(std::size_t i = 0; i < ref.size(); ++i) - { - T diff = std::abs(ref[i] - result[i]); - error += diff; - - if(max_diff < diff) - { - max_diff = diff; - ref_value = ref[i]; - result_value = result[i]; - } - } - - return max_diff; -} - int main(int, char*[]) { uint64_t num_divisor = 4096; @@ -135,9 +114,9 @@ int main(int, char*[]) naive_result_dev_buf.FromDevice(naive_result_host.data()); magic_result_dev_buf.FromDevice(magic_result_host.data()); - int32_t max_diff = check_error(naive_result_host, magic_result_host); + bool res = ck::utils::check_err(magic_result_host, naive_result_host); - if(max_diff != 0) + if(!res) { pass = false; continue; @@ -149,9 +128,9 @@ int main(int, char*[]) magic_result_host2.data(), num_dividend); - max_diff = check_error(naive_result_host, magic_result_host2); + res = ck::utils::check_err(magic_result_host2, naive_result_host); - if(max_diff != 0) + if(!res) { pass = false; continue; From 5a664e7b797a46f5e21f76f870ba698182b80fee Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 21 Mar 2022 21:53:56 +0100 Subject: [PATCH 56/82] Get rid of check_indices function. --- example/12_reduce/reduce_blockwise.cpp | 3 ++- example/13_pool2d_fwd/pool2d_fwd.cpp | 3 ++- .../ck/library/host_tensor/host_tensor.hpp | 24 ------------------- profiler/include/profile_reduce_impl.hpp | 6 +++-- 4 files changed, 8 insertions(+), 28 deletions(-) diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index 0b0b1660545..7625fa71f00 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -363,7 +363,8 @@ int main(int argc, char* argv[]) if(NeedIndices) { out_indices_dev.FromDevice(out_indices.mData.data()); - check_indices(out_indices_ref, out_indices); + ck::utils::conv::check_err(out_indices.mData, out_indices_ref.mData); + ; }; }; } diff --git a/example/13_pool2d_fwd/pool2d_fwd.cpp b/example/13_pool2d_fwd/pool2d_fwd.cpp index b18e329003b..761a93a9d68 100644 --- a/example/13_pool2d_fwd/pool2d_fwd.cpp +++ b/example/13_pool2d_fwd/pool2d_fwd.cpp @@ -307,7 +307,8 @@ int main(int argc, char* argv[]) { out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); - // check_indices(out_indices_n_c_ho_wo_host, out_indices_n_c_ho_wo_device); + // ck::utils::conv::check_err(out_indices_n_c_ho_wo_device.mData, + // out_indices_n_c_ho_wo_host.mData);; }; } } diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index 9582fba4f56..9aeba6df5e2 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -316,28 +316,4 @@ float bf16_to_f32_(ck::bhalf_t src_val); void bf16_to_f32_(const Tensor& src, Tensor& dst); -template -void check_indices(const Tensor& ref, const Tensor& result) -{ - bool has_error = false; - int error_count = 0; - - for(int i = 0; i < ref.mData.size(); ++i) - { - if(ref.mData[i] != result.mData[i]) - { - std::cerr << std::endl - << "Indices different at position " << i << " (ref: " << ref.mData[i] - << ", result: " << result.mData[i] << ")" << std::endl; - has_error = true; - error_count++; - if(error_count == 20) - break; - }; - } - - if(!has_error) - std::cout << std::endl << "Indices result is completely acccurate!" << std::endl; -} - #endif diff --git a/profiler/include/profile_reduce_impl.hpp b/profiler/include/profile_reduce_impl.hpp index eaaea1e9824..391fe2ccf74 100644 --- a/profiler/include/profile_reduce_impl.hpp +++ b/profiler/include/profile_reduce_impl.hpp @@ -416,7 +416,8 @@ void profile_reduce_impl_impl(bool do_verification, if(NeedIndices) { out_indices_dev.FromDevice(out_indices.mData.data()); - check_indices(out_indices_ref, out_indices); + ck::utils::conv::check_err(out_indices.mData, out_indices_ref.mData); + ; }; if(do_log) @@ -530,7 +531,8 @@ void profile_reduce_impl_impl(bool do_verification, if(NeedIndices) { out_indices_dev.FromDevice(out_indices.mData.data()); - check_indices(out_indices_ref, out_indices); + ck::utils::conv::check_err(out_indices.mData, out_indices_ref.mData); + ; }; if(do_log) From 0e8fb6d7f17661406f846cafc7007c96fd7a21cc Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 21 Mar 2022 23:14:18 +0100 Subject: [PATCH 57/82] Remove bf16_to_f32 function overload for scalars. --- .../include/ck/library/host_tensor/host_tensor.hpp | 2 -- library/src/host_tensor/host_tensor.cpp | 12 +----------- 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index 9aeba6df5e2..ff051ba5b3a 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -312,8 +312,6 @@ HostTensorDescriptor::HostTensorDescriptor(std::vector lens, std::vector s void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout); -float bf16_to_f32_(ck::bhalf_t src_val); - void bf16_to_f32_(const Tensor& src, Tensor& dst); #endif diff --git a/library/src/host_tensor/host_tensor.cpp b/library/src/host_tensor/host_tensor.cpp index 89b76f9a386..b199668bc09 100644 --- a/library/src/host_tensor/host_tensor.cpp +++ b/library/src/host_tensor/host_tensor.cpp @@ -64,18 +64,8 @@ void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream os << "}" << std::endl; } -float bf16_to_f32_(ck::bhalf_t src_val) -{ - union - { - uint32_t int32; - float fp32; - } u = {uint32_t(src_val) << 16}; - return u.fp32; -} - void bf16_to_f32_(const Tensor& src, Tensor& dst) { for(int i = 0; i < src.mData.size(); ++i) - dst.mData[i] = bf16_to_f32_(src.mData[i]); + dst.mData[i] = ck::type_convert(src.mData[i]); } From 7ba2eb33787c7cd70f9b07d07b29d0aff5dd3b82 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 21 Mar 2022 23:43:59 +0100 Subject: [PATCH 58/82] Fix namespace. --- example/01_gemm/gemm_xdl_bf16.cpp | 2 +- example/01_gemm/gemm_xdl_fp16.cpp | 2 +- example/01_gemm/gemm_xdl_int8.cpp | 2 +- example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp | 2 +- example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp | 2 +- example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp | 2 +- example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp | 2 +- example/11_conv2d_bwd_wgt/conv2d_bwd_wgt_xdl.cpp | 2 +- example/12_reduce/reduce_blockwise.cpp | 4 ++-- example/13_pool2d_fwd/pool2d_fwd.cpp | 4 ++-- .../conv_add_fwd_driver_offline_nchwc.cpp | 2 +- .../obselete_driver_offline/conv_bwd_driver_offline.cpp | 2 +- .../obselete_driver_offline/conv_fwd_driver_offline.cpp | 2 +- .../conv_fwd_driver_offline_nchwc.cpp | 2 +- .../conv_maxpool_fwd_driver_offline_nchwc.cpp | 4 ++-- .../obselete_driver_offline/conv_wrw_driver_offline.cpp | 2 +- .../src/obselete_driver_offline/gemm_driver_offline.cpp | 2 +- profiler/include/profile_gemm_impl.hpp | 4 ++-- profiler/include/profile_reduce_impl.hpp | 8 ++++---- 19 files changed, 26 insertions(+), 26 deletions(-) diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 7ce14143c40..6b649def54f 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -231,7 +231,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - ck::utils::conv::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData); } return 0; diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 478f26d9f47..0c45085d468 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -238,7 +238,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } return 0; diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index 26ab7cc51dc..cea5165fd7e 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -227,7 +227,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } return 0; diff --git a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp index c8f5a72cb87..6f45605a38c 100644 --- a/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp +++ b/example/02_gemm_alpha_beta/gemm_xdl_alpha_beta.cpp @@ -246,6 +246,6 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } } diff --git a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp index a01f13669c4..77c960913e6 100644 --- a/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp +++ b/example/03_gemm_bias_relu/gemm_xdl_bias_relu.cpp @@ -232,6 +232,6 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } } diff --git a/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp b/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp index ea184efa44f..9a0f41e1409 100644 --- a/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp +++ b/example/04_gemm_bias_relu_add/gemm_xdl_bias_relu_add.cpp @@ -250,6 +250,6 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } } diff --git a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp index 36325f703ba..cdb153296cf 100644 --- a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp +++ b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp @@ -244,7 +244,7 @@ int main(int argc, char* argv[]) in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); - ck::utils::conv::check_err(in_n_c_hi_wi_device_result.mData, + ck::utils::check_err(in_n_c_hi_wi_device_result.mData, in_n_c_hi_wi_host_result.mData); } } diff --git a/example/11_conv2d_bwd_wgt/conv2d_bwd_wgt_xdl.cpp b/example/11_conv2d_bwd_wgt/conv2d_bwd_wgt_xdl.cpp index 0a0d9d1f8cf..7b342787452 100644 --- a/example/11_conv2d_bwd_wgt/conv2d_bwd_wgt_xdl.cpp +++ b/example/11_conv2d_bwd_wgt/conv2d_bwd_wgt_xdl.cpp @@ -286,6 +286,6 @@ int main(int argc, char* argv[]) LogRangeAsType(std::cout << "wei_host : ", wei_k_c_y_x_host_result.mData, ",") << std::endl; } - ck::utils::conv::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData); + ck::utils::check_err(wei_k_c_y_x_device_result.mData, wei_k_c_y_x_host_result.mData); } } diff --git a/example/12_reduce/reduce_blockwise.cpp b/example/12_reduce/reduce_blockwise.cpp index 7625fa71f00..58fb3639ea5 100644 --- a/example/12_reduce/reduce_blockwise.cpp +++ b/example/12_reduce/reduce_blockwise.cpp @@ -358,12 +358,12 @@ int main(int argc, char* argv[]) if(args.do_verification) { out_dev.FromDevice(out.mData.data()); - ck::utils::conv::check_err(out.mData, out_ref.mData); + ck::utils::check_err(out.mData, out_ref.mData); if(NeedIndices) { out_indices_dev.FromDevice(out_indices.mData.data()); - ck::utils::conv::check_err(out_indices.mData, out_indices_ref.mData); + ck::utils::check_err(out_indices.mData, out_indices_ref.mData); ; }; }; diff --git a/example/13_pool2d_fwd/pool2d_fwd.cpp b/example/13_pool2d_fwd/pool2d_fwd.cpp index 761a93a9d68..6cfbb021294 100644 --- a/example/13_pool2d_fwd/pool2d_fwd.cpp +++ b/example/13_pool2d_fwd/pool2d_fwd.cpp @@ -301,13 +301,13 @@ int main(int argc, char* argv[]) out_device_buf.FromDevice(out_n_c_ho_wo_device.mData.data()); - ck::utils::conv::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData); + ck::utils::check_err(out_n_c_ho_wo_device.mData, out_n_c_ho_wo_host.mData); if constexpr(NeedIndices) { out_indices_device_buf.FromDevice(out_indices_n_c_ho_wo_device.mData.data()); - // ck::utils::conv::check_err(out_indices_n_c_ho_wo_device.mData, + // ck::utils::check_err(out_indices_n_c_ho_wo_device.mData, // out_indices_n_c_ho_wo_host.mData);; }; } diff --git a/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp index 7b70a5578b0..6cfa2cb108f 100644 --- a/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_add_fwd_driver_offline_nchwc.cpp @@ -403,7 +403,7 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), activ_type); - ck::utils::conv::check_err(add_device.mData, add_host.mData); + ck::utils::check_err(add_device.mData, add_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp index 9d5c06a0e8f..80ef87418dc 100644 --- a/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_bwd_driver_offline.cpp @@ -475,7 +475,7 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), layout); - ck::utils::conv::check_err(in_device.mData, in_host.mData); + ck::utils::check_err(in_device.mData, in_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp b/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp index 1f20fed44a8..891c9b9e74d 100644 --- a/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_fwd_driver_offline.cpp @@ -536,7 +536,7 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), layout); - ck::utils::conv::check_err(out_device.mData, out_host.mData); + ck::utils::check_err(out_device.mData, out_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp index 20fee073116..eaa1890aff0 100644 --- a/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_fwd_driver_offline_nchwc.cpp @@ -379,7 +379,7 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), activ_type); - ck::utils::conv::check_err(out_device.mData, out_host.mData); + ck::utils::check_err(out_device.mData, out_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp b/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp index 3f116fa168f..9e4b3639681 100644 --- a/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp +++ b/library/src/obselete_driver_offline/conv_maxpool_fwd_driver_offline_nchwc.cpp @@ -399,8 +399,8 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), activ_type); - ck::utils::conv::check_err(out_device.mData, out_host.mData); - ck::utils::conv::check_err(max_device.mData, max_host.mData); + ck::utils::check_err(out_device.mData, out_host.mData); + ck::utils::check_err(max_device.mData, max_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp b/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp index caa898dd720..3043a2b8fd6 100644 --- a/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp +++ b/library/src/obselete_driver_offline/conv_wrw_driver_offline.cpp @@ -519,7 +519,7 @@ int main(int argc, char* argv[]) make_tuple(in_right_pad_h, in_right_pad_w), layout); - ck::utils::conv::check_err(wei_device.mData, wei_host.mData); + ck::utils::check_err(wei_device.mData, wei_host.mData); if(do_log) { diff --git a/library/src/obselete_driver_offline/gemm_driver_offline.cpp b/library/src/obselete_driver_offline/gemm_driver_offline.cpp index 58d83cad651..9c97c3f7bd1 100644 --- a/library/src/obselete_driver_offline/gemm_driver_offline.cpp +++ b/library/src/obselete_driver_offline/gemm_driver_offline.cpp @@ -443,7 +443,7 @@ int main(int argc, char* argv[]) { host_gemm(a, b, c_host, layout); - ck::utils::conv::check_err(c_device.mData, c_host.mData); + ck::utils::check_err(c_device.mData, c_host.mData); if(do_log) { diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 26d0de6406e..cd4127bf9f2 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -402,7 +402,7 @@ void profile_gemm_impl(int do_verification, ref_invoker.Run(ref_argument); - ck::utils::conv::check_err(c_m_n_device_f32_result.mData, + ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData); if(do_log) @@ -432,7 +432,7 @@ void profile_gemm_impl(int do_verification, a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); ref_invoker.Run(ref_argument); - ck::utils::conv::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); if(do_log) { diff --git a/profiler/include/profile_reduce_impl.hpp b/profiler/include/profile_reduce_impl.hpp index 391fe2ccf74..f9cb714ed96 100644 --- a/profiler/include/profile_reduce_impl.hpp +++ b/profiler/include/profile_reduce_impl.hpp @@ -411,12 +411,12 @@ void profile_reduce_impl_impl(bool do_verification, if(do_verification) { out_dev.FromDevice(out.mData.data()); - ck::utils::conv::check_err(out.mData, out_ref.mData); + ck::utils::check_err(out.mData, out_ref.mData); if(NeedIndices) { out_indices_dev.FromDevice(out_indices.mData.data()); - ck::utils::conv::check_err(out_indices.mData, out_indices_ref.mData); + ck::utils::check_err(out_indices.mData, out_indices_ref.mData); ; }; @@ -526,12 +526,12 @@ void profile_reduce_impl_impl(bool do_verification, if(do_verification) { out_dev.FromDevice(out.mData.data()); - ck::utils::conv::check_err(out.mData, out_ref.mData); + ck::utils::check_err(out.mData, out_ref.mData); if(NeedIndices) { out_indices_dev.FromDevice(out_indices.mData.data()); - ck::utils::conv::check_err(out_indices.mData, out_indices_ref.mData); + ck::utils::check_err(out_indices.mData, out_indices_ref.mData); ; }; From 19f3fd8baaf496ad8a390cd3b12ef4794759c10e Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 21 Mar 2022 23:44:15 +0100 Subject: [PATCH 59/82] Add half_float::half for check_err. --- library/include/ck/library/utility/check_err.hpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index e23aa8cd000..05e032e3ed0 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -61,11 +62,11 @@ check_err(const std::vector& out, } template -typename std::enable_if::value || std::is_same::value, - bool>::type +typename std::enable_if::value || std::is_same::value || + std::is_same::value, bool>::type check_err(const std::vector& out, const std::vector& ref, - const std::string& msg, + const std::string& msg = "Error: Incorrect results!", double rtol = 1e-5, double atol = 1e-8) { @@ -110,7 +111,7 @@ template typename std::enable_if::value && !std::is_same::value, bool>::type check_err(const std::vector& out, const std::vector& ref, - const std::string& msg, + const std::string& msg = "Error: Incorrect results!", double = 0, double = 0) { From 2b09f01db614f9e1246cb294668a8d14135772a8 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 22 Mar 2022 10:26:59 +0100 Subject: [PATCH 60/82] Fix conv params size in UT. --- test/conv_util/conv_util.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp index 349477b493a..07264d8b1a8 100644 --- a/test/conv_util/conv_util.cpp +++ b/test/conv_util/conv_util.cpp @@ -65,7 +65,7 @@ bool TestConvParams_GetOutputSpatialLengths() res = ck::utils::check_err( out_spatial_len, std::vector{36}, "Error: ConvParams 1D."); - conv_params.conv_filter_strides = std::vector{1, 1}; + conv_params.conv_filter_strides = std::vector{1}; out_spatial_len = conv_params.GetOutputSpatialLengths(); res = ck::utils::check_err( out_spatial_len, std::vector{71}, "Error: ConvParams 1D stride {1}."); From d332ff7f3ab2e9e415ed516611c27cf6780cf7da Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 22 Mar 2022 10:45:51 +0100 Subject: [PATCH 61/82] Fix weights initialization for int8. --- test/include/conv_test_util.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/include/conv_test_util.hpp b/test/include/conv_test_util.hpp index 9228158c9fd..2355e4be30b 100644 --- a/test/include/conv_test_util.hpp +++ b/test/include/conv_test_util.hpp @@ -125,6 +125,8 @@ auto GetHostTensors(const ck::conv_util::ConvParams& params, bool init = true) std::uniform_int_distribution<> dis(-5, 5); std::generate( input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); + std::generate( + weights.begin(), weights.end(), [&dis, &gen]() { return WeiDataType(dis(gen)); }); } else { From 0bb0f4371eaecd395f9f397397666898c0df52bb Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 22 Mar 2022 10:47:49 +0100 Subject: [PATCH 62/82] Fix weights initialization for int8. --- library/include/ck/library/utility/conv_fwd_util.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/library/include/ck/library/utility/conv_fwd_util.hpp b/library/include/ck/library/utility/conv_fwd_util.hpp index f6b1b8107b4..2beaece232b 100644 --- a/library/include/ck/library/utility/conv_fwd_util.hpp +++ b/library/include/ck/library/utility/conv_fwd_util.hpp @@ -315,6 +315,8 @@ auto GetHostTensors(const ConvParams& params, bool init = true) std::uniform_int_distribution<> dis(-5, 5); std::generate( input.begin(), input.end(), [&dis, &gen]() { return InDataType(dis(gen)); }); + std::generate( + weights.begin(), weights.end(), [&dis, &gen]() { return WeiDataType(dis(gen)); }); } else { From 8a168d9e07d8a69a5a71600b134d9efa3f87b7fc Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 22 Mar 2022 18:40:21 +0100 Subject: [PATCH 63/82] Add type_convert when store output in ref conv 1D. --- .../reference_tensor_operation/cpu/reference_conv_fwd.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index 5ee5c622b8c..0095d51a5b2 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -112,7 +112,7 @@ struct ReferenceConvFwd : public device::BaseOperator float v_out; arg.out_element_op_(v_out, v_acc); - arg.output_(n, k, wo) = v_out; + arg.output_(n, k, wo) = ck::type_convert(v_out); }; make_ParallelTensorFunctor(f_ncw, From 4c3ed95471291f50d086cb995c3367606d985453 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 22 Mar 2022 18:41:09 +0100 Subject: [PATCH 64/82] Get back old conv2d_fwd_xdl operation. --- .../device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp | 721 ++++++++++++++++++ ...d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp | 132 ++-- ...2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp | 104 +-- ...2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp | 104 +-- ...d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp | 131 ++-- 5 files changed, 953 insertions(+), 239 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp new file mode 100644 index 00000000000..42a5d8d3b95 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,721 @@ +#ifndef DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP +#define DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP + +#include +#include +#include "device.hpp" +#include "device_base.hpp" +#include "device_conv_fwd.hpp" +#include "convolution_forward_specialization.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd +{ + using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + const index_t GemmK = Y * X * C; + + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; + + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization_t::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvFwdSpecializationStr(ConvForwardSpecialization) + << ">"; + // clang-format on + + return str.str(); + } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index c0f4a874df8..ae302e545db 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -27,75 +27,71 @@ static constexpr auto ConvFwd1x1S1P0 = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // FIXME: this instance causes numerical errors. - // DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp index 4222413f91e..beaad1d3b4e 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -29,67 +29,67 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp index 69ff3919685..402d65a6e00 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -28,67 +28,67 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp index e55f4fe2d7f..90e0320cff9 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -1,6 +1,6 @@ #include #include "config.hpp" -#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" #include "device_operation_instance.hpp" @@ -26,74 +26,71 @@ static constexpr auto ConvFwd1x1S1P0 = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; -using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> - // clang-format on - >; +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( std::vector>& instances) From 0175c28ad995dab5364a2cef69e58e0656b0402b Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 22 Mar 2022 18:43:43 +0100 Subject: [PATCH 65/82] Silence conv debug print. --- .../gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp index 42a5d8d3b95..b219fce335e 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -452,6 +452,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K float Run(const Argument& arg, int nrepeat = 1) { +#if 0 { std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " @@ -464,7 +465,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } - +#endif if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_, From 004ecdddb728ce2ee6c458b41f67aeeaa4ba6879 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 22 Mar 2022 23:55:53 +0000 Subject: [PATCH 66/82] format --- ...ice_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp | 2 +- ...ce_reduce_instance_blockwise_i8_i32_i8.hpp | 2 +- ...uffle_bf16_bf16_bf16_km_kn_mn_instance.cpp | 11 ++++--- ...uffle_bf16_bf16_bf16_km_nk_mn_instance.cpp | 11 ++++--- ...uffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp | 11 ++++--- ...l_splitk_f16_f16_f16_mk_nk_mn_instance.cpp | 32 ++++++++++++++++--- test/convnd_fwd/conv3d_fwd.cpp | 16 +++++----- test/gemm/gemm_bf16.cpp | 12 ++++--- test/gemm/gemm_int8.cpp | 14 +++++--- test/include/test_util.hpp | 16 +++++----- 10 files changed, 80 insertions(+), 47 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp index 27d7e0882a6..9058bb63a44 100644 --- a/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -468,7 +468,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K { continue; } - + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( N, K, diff --git a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp index 8d222d53dc8..f4a6677b3e0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/reduce/device_reduce_instance_blockwise_i8_i32_i8.hpp @@ -19,7 +19,7 @@ ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1); ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4); ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1); -ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); +ADD_BLOCKWISE_INST_REF_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1); // clang-format on } // namespace device_reduce_instance diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp index dceb7973021..272ae982c1b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -10,7 +10,7 @@ namespace device { namespace device_gemm_instance { using BF16 = ck::bhalf_t; -using F32 = float; +using F32 = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -21,8 +21,9 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances = std::tuple< - // clang-format off +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances = + std::tuple< + // clang-format off //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| @@ -43,8 +44,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances = std::tuple< DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp index 33e33b4988b..ebcde34546b 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -10,7 +10,7 @@ namespace device { namespace device_gemm_instance { using BF16 = ck::bhalf_t; -using F32 = float; +using F32 = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -21,8 +21,9 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[k, m] * b[n, k] = c[m, n] -using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances = std::tuple< - // clang-format off +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances = + std::tuple< + // clang-format off //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| @@ -43,8 +44,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances = std::tuple< DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp index 319db8ea7f1..4e35adfeab3 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -10,7 +10,7 @@ namespace device { namespace device_gemm_instance { using BF16 = ck::bhalf_t; -using F32 = float; +using F32 = float; using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -21,8 +21,9 @@ using S = ck::Sequence; using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Compilation parameters for a[m, k] * b[k, n] = c[m, n] -using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple< - // clang-format off +using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = + std::tuple< + // clang-format off //#####################| AData| BData| CData| AccData| CShuffle| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################| Type| Type| Type| Type| DataType| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //#####################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| @@ -43,8 +44,8 @@ using device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances = std::tuple< DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, DeviceGemmXdl_C_Shuffle< BF16, BF16, BF16, F32, BF16, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> - // clang-format on - >; + // clang-format on + >; void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp index 4b3524c30e1..346b1a4bec8 100644 --- a/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -47,11 +47,33 @@ using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances = std::tuple< // using device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_irregular_tile_instances = std::tuple< // // clang-format off -// //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -// //#########################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| -// //#########################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| -// //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -// DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, 16, 2, 9, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, true, 1, 9, S<1, 2, 1, 72>, 2> +// //#########################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| +// B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| +// ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| +// ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| +// BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| +// CBlockTransferClusterLengths| CBlockTransfer| +// //#########################| Type| Type| Type| Type| | | | +// Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | +// XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| +// SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| +// SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| +// _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| +// //#########################| | | | | | | | +// Operation| Operation| Operation| | | | | | | | +// | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| +// PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | +// PerVector| PerVector_K1| | PerShuffle| PerShuffle| +// _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| +// //#########################| | | | | | | | | | +// | | | | | | | | | | | | +// | | | | | | | | | | | | +// | | | | | +// DeviceGemmXdlSplitKCShuffle< F16, F16, F16, F32, Row, Col, Row, +// PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 144, 4, 8, 16, +// 16, 2, 9, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, +// true, S<1, 4, 16, 4>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 2, 2, +// true, 1, 9, S<1, 2, 1, 72>, 2> // // clang-format on // >; diff --git a/test/convnd_fwd/conv3d_fwd.cpp b/test/convnd_fwd/conv3d_fwd.cpp index 45438616bc6..ace8c40cdb8 100644 --- a/test/convnd_fwd/conv3d_fwd.cpp +++ b/test/convnd_fwd/conv3d_fwd.cpp @@ -205,14 +205,14 @@ template bool TestConv3DNDHWCInstances(const std::vector& conv_ptrs) { ck::conv_util::ConvParams params; - params.N = 64; - params.num_dim_spatial = 3; - params.filter_spatial_lengths = std::vector{3, 3, 2}; - params.input_spatial_lengths = std::vector{32, 32, 2}; - params.conv_filter_strides = std::vector{2, 2, 2}; - params.conv_filter_dilations = std::vector{1, 1, 1}; - params.input_left_pads = std::vector{1, 1, 1}; - params.input_right_pads = std::vector{1, 1, 1}; + params.N = 64; + params.num_dim_spatial = 3; + params.filter_spatial_lengths = std::vector{3, 3, 2}; + params.input_spatial_lengths = std::vector{32, 32, 2}; + params.conv_filter_strides = std::vector{2, 2, 2}; + params.conv_filter_dilations = std::vector{1, 1, 1}; + params.input_left_pads = std::vector{1, 1, 1}; + params.input_right_pads = std::vector{1, 1, 1}; auto host_tensors = test::conv::GetHostTensors&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector&); } // namespace device_gemm_instance } // namespace device } // namespace tensor_operation diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp index 0f4f1cbf01d..99073bbd8d5 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_int8.cpp @@ -32,11 +32,15 @@ namespace ck { namespace tensor_operation { namespace device { namespace device_gemm_instance { -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances(std::vector&); -void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances(std::vector&); -} +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances( + std::vector&); +} // namespace device_gemm_instance } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/test/include/test_util.hpp b/test/include/test_util.hpp index 3e88539ec35..069261f87d4 100644 --- a/test/include/test_util.hpp +++ b/test/include/test_util.hpp @@ -106,10 +106,10 @@ check_err(const std::vector& out, } bool check_err(const std::vector<_Float16>& out, - const std::vector<_Float16>& ref, - const std::string& msg, - _Float16 rtol = static_cast<_Float16>(1e-3f), - _Float16 atol = static_cast<_Float16>(1e-3f)) + const std::vector<_Float16>& ref, + const std::string& msg, + _Float16 rtol = static_cast<_Float16>(1e-3f), + _Float16 atol = static_cast<_Float16>(1e-3f)) { if(out.size() != ref.size()) { @@ -120,14 +120,14 @@ bool check_err(const std::vector<_Float16>& out, } bool res{true}; - int err_count = 0; - double err = 0; - double max_err = std::numeric_limits<_Float16>::min(); + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits<_Float16>::min(); for(std::size_t i = 0; i < ref.size(); ++i) { double out_ = double(out[i]); double ref_ = double(ref[i]); - err = std::abs(out_ - ref_); + err = std::abs(out_ - ref_); if(err > atol + rtol * std::abs(ref_) || !std::isfinite(out_) || !std::isfinite(ref_)) { max_err = err > max_err ? err : max_err; From 9a08edd9d8d268ce66828d5ca0bef03cd84cdece Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 23 Mar 2022 00:14:06 +0000 Subject: [PATCH 67/82] clean --- include/ck/config.hpp | 8 ++++++++ .../device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp | 4 +++- ...ice_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp | 4 +++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/include/ck/config.hpp b/include/ck/config.hpp index 7f51d29715d..fa9d8394365 100644 --- a/include/ck/config.hpp +++ b/include/ck/config.hpp @@ -157,6 +157,14 @@ #define CK_WORKAROUND_SWDEV_325164 1 #endif + +// workaround for verification failure ConvNd forward +// https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/135 +#ifndef CK_WORKAROUND_GITHUB_135 +#define CK_WORKAROUND_GITHUB_135 1 +#endif + + namespace ck { enum InMemoryDataOperationEnum_t diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp index bcb3a776a1a..4e878bf6bba 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -33,8 +33,10 @@ using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !CK_WORKAROUND_GITHUB_135 // FIXME: this instance causes numerical errors. - // DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, +#endif DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp index a59b1513f86..15e7cda0671 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -33,8 +33,10 @@ using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if !CK_WORKAROUND_GITHUB_135 // FIXME: this instance causes numerical errors. - // DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, +#endif DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, From f6bd459d4a96c4640b26a7630890e96cdbd81b5e Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 23 Mar 2022 00:18:42 +0000 Subject: [PATCH 68/82] clean --- include/ck/config.hpp | 2 - ...nv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp | 132 +++++++++--------- ...d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp | 102 +++++++------- ...wd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp | 132 +++++++++--------- 4 files changed, 181 insertions(+), 187 deletions(-) diff --git a/include/ck/config.hpp b/include/ck/config.hpp index fa9d8394365..3c9ae685299 100644 --- a/include/ck/config.hpp +++ b/include/ck/config.hpp @@ -157,14 +157,12 @@ #define CK_WORKAROUND_SWDEV_325164 1 #endif - // workaround for verification failure ConvNd forward // https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/135 #ifndef CK_WORKAROUND_GITHUB_135 #define CK_WORKAROUND_GITHUB_135 1 #endif - namespace ck { enum InMemoryDataOperationEnum_t diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp index 4e878bf6bba..2fcb64a5a7c 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -9,7 +9,8 @@ namespace tensor_operation { namespace device { namespace device_conv1d_fwd_instance { -using F32 = float; +using F32 = float; +using BF16 = bhalf_t; template using S = ck::Sequence; @@ -26,77 +27,74 @@ static constexpr auto ConvFwd1x1S1P0 = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = std::tuple< +// clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #if !CK_WORKAROUND_GITHUB_135 // FIXME: this instance causes numerical errors. - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, #endif - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; -using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_bf16_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; -using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; +using device_conv1d_fwd_xdl_nwc_kxc_nwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 1, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances( std::vector>& instances) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp index ae302e545db..50ce68fd71a 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -29,67 +29,67 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple< // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> // clang-format on >; diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp index 15e7cda0671..5f1ec520691 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -9,7 +9,8 @@ namespace tensor_operation { namespace device { namespace device_conv3d_fwd_instance { -using F32 = float; +using F32 = float; +using BF16 = bhalf_t; template using S = ck::Sequence; @@ -26,77 +27,74 @@ static constexpr auto ConvFwd1x1S1P0 = ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0; // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] -using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = std::tuple< +// clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | #if !CK_WORKAROUND_GITHUB_135 // FIXME: this instance causes numerical errors. - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, #endif - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; -using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_bf16_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; -using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances = - std::tuple< - // clang-format off - //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, - DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> - // clang-format on - >; +using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_1x1_s1_p0_bf16_instances = std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 3, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances( std::vector>& instances) From 6e5c7fe914ab6e58c537b1e33947159e6e7f46a5 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 23 Mar 2022 09:30:34 +0100 Subject: [PATCH 69/82] Fix merge. --- .../ck/tensor_operation/gpu/element/element_wise_operation.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp index 6fa8258bed9..fcc775e9000 100644 --- a/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -1,6 +1,8 @@ #ifndef CK_ELEMENT_WISE_OPERATION_HPP #define CK_ELEMENT_WISE_OPERATION_HPP +#include "data_type.hpp" + namespace ck { namespace tensor_operation { namespace element_wise { From 5121660fe7997cf888570f42cc5fa529541cd3a9 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 23 Mar 2022 09:31:21 +0100 Subject: [PATCH 70/82] Fix namespace for check_err --- test/batched_gemm/batched_gemm_fp16.cpp | 2 +- test/gemm/gemm_util.hpp | 8 ++++---- test/reduce/reduce_no_index.cpp | 8 ++++---- test/reduce/reduce_with_index.cpp | 21 ++++++++++----------- 4 files changed, 19 insertions(+), 20 deletions(-) diff --git a/test/batched_gemm/batched_gemm_fp16.cpp b/test/batched_gemm/batched_gemm_fp16.cpp index ec2ee0d4543..5ec08e78b0b 100644 --- a/test/batched_gemm/batched_gemm_fp16.cpp +++ b/test/batched_gemm/batched_gemm_fp16.cpp @@ -109,7 +109,7 @@ bool TestBatchedGemm(const std::size_t batch_count, DeviceBatchedGemmPtr& gemmPt gemmPtr, params, a, b, c_device, a_element_op, b_element_op, c_element_op); // Assert - // bool res = test_util::check_err( + // bool res = test::check_err( // c_device.mData, c_host.mData, "Error: incorrect results!", 1e-5f, 1e-4f); bool res = check_error(c_device, c_host) < 0.007815f; diff --git a/test/gemm/gemm_util.hpp b/test/gemm/gemm_util.hpp index 14d532defc1..a2502c04eff 100644 --- a/test/gemm/gemm_util.hpp +++ b/test/gemm/gemm_util.hpp @@ -202,19 +202,19 @@ struct TestGemm bool res = false; if(std::is_same::value) { - res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); + res = test::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; } else if(std::is_same::value) { - res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); + res = test::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; } else if(std::is_same::value) { - res = test_util::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); + res = test::check_err(c_device.mData, c_host.mData, "Error: incorrect results!"); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; } @@ -330,7 +330,7 @@ struct TestGemmBF16 bf16_to_f32_(c_device_bf16, c_device_fp32); // Assert - bool res = test_util::check_err( + bool res = test::check_err( c_device_fp32.mData, c_host_fp32.mData, "Error: incorrect results!", 1e-2f, 1e-3f); std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; diff --git a/test/reduce/reduce_no_index.cpp b/test/reduce/reduce_no_index.cpp index 911bdf0bb17..099ee96018e 100644 --- a/test/reduce/reduce_no_index.cpp +++ b/test/reduce/reduce_no_index.cpp @@ -289,13 +289,13 @@ bool test_reduce_no_index_impl(int init_method, { reduce_util::to_f32_vector(out, out_fp32); reduce_util::to_f32_vector(out_ref, out_ref_fp32); - single_result = test_util::check_err( + single_result = test::check_err( out_fp32.mData, out_ref_fp32.mData, "Error: incorrect data result!"); } else { single_result = - test_util::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); + test::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); }; if(!single_result) @@ -376,13 +376,13 @@ bool test_reduce_no_index_impl(int init_method, { reduce_util::to_f32_vector(out, out_fp32); reduce_util::to_f32_vector(out_ref, out_ref_fp32); - single_result = test_util::check_err( + single_result = test::check_err( out_fp32.mData, out_ref_fp32.mData, "Error: incorrect data result!"); } else { single_result = - test_util::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); + test::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); }; if(!single_result) diff --git a/test/reduce/reduce_with_index.cpp b/test/reduce/reduce_with_index.cpp index 4c51fad550d..911f17d8f0c 100644 --- a/test/reduce/reduce_with_index.cpp +++ b/test/reduce/reduce_with_index.cpp @@ -273,21 +273,21 @@ bool test_reduce_with_index_impl(int init_method, { reduce_util::to_f32_vector(out, out_fp32); reduce_util::to_f32_vector(out_ref, out_ref_fp32); - single_result = test_util::check_err( + single_result = test::check_err( out_fp32.mData, out_ref_fp32.mData, "Error: incorrect data result!"); } else { single_result = - test_util::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); + test::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); }; if(NeedIndices) { out_indices_dev.FromDevice(out_indices.mData.data()); - single_result = single_result && test_util::check_err(out_indices_ref.mData, - out_indices.mData, - "Error: incorrect index result!"); + single_result = single_result && test::check_err(out_indices_ref.mData, + out_indices.mData, + "Error: incorrect index result!"); }; if(!single_result) @@ -370,22 +370,21 @@ bool test_reduce_with_index_impl(int init_method, { reduce_util::to_f32_vector(out, out_fp32); reduce_util::to_f32_vector(out_ref, out_ref_fp32); - single_result = test_util::check_err( + single_result = test::check_err( out_fp32.mData, out_ref_fp32.mData, "Error: incorrect data result!"); } else { single_result = - test_util::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); + test::check_err(out.mData, out_ref.mData, "Error: incorrect data result!"); }; if(NeedIndices) { out_indices_dev.FromDevice(out_indices.mData.data()); - single_result = - single_result && test_util::check_err(out_indices_ref.mData, - out_indices.mData, - "Error: incorrect index result!"); + single_result = single_result && test::check_err(out_indices_ref.mData, + out_indices.mData, + "Error: incorrect index result!"); }; if(!single_result) From a4b8971732b28ef8a41dcf7738f9d1fb49599799 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 23 Mar 2022 11:36:30 +0100 Subject: [PATCH 71/82] Formatting. --- example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp | 3 +-- profiler/include/profile_batched_gemm_impl.hpp | 3 ++- profiler/include/profile_gemm_impl.hpp | 3 +-- test/reduce/reduce_with_index.cpp | 11 ++++++----- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp index 9ad8b48c943..7d58bdf6b01 100644 --- a/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp +++ b/example/10_conv2d_bwd_data/conv2d_bwd_data_xdl.cpp @@ -248,7 +248,6 @@ int main(int argc, char* argv[]) in_device_buf.FromDevice(in_n_c_hi_wi_device_result.mData.data()); - ck::utils::check_err(in_n_c_hi_wi_device_result.mData, - in_n_c_hi_wi_host_result.mData); + ck::utils::check_err(in_n_c_hi_wi_device_result.mData, in_n_c_hi_wi_host_result.mData); } } diff --git a/profiler/include/profile_batched_gemm_impl.hpp b/profiler/include/profile_batched_gemm_impl.hpp index 4bd4672fa30..eddee92016e 100644 --- a/profiler/include/profile_batched_gemm_impl.hpp +++ b/profiler/include/profile_batched_gemm_impl.hpp @@ -381,7 +381,8 @@ void profile_batched_gemm_impl(int do_verification, { bf16_to_f32_(c_g_m_n_device_result, *c_f32_g_m_n_device_result); - ck::utils::check_err((*c_f32_g_m_n_device_result).mData, (*c_f32_g_m_n_host_result).mData); + ck::utils::check_err((*c_f32_g_m_n_device_result).mData, + (*c_f32_g_m_n_host_result).mData); } else { diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 24ee338aa79..34fc51b9a68 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -468,8 +468,7 @@ void profile_gemm_impl(int do_verification, ref_invoker.Run(ref_argument); - ck::utils::check_err(c_m_n_device_f32_result.mData, - c_m_n_host_result.mData); + ck::utils::check_err(c_m_n_device_f32_result.mData, c_m_n_host_result.mData); if(do_log) { diff --git a/test/reduce/reduce_with_index.cpp b/test/reduce/reduce_with_index.cpp index b7aa813014f..b56e82de4ce 100644 --- a/test/reduce/reduce_with_index.cpp +++ b/test/reduce/reduce_with_index.cpp @@ -286,8 +286,8 @@ bool test_reduce_with_index_impl(int init_method, { out_indices_dev.FromDevice(out_indices.mData.data()); single_result = single_result && ck::utils::check_err(out_indices_ref.mData, - out_indices.mData, - "Error: incorrect index result!"); + out_indices.mData, + "Error: incorrect index result!"); }; if(!single_result) @@ -382,9 +382,10 @@ bool test_reduce_with_index_impl(int init_method, if(NeedIndices) { out_indices_dev.FromDevice(out_indices.mData.data()); - single_result = single_result && ck::utils::check_err(out_indices_ref.mData, - out_indices.mData, - "Error: incorrect index result!"); + single_result = + single_result && ck::utils::check_err(out_indices_ref.mData, + out_indices.mData, + "Error: incorrect index result!"); }; if(!single_result) From 6c3afc4e0b8beeed49626828ca14b5c9275c2a87 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 23 Mar 2022 13:22:23 +0100 Subject: [PATCH 72/82] Fix merge artifacts. --- .../gemm_xdl_requant_relu_requant_int8.cpp | 2 +- .../15_grouped_gemm/grouped_gemm_xdl_fp16.cpp | 5 ++-- .../include/ck/library/utility/check_err.hpp | 25 +++++++++---------- .../include/profile_grouped_gemm_impl.hpp | 4 ++- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp index 701650a9a8d..92283a17870 100644 --- a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp +++ b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp @@ -225,7 +225,7 @@ int main(int argc, char* argv[]) ref_invoker.Run(ref_argument); - check_error(c_m_n_host_result, c_m_n_device_result); + ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); } return 0; diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp index 03afb7c44c2..e739fefbdc4 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -225,8 +227,7 @@ int main(int argc, char* argv[]) c_element_op); ref_invoker.Run(ref_argument); - - check_error(c_host_tensors[i], c_device_tensors[i]); + ck::utils::check_err(c_device_tensors[i].mData, c_host_tensors[i].mData); } } diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 8eb3ef3c526..280ac83883d 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -1,4 +1,3 @@ -<<<<<<< HEAD:library/include/ck/library/utility/check_err.hpp #ifndef CHECK_ERR_HPP #define CHECK_ERR_HPP @@ -62,14 +61,13 @@ check_err(const std::vector& out, return res; } - template typename std::enable_if::value, bool>::type check_err(const std::vector& out, const std::vector& ref, const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) + double rtol = 1e-3, + double atol = 1e-3) { if(out.size() != ref.size()) { @@ -80,15 +78,15 @@ check_err(const std::vector& out, } bool res{true}; - int err_count = 0; - double err = 0; + int err_count = 0; + double err = 0; // TODO: This is a hack. We should have proper specialization for bhalf_t data type. double max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { double o = type_convert(out[i]); double r = type_convert(ref[i]); - err = std::abs(o - r); + err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? err : max_err; @@ -110,12 +108,13 @@ check_err(const std::vector& out, } template -typename std::enable_if::value || std::is_same::value, bool>::type +typename std::enable_if::value || std::is_same::value, + bool>::type check_err(const std::vector& out, const std::vector& ref, const std::string& msg = "Error: Incorrect results!", - double rtol = 1e-3, - double atol = 1e-3) + double rtol = 1e-3, + double atol = 1e-3) { if(out.size() != ref.size()) { @@ -133,7 +132,7 @@ check_err(const std::vector& out, { double o = type_convert(out[i]); double r = type_convert(ref[i]); - err = std::abs(o - r); + err = std::abs(o - r); if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? err : max_err; @@ -159,8 +158,8 @@ typename std::enable_if::value && !std::is_same: check_err(const std::vector& out, const std::vector& ref, const std::string& msg = "Error: Incorrect results!", - double = 0, - double = 0) + double = 0, + double = 0) { if(out.size() != ref.size()) { diff --git a/profiler/include/profile_grouped_gemm_impl.hpp b/profiler/include/profile_grouped_gemm_impl.hpp index 2d99e93cfde..8a12268b509 100644 --- a/profiler/include/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profile_grouped_gemm_impl.hpp @@ -1,5 +1,7 @@ #pragma once #include + +#include "check_err.hpp" #include "config.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -283,7 +285,7 @@ void profile_grouped_gemm_impl(int do_verification, c_element_op); ref_invoker.Run(ref_argument); - check_error(c_m_n_host_result, c_m_n_device_results[i]); + ck::utils::check_err(c_m_n_device_results[i].mData, c_m_n_host_result.mData); if(do_log) { From 821829116e8470cd312a832cb61e1cf3dd82a625 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 23 Mar 2022 15:10:10 +0100 Subject: [PATCH 73/82] Remove deleted header. --- test/grouped_gemm/grouped_gemm_fp16.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp index 9b3d2901ee6..4bb24b6559a 100644 --- a/test/grouped_gemm/grouped_gemm_fp16.cpp +++ b/test/grouped_gemm/grouped_gemm_fp16.cpp @@ -15,7 +15,6 @@ #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" -#include "test_util.hpp" using PassThrough = ck::tensor_operation::element_wise::PassThrough; From 476ab8d84d98d1ee7483e6dfd1222850bb37fdcc Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 24 Mar 2022 09:55:42 +0100 Subject: [PATCH 74/82] Fix some includes and use ck::utils::check_err. --- .../gemm_xdl_requant_relu_requant_int8.cpp | 2 ++ test/batched_gemm/batched_gemm_fp16.cpp | 1 - test/gemm/gemm_bf16.cpp | 1 - test/gemm/gemm_fp32.cpp | 1 - test/gemm/gemm_int8.cpp | 1 - test/grouped_gemm/grouped_gemm_fp16.cpp | 22 +++---------------- 6 files changed, 5 insertions(+), 23 deletions(-) diff --git a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp index 92283a17870..2b879600f42 100644 --- a/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp +++ b/example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" diff --git a/test/batched_gemm/batched_gemm_fp16.cpp b/test/batched_gemm/batched_gemm_fp16.cpp index 07dd65cc264..a5b13c65445 100644 --- a/test/batched_gemm/batched_gemm_fp16.cpp +++ b/test/batched_gemm/batched_gemm_fp16.cpp @@ -11,7 +11,6 @@ #include "device_tensor.hpp" #include "device_batched_gemm_xdl.hpp" #include "element_wise_operation.hpp" -#include "test_util.hpp" using PassThrough = ck::tensor_operation::element_wise::PassThrough; diff --git a/test/gemm/gemm_bf16.cpp b/test/gemm/gemm_bf16.cpp index 25f4dd30dba..162951dfcd9 100644 --- a/test/gemm/gemm_bf16.cpp +++ b/test/gemm/gemm_bf16.cpp @@ -6,7 +6,6 @@ #include #include -#include "check_err.hpp" #include "gemm_util.hpp" #include "config.hpp" #include "print.hpp" diff --git a/test/gemm/gemm_fp32.cpp b/test/gemm/gemm_fp32.cpp index 7010c92c7d1..cfdfac6b403 100644 --- a/test/gemm/gemm_fp32.cpp +++ b/test/gemm/gemm_fp32.cpp @@ -19,7 +19,6 @@ #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" -#include "check_err.hpp" using PassThrough = ck::tensor_operation::element_wise::PassThrough; diff --git a/test/gemm/gemm_int8.cpp b/test/gemm/gemm_int8.cpp index 14a60644eb2..84333626b96 100644 --- a/test/gemm/gemm_int8.cpp +++ b/test/gemm/gemm_int8.cpp @@ -19,7 +19,6 @@ #include "element_wise_operation.hpp" #include "reference_gemm.hpp" #include "gemm_specialization.hpp" -#include "check_err.hpp" using PassThrough = ck::tensor_operation::element_wise::PassThrough; diff --git a/test/grouped_gemm/grouped_gemm_fp16.cpp b/test/grouped_gemm/grouped_gemm_fp16.cpp index 4bb24b6559a..42dcdaceafa 100644 --- a/test/grouped_gemm/grouped_gemm_fp16.cpp +++ b/test/grouped_gemm/grouped_gemm_fp16.cpp @@ -4,6 +4,8 @@ #include #include #include + +#include "check_err.hpp" #include "config.hpp" #include "print.hpp" #include "device.hpp" @@ -45,24 +47,6 @@ using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; -template -static bool check_err(const Tensor& ref, const Tensor& result) -{ - float max_diff = 1e-2; - - for(int i = 0; i < ref.mData.size(); ++i) - { - float diff = std::abs(double(ref.mData[i]) - double(result.mData[i])); - if(max_diff < diff) - { - std::cout << double(ref.mData[i]) << "," << double(result.mData[i]) << std::endl; - return false; - } - } - - return true; -} - bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) { int group_count = 4; @@ -182,7 +166,7 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ref_invoker.Run(ref_argument); - bool res = check_err(c_device_tensors[i], c_host_tensors[i]); + bool res = ck::utils::check_err(c_host_tensors[i].mData, c_device_tensors[i].mData); std::cout << "group_id: " << i << (res ? " SUCCESS" : " FAILURE") << std::endl; From 494d923ca892609042a75579f0581f5ff34fe06a Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 30 Mar 2022 09:51:46 +0200 Subject: [PATCH 75/82] Remove unused check_indices restored by previous merge. --- .../ck/library/host_tensor/host_tensor.hpp | 24 ------------------- 1 file changed, 24 deletions(-) diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index b02dba8bd1e..61df67ea4f3 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -349,28 +349,4 @@ float check_error(const Tensor& ref, const Tensor& result) return linf_error; } -template -void check_indices(const Tensor& ref, const Tensor& result) -{ - bool has_error = false; - int error_count = 0; - - for(int i = 0; i < ref.mData.size(); ++i) - { - if(ref.mData[i] != result.mData[i]) - { - std::cerr << std::endl - << "Indices different at position " << i << " (ref: " << ref.mData[i] - << ", result: " << result.mData[i] << ")" << std::endl; - has_error = true; - error_count++; - if(error_count == 20) - break; - }; - } - - if(!has_error) - std::cout << std::endl << "Indices result is completely acccurate!" << std::endl; -} - #endif From 4248ed7ede401859868447462ebf71903bb33e03 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 30 Mar 2022 11:58:03 +0200 Subject: [PATCH 76/82] Fix namespaces after merge. --- .../convnd_bwd_data_xdl.cpp | 40 +++++++++---------- .../include/profile_convnd_bwd_data_impl.hpp | 24 +++++------ profiler/src/profile_convnd_bwd_data.cpp | 6 +-- test/convnd_bwd_data/convnd_bwd_data.cpp | 2 +- 4 files changed, 36 insertions(+), 36 deletions(-) diff --git a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp index 8db17f73986..66aa366d95b 100644 --- a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp +++ b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp @@ -99,10 +99,10 @@ void PrintUseMsg() << " , (ie RightPy, RightPx for 2D)\n" << std::endl; } -ck::conv_util::ConvParams ParseConvParams(int num_dim_spatial, char* argv[]) +ck::utils::conv::ConvParams ParseConvParams(int num_dim_spatial, char* argv[]) { // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; int arg_idx = 5; params.num_dim_spatial = num_dim_spatial; @@ -152,13 +152,13 @@ HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector switch(num_dim_spatial) { case 3: { - return ck::conv_util::GetHostTensorDescriptor(dims, tl::NDHWC{}); + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWC{}); } case 2: { - return ck::conv_util::GetHostTensorDescriptor(dims, tl::NHWC{}); + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWC{}); } case 1: { - return ck::conv_util::GetHostTensorDescriptor(dims, tl::NWC{}); + return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWC{}); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); @@ -173,13 +173,13 @@ HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vectorRun(argument.get(), nrepeat); - std::size_t flop = ck::conv_util::GetFlops( + std::size_t flop = ck::utils::conv::GetFlops( params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); - std::size_t num_btype = - ck::conv_util::GetBtype(params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, - output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::GetBtype( + params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; diff --git a/profiler/include/profile_convnd_bwd_data_impl.hpp b/profiler/include/profile_convnd_bwd_data_impl.hpp index c71d2cc9075..3763820af7a 100644 --- a/profiler/include/profile_convnd_bwd_data_impl.hpp +++ b/profiler/include/profile_convnd_bwd_data_impl.hpp @@ -1,7 +1,7 @@ #pragma once #include "config.hpp" #include "device.hpp" -#include "conv_utils.hpp" +#include "conv_fwd_util.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" #include "tensor_layout.hpp" @@ -68,13 +68,13 @@ HostTensorDescriptor get_input_host_tensor_descriptor(const std::vectorRun(argument_ptr.get(), nrepeat); std::size_t flop = - ck::conv_util::GetFlops(N, C, K, filter_spatial_lengths, output_spatial_lengths); - std::size_t num_btype = ck::conv_util::GetBtype( + ck::utils::conv::GetFlops(N, C, K, filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = ck::utils::conv::GetBtype( N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; diff --git a/profiler/src/profile_convnd_bwd_data.cpp b/profiler/src/profile_convnd_bwd_data.cpp index 2f406855cce..43611963d66 100644 --- a/profiler/src/profile_convnd_bwd_data.cpp +++ b/profiler/src/profile_convnd_bwd_data.cpp @@ -32,10 +32,10 @@ enum ConvOutputLayout NKHW, // 0 NHWK, // 1 }; -ck::conv_util::ConvParams parse_conv_params(int num_dim_spatial, char* argv[], int arg_idx) +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[], int arg_idx) { // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) - ck::conv_util::ConvParams params; + ck::utils::conv::ConvParams params; params.num_dim_spatial = num_dim_spatial; params.N = std::stoi(argv[arg_idx++]); @@ -106,7 +106,7 @@ int profile_convnd_bwd_data(int argc, char* argv[], int num_dim_spatial) const bool do_log = std::stoi(argv[8]); const int nrepeat = std::stoi(argv[9]); - ck::conv_util::ConvParams params = parse_conv_params(num_dim_spatial, argv, preParams); + ck::utils::conv::ConvParams params = parse_conv_params(num_dim_spatial, argv, preParams); auto Run = [&](auto input_type, auto wei_type, auto out_type, auto acc_type) { using InDataType = decltype(input_type); diff --git a/test/convnd_bwd_data/convnd_bwd_data.cpp b/test/convnd_bwd_data/convnd_bwd_data.cpp index 53c339fa8c7..cbc215033b4 100644 --- a/test/convnd_bwd_data/convnd_bwd_data.cpp +++ b/test/convnd_bwd_data/convnd_bwd_data.cpp @@ -12,7 +12,7 @@ int main() { bool pass = true; // check 1d - std::vector params; + std::vector params; params.push_back({1, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}}); params.push_back({1, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}}); params.push_back({1, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}}); From f67672a501888745d2a7e8d940d49c769174fb41 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 1 Apr 2022 11:01:09 +0200 Subject: [PATCH 77/82] Fix compilation error. --- example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp | 2 +- example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index f280cdbfa7e..92338e3f13f 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -34,7 +34,7 @@ using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::PassThrough; static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; using DeviceConvFwdBasePtr = ck::tensor_operation::device::DeviceConvFwdPtr; diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index 0c8c9baf7cc..770f3b698a6 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -36,7 +36,7 @@ using OutElementOp = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; using DeviceConvFwdBasePtr = ck::tensor_operation::device::DeviceConvFwdPtr; From ad143e341b2ed8a7c873e54abd40f45a7d4f3570 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 1 Apr 2022 11:13:14 +0200 Subject: [PATCH 78/82] Small fixes. * Use common functions. * Fix filename * Fix namespaces. --- .../convnd_bwd_data_xdl.cpp | 97 ++----------------- 1 file changed, 7 insertions(+), 90 deletions(-) diff --git a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp index 187988c43fe..095afecd982 100644 --- a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp +++ b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp @@ -6,7 +6,7 @@ #include #include "config.hpp" -#include "conv_utils.hpp" +#include "conv_fwd_util.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -144,91 +144,6 @@ ck::utils::conv::ConvParams ParseConvParams(int num_dim_spatial, char* argv[]) return params; } -HostTensorDescriptor GetInputHostTensorDescriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWC{}); - } - case 2: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWC{}); - } - case 1: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWC{}); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} -HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KZYXC{}); - } - case 2: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KYXC{}); - } - case 1: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KXC{}); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} - -HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector& dims, - int num_dim_spatial = 2) -{ - namespace tl = ck::tensor_layout::convolution; - - switch(num_dim_spatial) - { - case 3: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWK{}); - } - case 2: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWK{}); - } - case 1: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWK{}); - } - - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} - -DeviceConvBwdDataBasePtr GetConvInstance(int num_dim_spatial) -{ - switch(num_dim_spatial) - { - case 3: { - return std::make_unique>(); - } - case 2: { - return std::make_unique>(); - } - case 1: { - return std::make_unique>(); - } - default: { - throw std::runtime_error("Unsupported number of spatial dimensions provided!"); - } - } -} - int main(int argc, char* argv[]) { bool do_verification = 0; @@ -288,11 +203,13 @@ int main(int argc, char* argv[]) std::end(output_spatial_lengths)); Tensor in_n_c_hi_wi_host_result( - GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); + ck::utils::conv::GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); Tensor in_n_c_hi_wi_device_result( - GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); - Tensor wei_k_c_y_x(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); - Tensor out_n_k_ho_wo(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + ck::utils::conv::GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); + Tensor wei_k_c_y_x( + ck::utils::conv::GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); + Tensor out_n_k_ho_wo( + ck::utils::conv::GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; From 9ee768aa59cbb7f3f022f6f2075fd587e62f4b0b Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 1 Apr 2022 11:32:14 +0200 Subject: [PATCH 79/82] Fix merge artifact - retrieve removed by accident fun. --- .../convnd_bwd_data_xdl.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp index 095afecd982..e7574f47ade 100644 --- a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp +++ b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp @@ -144,6 +144,25 @@ ck::utils::conv::ConvParams ParseConvParams(int num_dim_spatial, char* argv[]) return params; } +DeviceConvBwdDataBasePtr GetConvInstance(int num_dim_spatial) +{ + switch(num_dim_spatial) + { + case 3: { + return std::make_unique>(); + } + case 2: { + return std::make_unique>(); + } + case 1: { + return std::make_unique>(); + } + default: { + throw std::runtime_error("Unsupported number of spatial dimensions provided!"); + } + } +} + int main(int argc, char* argv[]) { bool do_verification = 0; From 64b63bd694a467266e7100ffe3cf2b00c4e9d6c6 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 1 Apr 2022 12:29:38 +0200 Subject: [PATCH 80/82] Fix ConvForwardSpecialization. --- test/convnd_fwd/conv_util.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/convnd_fwd/conv_util.hpp b/test/convnd_fwd/conv_util.hpp index 299ce4afaa1..ee1f6825240 100644 --- a/test/convnd_fwd/conv_util.hpp +++ b/test/convnd_fwd/conv_util.hpp @@ -20,7 +20,7 @@ using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; using OutElementOp = ck::tensor_operation::element_wise::PassThrough; static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization_t::Default; + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; template using DeviceConvNDFwdInstance = ck::tensor_operation::device:: From 5658c10c06d5da6050512444c2390a61c9d6d32d Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 5 Apr 2022 11:33:57 +0200 Subject: [PATCH 81/82] Adhere to coding style rules. --- .../conv2d_fwd_xdl_bias_relu.cpp | 24 ++-- .../conv2d_fwd_xdl_bias_relu_add.cpp | 26 ++-- example/09_convnd_fwd/convnd_fwd_xdl.cpp | 36 ++--- example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp | 24 ++-- example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp | 24 ++-- .../convnd_bwd_data_xdl.cpp | 12 +- .../ck/library/utility/conv_fwd_util.hpp | 95 ++++++------- .../include/profile_convnd_bwd_data_impl.hpp | 25 ++-- test/conv_util/conv_util.cpp | 25 ++-- test/convnd_fwd/conv1d_fwd.cpp | 70 +++++----- test/convnd_fwd/conv2d_fwd.cpp | 61 ++++----- test/convnd_fwd/conv3d_fwd.cpp | 127 +++++++++--------- test/convnd_fwd/conv_util.hpp | 9 +- .../reference_conv_fwd/reference_conv_fwd.cpp | 93 ++++++------- 14 files changed, 336 insertions(+), 315 deletions(-) diff --git a/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp index 3ca08f63e1a..751ce16b901 100644 --- a/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp +++ b/example/06_conv2d_fwd_bias_relu/conv2d_fwd_xdl_bias_relu.cpp @@ -203,10 +203,12 @@ int main(int argc, char* argv[]) std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); - Tensor input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); - Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); - Tensor host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); - Tensor device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); // bias: assume contiguous 1d vector Tensor bias( HostTensorDescriptor(std::vector({static_cast(params.K)}))); @@ -269,15 +271,15 @@ int main(int argc, char* argv[]) float ave_time = invoker.Run(argument, nrepeat); - std::size_t flop = GetFlops( + std::size_t flop = get_flops( params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); std::size_t num_btype = - GetBtype(params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, - output_spatial_lengths) + + get_btype(params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths) + sizeof(OutDataType) * (params.K); float tflops = static_cast(flop) / 1.E9 / ave_time; diff --git a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp index 77d64e42a29..e6339fcd23a 100644 --- a/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp +++ b/example/07_conv2d_fwd_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp @@ -200,17 +200,19 @@ int main(int argc, char* argv[]) std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); - Tensor input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); - Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); - Tensor host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); - Tensor device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); // bias: assume contiguous 1d vector Tensor bias( HostTensorDescriptor(std::vector({static_cast(params.K)}))); // residual: assume same layout as output tensor - Tensor residual(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor residual(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); std::cout << "input: " << input.mDesc << std::endl; std::cout << "weights: " << weights.mDesc << std::endl; @@ -280,15 +282,15 @@ int main(int argc, char* argv[]) float ave_time = invoker.Run(argument, nrepeat); - std::size_t flop = GetFlops( + std::size_t flop = get_flops( params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); std::size_t num_btype = - GetBtype(params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, - output_spatial_lengths) + + get_btype(params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths) + sizeof(OutDataType) * (params.K) + sizeof(OutDataType) * (params.N * params.K * output_spatial_lengths[0] * output_spatial_lengths[1]); diff --git a/example/09_convnd_fwd/convnd_fwd_xdl.cpp b/example/09_convnd_fwd/convnd_fwd_xdl.cpp index 48278bd9a49..e8895b86391 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl.cpp @@ -84,7 +84,7 @@ using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd< OutElementOp, NumDimSpatial>; -DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial) +DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial) { switch(num_dim_spatial) { @@ -103,7 +103,7 @@ DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial) } } -void PrintUseMsg() +void print_use_msg() { std::cout << "arg1: verification (0=no, 1=yes)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" @@ -120,14 +120,14 @@ void PrintUseMsg() << std::endl; } -ck::utils::conv::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* argv[]) +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, char* argv[]) { // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) int conv_args = 3 + num_dim_spatial * 6; int cmdline_nargs = conv_args + 5; if(cmdline_nargs != argc) { - PrintUseMsg(); + print_use_msg(); exit(0); } @@ -196,7 +196,7 @@ int main(int argc, char* argv[]) if(argc >= 6) { - params = ParseConvParams(num_dim_spatial, argc, argv); + params = parse_conv_params(num_dim_spatial, argc, argv); } std::vector input_dims{static_cast(params.N), @@ -218,10 +218,12 @@ int main(int argc, char* argv[]) std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); - Tensor input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); - Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); - Tensor host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); - Tensor device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output( + get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); std::cout << "input: " << input.mDesc << std::endl; std::cout << "weights: " << weights.mDesc << std::endl; @@ -247,7 +249,7 @@ int main(int argc, char* argv[]) wei_device_buf.ToDevice(weights.mData.data()); // do GEMM - auto conv = GetConvInstance(num_dim_spatial); + auto conv = get_conv_instance(num_dim_spatial); auto invoker = conv->MakeInvokerPointer(); auto argument = conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), @@ -276,15 +278,15 @@ int main(int argc, char* argv[]) float ave_time = invoker->Run(argument.get(), nrepeat); - std::size_t flop = GetFlops( + std::size_t flop = get_flops( params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); std::size_t num_btype = - GetBtype(params.N, - params.C, - params.K, - params.input_spatial_lengths, - params.filter_spatial_lengths, - output_spatial_lengths); + get_btype(params.N, + params.C, + params.K, + params.input_spatial_lengths, + params.filter_spatial_lengths, + output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index 92338e3f13f..eaa5683978b 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -87,7 +87,7 @@ using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd< OutElementOp, NumDimSpatial>; -DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial) +DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial) { switch(num_dim_spatial) { @@ -106,7 +106,7 @@ DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial) } } -void PrintUseMsg() +void print_use_msg() { std::cout << "arg1: verification (0=no, 1=yes)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" @@ -123,14 +123,14 @@ void PrintUseMsg() << std::endl; } -ck::utils::conv::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* argv[]) +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, char* argv[]) { // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) int conv_args = 3 + num_dim_spatial * 6; int cmdline_nargs = conv_args + 5; if(cmdline_nargs != argc) { - PrintUseMsg(); + print_use_msg(); exit(0); } @@ -199,7 +199,7 @@ int main(int argc, char* argv[]) if(argc >= 6) { - params = ParseConvParams(num_dim_spatial, argc, argv); + params = parse_conv_params(num_dim_spatial, argc, argv); } std::vector input_dims{static_cast(params.N), @@ -221,10 +221,10 @@ int main(int argc, char* argv[]) std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); - Tensor input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); - Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); - Tensor host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); - Tensor device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); std::cout << "input: " << input.mDesc << std::endl; std::cout << "weights: " << weights.mDesc << std::endl; @@ -250,7 +250,7 @@ int main(int argc, char* argv[]) wei_device_buf.ToDevice(weights.mData.data()); // do GEMM - auto conv = GetConvInstance(num_dim_spatial); + auto conv = get_conv_instance(num_dim_spatial); auto invoker = conv->MakeInvokerPointer(); auto argument = conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), @@ -279,9 +279,9 @@ int main(int argc, char* argv[]) float ave_time = invoker->Run(argument.get(), nrepeat); - std::size_t flop = GetFlops( + std::size_t flop = get_flops( params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); - std::size_t num_btype = GetBtype( + std::size_t num_btype = get_btype( params.N, params.C, params.K, diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp index 770f3b698a6..34b46457706 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_int8.cpp @@ -89,7 +89,7 @@ using ReferenceConvNDFwdInstance = ck::tensor_operation::host::ReferenceConvFwd< OutElementOp, NumDimSpatial>; -DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial) +DeviceConvFwdBasePtr get_conv_instance(int num_dim_spatial) { switch(num_dim_spatial) { @@ -108,7 +108,7 @@ DeviceConvFwdBasePtr GetConvInstance(int num_dim_spatial) } } -void PrintUseMsg() +void print_use_msg() { std::cout << "arg1: verification (0=no, 1=yes)\n" << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n" @@ -125,14 +125,14 @@ void PrintUseMsg() << std::endl; } -ck::utils::conv::ConvParams ParseConvParams(int num_dim_spatial, int argc, char* argv[]) +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, int argc, char* argv[]) { // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) int conv_args = 3 + num_dim_spatial * 6; int cmdline_nargs = conv_args + 5; if(cmdline_nargs != argc) { - PrintUseMsg(); + print_use_msg(); exit(0); } @@ -201,7 +201,7 @@ int main(int argc, char* argv[]) if(argc >= 6) { - params = ParseConvParams(num_dim_spatial, argc, argv); + params = parse_conv_params(num_dim_spatial, argc, argv); } std::vector input_dims{static_cast(params.N), @@ -223,10 +223,10 @@ int main(int argc, char* argv[]) std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); - Tensor input(GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); - Tensor weights(GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); - Tensor host_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); - Tensor device_output(GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + Tensor input(get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); + Tensor weights(get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); + Tensor host_output(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); + Tensor device_output(get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); std::cout << "input: " << input.mDesc << std::endl; std::cout << "weights: " << weights.mDesc << std::endl; @@ -252,7 +252,7 @@ int main(int argc, char* argv[]) wei_device_buf.ToDevice(weights.mData.data()); // do GEMM - auto conv = GetConvInstance(num_dim_spatial); + auto conv = get_conv_instance(num_dim_spatial); auto invoker = conv->MakeInvokerPointer(); auto argument = conv->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), @@ -281,9 +281,9 @@ int main(int argc, char* argv[]) float ave_time = invoker->Run(argument.get(), nrepeat); - std::size_t flop = GetFlops( + std::size_t flop = get_flops( params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); - std::size_t num_btype = GetBtype( + std::size_t num_btype = get_btype( params.N, params.C, params.K, diff --git a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp index e7574f47ade..e41c1126654 100644 --- a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp +++ b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp @@ -222,13 +222,13 @@ int main(int argc, char* argv[]) std::end(output_spatial_lengths)); Tensor in_n_c_hi_wi_host_result( - ck::utils::conv::GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); Tensor in_n_c_hi_wi_device_result( - ck::utils::conv::GetInputHostTensorDescriptor(input_dims, num_dim_spatial)); + ck::utils::conv::get_input_host_tensor_descriptor(input_dims, num_dim_spatial)); Tensor wei_k_c_y_x( - ck::utils::conv::GetFiltersHostTensorDescriptor(filter_dims, num_dim_spatial)); + ck::utils::conv::get_filters_host_tensor_descriptor(filter_dims, num_dim_spatial)); Tensor out_n_k_ho_wo( - ck::utils::conv::GetOutputHostTensorDescriptor(output_dims, num_dim_spatial)); + ck::utils::conv::get_output_host_tensor_descriptor(output_dims, num_dim_spatial)); std::cout << "in_n_c_hi_wi: " << in_n_c_hi_wi_host_result.mDesc << std::endl; std::cout << "wei_k_c_y_x: " << wei_k_c_y_x.mDesc << std::endl; @@ -287,9 +287,9 @@ int main(int argc, char* argv[]) float ave_time = invoker->Run(argument.get(), nrepeat); - std::size_t flop = ck::utils::conv::GetFlops( + std::size_t flop = ck::utils::conv::get_flops( params.N, params.C, params.K, params.filter_spatial_lengths, output_spatial_lengths); - std::size_t num_btype = ck::utils::conv::GetBtype( + std::size_t num_btype = ck::utils::conv::get_btype( params.N, params.C, params.K, diff --git a/library/include/ck/library/utility/conv_fwd_util.hpp b/library/include/ck/library/utility/conv_fwd_util.hpp index 2beaece232b..f758b808c36 100644 --- a/library/include/ck/library/utility/conv_fwd_util.hpp +++ b/library/include/ck/library/utility/conv_fwd_util.hpp @@ -43,11 +43,11 @@ using DeviceConvFwdNoOpPtr = * * @return The number of flops. */ -std::size_t GetFlops(ck::index_t N, - ck::index_t C, - ck::index_t K, - const std::vector& filter_spatial_lengths, - const std::vector& output_spatial_lengths) +std::size_t get_flops(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) { // 2 * N * K * * C * return static_cast(2) * N * K * @@ -81,12 +81,12 @@ std::size_t GetFlops(ck::index_t N, template -std::size_t GetBtype(ck::index_t N, - ck::index_t C, - ck::index_t K, - const std::vector& input_spatial_lengths, - const std::vector& filter_spatial_lengths, - const std::vector& output_spatial_lengths) +std::size_t get_btype(ck::index_t N, + ck::index_t C, + ck::index_t K, + const std::vector& input_spatial_lengths, + const std::vector& filter_spatial_lengths, + const std::vector& output_spatial_lengths) { // sizeof(InDataType) * (N * C * ) + // sizeof(WeiDataType) * (K * C * ) + @@ -211,8 +211,8 @@ struct ConvParams * @return The host tensor descriptor object. */ template -HostTensorDescriptor GetHostTensorDescriptor(const std::vector& dims, - const TensorLayout& layout) +HostTensorDescriptor get_host_tensor_descriptor(const std::vector& dims, + const TensorLayout& layout) { std::size_t C = dims[1]; // 1D @@ -279,7 +279,7 @@ template -auto GetHostTensors(const ConvParams& params, bool init = true) +auto get_host_tensors(const ConvParams& params, bool init = true) { std::vector input_dims{static_cast(params.N), static_cast(params.C)}; @@ -300,12 +300,13 @@ auto GetHostTensors(const ConvParams& params, bool init = true) std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); - Tensor input(ck::utils::conv::GetHostTensorDescriptor(input_dims, InLayout{})); - Tensor weights(ck::utils::conv::GetHostTensorDescriptor(filter_dims, WeiLayout{})); + Tensor input(ck::utils::conv::get_host_tensor_descriptor(input_dims, InLayout{})); + Tensor weights( + ck::utils::conv::get_host_tensor_descriptor(filter_dims, WeiLayout{})); Tensor host_output( - ck::utils::conv::GetHostTensorDescriptor(output_dims, OutLayout{})); + ck::utils::conv::get_host_tensor_descriptor(output_dims, OutLayout{})); Tensor device_output( - ck::utils::conv::GetHostTensorDescriptor(output_dims, OutLayout{})); + ck::utils::conv::get_host_tensor_descriptor(output_dims, OutLayout{})); if(init) { @@ -333,21 +334,21 @@ auto GetHostTensors(const ConvParams& params, bool init = true) return std::make_tuple(input, weights, host_output, device_output); } -HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector& dims, - int num_dim_spatial = 2) +HostTensorDescriptor get_output_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) { namespace tl = ck::tensor_layout::convolution; switch(num_dim_spatial) { case 3: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWK{}); + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NDHWK{}); } case 2: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWK{}); + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NHWK{}); } case 1: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWK{}); + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NWK{}); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); @@ -355,21 +356,21 @@ HostTensorDescriptor GetOutputHostTensorDescriptor(const std::vector& dims, - int num_dim_spatial = 2) +HostTensorDescriptor get_filters_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) { namespace tl = ck::tensor_layout::convolution; switch(num_dim_spatial) { case 3: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KZYXC{}); + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::KZYXC{}); } case 2: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KYXC{}); + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::KYXC{}); } case 1: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::KXC{}); + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::KXC{}); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); @@ -377,21 +378,21 @@ HostTensorDescriptor GetFiltersHostTensorDescriptor(const std::vector& dims, - int num_dim_spatial = 2) +HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector& dims, + int num_dim_spatial = 2) { namespace tl = ck::tensor_layout::convolution; switch(num_dim_spatial) { case 3: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWC{}); + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NDHWC{}); } case 2: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWC{}); + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NHWC{}); } case 1: { - return ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWC{}); + return ck::utils::conv::get_host_tensor_descriptor(dims, tl::NWC{}); } default: { throw std::runtime_error("Unsupported number of spatial dimensions provided!"); @@ -403,10 +404,10 @@ template -void RunReferenceConvFwd(const ConvParams& params, - const Tensor& input, - const Tensor& weights, - Tensor& output) +void run_reference_convolution_forward(const ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) { using PassThrough = ck::tensor_operation::element_wise::PassThrough; auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd class DeviceConvNDFwdInstance> -void RunConvFwd(const ConvParams& params, - const Tensor& input, - const Tensor& weights, - Tensor& output) +void run_convolution_forward(const ConvParams& params, + const Tensor& input, + const Tensor& weights, + Tensor& output) { using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -486,12 +487,12 @@ template -bool RunConvInstances(const ConvParams& params, - const std::vector& conv_ptrs, - const Tensor& input, - const Tensor& weights, - Tensor& output, - const Tensor& host_output) +bool run_convolution_forward_instances(const ConvParams& params, + const std::vector& conv_ptrs, + const Tensor& input, + const Tensor& weights, + Tensor& output, + const Tensor& host_output) { using PassThrough = ck::tensor_operation::element_wise::PassThrough; diff --git a/profiler/include/profile_convnd_bwd_data_impl.hpp b/profiler/include/profile_convnd_bwd_data_impl.hpp index 935172d03b4..3b3595fdfca 100644 --- a/profiler/include/profile_convnd_bwd_data_impl.hpp +++ b/profiler/include/profile_convnd_bwd_data_impl.hpp @@ -68,13 +68,13 @@ HostTensorDescriptor get_input_host_tensor_descriptor(const std::vectorRun(argument_ptr.get(), nrepeat); std::size_t flop = - ck::utils::conv::GetFlops(N, C, K, filter_spatial_lengths, output_spatial_lengths); - std::size_t num_btype = ck::utils::conv::GetBtype( - N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths); + ck::utils::conv::get_flops(N, C, K, filter_spatial_lengths, output_spatial_lengths); + std::size_t num_btype = + ck::utils::conv::get_btype( + N, C, K, input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; diff --git a/test/conv_util/conv_util.cpp b/test/conv_util/conv_util.cpp index 85d02a3b672..cc487c39e34 100644 --- a/test/conv_util/conv_util.cpp +++ b/test/conv_util/conv_util.cpp @@ -9,7 +9,7 @@ namespace { -bool TestConvParams_GetOutputSpatialLengths() +bool test_conv_params_get_output_spatial_lengths() { bool res{true}; // -------------------------- default 2D ------------------------------------ @@ -138,36 +138,36 @@ bool TestConvParams_GetOutputSpatialLengths() return res; } -bool TestGetHostTensorDescriptor() +bool test_get_host_tensor_descriptor() { bool res{true}; namespace tl = ck::tensor_layout::convolution; std::vector dims{2, 3, 4, 5}; - HostTensorDescriptor h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NHWC{}); + HostTensorDescriptor h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NHWC{}); res = ck::utils::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NHWC dimensions lengths!"); res = ck::utils::check_err( h.GetStrides(), {3 * 4 * 5, 1, 3 * 5, 3}, "Error: wrong NHWC dimensions strides!"); - h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NCHW{}); + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCHW{}); res = ck::utils::check_err(h.GetLengths(), {2, 3, 4, 5}, "Error: wrong NCHW dimensions lengths!"); res = ck::utils::check_err( h.GetStrides(), {3 * 4 * 5, 4 * 5, 5, 1}, "Error: wrong NCHW dimensions strides!"); dims = std::vector{2, 3, 4}; - h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NWC{}); + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NWC{}); res = ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NWC dimensions lengths!"); res = ck::utils::check_err(h.GetStrides(), {3 * 4, 1, 3}, "Error: wrong NWC dimensions strides!"); - h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NCW{}); + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCW{}); res = ck::utils::check_err(h.GetLengths(), {2, 3, 4}, "Error: wrong NCW dimensions lengths!"); res = ck::utils::check_err(h.GetStrides(), {3 * 4, 4, 1}, "Error: wrong NCW dimensions strides!"); dims = std::vector{2, 3, 4, 5, 6}; - h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NDHWC{}); + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NDHWC{}); res = ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NDHWC dimensions lengths!"); res = ck::utils::check_err(h.GetStrides(), {3 * 4 * 5 * 6, // N @@ -177,7 +177,7 @@ bool TestGetHostTensorDescriptor() 3}, // W "Error: wrong NDHWC dimensions strides!"); - h = ck::utils::conv::GetHostTensorDescriptor(dims, tl::NCDHW{}); + h = ck::utils::conv::get_host_tensor_descriptor(dims, tl::NCDHW{}); res = ck::utils::check_err(h.GetLengths(), dims, "Error: wrong NCDHW dimensions lengths!"); res = ck::utils::check_err(h.GetStrides(), {3 * 4 * 5 * 6, // N @@ -194,10 +194,11 @@ bool TestGetHostTensorDescriptor() int main(void) { - bool res = TestConvParams_GetOutputSpatialLengths(); - std::cout << "TestConvParams_GetOutputSpatialLengths ..... " << (res ? "SUCCESS" : "FAILURE") + bool res = test_conv_params_get_output_spatial_lengths(); + std::cout << "test_conv_params_get_output_spatial_lengths ..... " + << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = test_get_host_tensor_descriptor(); + std::cout << "test_get_host_tensor_descriptor ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestGetHostTensorDescriptor(); - std::cout << "TestGetHostTensorDescriptor ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return res ? 0 : 1; } diff --git a/test/convnd_fwd/conv1d_fwd.cpp b/test/convnd_fwd/conv1d_fwd.cpp index f5d02a5aa72..e6df0e6f8cf 100644 --- a/test/convnd_fwd/conv1d_fwd.cpp +++ b/test/convnd_fwd/conv1d_fwd.cpp @@ -35,7 +35,7 @@ void add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(std::vector{1}; auto host_tensors = - ck::utils::conv::GetHostTensors(params); + ck::utils::conv::get_host_tensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - ck::utils::conv::RunReferenceConvFwd<1>(params, input, weights, host_output); + ck::utils::conv::run_reference_convolution_forward<1>(params, input, weights, host_output); test::conv::RunConv<1>(params, input, weights, device_output); res = res && ck::utils::check_err( @@ -72,7 +72,7 @@ bool TestConv1DNWC() } template -bool TestConv1DNWCInstances(const std::vector& conv_ptrs) +bool test_conv1d_nwc_instances(const std::vector& conv_ptrs) { ck::utils::conv::ConvParams params; params.num_dim_spatial = 1; @@ -84,51 +84,51 @@ bool TestConv1DNWCInstances(const std::vector& conv_ptrs) params.input_right_pads = std::vector{1}; auto host_tensors = - ck::utils::conv::GetHostTensors(params); + ck::utils::conv::get_host_tensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - ck::utils::conv::RunReferenceConvFwd<1>(params, input, weights, host_output); - return ck::utils::conv::RunConvInstances<1>( + ck::utils::conv::run_reference_convolution_forward<1>(params, input, weights, host_output); + return ck::utils::conv::run_convolution_forward_instances<1>( params, conv_ptrs, input, weights, device_output, host_output); } -bool TestConv1DNWCBF16Instances() +bool test_conv1d_nwc_bf16_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv1d_fwd_instance:: add_device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); - return TestConv1DNWCInstances(conv_ptrs); + return test_conv1d_nwc_instances(conv_ptrs); } -bool TestConv1DNWCF16Instances() +bool test_conv1d_nwc_f16_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv1d_fwd_instance:: add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); - return TestConv1DNWCInstances(conv_ptrs); + return test_conv1d_nwc_instances(conv_ptrs); } -bool TestConv1DNWCF32Instances() +bool test_conv1d_nwc_f32_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv1d_fwd_instance:: add_device_conv1d_fwd_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs); - return TestConv1DNWCInstances(conv_ptrs); + return test_conv1d_nwc_instances(conv_ptrs); } -bool TestConv1DNWCInt8Instances() +bool test_conv1d_nwc_int8_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv1d_fwd_instance:: add_device_conv1d_fwd_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); - return TestConv1DNWCInstances(conv_ptrs); + return test_conv1d_nwc_instances(conv_ptrs); } } // anonymous namespace @@ -136,18 +136,20 @@ bool TestConv1DNWCInt8Instances() int main() { bool res{true}; - res = TestConv1DNWC(); - std::cout << "TestConv1DNWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = test_conv1D_nwc(); + std::cout << "test_conv1D_nwc ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv1DNWCBF16Instances(); + res = test_conv1d_nwc_bf16_instances(); std::cout << "\nTestConv1DNWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv1DNWCF16Instances(); - std::cout << "\nTestConv1DNWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv1DNWCF32Instances(); - std::cout << "\nTestConv1DNWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv1DNWCInt8Instances(); - std::cout << "\nTestConv1DNWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") + res = test_conv1d_nwc_f16_instances(); + std::cout << "\ntest_conv1d_nwc_f16_instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; + res = test_conv1d_nwc_f32_instances(); + std::cout << "\ntest_conv1d_nwc_f32_instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; + res = test_conv1d_nwc_int8_instances(); + std::cout << "\ntes_tconv1_dnw_cint_8instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return res ? 0 : 1; diff --git a/test/convnd_fwd/conv2d_fwd.cpp b/test/convnd_fwd/conv2d_fwd.cpp index 37b4085a002..2a46d744958 100644 --- a/test/convnd_fwd/conv2d_fwd.cpp +++ b/test/convnd_fwd/conv2d_fwd.cpp @@ -37,7 +37,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector{16, 16}; params.conv_filter_strides = std::vector{1, 1}; - auto host_tensors = ck::utils::conv::GetHostTensors(params); + auto host_tensors = ck::utils::conv::get_host_tensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - ck::utils::conv::RunReferenceConvFwd<2>(params, input, weights, host_output); + ck::utils::conv::run_reference_convolution_forward<2>(params, input, weights, host_output); test::conv::RunConv<2>(params, input, weights, device_output); res = res && ck::utils::check_err( @@ -63,7 +63,7 @@ bool TestConv2DNHWC() } template -bool TestConv2DNHWCInstances(const std::vector& conv_ptrs) +bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs) { ck::utils::conv::ConvParams params; params.num_dim_spatial = 2; @@ -75,54 +75,54 @@ bool TestConv2DNHWCInstances(const std::vector& conv_ptrs) params.input_right_pads = std::vector{1, 1}; auto host_tensors = - ck::utils::conv::GetHostTensors(params); + ck::utils::conv::get_host_tensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - ck::utils::conv::RunReferenceConvFwd<2>(params, input, weights, host_output); - return ck::utils::conv::RunConvInstances<2>( + ck::utils::conv::run_reference_convolution_forward<2>(params, input, weights, host_output); + return ck::utils::conv::run_convolution_forward_instances<2>( params, conv_ptrs, input, weights, device_output, host_output); } -bool TestConv2DNHWCBF16Instances() +bool test_conv2d_nhwc_bf16_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv2d_fwd_instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); - return TestConv2DNHWCInstances(conv_ptrs); + return test_conv2d_nhwc_instances(conv_ptrs); } -bool TestConv2DNHWCF16Instances() +bool test_conv2d_nhwc_f16_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv2d_fwd_instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); ck::tensor_operation::device::device_conv2d_fwd_instance:: add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); - return TestConv2DNHWCInstances(conv_ptrs); + return test_conv2d_nhwc_instances(conv_ptrs); } -bool TestConv2DNHWCF32Instances() +bool test_conv2d_nhwc_f32_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv2d_fwd_instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); - return TestConv2DNHWCInstances(conv_ptrs); + return test_conv2d_nhwc_instances(conv_ptrs); } -bool TestConv2DNHWCInt8Instances() +bool test_conv2d_nhwc_int8_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv2d_fwd_instance:: add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); - return TestConv2DNHWCInstances(conv_ptrs); + return test_conv2d_nhwc_instances(conv_ptrs); } } // anonymous namespace @@ -130,19 +130,20 @@ bool TestConv2DNHWCInt8Instances() int main() { bool res{true}; - res = TestConv2DNHWC(); - std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = test_conv2d_nhwc(); + std::cout << "test_conv2d_nhwc ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv2DNHWCBF16Instances(); - std::cout << "\nTestConv2DNHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") + res = test_conv2d_nhwc_bf16_instances(); + std::cout << "\ntest_conv2d_nhwc_bf16_instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv2DNHWCF16Instances(); - std::cout << "\nTestConv2DNHWCF16Instances ....." << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv2DNHWCF32Instances(); - std::cout << "\nTestConv2DNHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") + res = test_conv2d_nhwc_f16_instances(); + std::cout << "\ntest_conv2d_nhwc_f16_instances ....." << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv2DNHWCInt8Instances(); - std::cout << "\nTestConv2DNHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") + res = test_conv2d_nhwc_f32_instances(); + std::cout << "\ntest_conv2d_nhwc_f32_instances ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; + res = test_conv2d_nhwc_int8_instances(); + std::cout << "\ntest_conv2d_nhwc_int8_instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return res ? 0 : 1; diff --git a/test/convnd_fwd/conv3d_fwd.cpp b/test/convnd_fwd/conv3d_fwd.cpp index e91f2453db5..3dc1a6b160f 100644 --- a/test/convnd_fwd/conv3d_fwd.cpp +++ b/test/convnd_fwd/conv3d_fwd.cpp @@ -35,7 +35,7 @@ void add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(std::vector{1, 1, 1}; auto host_tensors = - ck::utils::conv::GetHostTensors(params); + ck::utils::conv::get_host_tensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - ck::utils::conv::RunReferenceConvFwd<3>(params, input, weights, host_output); + ck::utils::conv::run_reference_convolution_forward<3>(params, input, weights, host_output); test::conv::RunConv<3>(params, input, weights, device_output); res = res && ck::utils::check_err( @@ -71,7 +71,7 @@ bool TestConv3DNDHWC() return res; } -bool TestConv3DNDHWC2GBInput() +bool test_conv3d_ndhwc_2gb_input() { // >2GB Input ck::utils::conv::ConvParams params; @@ -87,12 +87,12 @@ bool TestConv3DNDHWC2GBInput() params.input_right_pads = std::vector{1, 1, 1}; auto host_tensors = - ck::utils::conv::GetHostTensors(params, false); + ck::utils::conv::get_host_tensors(params, false); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); @@ -115,7 +115,7 @@ bool TestConv3DNDHWC2GBInput() return false; } -bool TestConv3DNDHWC2GBFilters() +bool test_conv3d_ndhwc_2gb_filters() { // >2GB Filters ck::utils::conv::ConvParams params; @@ -131,12 +131,12 @@ bool TestConv3DNDHWC2GBFilters() params.input_right_pads = std::vector{1, 1, 1}; auto host_tensors = - ck::utils::conv::GetHostTensors(params, false); + ck::utils::conv::get_host_tensors(params, false); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); @@ -159,7 +159,7 @@ bool TestConv3DNDHWC2GBFilters() return false; } -bool TestConv3DNDHWC2GBOutput() +bool test_conv3d_ndhwc_2gb_output() { // >2GB Output ck::utils::conv::ConvParams params; @@ -175,12 +175,12 @@ bool TestConv3DNDHWC2GBOutput() params.input_right_pads = std::vector{2, 2, 2}; auto host_tensors = - ck::utils::conv::GetHostTensors(params, false); + ck::utils::conv::get_host_tensors(params, false); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); @@ -204,7 +204,7 @@ bool TestConv3DNDHWC2GBOutput() } template -bool TestConv3DNDHWCInstances(const std::vector& conv_ptrs) +bool test_conv3d_ndhwc_instances(const std::vector& conv_ptrs) { ck::utils::conv::ConvParams params; params.N = 64; @@ -217,52 +217,52 @@ bool TestConv3DNDHWCInstances(const std::vector& conv_ptrs params.input_right_pads = std::vector{1, 1, 1}; auto host_tensors = - ck::utils::conv::GetHostTensors(params); + ck::utils::conv::get_host_tensors(params); const Tensor& input = std::get<0>(host_tensors); const Tensor& weights = std::get<1>(host_tensors); Tensor& host_output = std::get<2>(host_tensors); Tensor& device_output = std::get<3>(host_tensors); - ck::utils::conv::RunReferenceConvFwd<3>(params, input, weights, host_output); - return ck::utils::conv::RunConvInstances<3>( + ck::utils::conv::run_reference_convolution_forward<3>(params, input, weights, host_output); + return ck::utils::conv::run_convolution_forward_instances<3>( params, conv_ptrs, input, weights, device_output, host_output); } -bool TestConv3DNDHWCBF16Instances() +bool test_conv3d_ndhwc_bf16_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv3d_fwd_instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); - return TestConv3DNDHWCInstances(conv_ptrs); + return test_conv3d_ndhwc_instances(conv_ptrs); } -bool TestConv3DNDHWCF16Instances() +bool test_conv3d_ndhwc_f16_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv3d_fwd_instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); - return TestConv3DNDHWCInstances(conv_ptrs); + return test_conv3d_ndhwc_instances(conv_ptrs); } -bool TestConv3DNDHWCF32Instances() +bool test_conv3d_ndhwc_f32_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv3d_fwd_instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs); - return TestConv3DNDHWCInstances(conv_ptrs); + return test_conv3d_ndhwc_instances(conv_ptrs); } -bool TestConv3DNDHWCInt8Instances() +bool test_conv3d_ndhwc_int8_instances() { std::vector conv_ptrs; ck::tensor_operation::device::device_conv3d_fwd_instance:: add_device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); - return TestConv3DNDHWCInstances(conv_ptrs); + return test_conv3d_ndhwc_instances(conv_ptrs); } } // anonymous namespace @@ -270,27 +270,30 @@ bool TestConv3DNDHWCInt8Instances() int main() { bool res{true}; - res = TestConv3DNDHWC(); - std::cout << "TestConv3DNDHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = test_conv3d_ndhwc(); + std::cout << "test_conv3d_ndhwc ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNDHWC2GBInput(); - std::cout << "\nTestConv3DNDHWC2GBInput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNDHWC2GBFilters(); - std::cout << "\nTestConv3DNDHWC2GBFilters ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNDHWC2GBOutput(); - std::cout << "\nTestConv3DNDHWC2GBOutput ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = test_conv3d_ndhwc_2gb_input(); + std::cout << "\ntest_conv3d_ndhwc_2gb_input ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; + res = test_conv3d_ndhwc_2gb_filters(); + std::cout << "\ntest_conv3d_ndhwc_2gb_filters ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; + res = test_conv3d_ndhwc_2gb_output(); + std::cout << "\ntest_conv3d_ndhwc_2gb_output ..... " << (res ? "SUCCESS" : "FAILURE") + << std::endl; - res = TestConv3DNDHWCBF16Instances(); - std::cout << "\nTestConv3DNDHWCBF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") + res = test_conv3d_ndhwc_bf16_instances(); + std::cout << "\ntest_conv3d_ndhwc_bf16_instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNDHWCF16Instances(); - std::cout << "\nTestConv3DNDHWCF16Instances ..... " << (res ? "SUCCESS" : "FAILURE") + res = test_conv3d_ndhwc_f16_instances(); + std::cout << "\ntest_conv3d_ndhwc_f16_instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNDHWCF32Instances(); - std::cout << "\nTestConv3DNDHWCF32Instances ..... " << (res ? "SUCCESS" : "FAILURE") + res = test_conv3d_ndhwc_f32_instances(); + std::cout << "\ntest_conv3d_ndhwc_f32_instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNDHWCInt8Instances(); - std::cout << "\nTestConv3DNDHWCInt8Instances ..... " << (res ? "SUCCESS" : "FAILURE") + res = test_conv3d_ndhwc_int8_instances(); + std::cout << "\ntest_conv3d_ndhw_cint_8instances ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return res ? 0 : 1; diff --git a/test/convnd_fwd/conv_util.hpp b/test/convnd_fwd/conv_util.hpp index ee1f6825240..d62dab73668 100644 --- a/test/convnd_fwd/conv_util.hpp +++ b/test/convnd_fwd/conv_util.hpp @@ -76,9 +76,12 @@ void RunConv(const ck::utils::conv::ConvParams& params, const Tensor& weights, Tensor& output) { - ck::utils::conv:: - RunConvFwd( - params, input, weights, output); + ck::utils::conv::run_convolution_forward( + params, input, weights, output); } } // namespace conv diff --git a/test/reference_conv_fwd/reference_conv_fwd.cpp b/test/reference_conv_fwd/reference_conv_fwd.cpp index af7fb4f3631..d852e8f5eb2 100644 --- a/test/reference_conv_fwd/reference_conv_fwd.cpp +++ b/test/reference_conv_fwd/reference_conv_fwd.cpp @@ -57,9 +57,10 @@ template , typename FillWeightsOp = FillConstant> -Tensor RunReferenceConvFwd(const ck::utils::conv::ConvParams& params, - const FillInputOp& fill_input_op = FillInputOp{}, - const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) +Tensor +run_reference_convolution_forward(const ck::utils::conv::ConvParams& params, + const FillInputOp& fill_input_op = FillInputOp{}, + const FillWeightsOp& fill_weights_op = FillWeightsOp{0.5f}) { std::vector input_dims{static_cast(params.N), static_cast(params.C)}; @@ -80,10 +81,11 @@ Tensor RunReferenceConvFwd(const ck::utils::conv::ConvParams& param std::begin(output_spatial_lengths), std::end(output_spatial_lengths)); - Tensor input(ck::utils::conv::GetHostTensorDescriptor(input_dims, InLayout{})); - Tensor weights(ck::utils::conv::GetHostTensorDescriptor(filter_dims, WeiLayout{})); + Tensor input(ck::utils::conv::get_host_tensor_descriptor(input_dims, InLayout{})); + Tensor weights( + ck::utils::conv::get_host_tensor_descriptor(filter_dims, WeiLayout{})); Tensor host_output( - ck::utils::conv::GetHostTensorDescriptor(output_dims, OutLayout{})); + ck::utils::conv::get_host_tensor_descriptor(output_dims, OutLayout{})); fill_input_op(input.begin(), input.end()); fill_weights_op(weights.begin(), weights.end()); @@ -113,7 +115,7 @@ Tensor RunReferenceConvFwd(const ck::utils::conv::ConvParams& param return host_output; } -bool TestConv2DNHWC() +bool test_conv2d_nhwc() { bool res{true}; ck::utils::conv::ConvParams params; @@ -127,7 +129,7 @@ bool TestConv2DNHWC() params.input_left_pads = std::vector{0, 0}; params.input_right_pads = std::vector{0, 0}; - auto out_tensor = RunReferenceConvFwd<2>(params); + auto out_tensor = run_reference_convolution_forward<2>(params); std::vector ref_dims{1, 1, 4, 4}; std::vector ref_data{130.5, 148.5, @@ -160,7 +162,7 @@ bool TestConv2DNHWC() params.input_left_pads = std::vector{1, 1}; params.input_right_pads = std::vector{1, 1}; - out_tensor = RunReferenceConvFwd<2>(params); + out_tensor = run_reference_convolution_forward<2>(params); ref_dims = std::vector{1, 2, 5, 5}; ref_data = std::vector{ 210., 210., 327., 327., 351., 351., 375., 375., 399., 399., @@ -176,7 +178,7 @@ bool TestConv2DNHWC() return res; } -bool TestConv1DNWC() +bool test_conv1d_nwc() { bool res{true}; ck::utils::conv::ConvParams params; @@ -191,7 +193,8 @@ bool TestConv1DNWC() params.input_left_pads = std::vector{0}; params.input_right_pads = std::vector{0}; - auto out_tensor = RunReferenceConvFwd<1, + auto out_tensor = + run_reference_convolution_forward<1, float, float, float, @@ -216,13 +219,13 @@ bool TestConv1DNWC() params.input_left_pads = std::vector{1}; params.input_right_pads = std::vector{1}; - out_tensor = RunReferenceConvFwd<1, - float, - float, - float, - ck::tensor_layout::convolution::NWC, - ck::tensor_layout::convolution::KXC, - ck::tensor_layout::convolution::NWK>(params); + out_tensor = run_reference_convolution_forward<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>(params); ref_dims = std::vector{1, 2, 5}; ref_data = std::vector{9., 9., 19.5, 19.5, 31.5, 31.5, 43.5, 43.5, 55.5, 55.5}; res = res && ck::utils::check_err(out_tensor.mDesc.GetLengths(), @@ -241,13 +244,13 @@ bool TestConv1DNWC() params.input_left_pads = std::vector{1}; params.input_right_pads = std::vector{1}; - auto out_tensor2 = RunReferenceConvFwd<1, - float, - float, - float, - ck::tensor_layout::convolution::NWC, - ck::tensor_layout::convolution::KXC, - ck::tensor_layout::convolution::NWK>( + auto out_tensor2 = run_reference_convolution_forward<1, + float, + float, + float, + ck::tensor_layout::convolution::NWC, + ck::tensor_layout::convolution::KXC, + ck::tensor_layout::convolution::NWK>( params, FillMonotonicSeq{0.f, 0.1f}); ref_dims = std::vector{2, 16, 16}; @@ -324,7 +327,7 @@ bool TestConv1DNWC() return res; } -bool TestConv3DNCDHW() +bool test_conv3d_ncdhw() { bool res{true}; ck::utils::conv::ConvParams params; @@ -339,13 +342,13 @@ bool TestConv3DNCDHW() params.input_left_pads = std::vector{0, 0, 0}; params.input_right_pads = std::vector{0, 0, 0}; - auto out_tensor = RunReferenceConvFwd<3, - float, - float, - float, - ck::tensor_layout::convolution::NCDHW, - ck::tensor_layout::convolution::KCZYX, - ck::tensor_layout::convolution::NKDHW>( + auto out_tensor = run_reference_convolution_forward<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( params, FillMonotonicSeq{0.f, 0.1f}); std::vector ref_dims{1, 1, 4, 4, 4}; std::vector ref_data{ @@ -373,13 +376,13 @@ bool TestConv3DNCDHW() params.input_left_pads = std::vector{0, 0, 0}; params.input_right_pads = std::vector{0, 0, 0}; - out_tensor = RunReferenceConvFwd<3, - float, - float, - float, - ck::tensor_layout::convolution::NCDHW, - ck::tensor_layout::convolution::KCZYX, - ck::tensor_layout::convolution::NKDHW>( + out_tensor = run_reference_convolution_forward<3, + float, + float, + float, + ck::tensor_layout::convolution::NCDHW, + ck::tensor_layout::convolution::KCZYX, + ck::tensor_layout::convolution::NKDHW>( params, FillMonotonicSeq{0.f, 0.1f}); ref_dims = std::vector{1, 2, 4, 4, 4}; ref_data = std::vector{ @@ -414,11 +417,11 @@ bool TestConv3DNCDHW() int main(void) { bool res{true}; - res = TestConv2DNHWC(); - std::cout << "TestConv2DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv1DNWC(); + res = test_conv2d_nhwc(); + std::cout << "test_conv2d_nhwc ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = test_conv1d_nwc(); std::cout << "TestConv1DNHWC ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; - res = TestConv3DNCDHW(); - std::cout << "TestConv3DNCDHW ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; + res = test_conv3d_ncdhw(); + std::cout << "test_conv3d_ncdhw ..... " << (res ? "SUCCESS" : "FAILURE") << std::endl; return res ? 0 : 1; } From 3c5447d83f63193addaf76946ea3edf125deebec Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 5 Apr 2022 16:43:54 +0200 Subject: [PATCH 82/82] Fix merge artifacts. --- .../convnd_bwd_data_xdl.cpp | 2 +- test/conv2d_bwd_weight/conv2d_bwd_weight.cpp | 24 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp index 09ad8363b05..962627ce90b 100644 --- a/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp +++ b/example/17_convnd_bwd_data_xdl/convnd_bwd_data_xdl.cpp @@ -99,7 +99,7 @@ void print_use_msg() << " , (ie RightPy, RightPx for 2D)\n" << std::endl; } -ck::conv_util::ConvParams parse_conv_params(int num_dim_spatial, char* argv[]) +ck::utils::conv::ConvParams parse_conv_params(int num_dim_spatial, char* argv[]) { // (N, K, C) + num_dim_spatial * 6 (filter, input, strides, dilations, pad left, pad right) ck::utils::conv::ConvParams params; diff --git a/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp b/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp index 561e35e3773..bb3ed985e32 100644 --- a/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp +++ b/test/conv2d_bwd_weight/conv2d_bwd_weight.cpp @@ -6,13 +6,13 @@ #include #include -#include "conv_utils.hpp" +#include "conv_fwd_util.hpp" #include "profile_conv_bwd_weight_impl.hpp" int test_self() { bool pass = true; - std::vector params; + std::vector params; params.push_back({2, 128, 256, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}); params.push_back({2, 128, 256, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}}); @@ -136,16 +136,16 @@ int main(int argc, char* argv[]) exit(1); } - ck::conv_util::ConvParams param{2, - N, - K, - C, - {Y, X}, - {Hi, Wi}, - {conv_stride_h, conv_stride_w}, - {conv_dilation_h, conv_dilation_w}, - {in_left_pad_h, in_left_pad_w}, - {in_right_pad_h, in_right_pad_w}}; + ck::utils::conv::ConvParams param{2, + N, + K, + C, + {Y, X}, + {Hi, Wi}, + {conv_stride_h, conv_stride_w}, + {conv_dilation_h, conv_dilation_w}, + {in_left_pad_h, in_left_pad_w}, + {in_right_pad_h, in_right_pad_w}}; if(data_type == 0) { pass = ck::profiler::profile_conv_bwd_weight_impl<2,