Skip to content
Merged
97 changes: 50 additions & 47 deletions src/conv/invokers/impl_gemm_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
#include <miopen/solver/implicitgemm_util.hpp>
#include <miopen/batched_transpose_sol.hpp>
#include <boost/any.hpp>
#include <miopen/solver/problem_description_interpreter.hpp>

namespace miopen {
namespace conv {

using ProblemInterpreter = solver::ProblemInterpreter;

static inline uint32_t igemm_find_tile_size_with_upper_bound(
uint32_t out_size, size_t upper_bound, uint32_t stride, uint32_t dilation, uint32_t filter)
{
Expand Down Expand Up @@ -115,21 +118,21 @@ InvokerFactory MakeImplGemmDynamicForward1x1InvokerFactory(const ProblemDescript
InvokerFactory MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem,
const int cfg)
{
int hi = problem.GetOutHeight();
int wi = problem.GetOutWidth();
int n = problem.GetInBatchSize();
int k = problem.GetInChannels();
int c = problem.GetOutChannels();
int ho = problem.GetInHeight();
int wo = problem.GetInWidth();
int stride_h = problem.GetInHeight() > 1 ? problem.GetKernelStrideH() : 1;
int stride_w = problem.GetInWidth() > 1 ? problem.GetKernelStrideW() : 1;
int dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1;
int dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1;
int pad_h = problem.GetPadH();
int pad_w = problem.GetPadW();
int y = problem.GetWeightsHeight();
int x = problem.GetWeightsWidth();
const int hi = ProblemInterpreter::GetInputHeightHi(problem);
const int wi = ProblemInterpreter::GetInputWidthWi(problem);
const int n = ProblemInterpreter::GetBatchN(problem);
const int k = ProblemInterpreter::GetOutputChannelK(problem);
const int c = ProblemInterpreter::GetInputChannelC(problem);
const int ho = ProblemInterpreter::GetOutputHeightHo(problem);
const int wo = ProblemInterpreter::GetOutputWidthWo(problem);
const auto stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem);
const auto stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem);
const auto dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem);
const auto dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem);
const auto pad_h = ProblemInterpreter::GetInputLeftPadH(problem);
const auto pad_w = ProblemInterpreter::GetInputLeftPadW(problem);
const int y = ProblemInterpreter::GetFilterHeightY(problem);
const int x = ProblemInterpreter::GetFilterWidthX(problem);

int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h);
int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w);
Expand Down Expand Up @@ -249,22 +252,22 @@ InvokerFactory
MakeImplGemmDynamicBackwardDataInvokerFactory(const ProblemDescription& problem,
const solver::TunableImplicitGemmGTCDynamic_t& cfg)
{
int hi = problem.GetOutHeight();
int wi = problem.GetOutWidth();
int n = problem.GetInBatchSize();
int k = problem.GetInChannels();
int c = problem.GetOutChannels();
int ho = problem.GetInHeight();
int wo = problem.GetInWidth();
int stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1;
int stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1;
int dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1;
int dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1;
int pad_h = problem.GetPadH();
int pad_w = problem.GetPadW();
int y = problem.GetWeightsHeight();
int x = problem.GetWeightsWidth();
int group = problem.GetGroupCount();
const int hi = ProblemInterpreter::GetInputHeightHi(problem);
const int wi = ProblemInterpreter::GetInputWidthWi(problem);
const int n = ProblemInterpreter::GetBatchN(problem);
const int k = ProblemInterpreter::GetOutputChannelK(problem);
const int c = ProblemInterpreter::GetInputChannelC(problem);
const int ho = ProblemInterpreter::GetOutputHeightHo(problem);
const int wo = ProblemInterpreter::GetOutputWidthWo(problem);
const auto stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem);
const auto stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem);
const auto dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem);
const auto dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem);
const auto pad_h = ProblemInterpreter::GetInputLeftPadH(problem);
const auto pad_w = ProblemInterpreter::GetInputLeftPadW(problem);
const int y = ProblemInterpreter::GetFilterHeightY(problem);
const int x = ProblemInterpreter::GetFilterWidthX(problem);
const auto group = ProblemInterpreter::GetGroupCountG(problem);

int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h);
int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w);
Expand Down Expand Up @@ -727,22 +730,22 @@ InvokerFactory MakeImplGemmDynamicBackwardDataXdlopsNHWCInvokerFactory(
const ProblemDescription& problem,
const solver::conv::PerformanceConfigAsmImplicitGemmGTCBwdXdlopsNHWC& config)
{
int hi = problem.GetOutHeight();
int wi = problem.GetOutWidth();
int n = problem.GetInBatchSize();
int k = problem.GetInChannels();
int c = problem.GetOutChannels();
int ho = problem.GetInHeight();
int wo = problem.GetInWidth();
int stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1;
int stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1;
int dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1;
int dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1;
int pad_h = problem.GetPadH();
int pad_w = problem.GetPadW();
int y = problem.GetWeightsHeight();
int x = problem.GetWeightsWidth();
int group = problem.GetGroupCount();
const int hi = ProblemInterpreter::GetInputHeightHi(problem);
const int wi = ProblemInterpreter::GetInputWidthWi(problem);
const int n = ProblemInterpreter::GetBatchN(problem);
const int k = ProblemInterpreter::GetOutputChannelK(problem);
const int c = ProblemInterpreter::GetInputChannelC(problem);
const int ho = ProblemInterpreter::GetOutputHeightHo(problem);
const int wo = ProblemInterpreter::GetOutputWidthWo(problem);
const auto stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem);
const auto stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem);
const auto dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem);
const auto dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem);
const auto pad_h = ProblemInterpreter::GetInputLeftPadH(problem);
const auto pad_w = ProblemInterpreter::GetInputLeftPadW(problem);
const int y = ProblemInterpreter::GetFilterHeightY(problem);
const int x = ProblemInterpreter::GetFilterWidthX(problem);
const auto group = ProblemInterpreter::GetGroupCountG(problem);

int gcd_stride_dilation_h = solver::gcd(stride_h, dilation_h);
int gcd_stride_dilation_w = solver::gcd(stride_w, dilation_w);
Expand Down
4 changes: 0 additions & 4 deletions src/include/miopen/conv/asm_implicit_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,6 @@
/// https://github.com/ROCm/MIOpen/issues/2624
#define WORKAROUND_ISSUE_2624 1

/// W/A for issue 2867: asm igemm wrw computation error with stride=2, padding=2, filter=3, h=w=1
/// https://github.com/ROCm/MIOpen/issues/2867
#define WORKAROUND_ISSUE_2867 1

namespace miopen {

namespace solver {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,11 @@ struct ProblemInterpreter
return problem.IsDirectionForward() ? problem.GetInDataType() : problem.GetOutDataType();
}

/// Returns the weights data type reported by the given problem description.
/// Unlike the neighbouring accessors, this is a direct passthrough: the value
/// is read from the problem without any direction-dependent selection.
static auto GetWeightsDataType(const miopen::conv::ProblemDescription& problem)
{
    const auto weights_data_type = problem.GetWeightsDataType();
    return weights_data_type;
}

static int GetFilterDepthZ(const miopen::conv::ProblemDescription& problem)
{
return problem.GetWeightsDepth();
Expand Down
31 changes: 16 additions & 15 deletions src/solver/conv/conv_asm_implicit_gemm_bwd_v4r1_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <miopen/gcn_asm_utils.hpp>
#include <algorithm>
#include <miopen/solver/implicitgemm_util.hpp>
#include <miopen/solver/problem_description_interpreter.hpp>

MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_V4R1)

Expand All @@ -48,21 +49,21 @@ static inline bool FindImplicitGemmDynamicKernelBwd(const ProblemDescription& pr
// TODO: add more dynamic kernels to expand the supported range, and update this function
// clang-format off
// refer to ProblemInterpreter: in bwd, most dimensions are reversed
int hi = problem.GetOutHeight();
int wi = problem.GetOutWidth();
int n = problem.GetBatchSize();
int k = problem.GetInChannels();
int c = problem.GetOutChannels();
int ho = problem.GetInHeight();
int wo = problem.GetInWidth();
int stride_h = problem.GetInHeight() > 1 ? problem.GetKernelStrideH() : 1;
int stride_w = problem.GetInWidth() > 1 ? problem.GetKernelStrideW() : 1;
int dilation_h = problem.GetWeightsHeight() > 1? problem.GetDilationH() : 1;
int dilation_w = problem.GetWeightsWidth() > 1? problem.GetDilationW() : 1;
int pad_h = problem.GetPadH();
int pad_w = problem.GetPadW();
int y = problem.GetWeightsHeight();
int x = problem.GetWeightsWidth();
int hi = ProblemInterpreter::GetInputHeightHi(problem);
int wi = ProblemInterpreter::GetInputWidthWi(problem);
int n = ProblemInterpreter::GetBatchN(problem);
int k = ProblemInterpreter::GetOutputChannelK(problem);
int c = ProblemInterpreter::GetInputChannelC(problem);
int ho = ProblemInterpreter::GetOutputHeightHo(problem);
int wo = ProblemInterpreter::GetOutputWidthWo(problem);
int stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem);
int stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem);
int dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem);
int dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem);
int pad_h = ProblemInterpreter::GetInputLeftPadH(problem);
int pad_w = ProblemInterpreter::GetInputLeftPadW(problem);
int y = ProblemInterpreter::GetFilterHeightY(problem);
int x = ProblemInterpreter::GetFilterWidthX(problem);

int gcd_stride_dilation_h = gcd(stride_h, dilation_h);
int gcd_stride_dilation_w = gcd(stride_w, dilation_w);
Expand Down
35 changes: 18 additions & 17 deletions src/solver/conv/conv_asm_implicit_gemm_gtc_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <miopen/gcn_asm_utils.hpp>
#include <miopen/solver/implicitgemm_util.hpp>
#include <miopen/conv/asm_implicit_gemm.hpp>
#include <miopen/solver/problem_description_interpreter.hpp>

MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_ASM_BWD_GTC_XDLOPS)

Expand Down Expand Up @@ -793,22 +794,22 @@ FindImplicitGemmGtcDynamicBwdKernel(const ProblemDescription& problem)
auto tunables = GetImplicitGemmGtcDynamicBwdTunablesList(problem);

// so far, "group" is only supported by bwd fp16 kernels
const auto group = problem.IsFp16() ? problem.GetGroupCount() : 1;
const int hi = problem.GetOutHeight();
const int wi = problem.GetOutWidth();
const int n = problem.GetBatchSize();
const int k = problem.GetInChannels() / group;
const int c = problem.GetOutChannels() / group;
const int ho = problem.GetInHeight();
const int wo = problem.GetInWidth();
const auto stride_h = problem.GetOutHeight() > 1 ? problem.GetKernelStrideH() : 1;
const auto stride_w = problem.GetOutWidth() > 1 ? problem.GetKernelStrideW() : 1;
const auto dilation_h = problem.GetWeightsHeight() > 1 ? problem.GetDilationH() : 1;
const auto dilation_w = problem.GetWeightsWidth() > 1 ? problem.GetDilationW() : 1;
const auto pad_h = problem.GetPadH();
const auto pad_w = problem.GetPadW();
const int y = problem.GetWeightsHeight();
const int x = problem.GetWeightsWidth();
const auto group = problem.IsFp16() ? ProblemInterpreter::GetGroupCountG(problem) : 1;
const int hi = ProblemInterpreter::GetInputHeightHi(problem);
const int wi = ProblemInterpreter::GetInputWidthWi(problem);
const int n = ProblemInterpreter::GetBatchN(problem);
const int k = ProblemInterpreter::GetOutputChannelK(problem) / group;
const int c = ProblemInterpreter::GetInputChannelC(problem) / group;
const int ho = ProblemInterpreter::GetOutputHeightHo(problem);
const int wo = ProblemInterpreter::GetOutputWidthWo(problem);
const auto stride_h = ProblemInterpreter::GetAdjustedConvolutionStrideH(problem);
const auto stride_w = ProblemInterpreter::GetAdjustedConvolutionStrideW(problem);
const auto dilation_h = ProblemInterpreter::GetAdjustedConvolutionDilationH(problem);
const auto dilation_w = ProblemInterpreter::GetAdjustedConvolutionDilationW(problem);
const auto pad_h = ProblemInterpreter::GetInputLeftPadH(problem);
const auto pad_w = ProblemInterpreter::GetInputLeftPadW(problem);
const int y = ProblemInterpreter::GetFilterHeightY(problem);
const int x = ProblemInterpreter::GetFilterWidthX(problem);

const auto gcd_stride_dilation_h = gcd(stride_h, dilation_h);
const auto gcd_stride_dilation_w = gcd(stride_w, dilation_w);
Expand Down Expand Up @@ -1012,7 +1013,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext
return false;

// So far, "group" is not supported by the bwd fp32 kernels
if(problem.IsFp32() && problem.GetGroupCount() != 1)
if(problem.IsFp32() && ProblemInterpreter::GetGroupCountG(problem) != 1)
return false;

if(!problem.IsLayoutDefault())
Expand Down
Loading
Loading