From 369638fff831d5880d8c1d13ed1afd58b8e38114 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Tue, 17 Jan 2023 17:09:57 +0100 Subject: [PATCH 01/11] Conv3d bwd weight client example. --- .../11_grouped_conv_bwd_weight/CMakeLists.txt | 9 +- ...ouped_conv2d_bwd_weight.cpp => common.hpp} | 128 +++++++++++++----- .../grouped_conv2d_bwd_weight_fp16.cpp | 41 ++++++ .../grouped_conv3d_bwd_weight_fp16.cpp | 53 ++++++++ .../grouped_conv3d_bwd_weight_fp32.cpp | 53 ++++++++ 5 files changed, 246 insertions(+), 38 deletions(-) rename client_example/11_grouped_conv_bwd_weight/{grouped_conv2d_bwd_weight.cpp => common.hpp} (59%) create mode 100644 client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp create mode 100644 client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp create mode 100644 client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp diff --git a/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt b/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt index 3e3f6677666..761e0de95a4 100644 --- a/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt +++ b/client_example/11_grouped_conv_bwd_weight/CMakeLists.txt @@ -1,2 +1,7 @@ -add_executable(client_grouped_conv2d_bwd_weight grouped_conv2d_bwd_weight.cpp) -target_link_libraries(client_grouped_conv2d_bwd_weight PRIVATE composable_kernel::device_operations) +add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_fp16.cpp) +add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp) +add_executable(client_grouped_conv3d_bwd_weight_fp32 grouped_conv3d_bwd_weight_fp32.cpp) + +target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_operations) diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight.cpp b/client_example/11_grouped_conv_bwd_weight/common.hpp similarity index 59% rename from client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight.cpp rename to client_example/11_grouped_conv_bwd_weight/common.hpp index 1ecc8568959..a906263333c 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight.cpp +++ b/client_example/11_grouped_conv_bwd_weight/common.hpp @@ -13,27 +13,8 @@ #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -using InDataType = ck::half_t; -using WeiDataType = ck::half_t; -using OutDataType = ck::half_t; - -using InLayout = ck::tensor_layout::convolution::GNHWC; -using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 32; -static constexpr ck::index_t N = 256; -static constexpr ck::index_t K = 192; -static constexpr ck::index_t C = 192; -static constexpr ck::index_t Y = 3; -static constexpr ck::index_t X = 3; -static constexpr ck::index_t Hi = 28; -static constexpr ck::index_t Wi = 28; -static constexpr ck::index_t Ho = 28; -static constexpr ck::index_t Wo = 28; - struct SimpleDeviceMem { SimpleDeviceMem() = delete; @@ -50,22 +31,93 @@ struct SimpleDeviceMem void* p_mem_; }; -int main() +template +std::size_t GetFlops(ck::index_t G, + ck::index_t N, + ck::index_t K, + ck::index_t C, + const std::array& output_spatial_lengths, + const std::array& filter_spatial_lengths) { - std::array input_spatial_lengths{Hi, Wi}; - std::array filter_spatial_lengths{Y, X}; - std::array output_spatial_lengths{Ho, Wo}; + // 2 * G * N * K * C * * + return static_cast(2) * G * N * K * C * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies<>()) * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + static_cast(1), + std::multiplies<>()); +} - std::array conv_filter_strides{1, 1}; - std::array conv_filter_dilations{1, 1}; - std::array input_left_pads{1, 1}; - std::array input_right_pads{1, 1}; +template +std::size_t GetInputByte(ck::index_t G, + ck::index_t N, + ck::index_t C, + const std::array& input_spatial_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * (G * N * C * + std::accumulate(std::begin(input_spatial_lengths), + std::end(input_spatial_lengths), + static_cast(1), + std::multiplies<>())); +} - ck::index_t split_k = 2; +template +std::size_t GetWeightByte(ck::index_t G, + ck::index_t K, + ck::index_t C, + const std::array& filter_spatial_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * (G * K * C * + std::accumulate(std::begin(filter_spatial_lengths), + std::end(filter_spatial_lengths), + static_cast(1), + std::multiplies<>())); +} - SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K); +template +std::size_t GetOutputByte(ck::index_t G, + ck::index_t N, + ck::index_t K, + const std::array& output_spatial_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * (G * N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies())); +} + +template +bool run_grouped_conv_bwd_weight( + ck::index_t G, + ck::index_t N, + ck::index_t K, + ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) +{ + + ck::index_t split_k = 2; + SimpleDeviceMem in(GetInputByte(G, N, C, input_spatial_lengths)); + SimpleDeviceMem wei(GetWeightByte(G, K, C, filter_spatial_lengths)); + SimpleDeviceMem out(GetOutputByte(G, N, K, output_spatial_lengths)); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeightRun(argument_ptr.get(), StreamConfig{nullptr, true}); - std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C + - sizeof(WeiDataType) * G * K * Y * X * C + - sizeof(OutDataType) * G * N * Ho * Wo * K; + std::size_t flop = + GetFlops(G, N, K, C, output_spatial_lengths, filter_spatial_lengths); + std::size_t num_bytes = + GetInputByte(G, N, C, input_spatial_lengths) + + GetWeightByte(G, K, C, filter_spatial_lengths) + + GetOutputByte(G, N, K, output_spatial_lengths); float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time; @@ -149,7 +203,7 @@ int main() if(best_op_id < 0) { std::cerr << "no suitable instance" << std::endl; - return EXIT_FAILURE; + return false; } std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops @@ -187,4 +241,6 @@ int main() std::cout << "Done" << std::endl; } + + return true; } diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp new file mode 100644 index 00000000000..9b510ec91d9 --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::GNHWC; +using WeiLayout = ck::tensor_layout::convolution::GKYXC; +using OutLayout = ck::tensor_layout::convolution::GNHWK; + +static constexpr ck::index_t NumDimSpatial = 2; +static constexpr ck::index_t G = 32; +static constexpr ck::index_t N = 256; +static constexpr ck::index_t K = 192; +static constexpr ck::index_t C = 192; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 28; + +int main() +{ + return run_grouped_conv_bwd_weight( + G, N, K, C, {Hi, Wi}, {Y, X}, {Ho, Wo}, {1, 1}, {1, 1}, {1, 1}, {1, 1}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp new file mode 100644 index 00000000000..146696d4eb7 --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 8; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 128; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_bwd_weight(G, + N, + K, + C, + {Di, Hi, Wi}, + {Z, Y, X}, + {Do, Ho, Wo}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} \ No newline at end of file diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp new file mode 100644 index 00000000000..c41ba4eaae2 --- /dev/null +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 8; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 128; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_bwd_weight(G, + N, + K, + C, + {Di, Hi, Wi}, + {Z, Y, X}, + {Do, Ho, Wo}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} \ No newline at end of file From 195db449fcdc1a3d1e194e24d9a934b620ec6ea8 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 18 Jan 2023 16:17:18 +0100 Subject: [PATCH 02/11] Update year in license --- .../grouped_conv2d_bwd_weight_fp16.cpp | 2 +- .../grouped_conv3d_bwd_weight_fp16.cpp | 2 +- .../grouped_conv3d_bwd_weight_fp32.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp index 9b510ec91d9..1903bd95b67 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp index 146696d4eb7..e95915bf7a1 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp index c41ba4eaae2..77e289350c4 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #include "common.hpp" From f0b96d08ced12acfe6a8316eae9da7ff3fad1c81 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Wed, 18 Jan 2023 16:18:14 +0100 Subject: [PATCH 03/11] Convolution bwd data 3D fp16/fp32 client example. --- .../15_convnd_bwd_data/CMakeLists.txt | 5 + client_example/15_convnd_bwd_data/common.hpp | 234 ++++++++++++++++++ .../conv3d_bwd_data_fp16.cpp | 44 ++++ .../conv3d_bwd_data_fp32.cpp | 43 ++++ 4 files changed, 326 insertions(+) create mode 100644 client_example/15_convnd_bwd_data/CMakeLists.txt create mode 100644 client_example/15_convnd_bwd_data/common.hpp create mode 100644 client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp create mode 100644 client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp diff --git a/client_example/15_convnd_bwd_data/CMakeLists.txt b/client_example/15_convnd_bwd_data/CMakeLists.txt new file mode 100644 index 00000000000..8a60a71674f --- /dev/null +++ b/client_example/15_convnd_bwd_data/CMakeLists.txt @@ -0,0 +1,5 @@ +add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp) +add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp) + +target_link_libraries(client_conv3d_bwd_data_fp16 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_conv3d_bwd_data_fp32 PRIVATE composable_kernel::device_operations) diff --git a/client_example/15_convnd_bwd_data/common.hpp b/client_example/15_convnd_bwd_data/common.hpp new file mode 100644 index 00000000000..669728a2f7f --- /dev/null +++ b/client_example/15_convnd_bwd_data/common.hpp @@ -0,0 +1,234 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +std::size_t GetFlops(ck::index_t N, + ck::index_t K, + ck::index_t C, + const std::vector& output_spatial_lengths, + const std::vector& weights_spatial_lengths) +{ + // 2 * N * K * C * * + + return static_cast(2) * N * K * C * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies<>()) * + std::accumulate(std::begin(weights_spatial_lengths), + std::end(weights_spatial_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t GetInputByte(ck::index_t N, + ck::index_t C, + const std::vector& input_spatial_lengths) +{ + // sizeof(InDataType) * (N * C * ) + + return sizeof(InDataType) * N * C * std::accumulate(std::begin(input_spatial_lengths), + std::end(input_spatial_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t GetWeightByte(ck::index_t K, + ck::index_t C, + const std::vector& weights_spatial_lengths) +{ + // sizeof(WeiDataType) * (K * C * ) + + return sizeof(WeiDataType) * K * C * std::accumulate(std::begin(weights_spatial_lengths), + std::end(weights_spatial_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t GetOutputByte(ck::index_t N, + ck::index_t K, + const std::vector& output_spatial_lengths) +{ + // sizeof(OutDataType) * (N * K * ); + return sizeof(OutDataType) * N * K * std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies()); +} + +template +bool run_conv_bwd_data(ck::index_t N, + ck::index_t K, + ck::index_t C, + const std::vector& in_spatial_lengths, + const std::vector& wei_spatial_lengths, + const std::vector& out_spatial_lengths) +{ + std::size_t in_mem_size = GetInputByte(N, C, in_spatial_lengths); + std::size_t wei_mem_size = GetWeightByte(K, C, wei_spatial_lengths); + std::size_t out_mem_size = GetOutputByte(N, K, out_spatial_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + std::vector filter_strides(NumDimSpatial, 1); + std::vector filter_dilations(NumDimSpatial, 1); + std::vector input_left_pads(NumDimSpatial, 1); + std::vector input_right_pads(NumDimSpatial, 1); + + using DeviceOp = ck::tensor_operation::device::DeviceConvBwdData; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + std::size_t flop = GetFlops(N, K, C, out_spatial_lengths, wei_spatial_lengths); + std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; + + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + out.GetDeviceBuffer(), + N, + K, + C, + in_spatial_lengths, + wei_spatial_lengths, + out_spatial_lengths, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + out.GetDeviceBuffer(), + N, + K, + C, + in_spatial_lengths, + wei_spatial_lengths, + out_spatial_lengths, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp new file mode 100644 index 00000000000..6c11705bf58 --- /dev/null +++ b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::NDHWC; +using WeiLayout = ck::tensor_layout::convolution::KZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 28; + +int main() +{ + return run_conv_bwd_data( + N, K, C, {Di, Hi, Wi}, {Z, Y, X}, {Do, Ho, Wo}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} + diff --git a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp new file mode 100644 index 00000000000..5efcbdc697c --- /dev/null +++ b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; + +using InLayout = ck::tensor_layout::convolution::NDHWC; +using WeiLayout = ck::tensor_layout::convolution::KZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 28; + +int main() +{ + return run_conv_bwd_data( + N, K, C, {Di, Hi, Wi}, {Z, Y, X}, {Do, Ho, Wo}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} From 3c0fcc528fd42f9cc8fe37c2dc0628e7c1bd323b Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 19 Jan 2023 17:48:42 +0100 Subject: [PATCH 04/11] Client example for convnd fwd fp16 fp32 --- client_example/16_convnd_fwd/CMakeLists.txt | 5 + client_example/16_convnd_fwd/common.hpp | 302 ++++++++++++++++++ .../16_convnd_fwd/conv3d_fwd_fp16.cpp | 44 +++ .../16_convnd_fwd/conv3d_fwd_fp32.cpp | 44 +++ 4 files changed, 395 insertions(+) create mode 100644 client_example/16_convnd_fwd/CMakeLists.txt create mode 100644 client_example/16_convnd_fwd/common.hpp create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp create mode 100644 client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp diff --git a/client_example/16_convnd_fwd/CMakeLists.txt b/client_example/16_convnd_fwd/CMakeLists.txt new file mode 100644 index 00000000000..e2580a370ca --- /dev/null +++ b/client_example/16_convnd_fwd/CMakeLists.txt @@ -0,0 +1,5 @@ +add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp) +add_executable(client_conv3d_fwd_fp32 conv3d_fwd_fp32.cpp) + +target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_operations) +target_link_libraries(client_conv3d_fwd_fp32 PRIVATE composable_kernel::device_operations) diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp new file mode 100644 index 00000000000..b4792220259 --- /dev/null +++ b/client_example/16_convnd_fwd/common.hpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +struct SimpleDeviceMem +{ + SimpleDeviceMem() = delete; + + SimpleDeviceMem(std::size_t mem_size) : p_mem_{} + { + (void)hipMalloc(static_cast(&p_mem_), mem_size); + } + + void* GetDeviceBuffer() { return p_mem_; } + + ~SimpleDeviceMem() { (void)hipFree(p_mem_); } + + void* p_mem_; +}; + +template +std::size_t GetFlops(const std::array& output_lengths, + const std::array& weights_lengths) +{ + // 2 * G * N * K * C * * + + ck::index_t G = weights_lengths[0]; + ck::index_t N = output_lengths[1]; + ck::index_t K = weights_lengths[1]; + ck::index_t C = weights_lengths[2]; + + return static_cast(2) * G * N * K * C * + std::accumulate(std::next(std::begin(output_lengths), 3), + std::end(output_lengths), + static_cast(1), + std::multiplies<>()) * + std::accumulate(std::next(std::begin(weights_lengths), 3), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t GetInputByte(const std::array& input_lengths) +{ + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), + std::end(input_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t GetWeightByte(const std::array& weights_lengths) +{ + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), + std::end(weights_lengths), + static_cast(1), + std::multiplies<>()); +} + +template +std::size_t GetOutputByte(const std::array& output_lengths) +{ + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), + std::end(output_lengths), + static_cast(1), + std::multiplies()); +} + +template +void print_array(const std::array& a) +{ + for(int i = 0; i < NumDimSpatial + 3; ++i) + { + std::cout << a[i] << ", "; + } + std::cout << std::endl; +} + +template +bool run_grouped_conv_fwd(std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) +{ + std::size_t in_mem_size = GetInputByte(in_lengths); + std::size_t wei_mem_size = GetWeightByte(wei_lengths); + std::size_t out_mem_size = GetOutputByte(out_lengths); + + SimpleDeviceMem in(in_mem_size); + SimpleDeviceMem wei(wei_mem_size); + SimpleDeviceMem out(out_mem_size); + + std::array in_strides; + std::array wei_strides; + std::array out_strides; + in_strides.fill(0); + wei_strides.fill(0); + out_strides.fill(0); + in_strides.back() = 1; + wei_strides.back() = 1; + out_strides.back() = 1; + + std::partial_sum(rbegin(in_lengths), + std::prev(rend(in_lengths)), + std::next(rbegin(in_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(wei_lengths), + std::prev(rend(wei_lengths)), + std::next(rbegin(wei_strides)), + std::multiplies<>{}); + std::partial_sum(rbegin(out_lengths), + std::prev(rend(out_lengths)), + std::next(rbegin(out_strides)), + std::multiplies<>{}); + + // transpose GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(rbegin(in_lengths), + std::next(rbegin(in_lengths)), + std::next(rbegin(in_lengths), NumDimSpatial + 1)); + std::rotate(rbegin(in_strides), + std::next(rbegin(in_strides)), + std::next(rbegin(in_strides), NumDimSpatial + 1)); + std::rotate(rbegin(wei_lengths), + std::next(rbegin(wei_lengths)), + std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + std::rotate(rbegin(wei_strides), + std::next(rbegin(wei_strides)), + std::next(rbegin(wei_strides), NumDimSpatial + 1)); + std::rotate(rbegin(out_lengths), + std::next(rbegin(out_lengths)), + std::next(rbegin(out_lengths), NumDimSpatial + 1)); + std::rotate(rbegin(out_strides), + std::next(rbegin(out_strides)), + std::next(rbegin(out_strides), NumDimSpatial + 1)); + + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + conv_filter_strides.fill(1); + conv_filter_dilations.fill(1); + input_left_pads.fill(1); + input_right_pads.fill(1); + + print_array(in_lengths); + print_array(in_strides); + print_array(wei_lengths); + print_array(wei_strides); + print_array(out_lengths); + print_array(out_strides); + + std::size_t flop = GetFlops(out_lengths, wei_lengths); + std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; + + using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD, + OutLayout, + InDataType, + WeiDataType, + ck::Tuple<>, + OutDataType, + PassThrough, + PassThrough, + PassThrough>; + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + std::string best_op_name; + int best_op_id = -1; + float best_avg_time = std::numeric_limits::max(); + float best_gb_per_sec = 0; + float best_tflops = 0; + + // profile device operation instances + std::cout << "Run all instances and do timing" << std::endl; + + for(int i = 0; i < op_ptrs.size(); ++i) + { + auto& op_ptr = op_ptrs[i]; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); + + float tflops = static_cast(flop) / 1.E9 / avg_time; + float gb_per_sec = num_bytes / 1.E6 / avg_time; + + std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_id = i; + best_op_name = op_name; + best_avg_time = avg_time; + best_gb_per_sec = gb_per_sec; + best_tflops = tflops; + } + } + else + { + std::cerr << op_name << " does not support this problem" << std::endl; + } + } + + if(best_op_id < 0) + { + std::cerr << "no suitable instance" << std::endl; + return false; + } + + std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops + << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + // run the best intance + { + auto& op_ptr = op_ptrs[best_op_id]; + std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() + << std::endl; + auto argument_ptr = op_ptr->MakeArgumentPointer( + in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + std::array{}, + out.GetDeviceBuffer(), + in_lengths, + in_strides, + wei_lengths, + wei_strides, + std::array, 0>{{}}, + std::array, 0>{{}}, + out_lengths, + out_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false}); + } + + std::cout << "Done" << std::endl; + } + return true; +} diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp new file mode 100644 index 00000000000..7098afc8104 --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = ck::half_t; +using WeiDataType = ck::half_t; +using OutDataType = ck::half_t; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 3; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 3; + +int main() +{ + return run_grouped_conv_fwd( + {G, N, Di, Hi, Wi, C}, {G, K, Z, Y, X, C}, {G, N, Do, Ho, Wo, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp new file mode 100644 index 00000000000..551eecae894 --- /dev/null +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +using InDataType = float; +using WeiDataType = float; +using OutDataType = float; + +using InLayout = ck::tensor_layout::convolution::GNDHWC; +using WeiLayout = ck::tensor_layout::convolution::GKZYXC; +using OutLayout = ck::tensor_layout::convolution::GNDHWK; + +static constexpr ck::index_t NumDimSpatial = 3; +static constexpr ck::index_t G = 1; +static constexpr ck::index_t N = 64; +static constexpr ck::index_t K = 128; +static constexpr ck::index_t C = 64; +static constexpr ck::index_t Z = 3; +static constexpr ck::index_t Y = 3; +static constexpr ck::index_t X = 3; +static constexpr ck::index_t Di = 28; +static constexpr ck::index_t Hi = 28; +static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Do = 28; +static constexpr ck::index_t Ho = 28; +static constexpr ck::index_t Wo = 28; + +int main() +{ + return run_grouped_conv_fwd( + {G, N, Di, Hi, Wi, C}, {G, K, Z, Y, X, C}, {G, N, Do, Ho, Wo, K}) + ? EXIT_SUCCESS + : EXIT_FAILURE; +} From daa1c99a7f28200c1f4bb66f1fdf4f61eebce28a Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 20 Jan 2023 10:18:13 +0100 Subject: [PATCH 05/11] clang-format --- client_example/15_convnd_bwd_data/common.hpp | 115 +++++++++--------- .../conv3d_bwd_data_fp16.cpp | 14 +-- .../conv3d_bwd_data_fp32.cpp | 13 +- 3 files changed, 69 insertions(+), 73 deletions(-) diff --git a/client_example/15_convnd_bwd_data/common.hpp b/client_example/15_convnd_bwd_data/common.hpp index 669728a2f7f..9799fb73a5a 100644 --- a/client_example/15_convnd_bwd_data/common.hpp +++ b/client_example/15_convnd_bwd_data/common.hpp @@ -32,9 +32,9 @@ struct SimpleDeviceMem void* p_mem_; }; -std::size_t GetFlops(ck::index_t N, +std::size_t GetFlops(ck::index_t N, ck::index_t K, - ck::index_t C, + ck::index_t C, const std::vector& output_spatial_lengths, const std::vector& weights_spatial_lengths) { @@ -52,39 +52,39 @@ std::size_t GetFlops(ck::index_t N, } template -std::size_t GetInputByte(ck::index_t N, - ck::index_t C, - const std::vector& input_spatial_lengths) +std::size_t +GetInputByte(ck::index_t N, ck::index_t C, const std::vector& input_spatial_lengths) { // sizeof(InDataType) * (N * C * ) + - return sizeof(InDataType) * N * C * std::accumulate(std::begin(input_spatial_lengths), - std::end(input_spatial_lengths), - static_cast(1), - std::multiplies<>()); + return sizeof(InDataType) * N * C * + std::accumulate(std::begin(input_spatial_lengths), + std::end(input_spatial_lengths), + static_cast(1), + std::multiplies<>()); } template -std::size_t GetWeightByte(ck::index_t K, - ck::index_t C, - const std::vector& weights_spatial_lengths) +std::size_t +GetWeightByte(ck::index_t K, ck::index_t C, const std::vector& weights_spatial_lengths) { // sizeof(WeiDataType) * (K * C * ) + - return sizeof(WeiDataType) * K * C * std::accumulate(std::begin(weights_spatial_lengths), - std::end(weights_spatial_lengths), - static_cast(1), - std::multiplies<>()); + return sizeof(WeiDataType) * K * C * + std::accumulate(std::begin(weights_spatial_lengths), + std::end(weights_spatial_lengths), + static_cast(1), + std::multiplies<>()); } template -std::size_t GetOutputByte(ck::index_t N, - ck::index_t K, - const std::vector& output_spatial_lengths) +std::size_t +GetOutputByte(ck::index_t N, ck::index_t K, const std::vector& output_spatial_lengths) { // sizeof(OutDataType) * (N * K * ); - return sizeof(OutDataType) * N * K * std::accumulate(std::begin(output_spatial_lengths), - std::end(output_spatial_lengths), - static_cast(1), - std::multiplies()); + return sizeof(OutDataType) * N * K * + std::accumulate(std::begin(output_spatial_lengths), + std::end(output_spatial_lengths), + static_cast(1), + std::multiplies()); } template & wei_spatial_lengths, const std::vector& out_spatial_lengths) { - std::size_t in_mem_size = GetInputByte(N, C, in_spatial_lengths); + std::size_t in_mem_size = GetInputByte(N, C, in_spatial_lengths); std::size_t wei_mem_size = GetWeightByte(K, C, wei_spatial_lengths); std::size_t out_mem_size = GetOutputByte(N, K, out_spatial_lengths); SimpleDeviceMem in(in_mem_size); SimpleDeviceMem wei(wei_mem_size); SimpleDeviceMem out(out_mem_size); - + std::vector filter_strides(NumDimSpatial, 1); std::vector filter_dilations(NumDimSpatial, 1); std::vector input_left_pads(NumDimSpatial, 1); @@ -136,10 +136,9 @@ bool run_conv_bwd_data(ck::index_t N, float best_gb_per_sec = 0; float best_tflops = 0; - std::size_t flop = GetFlops(N, K, C, out_spatial_lengths, wei_spatial_lengths); + std::size_t flop = GetFlops(N, K, C, out_spatial_lengths, wei_spatial_lengths); std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; - // profile device operation instances std::cout << "Run all instances and do timing" << std::endl; @@ -147,21 +146,21 @@ bool run_conv_bwd_data(ck::index_t N, { auto& op_ptr = op_ptrs[i]; auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - out.GetDeviceBuffer(), - N, - K, - C, - in_spatial_lengths, - wei_spatial_lengths, - out_spatial_lengths, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); + wei.GetDeviceBuffer(), + out.GetDeviceBuffer(), + N, + K, + C, + in_spatial_lengths, + wei_spatial_lengths, + out_spatial_lengths, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); std::string op_name = op_ptr->GetTypeString(); @@ -204,22 +203,22 @@ bool run_conv_bwd_data(ck::index_t N, auto& op_ptr = op_ptrs[best_op_id]; std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() << std::endl; - auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), - wei.GetDeviceBuffer(), - out.GetDeviceBuffer(), - N, - K, - C, - in_spatial_lengths, - wei_spatial_lengths, - out_spatial_lengths, - filter_strides, - filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); + auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), + wei.GetDeviceBuffer(), + out.GetDeviceBuffer(), + N, + K, + C, + in_spatial_lengths, + wei_spatial_lengths, + out_spatial_lengths, + filter_strides, + filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp index 6c11705bf58..5210567241e 100644 --- a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp +++ b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp16.cpp @@ -10,9 +10,9 @@ using InDataType = ck::half_t; using WeiDataType = ck::half_t; using OutDataType = ck::half_t; -using InLayout = ck::tensor_layout::convolution::NDHWC; -using WeiLayout = ck::tensor_layout::convolution::KZYXC; -using OutLayout = ck::tensor_layout::convolution::NDHWK; +using InLayout = ck::tensor_layout::convolution::NDHWC; +using WeiLayout = ck::tensor_layout::convolution::KZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWK; static constexpr ck::index_t NumDimSpatial = 3; static constexpr ck::index_t N = 64; @@ -36,9 +36,7 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>( - N, K, C, {Di, Hi, Wi}, {Z, Y, X}, {Do, Ho, Wo}) - ? EXIT_SUCCESS - : EXIT_FAILURE; + OutLayout>(N, K, C, {Di, Hi, Wi}, {Z, Y, X}, {Do, Ho, Wo}) + ? EXIT_SUCCESS + : EXIT_FAILURE; } - diff --git a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp index 5efcbdc697c..441bdfe7bec 100644 --- a/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp +++ b/client_example/15_convnd_bwd_data/conv3d_bwd_data_fp32.cpp @@ -10,9 +10,9 @@ using InDataType = float; using WeiDataType = float; using OutDataType = float; -using InLayout = ck::tensor_layout::convolution::NDHWC; -using WeiLayout = ck::tensor_layout::convolution::KZYXC; -using OutLayout = ck::tensor_layout::convolution::NDHWK; +using InLayout = ck::tensor_layout::convolution::NDHWC; +using WeiLayout = ck::tensor_layout::convolution::KZYXC; +using OutLayout = ck::tensor_layout::convolution::NDHWK; static constexpr ck::index_t NumDimSpatial = 3; static constexpr ck::index_t N = 64; @@ -36,8 +36,7 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>( - N, K, C, {Di, Hi, Wi}, {Z, Y, X}, {Do, Ho, Wo}) - ? EXIT_SUCCESS - : EXIT_FAILURE; + OutLayout>(N, K, C, {Di, Hi, Wi}, {Z, Y, X}, {Do, Ho, Wo}) + ? EXIT_SUCCESS + : EXIT_FAILURE; } From bb692de6e9f50f606cc21d110317e8fa9e32bbd5 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 23 Jan 2023 10:17:43 +0100 Subject: [PATCH 06/11] Review remarks. --- .../grouped_conv3d_bwd_weight_fp16.cpp | 2 +- .../grouped_conv3d_bwd_weight_fp32.cpp | 2 +- client_example/16_convnd_fwd/common.hpp | 53 +++++++++---------- .../16_convnd_fwd/conv3d_fwd_fp32.cpp | 4 +- 4 files changed, 28 insertions(+), 33 deletions(-) diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp index e95915bf7a1..2f2b5d4e211 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp @@ -50,4 +50,4 @@ int main() {1, 1, 1}) ? EXIT_SUCCESS : EXIT_FAILURE; -} \ No newline at end of file +} diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp index 77e289350c4..796311d2318 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp @@ -50,4 +50,4 @@ int main() {1, 1, 1}) ? EXIT_SUCCESS : EXIT_FAILURE; -} \ No newline at end of file +} diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index b4792220259..2f5d92d0358 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ b/client_example/16_convnd_fwd/common.hpp @@ -32,9 +32,10 @@ struct SimpleDeviceMem void* p_mem_; }; -template -std::size_t GetFlops(const std::array& output_lengths, - const std::array& weights_lengths) +template +std::size_t +GetFlops(const std::array& output_lengths, + const std::array& weights_lengths) { // 2 * G * N * K * C * * @@ -44,18 +45,19 @@ std::size_t GetFlops(const std::array& output_le ck::index_t C = weights_lengths[2]; return static_cast(2) * G * N * K * C * - std::accumulate(std::next(std::begin(output_lengths), 3), + std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim), std::end(output_lengths), static_cast(1), std::multiplies<>()) * - std::accumulate(std::next(std::begin(weights_lengths), 3), + std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim), std::end(weights_lengths), static_cast(1), std::multiplies<>()); } template -std::size_t GetInputByte(const std::array& input_lengths) +std::size_t +GetInputByte(const std::array& input_lengths) { // sizeof(InDataType) * (G * N * C * ) + return sizeof(InDataType) * std::accumulate(std::begin(input_lengths), @@ -65,7 +67,8 @@ std::size_t GetInputByte(const std::array& input } template -std::size_t GetWeightByte(const std::array& weights_lengths) +std::size_t +GetWeightByte(const std::array& weights_lengths) { // sizeof(WeiDataType) * (G * K * C * ) + return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths), @@ -75,7 +78,8 @@ std::size_t GetWeightByte(const std::array& weig } template -std::size_t GetOutputByte(const std::array& output_lengths) +std::size_t +GetOutputByte(const std::array& output_lengths) { // sizeof(OutDataType) * (G * N * K * ); return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths), @@ -84,26 +88,17 @@ std::size_t GetOutputByte(const std::array& outp std::multiplies()); } -template -void print_array(const std::array& a) -{ - for(int i = 0; i < NumDimSpatial + 3; ++i) - { - std::cout << a[i] << ", "; - } - std::cout << std::endl; -} - template -bool run_grouped_conv_fwd(std::array in_lengths, - std::array wei_lengths, - std::array out_lengths) + typename OutLayout, + ck::index_t NumNonSpatialDim = 3> +bool run_grouped_conv_fwd(std::array in_lengths, + std::array wei_lengths, + std::array out_lengths) { std::size_t in_mem_size = GetInputByte(in_lengths); std::size_t wei_mem_size = GetWeightByte(wei_lengths); @@ -113,9 +108,9 @@ bool run_grouped_conv_fwd(std::array in_lengths, SimpleDeviceMem wei(wei_mem_size); SimpleDeviceMem out(out_mem_size); - std::array in_strides; - std::array wei_strides; - std::array out_strides; + std::array in_strides; + std::array wei_strides; + std::array out_strides; in_strides.fill(0); wei_strides.fill(0); out_strides.fill(0); @@ -214,8 +209,8 @@ bool run_grouped_conv_fwd(std::array in_lengths, in_strides, wei_lengths, wei_strides, - std::array, 0>{{}}, - std::array, 0>{{}}, + std::array, 0>{{}}, + std::array, 0>{{}}, out_lengths, out_strides, conv_filter_strides, @@ -277,8 +272,8 @@ bool run_grouped_conv_fwd(std::array in_lengths, in_strides, wei_lengths, wei_strides, - std::array, 0>{{}}, - std::array, 0>{{}}, + std::array, 0>{{}}, + std::array, 0>{{}}, out_lengths, out_strides, conv_filter_strides, diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp index 551eecae894..379ee8c94fc 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp @@ -24,10 +24,10 @@ static constexpr ck::index_t Y = 3; static constexpr ck::index_t X = 3; static constexpr ck::index_t Di = 28; static constexpr ck::index_t Hi = 28; -static constexpr ck::index_t Wi = 28; +static constexpr ck::index_t Wi = 3; static constexpr ck::index_t Do = 28; static constexpr ck::index_t Ho = 28; -static constexpr ck::index_t Wo = 28; +static constexpr ck::index_t Wo = 3; int main() { From 98c0a911a10d877b20df2787067a718d5fd9effc Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Mon, 23 Jan 2023 11:47:49 +0100 Subject: [PATCH 07/11] Fix compiler err. --- client_example/16_convnd_fwd/common.hpp | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index 2f5d92d0358..ca23467753b 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ b/client_example/16_convnd_fwd/common.hpp @@ -55,7 +55,7 @@ GetFlops(const std::array& output std::multiplies<>()); } -template +template std::size_t GetInputByte(const std::array& input_lengths) { @@ -66,7 +66,7 @@ GetInputByte(const std::array& in std::multiplies<>()); } -template +template std::size_t GetWeightByte(const std::array& weights_lengths) { @@ -77,7 +77,7 @@ GetWeightByte(const std::array& w std::multiplies<>()); } -template +template std::size_t GetOutputByte(const std::array& output_lengths) { @@ -160,13 +160,6 @@ bool run_grouped_conv_fwd(std::array(in_lengths); - print_array(in_strides); - print_array(wei_lengths); - print_array(wei_strides); - print_array(out_lengths); - print_array(out_strides); - std::size_t flop = GetFlops(out_lengths, wei_lengths); std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size; From 034c1622651b8312d94ed3b911d203bbb0768197 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 3 Feb 2023 15:26:08 +0100 Subject: [PATCH 08/11] Update data layout to standard one. --- client_example/16_convnd_fwd/common.hpp | 18 ++++++++++++++++-- .../16_convnd_fwd/conv3d_fwd_fp16.cpp | 8 ++++---- .../16_convnd_fwd/conv3d_fwd_fp32.cpp | 8 ++++---- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/client_example/16_convnd_fwd/common.hpp b/client_example/16_convnd_fwd/common.hpp index ca23467753b..a6bb5aa65be 100644 --- a/client_example/16_convnd_fwd/common.hpp +++ b/client_example/16_convnd_fwd/common.hpp @@ -38,7 +38,6 @@ GetFlops(const std::array& output const std::array& weights_lengths) { // 2 * G * N * K * C * * - ck::index_t G = weights_lengths[0]; ck::index_t N = output_lengths[1]; ck::index_t K = weights_lengths[1]; @@ -131,22 +130,37 @@ bool run_grouped_conv_fwd(std::array{}); - // transpose GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW + std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths)); std::rotate(rbegin(in_lengths), std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), NumDimSpatial + 1)); + + std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides)); std::rotate(rbegin(in_strides), std::next(rbegin(in_strides)), std::next(rbegin(in_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(wei_lengths)), std::next(rbegin(wei_lengths), 2), rend(wei_lengths)); std::rotate(rbegin(wei_lengths), std::next(rbegin(wei_lengths)), std::next(rbegin(wei_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(wei_strides)), std::next(rbegin(wei_strides), 2), rend(wei_strides)); std::rotate(rbegin(wei_strides), std::next(rbegin(wei_strides)), std::next(rbegin(wei_strides), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths)); std::rotate(rbegin(out_lengths), std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), NumDimSpatial + 1)); + + std::rotate( + std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides)); std::rotate(rbegin(out_strides), std::next(rbegin(out_strides)), std::next(rbegin(out_strides), NumDimSpatial + 1)); diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp index 7098afc8104..10f914bbee3 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp @@ -10,9 +10,9 @@ using InDataType = ck::half_t; using WeiDataType = ck::half_t; using OutDataType = ck::half_t; -using InLayout = ck::tensor_layout::convolution::GNDHWC; -using WeiLayout = ck::tensor_layout::convolution::GKZYXC; -using OutLayout = ck::tensor_layout::convolution::GNDHWK; +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::KZYXGC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; static constexpr ck::index_t NumDimSpatial = 3; static constexpr ck::index_t G = 1; @@ -38,7 +38,7 @@ int main() InLayout, WeiLayout, OutLayout>( - {G, N, Di, Hi, Wi, C}, {G, K, Z, Y, X, C}, {G, N, Do, Ho, Wo, K}) + {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp index 379ee8c94fc..43c98f1e9b8 100644 --- a/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp +++ b/client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp @@ -10,9 +10,9 @@ using InDataType = float; using WeiDataType = float; using OutDataType = float; -using InLayout = ck::tensor_layout::convolution::GNDHWC; -using WeiLayout = ck::tensor_layout::convolution::GKZYXC; -using OutLayout = ck::tensor_layout::convolution::GNDHWK; +using InLayout = ck::tensor_layout::convolution::NDHWGC; +using WeiLayout = ck::tensor_layout::convolution::KZYXGC; +using OutLayout = ck::tensor_layout::convolution::NDHWGK; static constexpr ck::index_t NumDimSpatial = 3; static constexpr ck::index_t G = 1; @@ -38,7 +38,7 @@ int main() InLayout, WeiLayout, OutLayout>( - {G, N, Di, Hi, Wi, C}, {G, K, Z, Y, X, C}, {G, N, Do, Ho, Wo, K}) + {N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K}) ? EXIT_SUCCESS : EXIT_FAILURE; } From 862982a9ec527abc5ec17873fddb8da558ce5dea Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 3 Feb 2023 15:26:32 +0100 Subject: [PATCH 09/11] Add conv 3d fwd NDHWGC instances --- .../gpu/grouped_convolution_forward.hpp | 82 +++++++++++++++++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 5 ++ 2 files changed, 87 insertions(+) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index ee38b738274..a8df7f0d5bf 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -244,6 +244,63 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances( PassThrough, PassThrough>>>& instances); +// grouped conv3d forward, NDHWGC/KZYXGC/NDHWGK +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( + std::vector>>& instances); + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances( + std::vector>>& instances); + template && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances(op_ptrs); + } + else if constexpr(is_same_v && + is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances(op_ptrs); + } + } return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index 78eedca5f76..90efc09ee75 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -3,4 +3,9 @@ add_instance_library(device_grouped_conv3d_fwd_instance device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp + + device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp + device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp ) From b40861c38e956dea967c91d8ecb8d4e7428e147d Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 3 Feb 2023 15:27:14 +0100 Subject: [PATCH 10/11] clang-format --- .../46_gemm_add_multiply/run_gemm_add_multiply_example.inc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc b/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc index 4f7a8a4ca73..e1b2bccfe11 100644 --- a/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc +++ b/example/46_gemm_add_multiply/run_gemm_add_multiply_example.inc @@ -53,7 +53,6 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); - a_device_buf.ToDevice(a_m_k.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data()); @@ -84,8 +83,8 @@ bool run_gemm_add_multiply(const ProblemSize& problem_size, const ExecutionConfi if(!device_op.IsSupportedArgument(argument)) { - std::cout << "wrong! this device_op instance does not support this problem" << std::endl; - return true; + std::cout << "wrong! this device_op instance does not support this problem" << std::endl; + return true; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); From ecacb72f080cfc65377861c65165f8ddfdcbbeb4 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Fri, 3 Feb 2023 15:27:31 +0100 Subject: [PATCH 11/11] Conv3d fwd NDHWGC instances. --- ...xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp | 129 ++++++++++++++++++ ..._xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp | 129 ++++++++++++++++++ ..._xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp | 128 +++++++++++++++++ ...xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp | 125 +++++++++++++++++ 4 files changed, 511 insertions(+) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..8c384937352 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NDHWGC = ck::tensor_layout::convolution::NDHWGC; +using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp new file mode 100644 index 00000000000..487cd22721a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp @@ -0,0 +1,129 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NDHWGC = ck::tensor_layout::convolution::NDHWGC; +using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp new file mode 100644 index 00000000000..d497cd57edf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp @@ -0,0 +1,128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NDHWGC = ck::tensor_layout::convolution::NDHWGC; +using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances = + std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp new file mode 100644 index 00000000000..2e53fbbda5c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NDHWGC = ck::tensor_layout::convolution::NDHWGC; +using KZYXGC = ck::tensor_layout::convolution::KZYXGC; +using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +// in[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = out[g, n, do, ho, wo, k] +using device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances = std::tuple< + // clang-format off + // Default + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + + // Filter1x1Stride1Pad0 + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 3, NDHWGC, KZYXGC, Empty_Tuple, NDHWGK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck