Skip to content

Commit 5477811

Browse files
authored
Grouped Conv Bwd Data out index calculation optimizations (#2917)
* Grouped Conv Bwd Data index calculation optimizations * fixes * refactor instances * gfx12 fixes * temporarily disable splitK for gfx12
1 parent 0f10e6d commit 5477811

17 files changed

+895
-75
lines changed

include/ck/tensor_description/multi_index_transform.hpp

Lines changed: 193 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
#pragma once
55

@@ -1553,6 +1553,198 @@ struct UnMerge
15531553
}
15541554
};
15551555

1556+
/**
1557+
* @brief Transformation struct for convolution backward data output indices to GEMM indices.
1558+
*
1559+
* This struct is responsible for mapping the output tensor indices (N, Ho, Wo, K) from the
1560+
* convolution backward data operation to the corresponding indices (K0, M, K1) used in the
1561+
* implicit GEMM computation. It encapsulates the necessary parameters and transformation logic
1562+
* required to efficiently perform the index conversion.
1563+
*/
1564+
struct ConvBwdDataImplicitGemmOutTransform
1565+
{
1566+
static constexpr auto I0 = Number<0>{};
1567+
static constexpr auto I1 = Number<1>{};
1568+
static constexpr auto I2 = Number<2>{};
1569+
static constexpr auto I3 = Number<3>{};
1570+
1571+
using LowerIndex = MultiIndex<4>; // N, Ho, Wo, K
1572+
using UpperIndex = MultiIndex<3>; // K0, M, K1
1573+
1574+
index_t N_, Ho_, Wo_, K_;
1575+
index_t XDot_;
1576+
index_t HTilde_, WTilde_;
1577+
index_t WTildeSlice_, TildeSlice_;
1578+
index_t IHTildeSliceBegin_, IWTildeSliceBegin_;
1579+
index_t HRatio_, WRatio_;
1580+
index_t XDotSlice_K_;
1581+
index_t MPad_, KPad_;
1582+
Tuple<index_t, index_t, index_t> up_lengths_; // K0_, MPadded, K1_;
1583+
1584+
Tuple<index_t, index_t, index_t, index_t>
1585+
low_lengths_magic_divisor_multiplier_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
1586+
Tuple<index_t, index_t, index_t, index_t>
1587+
low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_
1588+
1589+
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform() = default;
1590+
1591+
__host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N,
1592+
index_t Ho,
1593+
index_t Wo,
1594+
index_t K,
1595+
index_t XDot,
1596+
index_t HTilde,
1597+
index_t WTilde,
1598+
index_t WTildeSlice,
1599+
index_t HWTildeSlice,
1600+
index_t IHTildeSliceBegin,
1601+
index_t IWTildeSliceBegin,
1602+
index_t HRatio,
1603+
index_t WRatio,
1604+
index_t XDotSlice_K,
1605+
index_t K0,
1606+
index_t MPadded,
1607+
index_t K1,
1608+
index_t MPad,
1609+
index_t KPad)
1610+
: N_{N},
1611+
Ho_{Ho},
1612+
Wo_{Wo},
1613+
K_{K},
1614+
XDot_{XDot},
1615+
HTilde_{HTilde},
1616+
WTilde_{WTilde},
1617+
WTildeSlice_{WTildeSlice},
1618+
TildeSlice_{HWTildeSlice},
1619+
IHTildeSliceBegin_{IHTildeSliceBegin},
1620+
IWTildeSliceBegin_{IWTildeSliceBegin},
1621+
HRatio_{HRatio},
1622+
WRatio_{WRatio},
1623+
XDotSlice_K_{XDotSlice_K},
1624+
MPad_{MPad},
1625+
KPad_{KPad},
1626+
up_lengths_{make_tuple(K0, MPadded, K1)},
1627+
low_lengths_magic_divisor_multiplier_{
1628+
MagicDivision::CalculateMagicMultiplier(XDotSlice_K_),
1629+
MagicDivision::CalculateMagicMultiplier(K_),
1630+
MagicDivision::CalculateMagicMultiplier(TildeSlice_),
1631+
MagicDivision::CalculateMagicMultiplier(WTildeSlice_)},
1632+
low_lengths_magic_divisor_shift_{MagicDivision::CalculateMagicShift(XDotSlice_K_),
1633+
MagicDivision::CalculateMagicShift(K_),
1634+
MagicDivision::CalculateMagicShift(TildeSlice_),
1635+
MagicDivision::CalculateMagicShift(WTildeSlice_)}
1636+
{
1637+
}
1638+
1639+
__host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 4; }
1640+
1641+
__host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; }
1642+
1643+
__host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }
1644+
1645+
template <typename UpIdx>
1646+
__host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const
1647+
{
1648+
index_t NStep, HStep, WStep;
1649+
// Merge
1650+
// NStep = M_id / TildeSlice_
1651+
NStep = MagicDivision::DoMagicDivision(idx_up[I1],
1652+
this->low_lengths_magic_divisor_multiplier_[I2],
1653+
this->low_lengths_magic_divisor_shift_[I2]);
1654+
HStep = idx_up[I1] - NStep * TildeSlice_;
1655+
// HStep = HStep / WTildeSlice_
1656+
HStep = MagicDivision::DoMagicDivision(HStep,
1657+
this->low_lengths_magic_divisor_multiplier_[I3],
1658+
this->low_lengths_magic_divisor_shift_[I3]);
1659+
WStep = idx_up[I1] - NStep * TildeSlice_ - HStep * WTildeSlice_;
1660+
// Slice
1661+
HStep += IHTildeSliceBegin_;
1662+
WStep += IWTildeSliceBegin_;
1663+
1664+
return make_tuple(NStep, HStep, WStep, 0);
1665+
}
1666+
1667+
template <typename UpIdx>
1668+
__host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const
1669+
{
1670+
// UnMerge
1671+
// K_idx <- K0_idx * K1 + K1_idx
1672+
index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2];
1673+
// Merge
1674+
// YStep = K_idx / XDotSlice_K_
1675+
index_t YStep =
1676+
MagicDivision::DoMagicDivision(K_idx,
1677+
this->low_lengths_magic_divisor_multiplier_[I0],
1678+
this->low_lengths_magic_divisor_shift_[I0]);
1679+
index_t KStep = K_idx - YStep * XDotSlice_K_;
1680+
// Xstep = KStep / K_
1681+
index_t XStep =
1682+
MagicDivision::DoMagicDivision(KStep,
1683+
this->low_lengths_magic_divisor_multiplier_[I1],
1684+
this->low_lengths_magic_divisor_shift_[I1]);
1685+
KStep -= XStep * K_;
1686+
// Embed
1687+
YStep *= HRatio_;
1688+
XStep *= WRatio_;
1689+
1690+
return make_tuple(0, YStep, XStep, KStep);
1691+
}
1692+
1693+
template <typename LowIdx, typename UpIdx>
1694+
__host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
1695+
const UpIdx& idx_up) const
1696+
{
1697+
idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
1698+
}
1699+
1700+
template <typename LowIdxDiff,
1701+
typename UpIdxDiff,
1702+
typename LowIdx,
1703+
typename UpIdx,
1704+
index_t Hack>
1705+
__host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
1706+
const UpIdxDiff& /* idx_diff_up */,
1707+
LowIdx& idx_low,
1708+
const UpIdx& idx_up,
1709+
Number<Hack>) const
1710+
{
1711+
LowIdx low_old = idx_low;
1712+
idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up);
1713+
idx_diff_low = idx_low - low_old;
1714+
}
1715+
1716+
__host__ __device__ static constexpr bool IsLinearTransform() { return false; }
1717+
1718+
__host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
1719+
{
1720+
return true;
1721+
}
1722+
1723+
template <typename UpIdx>
1724+
__host__ __device__ constexpr bool
1725+
IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
1726+
{
1727+
// Padding
1728+
index_t K_idx = idx_up[Number<0>{}] * up_lengths_[Number<2>{}] + idx_up[Number<2>{}];
1729+
index_t& M_idx = idx_up[Number<1>{}];
1730+
1731+
bool pad_valid = M_idx < up_lengths_[Number<1>{}] - MPad_ &&
1732+
K_idx < up_lengths_[Number<0>{}] * up_lengths_[Number<2>{}] - KPad_;
1733+
return pad_valid;
1734+
}
1735+
1736+
__host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; }
1737+
1738+
__host__ __device__ void Print() const
1739+
{
1740+
printf("{");
1741+
printf("ConvBwdDataImplicitGemmOutTransform, ");
1742+
printf("up_lengths_");
1743+
print_multi_index(up_lengths_);
1744+
printf("}");
1745+
}
1746+
};
1747+
15561748
template <typename LowerIndex>
15571749
struct Freeze
15581750
{

include/ck/tensor_description/multi_index_transform_helper.hpp

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// SPDX-License-Identifier: MIT
2-
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
2+
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
33

44
#pragma once
55

@@ -94,6 +94,59 @@ __host__ __device__ constexpr auto make_unmerge_transform(
9494
return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
9595
}
9696

97+
/**
 * @brief Build the backward-data output transform (N, Ho, Wo, K) -> (K0, M, K1).
 *
 * Derives the GEMM padding from the block tile sizes and forwards everything to
 * ConvBwdDataImplicitGemmOutTransform. M merges (N, HTildeSlice, WTildeSlice);
 * K merges (YDotSlice, XDotSlice, K). Both are padded up to multiples of
 * MPerBlock / GemmKPerBlock respectively.
 */
__host__ __device__ constexpr auto make_conv_bwd_data_out_transform(index_t N,
                                                                    index_t Ho,
                                                                    index_t Wo,
                                                                    index_t K,
                                                                    [[maybe_unused]] index_t YDot,
                                                                    index_t XDot,
                                                                    index_t HTilde,
                                                                    index_t WTilde,
                                                                    index_t ConvDilationH,
                                                                    index_t ConvDilationW,
                                                                    index_t HTildeSlice,
                                                                    index_t WTildeSlice,
                                                                    index_t YDotSlice,
                                                                    index_t XDotSlice,
                                                                    index_t IHTildeSliceBegin,
                                                                    index_t IWTildeSliceBegin,
                                                                    index_t GcdStrideDilationH,
                                                                    index_t GcdStrideDilationW,
                                                                    index_t K0,
                                                                    index_t K1,
                                                                    index_t MPerBlock,
                                                                    index_t GemmKPerBlock)
{
    // GEMM K padding: round the merged K length up to a multiple of GemmKPerBlock.
    const auto gemm_k_raw    = YDotSlice * XDotSlice * K;
    const auto gemm_k_padded = math::integer_divide_ceil(gemm_k_raw, GemmKPerBlock) * GemmKPerBlock;
    const auto gemm_k_pad    = gemm_k_padded - gemm_k_raw;

    // GEMM M padding: round the merged M length up to a multiple of MPerBlock.
    const auto hw_tilde_slice = HTildeSlice * WTildeSlice;
    const auto gemm_m_raw     = N * hw_tilde_slice;
    const auto gemm_m_padded  = math::integer_divide_ceil(gemm_m_raw, MPerBlock) * MPerBlock;
    const auto gemm_m_pad     = gemm_m_padded - gemm_m_raw;

    // Embed coefficients and merged-K sub-length consumed by the transform.
    const auto h_ratio      = -ConvDilationH / GcdStrideDilationH;
    const auto w_ratio      = -ConvDilationW / GcdStrideDilationW;
    const auto x_dot_slice_k = XDotSlice * K;

    return ConvBwdDataImplicitGemmOutTransform{N,
                                               Ho,
                                               Wo,
                                               K,
                                               XDot,
                                               HTilde,
                                               WTilde,
                                               WTildeSlice,
                                               hw_tilde_slice,
                                               IHTildeSliceBegin,
                                               IWTildeSliceBegin,
                                               h_ratio,
                                               w_ratio,
                                               x_dot_slice_k,
                                               K0,
                                               gemm_m_padded,
                                               K1,
                                               gemm_m_pad,
                                               gemm_k_pad};
}
149+
97150
template <typename LowerIndex>
98151
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
99152
{

include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1485,7 +1485,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
14851485
static bool IsSupportedArgument(const Argument& arg)
14861486
{
14871487
// gfx11 doesn't support float atomic
1488-
if(ck::is_gfx11_supported() && arg.k_batch_ > 1)
1488+
// Todo: Enable splitK for gfx12
1489+
if((ck::is_gfx12_supported() || ck::is_gfx11_supported()) && arg.k_batch_ > 1)
14891490
{
14901491
return false;
14911492
}

include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
namespace ck {
1414
namespace tensor_operation {
1515

16+
/**
17+
* @brief Enable custom tensor transform for convolution backward data output.
18+
*
19+
* When set to 1, this macro enables a custom transformation of the output tensor
20+
* in convolution backward data operations.
21+
*/
22+
#define CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT 1
23+
1624
template <
1725
index_t NDimSpatial,
1826
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
@@ -705,6 +713,12 @@ struct TransformConvBwdDataToGemm_v1
705713

706714
if constexpr(NDimSpatial == 2)
707715
{
716+
const index_t K0PerBlock = GemmKPerBlock / AK1;
717+
const index_t AK0 = math::integer_divide_ceil(YDotSlice * XDotSlice * K_,
718+
AK1 * K0PerBlock * batch_k_) *
719+
K0PerBlock;
720+
721+
#if CK_USE_CUSTOM_TENSOR_TRANSFORM_FOR_BWD_DATA_OUT == 0
708722
// A: output tensor
709723
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
710724
out_grid_desc,
@@ -762,21 +776,53 @@ struct TransformConvBwdDataToGemm_v1
762776
make_tuple(GemmKPerBlock, GemmMPerBlock),
763777
Sequence<true, DoPadGemmM>{});
764778

765-
const index_t K0PerBlock = GemmKPerBlock / AK1;
766-
const index_t AK0 =
767-
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
768-
AK1 * K0PerBlock * batch_k_) *
769-
K0PerBlock;
770-
771779
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
772780
out_gemmk_gemmm_padded_grid_desc,
773781
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
774782
make_pass_through_transform(
775783
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
776784
make_tuple(Sequence<0>{}, Sequence<1>{}),
777785
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
778-
779786
return out_gemmak0_gemmm_gemmak1_grid_desc;
787+
#else
788+
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
789+
out_grid_desc,
790+
make_tuple(make_pass_through_transform(N_),
791+
make_pad_transform(Ho_, I0, I0),
792+
make_pad_transform(Wo_, I0, I0),
793+
make_pass_through_transform(K_)),
794+
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
795+
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
796+
797+
const auto out_n_hop_wop_k_grid_desc_final = transform_tensor_descriptor(
798+
out_n_hop_wop_k_grid_desc,
799+
make_tuple(make_conv_bwd_data_out_transform(N_,
800+
Ho_,
801+
Wo_,
802+
K_,
803+
YDot_,
804+
XDot_,
805+
HTilde_,
806+
WTilde_,
807+
ConvDilationH_,
808+
ConvDilationW_,
809+
HTildeSlice,
810+
WTildeSlice,
811+
YDotSlice,
812+
XDotSlice,
813+
IHTildeSliceBegin,
814+
IWTildeSliceBegin,
815+
GcdStrideDilationH_,
816+
GcdStrideDilationW_,
817+
AK0,
818+
AK1,
819+
GemmMPerBlock,
820+
GemmKPerBlock)),
821+
make_tuple(Sequence<0, 1, 2, 3>{}),
822+
make_tuple(Sequence<0, 1, 2>{}));
823+
824+
return out_n_hop_wop_k_grid_desc_final;
825+
#endif
780826
}
781827
else if constexpr(NDimSpatial == 3)
782828
{

0 commit comments

Comments (0)