Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/01_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endforeach()

list(APPEND gpu_list_tf32 gfx942)
list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)
Expand Down
2 changes: 1 addition & 1 deletion example/09_convnd_fwd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endforeach()

list(APPEND gpu_list_tf32 gfx942)
list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)
Expand Down
2 changes: 1 addition & 1 deletion example/09_convnd_fwd/convnd_fwd_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
{
return 1e-2;
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, float>)
{
Expand Down
10 changes: 10 additions & 0 deletions example/15_grouped_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,13 @@ if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4)
endif()

list(APPEND gpu_list_tf32 gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0)
add_example_executable(example_grouped_gemm_xdl_fp32_tf32 grouped_gemm_xdl_fp32_tf32.cpp)
add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_fp32_tf32)
set(target 1)
endif()
endforeach()
66 changes: 66 additions & 0 deletions example/15_grouped_gemm/grouped_gemm_xdl_fp32_tf32.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

#define EXAMPLE_WITH_COMPUTE_DATATYPE

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ADataType = F32;
using BDataType = F32;
using AccDataType = F32;
using CShuffleDataType = F32;
using DsDataType = ck::Tuple<>;
using EDataType = F32;
using ComputeDataType = ck::tf32_t;

using ALayout = Row;
using BLayout = Col;
using DsLayout = ck::Tuple<>;
using ELayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 16, 16, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4, ck::LoopScheduler::Default, ComputeDataType>;
// clang-format on

#include "run_grouped_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); }

#undef EXAMPLE_WITH_COMPUTE_DATATYPE
12 changes: 10 additions & 2 deletions example/15_grouped_gemm/run_grouped_gemm_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

#pragma once

// use macro to minimize code change
#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
using ComputeDataType = AccDataType;
#endif

struct ProblemSize final
{
std::vector<ck::index_t> Ms;
Expand Down Expand Up @@ -231,7 +236,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
AccDataType,
AElementOp,
BElementOp,
CDEElementOp>;
CDEElementOp,
ComputeDataType>;

for(std::size_t i = 0; i < gemm_descs.size(); i++)
{
Expand All @@ -253,7 +259,9 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]);

#else
pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
pass &= ck::utils::check_err<decltype(c_device_tensors[i]),
decltype(c_host_tensors[i]),
ComputeDataType>(c_device_tensors[i], c_host_tensors[i]);
#endif
}
}
Expand Down
5 changes: 4 additions & 1 deletion include/ck/host_utility/device_prop.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ inline bool is_wmma_supported()
return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported();
}

inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? true : false; }
inline bool is_tf32_supported()
{
return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
}

} // namespace ck
#endif
4 changes: 2 additions & 2 deletions include/ck/library/utility/check_err.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ typename std::enable_if<
check_err(const Range& out,
const RefRange& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-5,
double atol = 3e-5)
double rtol = 5e-4,
double atol = 5e-4)
{
if(out.size() != ref.size())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ template <typename ALayout,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
typename CElementwiseOperation,
typename ComputeDataType = ADataType>
struct DeviceGroupedGemm : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ template <typename ALayout,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler()>
LoopScheduler LoopSched = make_default_loop_scheduler(),
typename ComputeDataType = ADataType>
struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
Expand All @@ -145,7 +146,8 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
CDEElementwiseOperation,
ComputeDataType>
{
using DeviceOp = DeviceGroupedGemm_Xdl;
GET_NXDL_PER_WAVE_IMPL
Expand Down Expand Up @@ -233,8 +235,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));

using ComputeDataType = ADataType;

// GridwiseGemm
template <index_t NXdlPerWave_>
using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle<
Expand Down
63 changes: 59 additions & 4 deletions include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,10 @@ enum struct MfmaInstr
mfma_f32_16x16x128f8f6f4,
mfma_scale_f32_32x32x64f8f6f4,
mfma_scale_f32_16x16x128f8f6f4,
mfma_f32_16x16x8xf32, // tf32
mfma_f32_32x32x4xf32,
mfma_f32_16x16x8xf32, // tf32 on gfx942
mfma_f32_32x32x4xf32, // tf32 on gfx942
mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950
mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950
// gfx11
wmma_f32_16x16x16_f16,
wmma_f32_16x16x16_bf16,
Expand Down Expand Up @@ -1015,6 +1017,51 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4xf32>
}
};

template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16xf32>
{
// gfx950 specific: use bf16x3 simulate tf32
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;

template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32xf32>
{
// gfx950 specific: use bf16x3 simulate tf32
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;

template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32xf32<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};

// gfx11
struct mfma_type_gfx11_base
{
Expand Down Expand Up @@ -1275,12 +1322,14 @@ struct MfmaSelector
}

template <>
constexpr auto GetMfma<tf32_t, 32, 32>()
constexpr auto GetMfma<tf32_t, 32, 32, tf32_t>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_unsupport_16x16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_32x32x16xf32;
#elif defined(__gfx942__)
return MfmaInstr::mfma_f32_32x32x4xf32;
#else
Expand All @@ -1289,12 +1338,14 @@ struct MfmaSelector
}

template <>
constexpr auto GetMfma<tf32_t, 16, 16>()
constexpr auto GetMfma<tf32_t, 16, 16, tf32_t>()
{
#if defined(__gfx12__)
return MfmaInstr::wmma_unsupport_16x16_gfx12;
#elif defined(__gfx11__)
return MfmaInstr::wmma_unsupport_16x16_gfx11;
#elif defined(__gfx950__)
return MfmaInstr::mfma_f32_16x16x32xf32;
#elif defined(__gfx942__)
return MfmaInstr::mfma_f32_16x16x8xf32;
#else
Expand Down Expand Up @@ -2185,6 +2236,10 @@ struct XdlopsGemm
(is_same<base_type, int8_t>::value && KPack <= 8) ||
((is_same<base_type, f8_t>::value || is_same<base_type, bf8_t>::value) && KPack < 32) ||
is_same<additional_type, pk_i4_t>::value)
#if defined(__gfx950__)
// tf32 on gfx950 is implemented as bf16x3, so it should be treated as bf16.
|| (is_same<base_type, tf32_t>::value && KPack <= 4)
#endif
? true
: false;
static constexpr auto mfma = MfmaSelector<base_type,
Expand Down
Loading