Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented find modes and fallback for FusionPlanDescriptor::Compile #3158

Merged
merged 12 commits into from
Aug 1, 2024
Merged
10 changes: 10 additions & 0 deletions cmake/ClangTidy.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ else()
message( STATUS "Clang tidy found: ${CLANG_TIDY_VERSION}")
endif()

set(EXTRA_CHECKS)
# There is a bug in tidy that hangs it in some cases when it encounters optional access
# It can spend 3.5h+ and timeout the CI
# It is fixed in 18.0.0 or worked around by disabling this check: bugprone-unchecked-optional-access
# https://github.com/llvm/llvm-project/issues/59492
if (CLANG_TIDY_VERSION VERSION_LESS "18.0.0")
list(APPEND EXTRA_CHECKS -bugprone-unchecked-optional-access)
endif()

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

set(CLANG_TIDY_FIXIT_DIR ${CMAKE_BINARY_DIR}/fixits)
Expand All @@ -81,6 +90,7 @@ macro(enable_clang_tidy)
set(multiValueArgs CHECKS ERRORS EXTRA_ARGS)

cmake_parse_arguments(PARSE "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
list(APPEND PARSE_CHECKS ${EXTRA_CHECKS})
string(REPLACE ";" "," CLANG_TIDY_CHECKS "${PARSE_CHECKS}")
string(REPLACE ";" "," CLANG_TIDY_ERRORS "${PARSE_ERRORS}")
set(CLANG_TIDY_EXTRA_ARGS)
Expand Down
42 changes: 27 additions & 15 deletions src/find_controls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,13 @@
#include <ostream>
#include <cstdlib>
#include <cstring>
#include <string_view>
#include <optional>

MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_FIND_ENFORCE)
MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_DEBUG_FIND_ONLY_SOLVER)
MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_FIND_MODE)
MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_FIND_MODE_FUSION)

namespace miopen {

Expand Down Expand Up @@ -196,11 +199,12 @@ std::ostream& operator<<(std::ostream& os, const FindMode::Values& v)
return os << ToCString(v) << "(" << static_cast<int>(v) << ')';
}

FindMode::Values GetFindModeValueImpl2()
template <class Variable>
std::optional<FindMode::Values> GetFindModeValueImpl2(Variable)
{
auto str = miopen::GetStringEnv(ENV(MIOPEN_FIND_MODE));
auto str = miopen::GetStringEnv(Variable{});
if(str.empty())
return FindMode::Values::Default_;
return std::nullopt;
for(auto& c : str)
c = toupper(static_cast<unsigned char>(c));
if(str == "NORMAL")
Expand All @@ -225,26 +229,34 @@ FindMode::Values GetFindModeValueImpl2()
const auto val = static_cast<FindMode::Values>(stoul(str));
if(FindMode::Values::Begin_ <= val && val < FindMode::Values::End_)
return val;
MIOPEN_LOG_NQE("Wrong MIOPEN_FIND_MODE, using default.");
return FindMode::Values::Default_;
MIOPEN_LOG_NQE("Wrong " << Variable::GetName() << ", using default.");
return std::nullopt;
}

FindMode::Values GetFindModeValueImpl()
template <class Variable>
FindMode::Values GetFindModeValue(Variable, FindMode::Values defaultValue)
{
auto rv = GetFindModeValueImpl2();
MIOPEN_LOG_NQI("MIOPEN_FIND_MODE = " << rv);
return rv;
}

FindMode::Values GetFindModeValue()
{
static const FindMode::Values val = GetFindModeValueImpl();
static const FindMode::Values val = [&]() {
auto rv = GetFindModeValueImpl2(Variable{}).value_or(defaultValue);
MIOPEN_LOG_NQI(Variable::GetName() << " = " << rv);
return rv;
}();
return val;
}

} // namespace

FindMode::FindMode() { value = GetFindModeValue(); }
FindMode::FindMode(solver::Primitive primitive)
{
switch(primitive)
{
case solver::Primitive::Fusion:
value = GetFindModeValue(ENV(MIOPEN_FIND_MODE_FUSION), FindMode::Values::Fast);
break;
default: value = GetFindModeValue(ENV(MIOPEN_FIND_MODE), FindMode::Values::Default_); break;
}
}

std::ostream& operator<<(std::ostream& os, const FindMode& obj) { return os << obj.value; }

static_assert(miopenConvolutionFindModeNormal ==
Expand Down
202 changes: 194 additions & 8 deletions src/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@
#include <miopen/fusion_plan.hpp>
#include <miopen/logger.hpp>
#include <miopen/handle.hpp>
#include <miopen/names.hpp>
#include <miopen/visit_float.hpp>
#include <miopen/stringutils.hpp>
#include <miopen/solver_id.hpp>
#include <miopen/solution.hpp>
#include <miopen/fusion/solvers.hpp>
#include <miopen/fusion/fusion_invoke_params.hpp>
#include <miopen/fusion/utils.hpp>
Expand Down Expand Up @@ -878,26 +880,205 @@ FindFusion(const ExecutionContext& ctx,
"fusion");
}

namespace {

// Copy from convolutionocl.cpp
struct SolutionTimeComparator
{
inline bool operator()(const miopenConvSolution_t& lhs, const miopenConvSolution_t& rhs) const
{
// Negative values are very coarse estimations.
// The more modulus, the "worse" (slower) is solution.
if(lhs.time < 0 && rhs.time < 0)
return !(lhs.time < rhs.time);
// Positive values are always "better" than negative (coarse) estimations.
if(lhs.time > 0 && rhs.time < 0)
return true;
if(lhs.time < 0 && rhs.time > 0)
return false;
// Both values are positive. The less is the better.
return (lhs.time < rhs.time);
}
};

std::ostream& operator<<(std::ostream& os, const miopenConvSolution_t& s)
{
return os << "id: " << s.solution_id //
<< ", algo: " << s.algorithm //
<< ", time: " << s.time << ", ws: " << s.workspace_size //
<< ", name: " << miopen::solver::Id(s.solution_id).ToString();
}

// Modified copy from convolutionocl.cpp
std::vector<miopenConvSolution_t> GetSolutions(const FusionContext& ctx,
const FusionDescription& problem,
const size_t maxSolutionCount)
{
const FindDbRecord fdb_record{ctx.GetStream(), problem, "fusion"};

if(fdb_record.empty())
return {};

auto interim = std::vector<miopenConvSolution_t>{};
interim.reserve(20); // Heuristic for speed.

for(const auto& pair : fdb_record)
{
const auto solver_id = solver::Id{pair.first};

// Wrong IDs can't be used to call IsApplicable(), so let's
// ignore obsolete or invalid IDs read from find-db first.
if(!solver_id.IsValid())
{
// Do not disturb users with warnings unless detailed log is enabled.
MIOPEN_LOG_I("[Warning] incorrect solver_id: " << pair.first);
continue;
}

// algorithm doesn't matter for our purpose here, so we stub it out
interim.emplace_back(miopenConvSolution_t{pair.second.time,
pair.second.workspace,
solver_id.Value(),
miopenConvolutionAlgoDirect});
}

std::sort(begin(interim), end(interim), SolutionTimeComparator{});
auto out = std::vector<miopenConvSolution_t>{};
out.reserve(maxSolutionCount);
auto n_copied = 0;
for(const auto& s : interim)
{
const auto solver_id = solver::Id{s.solution_id};
bool is_applicable = false;

GetAllFusionSolvers().FindById(
solver_id, [&](auto solver) { is_applicable = solver.IsApplicable(ctx, problem); });

if(!is_applicable)
continue;
out.push_back(s);
if(++n_copied >= maxSolutionCount)
break;
}

for(const auto& s : out)
MIOPEN_LOG_I2(s);

return out;
}

} // namespace

miopenStatus_t FusionPlanDescriptor::Compile(Handle& handle)
{
std::vector<Allocator::ManageDataPtr> invoke_bufs;
miopen::OperatorArgs params;

const auto find_results = Find(handle, [&]() {
return AllocateBuffersAndMakeFusionInvokeParams(
handle, FusionDescription{this}, invoke_bufs, params, *this);
});
const auto& fusion_problem = FusionDescription{this};

const auto network_config = fusion_problem.MakeNetworkConfig();
auto invoker = handle.GetInvoker(network_config, boost::none, AlgorithmName{"fusion"});

const auto network_config = FusionDescription{this}.MakeNetworkConfig();
if(invoker)
{
invokers.push_back(*invoker);
return miopenStatusSuccess;
}

std::vector<PerfField> find_results;
const auto hasConv = [](solver::Id id) {
bool ret = true;
GetFusedNonConvSolvers().FindById(id, [&](auto) { ret = false; });
return ret;
};

{
FindMode findMode(solver::Primitive::Fusion);
auto sol = boost::optional<miopenConvSolution_t>{};

if(findMode.IsFast(fusion_problem) || findMode.IsHybrid(fusion_problem))
{
const auto ctx = FusionContext{handle};
auto sols = GetSolutions(ctx, fusion_problem, 1);
const auto fallback = sols.empty();

if(fallback)
{
auto fallback_failed = true;
bool found = false;

GetAllFusionSolvers().Foreach([&](auto solver) {
if(found || !solver.IsApplicable(ctx, fusion_problem))
return;
const auto id = solver::Id(solver.SolverDbId());
const auto wti = solver.GetWti(ctx, fusion_problem);
// Assume WTI == 1.0 (100%) is 10 ms.
// Return negative values as is, avoid DIV/0.
const auto time = wti <= 0.0f ? wti : (10.f / wti);

const auto algo = hasConv(id) ? id.GetAlgo() : miopenConvolutionAlgoDirect;
sols.push_back(miopenConvSolution_t{time, 0, id.Value(), algo});
fallback_failed = false;
});

if(fallback_failed)
{
MIOPEN_LOG_I("No supported fusion solvers found");
return miopenStatusUnsupportedOp;
}
}

// override the normal find with immed mode with env var
if(!sols.empty() && (!(findMode.IsHybrid(fusion_problem) && fallback)))
// || env::enabled(MIOPEN_DEBUG_FORCE_IMMED_MODE_FALLBACK)
{
std::sort(sols.begin(), sols.end(), SolutionTimeComparator());
sol = sols.front();
}
// In Hybrid Find mode, we use Normal Find instead of Immediate fallback kernels.
}

if(sol.has_value())
{
// We need to create an invoker

const auto id = solver::Id{sol->solution_id};

GetAllFusionSolvers().FindById(id, [&](auto solver) {
const auto ctx = FusionContext{handle};
auto db = GetDb(ctx);
const auto solution = solver::FindSolution(
solver, ctx, fusion_problem, db, {}); // auto tune is not expected here
auto invoker =
handle.PrepareInvoker(*solution.invoker_factory, solution.construction_params);

auto algo =
hasConv(id)
? ConvolutionAlgoToDirectionalString(id.GetAlgo(), conv::Direction::Forward)
: "fusion";

handle.RegisterInvoker(invoker, network_config, id.ToString());
find_results.push_back(PerfField{std::move(algo), id.ToString(), .0f, 0});
});
}
else
{
find_results = Find(handle, [&]() {
return AllocateBuffersAndMakeFusionInvokeParams(
handle, fusion_problem, invoke_bufs, params, *this);
});
}
}

for(const auto& result : find_results)
{
if(conv_fwd_algo && result.algorithm != "fusion" &&
const auto id = solver::Id(result.solver_id);

if(conv_fwd_algo && hasConv(id) &&
miopen::StringToConvolutionFwdAlgo(result.algorithm) != *conv_fwd_algo)
continue;

const auto id = solver::Id{result.solver_id};
const auto invoker = handle.GetInvoker(network_config, id);
invoker = handle.GetInvoker(network_config, id);

if(!invoker)
{
Expand All @@ -907,6 +1088,8 @@ miopenStatus_t FusionPlanDescriptor::Compile(Handle& handle)

invokers.push_back(*invoker);
MIOPEN_LOG_I2(result.algorithm);
break;
// We never use any invoker after the first anyway.
}

if(invokers.empty())
Expand All @@ -915,6 +1098,9 @@ miopenStatus_t FusionPlanDescriptor::Compile(Handle& handle)
return miopenStatusUnsupportedOp;
}

handle.SetAsFound1_0(network_config,
AlgorithmName{"fusion"},
solver::Id(find_results.front().solver_id).ToString());
return miopenStatusSuccess;
}

Expand Down
1 change: 0 additions & 1 deletion src/include/miopen/conv_solution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ struct ConvSolution
int n_out_pix_tiles; // # output pixel tiles per wk-item (ALU)
int n_in_data_tiles; // # of blocks of different inputs in LDS
int n_stacks; // # of diff stacks (part of batch).
float weight = 0.0f;

ConvSolution(miopenStatus_t status_ = miopenStatusSuccess)
: status(status_),
Expand Down
1 change: 1 addition & 0 deletions src/include/miopen/env.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ struct EnvVar
static miopen::internal::EnvVar<type> var{#name, default_val}; \
return var; \
} \
static constexpr std::string_view GetName() { return #name; } \
}; \
}

Expand Down
3 changes: 2 additions & 1 deletion src/include/miopen/find_controls.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ class FindMode
}

public:
FindMode();
// Todo: remove default value of primitive
FindMode(solver::Primitive primitive = solver::Primitive::Convolution);
Values Get() const { return value; }
void Set(Values const v) { value = v; }

Expand Down
7 changes: 7 additions & 0 deletions src/include/miopen/find_solution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,13 @@ struct SolverContainer
Solvers{}...);
}

///\todo: remove when AnySolver would be able to work with non-conv solvers
template <class Functor>
void Foreach(Functor&& receiver)
{
miopen::each_args([&](auto solver) { receiver(solver); }, Solvers{}...);
}

// Search for all applicable solutions among many solvers
template <class Context, class Problem, class Db, class Solution = miopen::solver::ConvSolution>
std::vector<Solution>
Expand Down
2 changes: 2 additions & 0 deletions src/include/miopen/fusion/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ struct FusionDescription : ProblemDescriptionBase
#endif
{
const miopen::FusionPlanDescriptor* fusion_plan_desc;
bool disable_search_enforce = false;

FusionDescription(const miopen::FusionPlanDescriptor* ptr_desc) : fusion_plan_desc(ptr_desc) {}

[[nodiscard]] NetworkConfig MakeNetworkConfig() const override
Expand Down
Loading