diff --git a/projects/miopen/CHANGELOG.md b/projects/miopen/CHANGELOG.md index ee975761317..50f9b58f73c 100644 --- a/projects/miopen/CHANGELOG.md +++ b/projects/miopen/CHANGELOG.md @@ -16,6 +16,7 @@ Full documentation for MIOpen is available [here](https://rocm.docs.amd.com/proj ### Optimized * [Conv] Enabled Composable Kernel (CK) implicit gemms on gfx950. +* [Conv] Improved Composable Kernel (CK) kernel selection during tuning. ### Resolved issues diff --git a/projects/miopen/src/include/miopen/conv/solvers.hpp b/projects/miopen/src/include/miopen/conv/solvers.hpp index 5fa56cb8da1..f48ec8c7b46 100644 --- a/projects/miopen/src/include/miopen/conv/solvers.hpp +++ b/projects/miopen/src/include/miopen/conv/solvers.hpp @@ -4644,6 +4644,10 @@ struct ConvHipImplicitGemm3DGroupWrwXdlops final private: template bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; + + template + std::size_t GetCKMaxWorkspaceSize(const miopen::conv::ProblemDescription&) const; + size_t GetCKMaxWorkspaceSize(const miopen::conv::ProblemDescription& problem) const; }; struct PerformanceConfigHipImplicitGemm3DGroupBwdXdlops @@ -4816,6 +4820,8 @@ struct ConvHipImplicitGemmGroupBwdXdlops final private: template bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; + + size_t GetCKMaxWorkspaceSize(const miopen::conv::ProblemDescription& problem) const; }; struct PerformanceConfigHipImplicitGemmGroupWrwXdlops @@ -4910,6 +4916,8 @@ struct ConvHipImplicitGemmGroupWrwXdlops final private: template bool CheckCKApplicability(const miopen::conv::ProblemDescription&) const; + + size_t GetCKMaxWorkspaceSize(const miopen::conv::ProblemDescription& problem) const; }; } // namespace conv diff --git a/projects/miopen/src/include/miopen/generic_search.hpp b/projects/miopen/src/include/miopen/generic_search.hpp index 35426fd35fc..e922a23caf0 100644 --- a/projects/miopen/src/include/miopen/generic_search.hpp +++ b/projects/miopen/src/include/miopen/generic_search.hpp @@ 
-553,15 +553,6 @@ auto GenericSearch(const Solver s, try { - if(default_solution.workspace_sz != current_solution.workspace_sz) - { - ret = -2; - MIOPEN_LOG_E('#' << n_current << " (" << n_runs_total << ") " - << "Workspace size should not depend on PerformanceConfig: " - << default_solution.workspace_sz - << " != " << current_solution.workspace_sz); - } - invoker = profile_h.PrepareInvoker(*current_solution.invoker_factory, current_solution.construction_params); diff --git a/projects/miopen/src/include/miopen/solver/implicitgemm_ck_util.hpp b/projects/miopen/src/include/miopen/solver/implicitgemm_ck_util.hpp index cbb1fd60eb6..a58e04ce4dd 100644 --- a/projects/miopen/src/include/miopen/solver/implicitgemm_ck_util.hpp +++ b/projects/miopen/src/include/miopen/solver/implicitgemm_ck_util.hpp @@ -33,6 +33,7 @@ #include #include #include +#include #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL #include @@ -50,7 +51,30 @@ struct ProblemDescription; } // namespace conv namespace solver { + +static constexpr int CkSplitkAutoDeduce = -1; + +template +inline static bool NextCKSplitkValue(int& v) +{ + assert((IsTwoPower(v) || v == CkSplitkAutoDeduce)); + if(v == H) + { + v = CkSplitkAutoDeduce; + return false; + } + if(v == CkSplitkAutoDeduce) + { + v = L; + return true; + } + + v *= 2; + return false; +} + #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + namespace conv { template using DeviceOpGWrw = ck::tensor_operation::device::DeviceGroupedConvBwdWeight< @@ -300,6 +324,33 @@ bool IsCKApplicable(const ProblemDescriptionType& problem) ptrs.begin(), ptrs.end(), [&args](auto& ptr) { return args.IsSupportedBy(ptr); }); } +template +size_t GetCKSplitkMaxWorkspaceSize(const ProblemDescriptionType& problem) /* Returns the largest workspace size reported for any instance/split_k pair accepted by IsSupportedBySplitK; 0 when no pair is accepted. */ +{ + const auto args = CKArgsType{problem}; + size_t max_workspace_size = 0; /* size_t, not auto: a literal 0 deduces int, which would narrow the size_t sizes compared and stored below */ + + const auto ptrs = DeviceOpType::GetInstances(); + for(auto& ptr : ptrs) + { + int split_k = 1; /* start at the low bound: NextCKSplitkValue<1, 128> steps 1, 2, ..., 128, CkSplitkAutoDeduce and returns true only when leaving CkSplitkAutoDeduce, so starting there would end the loop after a single probe */ + do + { + if(args.IsSupportedBySplitK(ptr, split_k)) + { + auto 
workspace_size = args.GetCKSplitkWorkspaceSize(ptr, split_k); + if(workspace_size > max_workspace_size) + max_workspace_size = workspace_size; + } + } while(!NextCKSplitkValue<1, 128>(split_k)); + } + + MIOPEN_LOG_I("Max workspace size reported by CK: " << max_workspace_size); + return max_workspace_size; +} + #define WORKAROUND_CK_ISSUE_1184 1 #if WORKAROUND_CK_ISSUE_1184 using WorkAroundHipEventProfiler = HipEventProfiler; @@ -744,13 +795,14 @@ inline bool CKWrwRequireWorkspace( } /// \todo move to a cpp file -inline size_t GetWorkspaceSizeLayoutTransformConv(const miopen::conv::ProblemDescription& problem) +inline size_t GetWorkspaceSizeLayoutTransformConv(const miopen::conv::ProblemDescription& problem, + size_t ck_ws_size = 0) { if(problem.IsLayoutNHWC()) { if(problem.GetDirection() == ::miopen::conv::Direction::BackwardWeights) { - return GetCKAlphaBetaWorkspace(problem); + return (ck_ws_size > 0) ? ck_ws_size : GetCKAlphaBetaWorkspace(problem); } return 0; } @@ -759,10 +811,11 @@ inline size_t GetWorkspaceSizeLayoutTransformConv(const miopen::conv::ProblemDes if(problem.GetDirection() == ::miopen::conv::Direction::BackwardWeights) { - MultiBufferWorkspaceTraits wt({GetPackedSize(problem.GetIn()), - GetPackedSize(problem.GetWeights()), - GetPackedSize(problem.GetOut()), - GetCKAlphaBetaWorkspace(problem)}); + MultiBufferWorkspaceTraits wt( + {GetPackedSize(problem.GetIn()), + GetPackedSize(problem.GetWeights()), + GetPackedSize(problem.GetOut()), + (ck_ws_size > 0) ? 
ck_ws_size : GetCKAlphaBetaWorkspace(problem)}); return wt.GetSize(); } @@ -1079,11 +1132,6 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx, std::optional _ck_buff_des; - if(problem.IsDirectionBackwardWrW()) - { - _ck_buff_des.emplace(GetCKAlphaBetaWorkspace(problem), 0); - } - auto ptr_iter = FindConvPtrByID(conv_ptrs, id_string); if(ptr_iter == conv_ptrs.end()) { @@ -1091,6 +1139,14 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx, return {miopenStatusInvalidValue}; } + if constexpr(std::is_same_v) { + auto ck_ws_size = ck_args.GetCKSplitkWorkspaceSize(*ptr_iter, split_k.value_or(1)); + _ck_buff_des.emplace(ck_ws_size, 0); + result.workspace_sz = GetWorkspaceSizeLayoutTransformConv(problem, ck_ws_size); + } else { + result.workspace_sz = GetWorkspaceSizeLayoutTransformConv(problem); + } + auto [_input1_tr_inst, _input2_tr_inst, _output_tr_inst, _output_init_tr_inst] = internal::MakeTaggedTransposeInstances( result, ctx, problem, ck_args, input1_op, input2_op, output_op, _ck_buff_des); @@ -1197,8 +1253,6 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx, output_tr_inst.ConvertTo(handle, kernels, conv_tensors); }; }; - - result.workspace_sz = GetWorkspaceSizeLayoutTransformConv(problem); #endif return result; } @@ -1235,8 +1289,9 @@ ConvSolution InitInvokerFactoryNHWC(const ExecutionContext&, ConvSolution result; #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL miopenAlphaBetaCase_t alpha_beta_case = problem.GetAlphaBetaCase(); - [[maybe_unused]] bool should_allocated_wrw_buffer = - ShouldAllocateWorkSpaceBufferForWRW(problem); + auto ck_args = CKArgsType{problem}; + auto ck_ws_size = ck_args.GetCKSplitkWorkspaceSize(*ptr_iter, split_k.value_or(1)); + [[maybe_unused]] bool should_allocated_wrw_buffer = ck_ws_size > 0; result.invoker_factory = [kernel_id = kernel_id, split_k = split_k, @@ -1297,7 +1352,7 @@ ConvSolution InitInvokerFactoryNHWC(const ExecutionContext&, } }; }; - result.workspace_sz = 
GetWorkspaceSizeLayoutTransformConv(problem); + result.workspace_sz = GetWorkspaceSizeLayoutTransformConv(problem, ck_ws_size); #endif return result; } diff --git a/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp b/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp index 0c15d775091..c17dec86841 100644 --- a/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp +++ b/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_3d_grouped_wrw_xdlops.cpp @@ -250,29 +250,32 @@ struct CKArgs template bool IsSupportedBy(const ConvPtr& conv_ptr) const { - auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, 1); - // Creat dummy workspace to pass the ck IsSupportedArgument check. - - int dummy_var = 1; - conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &dummy_var); - + auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, 1); + auto workspace_size = conv_ptr->GetWorkSpaceSize(arg_ptr.get()); + if(workspace_size != 0) + conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &workspace_size); return conv_ptr->IsSupportedArgument(arg_ptr.get()); } template bool IsSupportedBySplitK(const ConvPtr& conv_ptr, int split_k) const { - auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, split_k); - - if(CKWrwRequireWorkspace(G, C1, K1, data_type, alpha_beta_case)) + auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, split_k); + auto workspace_size = conv_ptr->GetWorkSpaceSize(arg_ptr.get()); + if(workspace_size != 0) { - // Creat dummy workspace to pass the ck IsSupportedArgument check. 
- int dummy_var = 1; - conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &dummy_var); + conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &workspace_size); } return conv_ptr->IsSupportedArgument(arg_ptr.get()); } + template + std::size_t GetCKSplitkWorkspaceSize(const ConvPtr& conv_ptr, int split_k) const + { + auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, split_k); + return conv_ptr->GetWorkSpaceSize(arg_ptr.get()); + } + int G; int N; int K; @@ -398,7 +401,7 @@ bool PerformanceConfigHipImplicitGemm3DGroupWrwXdlops::SetNextValue( } do { - bool flag = NextTwoPower<1, 128>(split_k); + bool flag = NextCKSplitkValue<1, 128>(split_k); if(!flag) { kernel_id = valid_kernels[index] + "+" + std::to_string(split_k); @@ -465,11 +468,54 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsValidPerformanceConfig( return config.IsValid(problem); } +template +size_t +ConvHipImplicitGemm3DGroupWrwXdlops::GetCKMaxWorkspaceSize(const ProblemDescription& problem) const +{ +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + switch(problem.GetAlphaBetaCase()) + { + case BILINEAR: + return GetCKSplitkMaxWorkspaceSize, + CKArgs>(problem); + case SCALE: + return GetCKSplitkMaxWorkspaceSize, CKArgs>( + problem); + default: + return GetCKSplitkMaxWorkspaceSize, + CKArgs>(problem); + } +#else + return 0; +#endif +} + +size_t +ConvHipImplicitGemm3DGroupWrwXdlops::GetCKMaxWorkspaceSize(const ProblemDescription& problem) const +{ +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + switch(problem.GetInDataType()) + { + case miopenHalf: return GetCKMaxWorkspaceSize(problem); + case miopenFloat: return GetCKMaxWorkspaceSize(problem); + case miopenInt8: return GetCKMaxWorkspaceSize(problem); + case miopenBFloat16: return GetCKMaxWorkspaceSize(problem); + case miopenInt64: + case miopenInt32: + case miopenFloat8_fnuz: + case miopenBFloat8_fnuz: + case miopenDouble: break; + } +#endif + return 0; // other types not applicable for this solver +} + size_t 
ConvHipImplicitGemm3DGroupWrwXdlops::GetWorkspaceSize(const ExecutionContext&, const ProblemDescription& problem) const { - return GetWorkspaceSizeLayoutTransformConv(problem); + auto ck_ws_size = GetCKMaxWorkspaceSize(problem); + return GetWorkspaceSizeLayoutTransformConv(problem, ck_ws_size); } PerformanceConfigHipImplicitGemm3DGroupWrwXdlops diff --git a/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp b/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp index 4533508e15b..5dc6e154021 100644 --- a/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp +++ b/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_grouped_bwd_xdlops.cpp @@ -166,24 +166,30 @@ struct CKArgs template bool IsSupportedBy(const ConvPtr& conv_ptr) const { - auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, 1); + auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, 1); + auto workspace_size = conv_ptr->GetWorkSpaceSize(arg_ptr.get()); + if(workspace_size != 0) + conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &workspace_size); return conv_ptr->IsSupportedArgument(arg_ptr.get()); } template bool IsSupportedBySplitK(const ConvPtr& conv_ptr, int split_k) const { - auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, split_k); - - if(CKWrwRequireWorkspace(G, C1, K1, data_type, alpha_beta_case)) - { - // Creat dummy workspace to pass the ck IsSupportedArgument check. 
- int dummy_var = 1; - conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &dummy_var); - } + auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, split_k); + auto workspace_size = conv_ptr->GetWorkSpaceSize(arg_ptr.get()); + if(workspace_size != 0) + conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &workspace_size); return conv_ptr->IsSupportedArgument(arg_ptr.get()); } + template + std::size_t GetCKSplitkWorkspaceSize(const ConvPtr& conv_ptr, int split_k) const + { + auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, split_k); + return conv_ptr->GetWorkSpaceSize(arg_ptr.get()); + } + int G; int N; int K; @@ -527,10 +533,33 @@ bool ConvHipImplicitGemmGroupBwdXdlops::IsValidPerformanceConfig( return config.IsValid(problem); } +size_t +ConvHipImplicitGemmGroupBwdXdlops::GetCKMaxWorkspaceSize(const ProblemDescription& problem) const +{ +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + switch(problem.GetInDataType()) + { + case miopenHalf: + return GetCKSplitkMaxWorkspaceSize, CKArgs>(problem); + case miopenFloat: return GetCKSplitkMaxWorkspaceSize, CKArgs>(problem); + case miopenInt8: return GetCKSplitkMaxWorkspaceSize, CKArgs>(problem); + case miopenBFloat16: + return GetCKSplitkMaxWorkspaceSize, CKArgs>(problem); + case miopenInt64: + case miopenInt32: + case miopenFloat8_fnuz: + case miopenBFloat8_fnuz: + case miopenDouble: break; + } +#endif + return 0; // other types not applicable for this solver +} + size_t ConvHipImplicitGemmGroupBwdXdlops::GetWorkspaceSize(const ExecutionContext&, const ProblemDescription& problem) const { - return GetWorkspaceSizeLayoutTransformConv(problem); + auto ck_ws_size = GetCKMaxWorkspaceSize(problem); + return GetWorkspaceSizeLayoutTransformConv(problem, ck_ws_size); } PerformanceConfigHipImplicitGemmGroupBwdXdlops diff --git a/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp b/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp index 
f6a38370b43..a0aacc56506 100644 --- a/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp +++ b/projects/miopen/src/solver/conv/conv_hip_implicit_gemm_grouped_wrw_xdlops.cpp @@ -165,29 +165,30 @@ struct CKArgs template bool IsSupportedBy(const ConvPtr& conv_ptr) const { - auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, 1); - // Creat dummy workspace to pass the ck IsSupportedArgument check. - - int dummy_var = 1; - conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &dummy_var); - + auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, 1); + auto workspace_size = conv_ptr->GetWorkSpaceSize(arg_ptr.get()); + if(workspace_size != 0) + conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &workspace_size); return conv_ptr->IsSupportedArgument(arg_ptr.get()); } template bool IsSupportedBySplitK(const ConvPtr& conv_ptr, int split_k) const { - auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, split_k); - - if(CKWrwRequireWorkspace(G, C1, K1, data_type, alpha_beta_case)) - { - // Creat dummy workspace to pass the ck IsSupportedArgument check. 
- int dummy_var = 1; - conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &dummy_var); - } + auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, split_k); + auto workspace_size = conv_ptr->GetWorkSpaceSize(arg_ptr.get()); + if(workspace_size != 0) + conv_ptr->SetWorkSpacePointer(arg_ptr.get(), &workspace_size); return conv_ptr->IsSupportedArgument(arg_ptr.get()); } + template + std::size_t GetCKSplitkWorkspaceSize(const ConvPtr& conv_ptr, int split_k) const + { + auto arg_ptr = MakeArgPtr(conv_ptr, nullptr, nullptr, nullptr, 1.0f, 0.0f, split_k); + return conv_ptr->GetWorkSpaceSize(arg_ptr.get()); + } + int G; int N; int K; @@ -499,7 +500,7 @@ bool PerformanceConfigHipImplicitGemmGroupWrwXdlops::SetNextValue(const ProblemD } do { - bool flag = NextTwoPower<1, 128>(split_k); + bool flag = NextCKSplitkValue<1, 128>(split_k); if(!flag) { kernel_id = valid_kernels[index] + "+" + std::to_string(split_k); @@ -567,10 +568,33 @@ bool ConvHipImplicitGemmGroupWrwXdlops::IsValidPerformanceConfig( return config.IsValid(problem); } +size_t +ConvHipImplicitGemmGroupWrwXdlops::GetCKMaxWorkspaceSize(const ProblemDescription& problem) const +{ +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + switch(problem.GetInDataType()) + { + case miopenHalf: + return GetCKSplitkMaxWorkspaceSize, CKArgs>(problem); + case miopenFloat: return GetCKSplitkMaxWorkspaceSize, CKArgs>(problem); + case miopenInt8: return GetCKSplitkMaxWorkspaceSize, CKArgs>(problem); + case miopenBFloat16: + return GetCKSplitkMaxWorkspaceSize, CKArgs>(problem); + case miopenInt64: + case miopenInt32: + case miopenFloat8_fnuz: + case miopenBFloat8_fnuz: + case miopenDouble: break; + } +#endif + return 0; // other types not applicable for this solver +} + size_t ConvHipImplicitGemmGroupWrwXdlops::GetWorkspaceSize(const ExecutionContext&, const ProblemDescription& problem) const { - return GetWorkspaceSizeLayoutTransformConv(problem); + auto ck_ws_size = GetCKMaxWorkspaceSize(problem); + return 
GetWorkspaceSizeLayoutTransformConv(problem, ck_ws_size); } PerformanceConfigHipImplicitGemmGroupWrwXdlops