Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions src/driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@
#include <iomanip>

namespace {

using dims_map = std::unordered_map<std::string, std::vector<std::size_t>>;

std::vector<std::string>
get_unrecognized_migraphx_envs(const char* envp[],
const std::map<std::string, std::string>& used_env)
Expand Down Expand Up @@ -213,7 +216,7 @@ struct loader

static auto parse_param_dims(const std::vector<std::string>& param_dims_info)
{
std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims;
dims_map map_input_dims;
std::string name = "";
for(auto&& x : param_dims_info)
{
Expand Down Expand Up @@ -502,16 +505,24 @@ struct program_params
return map_load_args;
}

auto generate(const program& p, const target& t, bool offload, unsigned batch)
auto generate(const program& p,
const target& t,
bool offload,
unsigned batch,
dims_map map_input_dims = {})
{
parameter_map m;
auto param_shapes = p.get_parameter_shapes();
std::unordered_map<std::string, shape> static_param_shapes;
std::transform(
param_shapes.cbegin(),
param_shapes.cend(),
std::inserter(static_param_shapes, static_param_shapes.end()),
[&](const auto& x) { return std::make_pair(x.first, x.second.to_static(batch)); });
for(auto&& param : param_shapes)
{
if(contains(map_input_dims, param.first))
Copy link

Copilot AI Oct 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The contains function is not defined or included. This should likely be map_input_dims.contains(param.first) for C++20 or map_input_dims.find(param.first) != map_input_dims.end() for earlier standards.

Suggested change
if(contains(map_input_dims, param.first))
if(map_input_dims.find(param.first) != map_input_dims.end())

Copilot uses AI. Check for mistakes.
static_param_shapes[param.first] = {param.second.type(),
map_input_dims[param.first]};
else
static_param_shapes[param.first] = param.second.to_static(batch);
}

for(auto&& s : fill0)
m[s] = fill_argument(static_param_shapes.at(s), 0);
for(auto&& s : fill1)
Expand Down Expand Up @@ -591,7 +602,8 @@ struct compiler

auto params(const program& p)
{
return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch);
return parameters.generate(
p, ct.get_target(), co.offload_copy, l.batch, loader::parse_param_dims(l.param_dims));
}

auto host_params(const program& p)
Expand Down Expand Up @@ -730,7 +742,8 @@ struct verify : command<verify>
std::cout << p << std::endl;

auto t = c.ct.get_target();
auto m = c.parameters.generate(p, t, true, c.l.batch);
auto m =
c.parameters.generate(p, t, true, c.l.batch, loader::parse_param_dims(c.l.param_dims));

if(c.to_fp16)
{
Expand Down
Loading