Skip to content

Commit c3b9c88

Browse files
Revert "[BUG] [CONV] Fix incorrect stride calculation when w=1/h=1 in MISA solvers (#3786)" (#3836)
This reverts commit 8ad3741.
1 parent 2f42c3c commit c3b9c88

23 files changed

+510
-695
lines changed

src/conv/invokers/impl_gemm_dynamic.cpp

Lines changed: 47 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,10 @@
77
#include <miopen/solver/implicitgemm_util.hpp>
88
#include <miopen/batched_transpose_sol.hpp>
99
#include <boost/any.hpp>
10-
#include <miopen/solver/problem_description_interpreter.hpp>
1110

1211
namespace miopen {
1312
namespace conv {
1413

15-
using ProblemInterpreter = solver::ProblemInterpreter;
16-
1714
static inline uint32_t igemm_find_tile_size_with_upper_bound(
1815
uint32_t out_size, size_t upper_bound, uint32_t stride, uint32_t dilation, uint32_t filter)
1916
{
@@ -118,21 +115,21 @@ InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ProblemDescript
118115
InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem,
119116
const int cfg)
120117
{
121-
const int hi = ProblemInterpreter::GetInputHeightHi(problem);
122-
const int wi = ProblemInterpreter::GetInputWidthWi(problem);
123-
const int n = ProblemInterpreter::GetBatchN(problem);
124-
const int k = ProblemInterpreter::GetOutputChannelK(problem);
125-
const int c = ProblemInterpreter::GetInputChannelC(problem);
126-
const int ho = ProblemInterpreter::GetOutputHeightHo(problem);
127-
const int wo = ProblemInterpreter::GetOutputWidthWo(problem);
128-
const auto stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem);
129-
const auto stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem);
130-
const auto dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem);
131-
const auto dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem);
132-
const auto pad_h = ProblemInterpreter::GetInputLeftPadH(problem);
133-
const auto pad_w = ProblemInterpreter::GetInputLeftPadW(problem);
134-
const int y = ProblemInterpreter::GetFilterHeightY(problem);
135-
const int x = ProblemInterpreter::GetFilterWidthX(problem);
118+
int hi = problem.GetOutHeight();
119+
int wi = problem.GetOutWidth();
120+
int n = problem.GetInBatchSize();
121+
int k = problem.GetInChannels();
122+
int c = problem.GetOutChannels();
123+
int ho = problem.GetInHeight();
124+
int wo = problem.GetInWidth();
125+
int stride_h = problem.GetInHeight() > 1 ? problem.GetKernelStrideH() : 1;
126+
int stride_w = problem.GetInWidth() > 1 ? problem.GetKernelStrideW() : 1;
127+
int dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1;
128+
int dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1;
129+
int pad_h = problem.GetPadH();
130+
int pad_w = problem.GetPadW();
131+
int y = problem.GetWeightsHeight();
132+
int x = problem.GetWeightsWidth();
136133

137134
int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h);
138135
int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w);
@@ -252,22 +249,22 @@ InvokerFactory
252249
MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem,
253250
const solver::TunableImplicitGemmGTCDynamic_t& cfg)
254251
{
255-
const int hi = ProblemInterpreter::GetInputHeightHi(problem);
256-
const int wi = ProblemInterpreter::GetInputWidthWi(problem);
257-
const int n = ProblemInterpreter::GetBatchN(problem);
258-
const int k = ProblemInterpreter::GetOutputChannelK(problem);
259-
const int c = ProblemInterpreter::GetInputChannelC(problem);
260-
const int ho = ProblemInterpreter::GetOutputHeightHo(problem);
261-
const int wo = ProblemInterpreter::GetOutputWidthWo(problem);
262-
const auto stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem);
263-
const auto stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem);
264-
const auto dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem);
265-
const auto dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem);
266-
const auto pad_h = ProblemInterpreter::GetInputLeftPadH(problem);
267-
const auto pad_w = ProblemInterpreter::GetInputLeftPadW(problem);
268-
const int y = ProblemInterpreter::GetFilterHeightY(problem);
269-
const int x = ProblemInterpreter::GetFilterWidthX(problem);
270-
const auto group = ProblemInterpreter::GetGroupCountG(problem);
252+
int hi = problem.GetOutHeight();
253+
int wi = problem.GetOutWidth();
254+
int n = problem.GetInBatchSize();
255+
int k = problem.GetInChannels();
256+
int c = problem.GetOutChannels();
257+
int ho = problem.GetInHeight();
258+
int wo = problem.GetInWidth();
259+
int stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1;
260+
int stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1;
261+
int dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1;
262+
int dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1;
263+
int pad_h = problem.GetPadH();
264+
int pad_w = problem.GetPadW();
265+
int y = problem.GetWeightsHeight();
266+
int x = problem.GetWeightsWidth();
267+
int group = problem.GetGroupCount();
271268

272269
int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h);
273270
int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w);
@@ -730,22 +727,22 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory(
730727
const ProblemDescription& problem,
731728
const solver::conv::PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC& config)
732729
{
733-
const int hi = ProblemInterpreter::GetInputHeightHi(problem);
734-
const int wi = ProblemInterpreter::GetInputWidthWi(problem);
735-
const int n = ProblemInterpreter::GetBatchN(problem);
736-
const int k = ProblemInterpreter::GetOutputChannelK(problem);
737-
const int c = ProblemInterpreter::GetInputChannelC(problem);
738-
const int ho = ProblemInterpreter::GetOutputHeightHo(problem);
739-
const int wo = ProblemInterpreter::GetOutputWidthWo(problem);
740-
const auto stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem);
741-
const auto stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem);
742-
const auto dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem);
743-
const auto dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem);
744-
const auto pad_h = ProblemInterpreter::GetInputLeftPadH(problem);
745-
const auto pad_w = ProblemInterpreter::GetInputLeftPadW(problem);
746-
const int y = ProblemInterpreter::GetFilterHeightY(problem);
747-
const int x = ProblemInterpreter::GetFilterWidthX(problem);
748-
const auto group = ProblemInterpreter::GetGroupCountG(problem);
730+
int hi = problem.GetOutHeight();
731+
int wi = problem.GetOutWidth();
732+
int n = problem.GetInBatchSize();
733+
int k = problem.GetInChannels();
734+
int c = problem.GetOutChannels();
735+
int ho = problem.GetInHeight();
736+
int wo = problem.GetInWidth();
737+
int stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1;
738+
int stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1;
739+
int dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1;
740+
int dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1;
741+
int pad_h = problem.GetPadH();
742+
int pad_w = problem.GetPadW();
743+
int y = problem.GetWeightsHeight();
744+
int x = problem.GetWeightsWidth();
745+
int group = problem.GetGroupCount();
749746

750747
int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h);
751748
int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w);

src/include/miopen/conv/asm_implicit_gemm.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@
4343
/// https://github.com/ROCm/MIOpen/issues/2624
4444
#define WORKAROUND_ISSUE_2624 1
4545

46+
/// W/A for issue 2867: asm igemm wrw computation error with stride=2, padding=2, filter=3, h=w=1
47+
/// https://github.com/ROCm/MIOpen/issues/2867
48+
#define WORKAROUND_ISSUE_2867 1
49+
4650
namespace miopen {
4751

4852
namespace solver {

src/include/miopen/solver/problem_description_interpreter.hpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,11 +155,6 @@ struct ProblemInterpreter
155155
return problem.IsDirectionForward() ? problem.GetInDataType() : problem.GetOutDataType();
156156
}
157157

158-
static auto GetWeightsDataType(const miopen::conv::ProblemDescription& problem)
159-
{
160-
return problem.GetWeightsDataType();
161-
}
162-
163158
static int GetFilterDepthZ(const miopen::conv::ProblemDescription& problem)
164159
{
165160
return problem.GetWeightsDepth();

src/solver/conv/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
#include <miopen/gcn_asm_utils.hpp>
3232
#include <algorithm>
3333
#include <miopen/solver/implicitgemm_util.hpp>
34-
#include <miopen/solver/problem_description_interpreter.hpp>
3534

3635
MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_V4R1)
3736

@@ -49,21 +48,21 @@ static inline bool FindImplicitGemmDynamicKernelBwd(const ProblemDescription& pr
4948
// TODO: add more dynamic kernel to expand support range, and update this function
5049
// clang-format off
5150
// refer to ProblemInterpreter: in bwd, most dimensions are reversed
52-
int hi = ProblemInterpreter::GetInputHeightHi(problem);
53-
int wi = ProblemInterpreter::GetInputWidthWi(problem);
54-
int n = ProblemInterpreter::GetBatchN(problem);
55-
int k = ProblemInterpreter::GetOutputChannelK(problem);
56-
int c = ProblemInterpreter::GetInputChannelC(problem);
57-
int ho = ProblemInterpreter::GetOutputHeightHo(problem);
58-
int wo = ProblemInterpreter::GetOutputWidthWo(problem);
59-
int stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem);
60-
int stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem);
61-
int dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem);
62-
int dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem);
63-
int pad_h = ProblemInterpreter::GetInputLeftPadH(problem);
64-
int pad_w = ProblemInterpreter::GetInputLeftPadW(problem);
65-
int y = ProblemInterpreter::GetFilterHeightY(problem);
66-
int x = ProblemInterpreter::GetFilterWidthX(problem);
51+
int hi = problem.GetOutHeight();
52+
int wi = problem.GetOutWidth();
53+
int n = problem.GetBatchSize();
54+
int k = problem.GetInChannels();
55+
int c = problem.GetOutChannels();
56+
int ho = problem.GetInHeight();
57+
int wo = problem.GetInWidth();
58+
int stride_h = problem.GetInHeight() > 1 ? problem.GetKernelStrideH() : 1;
59+
int stride_w = problem.GetInWidth() > 1 ? problem.GetKernelStrideW() : 1;
60+
int dilation_h = problem.GetWeightsHeight() > 1? problem.GetDilationH() : 1;
61+
int dilation_w = problem.GetWeightsWidth() > 1? problem.GetDilationW() : 1;
62+
int pad_h = problem.GetPadH();
63+
int pad_w = problem.GetPadW();
64+
int y = problem.GetWeightsHeight();
65+
int x = problem.GetWeightsWidth();
6766

6867
int gcd_stride_dilation_h = gcd(stride_h, dilation_h);
6968
int gcd_stride_dilation_w = gcd(stride_w, dilation_w);

src/solver/conv/conv_asm_implicit_gemm_gtc_bwd.cpp

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
#include <miopen/gcn_asm_utils.hpp>
3232
#include <miopen/solver/implicitgemm_util.hpp>
3333
#include <miopen/conv/asm_implicit_gemm.hpp>
34-
#include <miopen/solver/problem_description_interpreter.hpp>
3534

3635
MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS)
3736

@@ -794,22 +793,22 @@ FindImplicitGemmGtcDynamicBwdKernel(const ProblemDescription& problem)
794793
auto tunables = GetImplicitGemmGtcDynamicBwdTunablesList(problem);
795794

796795
// so far, "group" is only supported by bwd fp16 kernels
797-
const auto group = problem.IsFp16() ? ProblemInterpreter::GetGroupCountG(problem) : 1;
798-
const int hi = ProblemInterpreter::GetInputHeightHi(problem);
799-
const int wi = ProblemInterpreter::GetInputWidthWi(problem);
800-
const int n = ProblemInterpreter::GetBatchN(problem);
801-
const int k = ProblemInterpreter::GetOutputChannelK(problem) / group;
802-
const int c = ProblemInterpreter::GetInputChannelC(problem) / group;
803-
const int ho = ProblemInterpreter::GetOutputHeightHo(problem);
804-
const int wo = ProblemInterpreter::GetOutputWidthWo(problem);
805-
const auto stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem);
806-
const auto stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem);
807-
const auto dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem);
808-
const auto dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem);
809-
const auto pad_h = ProblemInterpreter::GetInputLeftPadH(problem);
810-
const auto pad_w = ProblemInterpreter::GetInputLeftPadW(problem);
811-
const int y = ProblemInterpreter::GetFilterHeightY(problem);
812-
const int x = ProblemInterpreter::GetFilterWidthX(problem);
796+
const auto group = problem.IsFp16() ? problem.GetGroupCount() : 1;
797+
const int hi = problem.GetOutHeight();
798+
const int wi = problem.GetOutWidth();
799+
const int n = problem.GetBatchSize();
800+
const int k = problem.GetInChannels() / group;
801+
const int c = problem.GetOutChannels() / group;
802+
const int ho = problem.GetInHeight();
803+
const int wo = problem.GetInWidth();
804+
const auto stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1;
805+
const auto stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1;
806+
const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1;
807+
const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1;
808+
const auto pad_h = problem.GetPadH();
809+
const auto pad_w = problem.GetPadW();
810+
const int y = problem.GetWeightsHeight();
811+
const int x = problem.GetWeightsWidth();
813812

814813
const auto gcd_stride_dilation_h = gcd(stride_h, dilation_h);
815814
const auto gcd_stride_dilation_w = gcd(stride_w, dilation_w);
@@ -1013,7 +1012,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext
10131012
return false;
10141013

10151014
// So far, "group" is not supported by the bwd fp32 kernels
1016-
if(problem.IsFp32() && ProblemInterpreter::GetGroupCountG(problem) != 1)
1015+
if(problem.IsFp32() && problem.GetGroupCount() != 1)
10171016
return false;
10181017

10191018
if(!problem.IsLayoutDefault())

0 commit comments

Comments
 (0)