Skip to content

Commit

Permalink
[NFC] Moved DetectRocm calls to constructor of ExecutionContext (#2295)
Browse files Browse the repository at this point in the history
  • Loading branch information
DrizztDoUrden authored Aug 10, 2023
1 parent 9ad4dc4 commit a0c73f5
Show file tree
Hide file tree
Showing 23 changed files with 21 additions and 69 deletions.
2 changes: 1 addition & 1 deletion fin
Submodule fin updated from 74382a to ebf9b3
27 changes: 6 additions & 21 deletions src/convolution_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,8 @@ static inline auto MakeFwdCtxAndProblem(miopenHandle_t handle,
conv,
direction};

auto ctx = [&] {
auto tmp = ExecutionContext{&miopen::deref(handle)};
tmp.DetectRocm();
problem.SetupFloats(tmp);
return tmp;
}();

auto ctx = ExecutionContext{&miopen::deref(handle)};
problem.SetupFloats(ctx);
return std::make_tuple(std::move(ctx), std::move(problem));
}

Expand All @@ -97,13 +92,8 @@ static inline auto MakeBwdCtxAndProblem(miopenHandle_t handle,
conv,
direction};

auto ctx = [&] {
auto tmp = ExecutionContext{&miopen::deref(handle)};
tmp.DetectRocm();
problem.SetupFloats(tmp);
return tmp;
}();

auto ctx = ExecutionContext{&miopen::deref(handle)};
problem.SetupFloats(ctx);
return std::make_tuple(std::move(ctx), std::move(problem));
}

Expand All @@ -127,13 +117,8 @@ static inline auto MakeWrWCtxAndProblem(miopenHandle_t handle,
conv,
direction};

auto ctx = [&] {
auto tmp = ExecutionContext{&miopen::deref(handle)};
tmp.DetectRocm();
problem.SetupFloats(tmp);
return tmp;
}();

auto ctx = ExecutionContext{&miopen::deref(handle)};
problem.SetupFloats(ctx);
return std::make_tuple(std::move(ctx), std::move(problem));
}

Expand Down
3 changes: 1 addition & 2 deletions src/execution_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ bool IsHipKernelsEnabled()
#endif
}

ExecutionContext& ExecutionContext::DetectRocm()
void ExecutionContext::DetectRocm()
{
use_binaries = false;
use_asm_kernels = false;
Expand All @@ -220,7 +220,6 @@ ExecutionContext& ExecutionContext::DetectRocm()
use_binaries = !IsDisabled(MIOPEN_DEBUG_AMD_ROCM_PRECOMPILED_BINARIES{});
#endif
}
return *this;
}

} // namespace miopen
1 change: 0 additions & 1 deletion src/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,6 @@ miopenStatus_t FusionPlanDescriptor::Compile(Handle& handle)
const auto solvers = GetFusedSolvers();
auto fusion_ctx = FusionContext{handle};
auto fusion_problem = FusionDescription{this};
fusion_ctx.DetectRocm();
AnyInvokeParams invoke_params;
miopen::OperatorArgs params;
std::vector<Allocator::ManageDataPtr> invoke_bufs;
Expand Down
8 changes: 4 additions & 4 deletions src/include/miopen/execution_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,17 +94,15 @@ struct ExecutionContext
inline Handle& GetStream() const { return *stream; }
inline void SetStream(Handle* stream_) { stream = stream_; }

ExecutionContext(Handle* stream_) : stream(stream_) {}
ExecutionContext() { DetectRocm(); }
ExecutionContext(Handle* stream_) : stream(stream_) { DetectRocm(); }

ExecutionContext() = default;
virtual ~ExecutionContext() = default;
ExecutionContext(const ExecutionContext&) = default;
ExecutionContext(ExecutionContext&&) = default;
ExecutionContext& operator=(const ExecutionContext&) = default;
ExecutionContext& operator=(ExecutionContext&&) = default;

ExecutionContext& DetectRocm();

#if MIOPEN_EMBED_DB
std::string GetPerfDbPathEmbed() const
{
Expand Down Expand Up @@ -281,6 +279,8 @@ struct ExecutionContext

private:
Handle* stream = nullptr;

void DetectRocm();
};

bool IsHipKernelsEnabled();
Expand Down
3 changes: 1 addition & 2 deletions src/include/miopen/find_solution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,7 @@ struct SolverContainer
return;
}

auto ctx = ExecutionContext{&handle};
ctx.DetectRocm();
auto ctx = ExecutionContext{&handle};
const auto slns = SearchForSolutions(ctx, problem, 1);

if(slns.empty())
Expand Down
1 change: 0 additions & 1 deletion src/include/miopen/fusion/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ struct FusionContext : miopen::ExecutionContext
ConvolutionContext GetConvContext(const miopen::ProblemDescription& conv_problem) const
{
auto ctx = ConvolutionContext{*this};
ctx.DetectRocm();
conv_problem.conv_problem.SetupFloats(ctx);
return ctx;
}
Expand Down
8 changes: 0 additions & 8 deletions src/ocl/convolutionocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ static Invoker PrepareInvoker(ExecutionContext ctx,
const NetworkConfig& config,
solver::Id solver_id)
{
ctx.DetectRocm();
problem.SetupFloats(ctx);
ctx.do_search = false;

Expand Down Expand Up @@ -255,7 +254,6 @@ void ConvolutionDescriptor::FindConvFwdAlgorithm(Handle& handle,
conv::ProblemDescription(xDesc, wDesc, yDesc, *this, conv::Direction::Forward);
const auto ctx = [&] {
auto tmp = ExecutionContext{&handle};
tmp.DetectRocm();
problem.SetupFloats(tmp);
tmp.do_search = exhaustiveSearch;
return tmp;
Expand Down Expand Up @@ -649,7 +647,6 @@ std::vector<miopenConvSolution_t> GetSolutions(const ExecutionContext& exec_ctx,
// All the above can be found by calling IsApplicable().
// We need fully initialized context for this, see below.
auto ctx = ConvolutionContext{exec_ctx};
ctx.DetectRocm();

for(const auto& pair : fdb_record)
{
Expand Down Expand Up @@ -724,7 +721,6 @@ std::size_t ConvolutionDescriptor::GetForwardSolutionWorkspaceSize(Handle& handl
conv::ProblemDescription{xDesc, wDesc, yDesc, *this, conv::Direction::Forward};
auto ctx = ConvolutionContext{};
ctx.SetStream(&handle);
ctx.DetectRocm();
if(sol.IsApplicable(ctx, problem))
return sol.GetWorkspaceSize(ctx, problem);
MIOPEN_THROW(miopenStatusBadParm,
Expand Down Expand Up @@ -804,7 +800,6 @@ void ConvolutionDescriptor::FindConvBwdDataAlgorithm(Handle& handle,

const auto ctx = [&] {
auto tmp = ExecutionContext{&handle};
tmp.DetectRocm();
problem.SetupFloats(tmp);
tmp.do_search = exhaustiveSearch;
return tmp;
Expand Down Expand Up @@ -935,7 +930,6 @@ std::size_t ConvolutionDescriptor::GetBackwardSolutionWorkspaceSize(Handle& hand
conv::ProblemDescription{dyDesc, wDesc, dxDesc, *this, conv::Direction::BackwardData};
auto ctx = ConvolutionContext{};
ctx.SetStream(&handle);
ctx.DetectRocm();
if(sol.IsApplicable(ctx, problem))
return sol.GetWorkspaceSize(ctx, problem);
else
Expand Down Expand Up @@ -1013,7 +1007,6 @@ void ConvolutionDescriptor::FindConvBwdWeightsAlgorithm(Handle& handle,
conv::ProblemDescription{dyDesc, dwDesc, xDesc, *this, conv::Direction::BackwardWeights};
const auto ctx = [&] {
auto tmp = ExecutionContext{&handle};
tmp.DetectRocm();
problem.SetupFloats(tmp);
tmp.do_search = exhaustiveSearch;
return tmp;
Expand Down Expand Up @@ -1135,7 +1128,6 @@ std::size_t ConvolutionDescriptor::GetWrwSolutionWorkspaceSize(Handle& handle,
conv::ProblemDescription{dyDesc, dwDesc, xDesc, *this, conv::Direction::BackwardWeights};
auto ctx = ConvolutionContext{};
ctx.SetStream(&handle);
ctx.DetectRocm();
if(sol.IsApplicable(ctx, problem))
return sol.GetWorkspaceSize(ctx, problem);
else
Expand Down
4 changes: 1 addition & 3 deletions src/problem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,7 @@ std::vector<Solution> Problem::FindSolutionsImpl(Handle& handle,
}
else
{
auto tmp_ctx = ExecutionContext{&handle};
tmp_ctx.DetectRocm();
auto tmp_ctx = ExecutionContext{&handle};
const auto workspace_max = conv_desc.GetWorkSpaceSize(tmp_ctx, conv_problem);
workspace_size = std::min(options.workspace_limit, workspace_max);
owned_workspace = workspace_size != 0 ? handle.Create(workspace_size) : nullptr;
Expand Down Expand Up @@ -349,7 +348,6 @@ std::vector<Solution> Problem::FindSolutionsImpl(Handle& handle,
const auto legacy_problem = ProblemDescription{conv_problem};
const auto netcfg = conv_problem.BuildConfKey();
auto conv_ctx = ConvolutionContext{{&handle}};
conv_ctx.DetectRocm();
conv_problem.SetupFloats(conv_ctx);

decltype(auto) db = GetDb(conv_ctx);
Expand Down
1 change: 0 additions & 1 deletion src/solution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ void Solution::RunImpl(Handle& handle,

const auto legacy_problem = ProblemDescription{conv_problem};
auto conv_ctx = ConvolutionContext{{&handle}};
conv_ctx.DetectRocm();
conv_problem.SetupFloats(conv_ctx);

decltype(auto) db = GetDb(conv_ctx);
Expand Down
12 changes: 4 additions & 8 deletions test/conv_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ static inline bool is_direct_fwd_bwd_data_supported(miopen::Handle& handle,
ctx.general_compile_options = "";
ctx.SetStream(&handle);
problem.SetupFloats(ctx);
ctx.DetectRocm();
if(FindAllDirectSolutions(ctx, problem, {}).empty())
return false;
}
Expand All @@ -119,7 +118,6 @@ static inline bool is_direct_bwd_wrw_supported(miopen::Handle& handle,
ctx.disable_perfdb_access = true;
ctx.SetStream(&handle);
problem.SetupFloats(ctx);
ctx.DetectRocm();

return !FindAllBwdWrW2DSolutions(ctx, problem, {}).empty();
}
Expand All @@ -146,7 +144,6 @@ static inline bool skip_config(miopen::Handle& handle,
ctx.disable_perfdb_access = true;
ctx.SetStream(&handle);
problem.conv_problem.SetupFloats(ctx);
ctx.DetectRocm();

return ctx.GetStream().GetDeviceName() == "gfx908" && problem.Is2d() && problem.IsFp16() &&
problem.IsLayoutDefault() && ctx.use_hip_kernels && problem.GetGroupCount() == 1 &&
Expand Down Expand Up @@ -547,7 +544,7 @@ struct verify_forward_conv : conv_base<T, Tout>
std::vector<char> ws;
miopen::Allocator::ManageDataPtr ws_dev = nullptr;

const auto ctx = ExecutionContext{&handle}.DetectRocm();
const auto ctx = ExecutionContext{&handle};
const auto problem = ConvProblemDescription{
input.desc,
weights.desc,
Expand Down Expand Up @@ -1035,7 +1032,7 @@ struct verify_backward_conv : conv_base<T>
bool fallback_path_taken = false;
std::size_t count = 0;

const auto ctx = ExecutionContext{&handle}.DetectRocm();
const auto ctx = ExecutionContext{&handle};
const auto problem = ConvProblemDescription{
out.desc,
weights.desc,
Expand Down Expand Up @@ -1405,7 +1402,7 @@ struct verify_backward_weights_conv : conv_base<T>
bool fallback_path_taken = false;
std::size_t count = 0;

const auto ctx = ExecutionContext{&handle}.DetectRocm();
const auto ctx = ExecutionContext{&handle};
const auto problem =
ConvProblemDescription{filter.mode != miopenTranspose ? out.desc : input.desc,
rweights.desc,
Expand Down Expand Up @@ -1666,7 +1663,7 @@ struct verify_forward_conv_int8 : conv_base<T>
auto in_vpad_dev = handle.Write(input_vpad.data);
auto wei_vpad_dev = handle.Write(weights_vpad.data);

const auto ctx = ExecutionContext{&handle}.DetectRocm();
const auto ctx = ExecutionContext{&handle};
const auto problem = ConvProblemDescription{
is_transform ? weight_vpad_desc : weights.desc,
is_transform ? input_vpad_desc : input.desc,
Expand Down Expand Up @@ -2275,7 +2272,6 @@ struct conv_driver : test_driver
};

auto ctx = miopen::ExecutionContext{&get_handle()};
ctx.DetectRocm();

bool skip_forward = false;

Expand Down
1 change: 0 additions & 1 deletion test/embed_sqlite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ struct EmbedSQLite : test_driver
const auto problem = miopen::ProblemDescription{conv_problem};
miopen::ConvolutionContext ctx{};
ctx.SetStream(&handle);
ctx.DetectRocm();
// Check PerfDb
{
// Get filename for the sys db
Expand Down
6 changes: 3 additions & 3 deletions test/find_db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ struct FindDbTest : test_driver
{
MIOPEN_LOG_I("Starting backward find-db test.");

const auto ctx = ExecutionContext{&handle}.DetectRocm();
const auto ctx = ExecutionContext{&handle};
const auto problem =
conv::ProblemDescription{y.desc, w.desc, x.desc, filter, conv::Direction::BackwardData};
const auto workspace_size = filter.GetWorkSpaceSize(ctx, problem);
Expand Down Expand Up @@ -137,7 +137,7 @@ struct FindDbTest : test_driver
{
std::cout << "Starting forward find-db test." << std::endl;

const auto ctx = ExecutionContext{&handle}.DetectRocm();
const auto ctx = ExecutionContext{&handle};
const auto problem =
conv::ProblemDescription{x.desc, w.desc, y.desc, filter, conv::Direction::Forward};
const auto workspace_size = filter.GetWorkSpaceSize(ctx, problem);
Expand Down Expand Up @@ -171,7 +171,7 @@ struct FindDbTest : test_driver
{
MIOPEN_LOG_I("Starting wrw find-db test.");

const auto ctx = ExecutionContext{&handle}.DetectRocm();
const auto ctx = ExecutionContext{&handle};
const auto problem = conv::ProblemDescription{
y.desc, w.desc, x.desc, filter, conv::Direction::BackwardWeights};
const auto workspace_size = filter.GetWorkSpaceSize(ctx, problem);
Expand Down
3 changes: 0 additions & 3 deletions test/gpu_conv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ bool gpu_ref_convolution_fwd(const tensor<Tin>& input,
input.desc, weights.desc, rout.desc, filter, miopen::conv::Direction::Forward};
auto ctx = miopen::ConvolutionContext{};
ctx.SetStream(&handle);
ctx.DetectRocm();
if(naive_solver.IsApplicable(ctx, problem))
{
gpu_ref_used = true;
Expand Down Expand Up @@ -128,7 +127,6 @@ bool gpu_ref_convolution_bwd(tensor<Tin>& input,
output.desc, weights.desc, input.desc, filter, miopen::conv::Direction::BackwardData};
auto ctx = miopen::ConvolutionContext{};
ctx.SetStream(&handle);
ctx.DetectRocm();
if(naive_solver.IsApplicable(ctx, problem))
{
gpu_ref_used = true;
Expand Down Expand Up @@ -169,7 +167,6 @@ bool gpu_ref_convolution_wrw(const tensor<Tin>& input,
miopen::conv::Direction::BackwardWeights};
auto ctx = miopen::ConvolutionContext{};
ctx.SetStream(&handle);
ctx.DetectRocm();
if(naive_solver.IsApplicable(ctx, problem))
{
gpu_ref_used = true;
Expand Down
1 change: 0 additions & 1 deletion test/gpu_nchw_nhwc_transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,6 @@ struct transpose_test : transpose_base

miopen::ExecutionContext ctx;
ctx.SetStream(&miopen::deref(this->handle));
ctx.DetectRocm();
// ctx.SetupFloats();

TRANSPOSE_SOL transpose_sol(ctx, to_miopen_data_type<T>::get(), n, c, h, w);
Expand Down
1 change: 0 additions & 1 deletion test/gtest/bad_fusion_plan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ class TestFusionPlan
Solver solv{};
const auto fusion_problem = miopen::FusionDescription{&fusePlanDesc};
auto fusion_ctx = miopen::FusionContext{handle};
fusion_ctx.DetectRocm();

return solv.IsApplicable(fusion_ctx, fusion_problem);
}
Expand Down
2 changes: 0 additions & 2 deletions test/gtest/cba_infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ void RunSolver(miopen::FusionPlanDescriptor& fusePlanDesc,
Solver solv{};
const auto fusion_problem = miopen::FusionDescription{&fusePlanDesc};
auto fusion_ctx = miopen::FusionContext{handle};
fusion_ctx.DetectRocm();
if(!solv.IsApplicable(fusion_ctx, fusion_problem))
{
test_skipped = true;
Expand All @@ -95,7 +94,6 @@ void RunTunableSolver(miopen::FusionPlanDescriptor& fusePlanDesc,
Solver solv{};
const auto fusion_problem = miopen::FusionDescription{&fusePlanDesc};
auto fusion_ctx = miopen::FusionContext{handle};
fusion_ctx.DetectRocm();
if(!solv.IsApplicable(fusion_ctx, fusion_problem))
{
test_skipped = true;
Expand Down
1 change: 0 additions & 1 deletion test/gtest/group_conv_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ void SolverFwd(const miopen::TensorDescriptor& inputDesc,
auto ctx = miopen::ConvolutionContext{};

ctx.SetStream(&handle);
ctx.DetectRocm();

if(!solv.IsApplicable(ctx, problem))
{
Expand Down
1 change: 0 additions & 1 deletion test/gtest/kernel_tuning_net.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ void TestParameterPredictionModel(miopen::ProblemDescription problem,
GTEST_SKIP();
miopen::ConvolutionContext ctx;
ctx.SetStream(&handle);
ctx.DetectRocm();
T perf_config;
bool valid = false;
perf_config.RunParmeterPredictionModel(ctx, problem, valid);
Expand Down
1 change: 0 additions & 1 deletion test/gtest/na_infer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ void RunSolver(miopen::FusionPlanDescriptor& fusePlanDesc,
Solver solv{};
const auto fusion_problem = miopen::FusionDescription{&fusePlanDesc};
auto fusion_ctx = miopen::FusionContext{handle};
fusion_ctx.DetectRocm();
if(!solv.IsApplicable(fusion_ctx, fusion_problem))
{
test_skipped = true;
Expand Down
Loading

0 comments on commit a0c73f5

Please sign in to comment.