Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions example/01_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)

add_example_executable(example_gemm_xdl_lds_direct_load_fp32_tf32 gemm_xdl_lds_direct_load_fp32_tf32.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32_tf32)

add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16)

Expand Down
16 changes: 12 additions & 4 deletions example/01_gemm/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,14 @@ bool parse_cmd_args<ProblemSizeSplitK>(int argc,
return true;
}

template <typename DataType>
template <typename DataType, typename ComputeDataType = DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
Expand Down Expand Up @@ -351,10 +355,14 @@ inline __host__ __device__ constexpr double get_rtol()
}
}

template <typename DataType>
template <typename DataType, typename ComputeDataType = DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<ComputeDataType, ck::tf32_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
Expand Down
85 changes: 85 additions & 0 deletions example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>

#include "common.hpp"

#define USING_DIRECT_LOADS 1
#if USING_DIRECT_LOADS
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp"
#else
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
#endif

#define EXAMPLE_WITH_COMPUTE_DATATYPE

using F32 = float;

using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using CDataType = F32;
using ComputeDataType = ck::tf32_t;

using ALayout = Row;
using BLayout = Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

#if USING_DIRECT_LOADS
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad
Comment thread
linqun marked this conversation as resolved.
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer|
// ######| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds|
// ######| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler | pipeline ver | gemm type |
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block|
// ######| XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32,
8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1,
1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>;
// clang-format on
#else
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
#endif
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp,
ComputeDataType,
ComputeDataType>;

using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;

#include "run_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }

#undef EXAMPLE_WITH_COMPUTE_DATATYPE
13 changes: 9 additions & 4 deletions example/01_gemm/run_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
#pragma once
#include "ck/library/utility/validation_common.hpp"

// use macro to minimize code change
#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
using ComputeDataType = AccDataType;
#endif

template <typename ProblemType>
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
{
Expand Down Expand Up @@ -218,8 +223,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
get_rtol<CDataType, ComputeDataType>(),
get_atol<CDataType, ComputeDataType>());
#endif
}

Expand Down Expand Up @@ -249,8 +254,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
pass &= ck::utils::check_err(c_m_n_device_result,
c_m_n_device_ref_result,
"Error: Incorrect results!",
get_rtol<CDataType>(),
get_atol<CDataType>());
get_rtol<CDataType, ComputeDataType>(),
get_atol<CDataType, ComputeDataType>());
}

return pass == true;
Expand Down
3 changes: 2 additions & 1 deletion example/09_convnd_fwd/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp)
add_example_executable(example_convnd_fwd_xdl_fp32_tf32 convnd_fwd_xdl_fp32_tf32.cpp)
add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp)
add_example_executable(example_convnd_fwd_xdl_bf16 convnd_fwd_xdl_bf16.cpp)
add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp)
Expand All @@ -19,4 +20,4 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp)
set(target 1)
endif()
endforeach()
endforeach()
29 changes: 21 additions & 8 deletions example/09_convnd_fwd/convnd_fwd_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,14 @@ void print_helper_msg()
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}

template <typename DataType>
template <typename DataType, typename GemmType = DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
{
return 5e-3;
}
else if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
Expand Down Expand Up @@ -68,10 +72,14 @@ inline __host__ __device__ constexpr double get_rtol()
}
}

template <typename DataType>
template <typename DataType, typename GemmType = DataType>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Compute Type

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
{
return 1e-2;
}
else if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
Expand Down Expand Up @@ -116,7 +124,8 @@ template <ck::index_t NDimSpatial,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
typename DeviceConvNDFwdInstance,
typename ComputeDataType = OutDataType>
bool run_grouped_conv_fwd(bool do_verification,
int init_method,
bool time_kernel,
Expand Down Expand Up @@ -228,7 +237,11 @@ bool run_grouped_conv_fwd(bool do_verification,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp>();
OutElementOp,
0,
0,
0,
ComputeDataType>();

auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
Expand All @@ -249,8 +262,8 @@ bool run_grouped_conv_fwd(bool do_verification,
return ck::utils::check_err(out_device,
out_host,
"Error: incorrect results!",
get_rtol<OutDataType>(),
get_atol<OutDataType>());
get_rtol<OutDataType, ComputeDataType>(),
get_atol<OutDataType, ComputeDataType>());
}

return true;
Expand Down
89 changes: 89 additions & 0 deletions example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include "convnd_fwd_common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"

#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"

#define EXAMPLE_WITH_COMPUTE_DATATYPE

using InDataType = float;
using WeiDataType = float;
using AccDataType = float;
using CShuffleDataType = float;
using OutDataType = float;
using ComputeDataType = ck::tf32_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

template <ck::index_t NDimSpatial, typename InLayout, typename WeiLayout, typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout, // ALayout
WeiLayout, // BLayout
ck::Tuple<>, // DsLayout
OutLayout, // ELayout
InDataType, // ADataType
WeiDataType, // BDataType
AccDataType, // AccDataType
CShuffleDataType, // CShuffleDataType
ck::Tuple<>, // DsDataType
OutDataType, // EDataType
InElementOp, // AElementwiseOperation
WeiElementOp, // BElementwiseOperation
OutElementOp, // CDEElementwiseOperation
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, // NumGemmKPrefetchStage
256, // BlockSize
128, // MPerBlock
192, // NPerBlock
16, // KPerBlock
4, // AK1
4, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
3, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
4, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
4, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 16>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
4, // CDEBlockTransferScalarPerVector_NPerBlock
ComputeDataType, // AComputeDataType
ComputeDataType, // BComputeDataType
ck::LoopScheduler::Default, // LoopScheduler
1 // NumGroupsToMerge
>;

#include "run_convnd_fwd_example.inc"

int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }

#undef EXAMPLE_WITH_COMPUTE_DATATYPE
4 changes: 4 additions & 0 deletions example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"

#define EXAMPLE_WITH_COMPUTE_DATATYPE

using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
Expand Down Expand Up @@ -87,3 +89,5 @@ int main(int argc, char* argv[])
}
return run_convnd_fwd_example(argc, argv) ? 0 : 1;
}

#undef EXAMPLE_WITH_COMPUTE_DATATYPE
Loading
Loading