Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/include/miopen/conv/solvers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4556,11 +4556,7 @@ struct ConvHipImplicitGemm3DGroupFwdXdlops final
GetSolution(const ExecutionContext&,
const miopen::conv::ProblemDescription&,
const PerformanceConfigHipImplicitGemm3DGroupFwdXdlops&) const override;
/// \ref igemm_get_wti_magic_number
float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override
{
return 0.02f;
};
float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override;

MIOPEN_INTERNALS_EXPORT size_t GetWorkspaceSize(
const ExecutionContext&, const miopen::conv::ProblemDescription&) const override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#endif
#include <miopen/solver/implicitgemm_ck_util.hpp>
// NOTE(review): presumably the debug on/off switch for the 3D grouped forward xdlops solver;
// its use is not visible in this excerpt -- confirm against IsApplicable().
MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS)
// Read via env::value(CK_CONV3D_IDX) as the starting index into valid_kernels when the
// performance config is initialised. When it reads as 0 and the alpha/beta case is DEFAULT,
// the built-in BF16/FP16 default indices may replace it.
// NOTE(review): the value is used to index valid_kernels as-is -- confirm that callers cannot
// supply an index >= valid_kernels.size().
MIOPEN_DECLARE_ENV_VAR_UINT64(CK_CONV3D_IDX);

namespace miopen {
namespace solver {
Expand Down Expand Up @@ -360,8 +361,56 @@ void PerformanceConfigHipImplicitGemm3DGroupFwdXdlops::Init(const ProblemDescrip
FillValidKernelsIDs<DeviceOpGFwdDefaultPtrs<DataType>, CKArgs<DataType>>(problem);
break;
}
index = 0;
index = 0;

// for BF16 and FP16
index = env::value(CK_CONV3D_IDX);
if(index == 0 && problem.GetAlphaBetaCase() == DEFAULT)
{
int G = ProblemInterpreter::GetGroupCountG(problem);
int K1 = ProblemInterpreter::GetOutputChannelK(problem);
int K = K1 / G; // Number of output Channel per group
if(problem.GetInDataType() == miopenBFloat16)
{
if(valid_kernels.size() > 30)
{
index = 30;
assert(valid_kernels[30] ==
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
"<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, "
"BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>");
}
if(K < 64 && valid_kernels.size() > 38)
{
index = 38;
assert(valid_kernels[38] ==
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
"<256, 64, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, "
"BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>");
}
}
else if(problem.GetInDataType() == miopenHalf)
{
if(valid_kernels.size() > 31)
{
index = 31;
assert(valid_kernels[31] ==
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
"<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, "
"BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>");
}
if(K < 64 && valid_kernels.size() > 57)
{
index = 57;
assert(valid_kernels[57] ==
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
"<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, "
"BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>");
}
}
}
kernel_id = valid_kernels[index];
MIOPEN_LOG_I2("index:" << index << " kernel_id:" << kernel_id);
}

template <typename DataType>
Expand Down Expand Up @@ -425,6 +474,11 @@ bool PerformanceConfigHipImplicitGemm3DGroupFwdXdlops::SetNextValue(
{
HeuristicInit(problem);
assert(!valid_kernels.empty());
if(index != 0)
{
index = 0;
kernel_id = valid_kernels[index];
}
return true;
}
if((index + 1) < valid_kernels.size())
Expand Down Expand Up @@ -540,6 +594,35 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable(
return false;
}

/// Returns the WTI value reported for this solver on the given problem.
///
/// The execution context is not used. The result depends only on the input data type,
/// the input length at index 1 (the input channel count) and the weight lengths at
/// indices 2..4 (the three filter extents):
///   - input is neither FP16 nor BF16      -> 0.02f (default)
///   - FP16/BF16 with a 1x1x1 filter       -> 0.02f (default)
///   - FP16/BF16, non-1x1x1, channels < 8  -> 0.00002 (force disable)
///   - FP16/BF16, non-1x1x1, channels >= 8 -> 1.0 (force enable)
float ConvHipImplicitGemm3DGroupFwdXdlops::GetWti(
    const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const
{
    constexpr float default_wti   = 0.02f;
    constexpr float force_disable = static_cast<float>(0.00002);
    constexpr float force_enable  = static_cast<float>(1.0);

    const auto& input   = problem.GetIn();
    const auto& weights = problem.GetWeights();

    // The special-casing below applies to half-precision inputs only.
    const auto data_type = input.GetType();
    if(data_type != miopenHalf && data_type != miopenBFloat16)
        return default_wti;

    const auto channels = input.GetLengths()[1];

    // A filter of extent 1 in every one of the three dimensions keeps the default.
    const auto& filter_lens = weights.GetLengths();
    const bool unit_filter =
        filter_lens[2] == 1 && filter_lens[3] == 1 && filter_lens[4] == 1;
    if(unit_filter)
        return default_wti;

    // Non-1x1x1 filter: CK's implementation performs better once there are at least
    // 8 input channels, so prefer it there and effectively rule it out below that.
    return (channels < 8) ? force_disable : force_enable;
}

ConvSolution ConvHipImplicitGemm3DGroupFwdXdlops::GetSolution(
[[maybe_unused]] const ExecutionContext& ctx,
[[maybe_unused]] const ProblemDescription& problem,
Expand Down
Loading