diff --git a/src/include/miopen/conv/solvers.hpp b/src/include/miopen/conv/solvers.hpp index 2626b00b70..513b5552f2 100644 --- a/src/include/miopen/conv/solvers.hpp +++ b/src/include/miopen/conv/solvers.hpp @@ -4556,11 +4556,7 @@ struct ConvHipImplicitGemm3DGroupFwdXdlops final GetSolution(const ExecutionContext&, const miopen::conv::ProblemDescription&, const PerformanceConfigHipImplicitGemm3DGroupFwdXdlops&) const override; - /// \ref igemm_get_wti_magic_number - float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override - { - return 0.02f; - }; + float GetWti(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override; MIOPEN_INTERNALS_EXPORT size_t GetWorkspaceSize( const ExecutionContext&, const miopen::conv::ProblemDescription&) const override; diff --git a/src/solver/conv/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp b/src/solver/conv/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp index 8968230e6b..2d21a5e081 100644 --- a/src/solver/conv/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp +++ b/src/solver/conv/conv_hip_implicit_gemm_3d_grouped_fwd_xdlops.cpp @@ -40,6 +40,7 @@ #endif #include MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_3D_CONV_IMPLICIT_GEMM_HIP_FWD_XDLOPS) +MIOPEN_DECLARE_ENV_VAR_UINT64(CK_CONV3D_IDX); namespace miopen { namespace solver { @@ -360,8 +361,56 @@ void PerformanceConfigHipImplicitGemm3DGroupFwdXdlops::Init(const ProblemDescrip FillValidKernelsIDs, CKArgs>(problem); break; } - index = 0; + index = 0; + + // for BF16 and FP16 + index = env::value(CK_CONV3D_IDX); + if(index == 0 && problem.GetAlphaBetaCase() == DEFAULT) + { + int G = ProblemInterpreter::GetGroupCountG(problem); + int K1 = ProblemInterpreter::GetOutputChannelK(problem); + int K = K1 / G; // Number of output Channel per group + if(problem.GetInDataType() == miopenBFloat16) + { + if(valid_kernels.size() > 30) + { + index = 30; + assert(valid_kernels[30] == + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" + "<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, " + "BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>"); + } + if(K < 64 && valid_kernels.size() > 38) + { + index = 38; + assert(valid_kernels[38] == + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" + "<256, 64, 64, 64, Default, 32, 32, 1, 1, 8, 8, 8, 1, 1, " + "BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>"); + } + } + else if(problem.GetInDataType() == miopenHalf) + { + if(valid_kernels.size() > 31) + { + index = 31; + assert(valid_kernels[31] == + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" + "<256, 128, 128, 64, Default, 32, 32, 2, 2, 8, 8, 8, 1, 1, " + "BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v3>"); + } + if(K < 64 && valid_kernels.size() > 57) + { + index = 57; + assert(valid_kernels[57] == + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" + "<64, 16, 16, 128, Default, 16, 16, 1, 1, 8, 8, 4, 1, 1, " + "BlkGemmPipelineScheduler: Interwave, BlkGemmPipelineVersion: v1>"); + } + } + } kernel_id = valid_kernels[index]; + MIOPEN_LOG_I2("index:" << index << " kernel_id:" << kernel_id); } template @@ -425,6 +474,11 @@ bool PerformanceConfigHipImplicitGemm3DGroupFwdXdlops::SetNextValue( { HeuristicInit(problem); assert(!valid_kernels.empty()); + if(index != 0) + { + index = 0; + kernel_id = valid_kernels[index]; + } return true; } if((index + 1) < valid_kernels.size()) @@ -540,6 +594,35 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable( return false; } +float ConvHipImplicitGemm3DGroupFwdXdlops::GetWti( + const ExecutionContext&, const miopen::conv::ProblemDescription& problem) const +{ + decltype(auto) xDesc = problem.GetIn(); + decltype(auto) wDesc = problem.GetWeights(); + + if(xDesc.GetType() == miopenHalf || xDesc.GetType() == miopenBFloat16) + { + auto& in_c = xDesc.GetLengths()[1]; + auto& w_x = wDesc.GetLengths()[2]; + auto& w_y = wDesc.GetLengths()[3]; + auto& w_d = wDesc.GetLengths()[4]; + // For cases where the filter shape is not 1x1x1 and the input channel (in_c) is greater + // than 8, CK's implementation offers better performance. + if((w_x == 1 && w_y == 1 && w_d == 1) == false) + { + if(in_c < 8) + { + return 0.00002; // force disable + } + else + { + return 1.0; // force enable + } + } + } + return 0.02f; +} + ConvSolution ConvHipImplicitGemm3DGroupFwdXdlops::GetSolution( [[maybe_unused]] const ExecutionContext& ctx, [[maybe_unused]] const ProblemDescription& problem,