diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 31cb7472c27..4f402e06b5d 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -64,6 +64,9 @@ #include namespace { + +using dims_map = std::unordered_map>; + std::vector get_unrecognized_migraphx_envs(const char* envp[], const std::map& used_env) @@ -213,7 +216,7 @@ struct loader static auto parse_param_dims(const std::vector& param_dims_info) { - std::unordered_map> map_input_dims; + dims_map map_input_dims; std::string name = ""; for(auto&& x : param_dims_info) { @@ -502,16 +505,24 @@ struct program_params return map_load_args; } - auto generate(const program& p, const target& t, bool offload, unsigned batch) + auto generate(const program& p, + const target& t, + bool offload, + unsigned batch, + dims_map map_input_dims = {}) { parameter_map m; auto param_shapes = p.get_parameter_shapes(); std::unordered_map static_param_shapes; - std::transform( - param_shapes.cbegin(), - param_shapes.cend(), - std::inserter(static_param_shapes, static_param_shapes.end()), - [&](const auto& x) { return std::make_pair(x.first, x.second.to_static(batch)); }); + for(auto&& param : param_shapes) + { + if(contains(map_input_dims, param.first)) + static_param_shapes[param.first] = {param.second.type(), + map_input_dims[param.first]}; + else + static_param_shapes[param.first] = param.second.to_static(batch); + } + for(auto&& s : fill0) m[s] = fill_argument(static_param_shapes.at(s), 0); for(auto&& s : fill1) @@ -591,7 +602,8 @@ struct compiler auto params(const program& p) { - return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch); + return parameters.generate( + p, ct.get_target(), co.offload_copy, l.batch, loader::parse_param_dims(l.param_dims)); } auto host_params(const program& p) @@ -730,7 +742,8 @@ struct verify : command std::cout << p << std::endl; auto t = c.ct.get_target(); - auto m = c.parameters.generate(p, t, true, c.l.batch); + auto m = + c.parameters.generate(p, t, true, c.l.batch, loader::parse_param_dims(c.l.param_dims)); if(c.to_fp16) {