|
7 | 7 | #include <miopen/solver/implicitgemm_util.hpp>
|
8 | 8 | #include <miopen/batched_transpose_sol.hpp>
|
9 | 9 | #include <boost/any.hpp>
|
10 |
| -#include <miopen/solver/problem_description_interpreter.hpp> |
11 | 10 |
|
12 | 11 | namespace miopen {
|
13 | 12 | namespace conv {
|
14 | 13 |
|
15 |
| -using ProblemInterpreter = solver::ProblemInterpreter; |
16 |
| - |
17 | 14 | static inline uint32_t igemm_find_tile_size_with_upper_bound(
|
18 | 15 | uint32_t out_size, size_t upper_bound, uint32_t stride, uint32_t dilation, uint32_t filter)
|
19 | 16 | {
|
@@ -118,21 +115,21 @@ InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ProblemDescript
|
118 | 115 | InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem,
|
119 | 116 | const int cfg)
|
120 | 117 | {
|
121 |
| - const int hi = ProblemInterpreter::GetInputHeightHi(problem); |
122 |
| - const int wi = ProblemInterpreter::GetInputWidthWi(problem); |
123 |
| - const int n = ProblemInterpreter::GetBatchN(problem); |
124 |
| - const int k = ProblemInterpreter::GetOutputChannelK(problem); |
125 |
| - const int c = ProblemInterpreter::GetInputChannelC(problem); |
126 |
| - const int ho = ProblemInterpreter::GetOutputHeightHo(problem); |
127 |
| - const int wo = ProblemInterpreter::GetOutputWidthWo(problem); |
128 |
| - const auto stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem); |
129 |
| - const auto stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem); |
130 |
| - const auto dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem); |
131 |
| - const auto dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem); |
132 |
| - const auto pad_h = ProblemInterpreter::GetInputLeftPadH(problem); |
133 |
| - const auto pad_w = ProblemInterpreter::GetInputLeftPadW(problem); |
134 |
| - const int y = ProblemInterpreter::GetFilterHeightY(problem); |
135 |
| - const int x = ProblemInterpreter::GetFilterWidthX(problem); |
| 118 | + int hi = problem.GetOutHeight(); |
| 119 | + int wi = problem.GetOutWidth(); |
| 120 | + int n = problem.GetInBatchSize(); |
| 121 | + int k = problem.GetInChannels(); |
| 122 | + int c = problem.GetOutChannels(); |
| 123 | + int ho = problem.GetInHeight(); |
| 124 | + int wo = problem.GetInWidth(); |
| 125 | + int stride_h = problem.GetInHeight() > 1 ? problem.GetKernelStrideH() : 1; |
| 126 | + int stride_w = problem.GetInWidth() > 1 ? problem.GetKernelStrideW() : 1; |
| 127 | + int dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; |
| 128 | + int dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; |
| 129 | + int pad_h = problem.GetPadH(); |
| 130 | + int pad_w = problem.GetPadW(); |
| 131 | + int y = problem.GetWeightsHeight(); |
| 132 | + int x = problem.GetWeightsWidth(); |
136 | 133 |
|
137 | 134 | int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h);
|
138 | 135 | int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w);
|
@@ -252,22 +249,22 @@ InvokerFactory
|
252 | 249 | MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem,
|
253 | 250 | const solver::TunableImplicitGemmGTCDynamic_t& cfg)
|
254 | 251 | {
|
255 |
| - const int hi = ProblemInterpreter::GetInputHeightHi(problem); |
256 |
| - const int wi = ProblemInterpreter::GetInputWidthWi(problem); |
257 |
| - const int n = ProblemInterpreter::GetBatchN(problem); |
258 |
| - const int k = ProblemInterpreter::GetOutputChannelK(problem); |
259 |
| - const int c = ProblemInterpreter::GetInputChannelC(problem); |
260 |
| - const int ho = ProblemInterpreter::GetOutputHeightHo(problem); |
261 |
| - const int wo = ProblemInterpreter::GetOutputWidthWo(problem); |
262 |
| - const auto stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem); |
263 |
| - const auto stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem); |
264 |
| - const auto dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem); |
265 |
| - const auto dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem); |
266 |
| - const auto pad_h = ProblemInterpreter::GetInputLeftPadH(problem); |
267 |
| - const auto pad_w = ProblemInterpreter::GetInputLeftPadW(problem); |
268 |
| - const int y = ProblemInterpreter::GetFilterHeightY(problem); |
269 |
| - const int x = ProblemInterpreter::GetFilterWidthX(problem); |
270 |
| - const auto group = ProblemInterpreter::GetGroupCountG(problem); |
| 252 | + int hi = problem.GetOutHeight(); |
| 253 | + int wi = problem.GetOutWidth(); |
| 254 | + int n = problem.GetInBatchSize(); |
| 255 | + int k = problem.GetInChannels(); |
| 256 | + int c = problem.GetOutChannels(); |
| 257 | + int ho = problem.GetInHeight(); |
| 258 | + int wo = problem.GetInWidth(); |
| 259 | + int stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1; |
| 260 | + int stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1; |
| 261 | + int dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; |
| 262 | + int dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; |
| 263 | + int pad_h = problem.GetPadH(); |
| 264 | + int pad_w = problem.GetPadW(); |
| 265 | + int y = problem.GetWeightsHeight(); |
| 266 | + int x = problem.GetWeightsWidth(); |
| 267 | + int group = problem.GetGroupCount(); |
271 | 268 |
|
272 | 269 | int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h);
|
273 | 270 | int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w);
|
@@ -730,22 +727,22 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory(
|
730 | 727 | const ProblemDescription& problem,
|
731 | 728 | const solver::conv::PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC& config)
|
732 | 729 | {
|
733 |
| - const int hi = ProblemInterpreter::GetInputHeightHi(problem); |
734 |
| - const int wi = ProblemInterpreter::GetInputWidthWi(problem); |
735 |
| - const int n = ProblemInterpreter::GetBatchN(problem); |
736 |
| - const int k = ProblemInterpreter::GetOutputChannelK(problem); |
737 |
| - const int c = ProblemInterpreter::GetInputChannelC(problem); |
738 |
| - const int ho = ProblemInterpreter::GetOutputHeightHo(problem); |
739 |
| - const int wo = ProblemInterpreter::GetOutputWidthWo(problem); |
740 |
| - const auto stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem); |
741 |
| - const auto stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem); |
742 |
| - const auto dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem); |
743 |
| - const auto dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem); |
744 |
| - const auto pad_h = ProblemInterpreter::GetInputLeftPadH(problem); |
745 |
| - const auto pad_w = ProblemInterpreter::GetInputLeftPadW(problem); |
746 |
| - const int y = ProblemInterpreter::GetFilterHeightY(problem); |
747 |
| - const int x = ProblemInterpreter::GetFilterWidthX(problem); |
748 |
| - const auto group = ProblemInterpreter::GetGroupCountG(problem); |
| 730 | + int hi = problem.GetOutHeight(); |
| 731 | + int wi = problem.GetOutWidth(); |
| 732 | + int n = problem.GetInBatchSize(); |
| 733 | + int k = problem.GetInChannels(); |
| 734 | + int c = problem.GetOutChannels(); |
| 735 | + int ho = problem.GetInHeight(); |
| 736 | + int wo = problem.GetInWidth(); |
| 737 | + int stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1; |
| 738 | + int stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1; |
| 739 | + int dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1; |
| 740 | + int dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1; |
| 741 | + int pad_h = problem.GetPadH(); |
| 742 | + int pad_w = problem.GetPadW(); |
| 743 | + int y = problem.GetWeightsHeight(); |
| 744 | + int x = problem.GetWeightsWidth(); |
| 745 | + int group = problem.GetGroupCount(); |
749 | 746 |
|
750 | 747 | int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h);
|
751 | 748 | int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w);
|
|
0 commit comments