|
64 | 64 | #include <iomanip>
|
65 | 65 |
|
66 | 66 | namespace {
|
| 67 | + |
| 68 | +using dims_map = std::unordered_map<std::string, std::vector<std::size_t>>; |
| 69 | + |
67 | 70 | std::vector<std::string>
|
68 | 71 | get_unrecognized_migraphx_envs(const char* envp[],
|
69 | 72 | const std::map<std::string, std::string>& used_env)
|
@@ -213,7 +216,7 @@ struct loader
|
213 | 216 |
|
214 | 217 | static auto parse_param_dims(const std::vector<std::string>& param_dims_info)
|
215 | 218 | {
|
216 |
| - std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims; |
| 219 | + dims_map map_input_dims; |
217 | 220 | std::string name = "";
|
218 | 221 | for(auto&& x : param_dims_info)
|
219 | 222 | {
|
@@ -502,16 +505,24 @@ struct program_params
|
502 | 505 | return map_load_args;
|
503 | 506 | }
|
504 | 507 |
|
505 |
| - auto generate(const program& p, const target& t, bool offload, unsigned batch) |
| 508 | + auto generate(const program& p, |
| 509 | + const target& t, |
| 510 | + bool offload, |
| 511 | + unsigned batch, |
| 512 | + dims_map map_input_dims = {}) |
506 | 513 | {
|
507 | 514 | parameter_map m;
|
508 | 515 | auto param_shapes = p.get_parameter_shapes();
|
509 | 516 | std::unordered_map<std::string, shape> static_param_shapes;
|
510 |
| - std::transform( |
511 |
| - param_shapes.cbegin(), |
512 |
| - param_shapes.cend(), |
513 |
| - std::inserter(static_param_shapes, static_param_shapes.end()), |
514 |
| - [&](const auto& x) { return std::make_pair(x.first, x.second.to_static(batch)); }); |
| 517 | + for(auto&& param : param_shapes) |
| 518 | + { |
| 519 | + if(contains(map_input_dims, param.first)) |
| 520 | + static_param_shapes[param.first] = {param.second.type(), |
| 521 | + map_input_dims[param.first]}; |
| 522 | + else |
| 523 | + static_param_shapes[param.first] = param.second.to_static(batch); |
| 524 | + } |
| 525 | + |
515 | 526 | for(auto&& s : fill0)
|
516 | 527 | m[s] = fill_argument(static_param_shapes.at(s), 0);
|
517 | 528 | for(auto&& s : fill1)
|
@@ -591,7 +602,8 @@ struct compiler
|
591 | 602 |
|
592 | 603 | auto params(const program& p)
|
593 | 604 | {
|
594 |
| - return parameters.generate(p, ct.get_target(), co.offload_copy, l.batch); |
| 605 | + return parameters.generate( |
| 606 | + p, ct.get_target(), co.offload_copy, l.batch, loader::parse_param_dims(l.param_dims)); |
595 | 607 | }
|
596 | 608 |
|
597 | 609 | auto host_params(const program& p)
|
@@ -730,7 +742,8 @@ struct verify : command<verify>
|
730 | 742 | std::cout << p << std::endl;
|
731 | 743 |
|
732 | 744 | auto t = c.ct.get_target();
|
733 |
| - auto m = c.parameters.generate(p, t, true, c.l.batch); |
| 745 | + auto m = |
| 746 | + c.parameters.generate(p, t, true, c.l.batch, loader::parse_param_dims(c.l.param_dims)); |
734 | 747 |
|
735 | 748 | if(c.to_fp16)
|
736 | 749 | {
|
|
0 commit comments