Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions client_example/16_convnd_fwd/conv3d_fwd_fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;

using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::KZYXGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;

static constexpr ck::index_t NumDimSpatial = 3;
Expand All @@ -38,7 +38,7 @@ int main()
InLayout,
WeiLayout,
OutLayout>(
{N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K})
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
4 changes: 2 additions & 2 deletions client_example/16_convnd_fwd/conv3d_fwd_fp32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using WeiDataType = float;
using OutDataType = float;

using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::KZYXGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;

static constexpr ck::index_t NumDimSpatial = 3;
Expand All @@ -38,7 +38,7 @@ int main()
InLayout,
WeiLayout,
OutLayout>(
{N, Di, Hi, Wi, G, C}, {K, Z, Y, X, G, C}, {N, Do, Ho, Wo, G, K})
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,11 @@ void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instances(
PassThrough,
PassThrough>>>& instances);

// grouped conv3d forward, NDHWGC/KZYXGC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances(
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
KZYXGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
BF16,
Expand All @@ -260,10 +260,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances(
PassThrough,
PassThrough>>>& instances);

void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances(
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
KZYXGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F16,
Expand All @@ -274,10 +274,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances(
PassThrough,
PassThrough>>>& instances);

void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances(
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
KZYXGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
F32,
Expand All @@ -288,10 +288,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances(
PassThrough,
PassThrough>>>& instances);

void add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances(
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<3,
NDHWGC,
KZYXGC,
GKZYXC,
Empty_Tuple,
NDHWGK,
int8_t,
Expand Down Expand Up @@ -433,28 +433,28 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
}
}
else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, KZYXGC> && is_same_v<OutLayout, NDHWGK>)
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instances(op_ptrs);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ add_instance_library(device_grouped_conv3d_fwd_instance
device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_int8_instance.cpp

device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_bf16_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f16_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_f32_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_kzyxgc_ndhwgk_int8_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp
)
Loading