|
1 | 1 | // SPDX-License-Identifier: MIT
|
2 |
| -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. |
| 2 | +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. |
3 | 3 |
|
4 | 4 | #pragma once
|
5 | 5 |
|
@@ -1553,6 +1553,198 @@ struct UnMerge
|
1553 | 1553 | }
|
1554 | 1554 | };
|
1555 | 1555 |
|
| 1556 | +/** |
| 1557 | + * @brief Transformation struct for convolution backward data output indices to GEMM indices. |
| 1558 | + * |
| 1559 | + * This struct is responsible for mapping the output tensor indices (N, Ho, Wo, K) from the |
| 1560 | + * convolution backward data operation to the corresponding indices (K0, M, K1) used in the |
| 1561 | + * implicit GEMM computation. It encapsulates the necessary parameters and transformation logic |
| 1562 | + * required to efficiently perform the index conversion. |
| 1563 | + */ |
| 1564 | +struct ConvBwdDataImplicitGemmOutTransform |
| 1565 | +{ |
| 1566 | + static constexpr auto I0 = Number<0>{}; |
| 1567 | + static constexpr auto I1 = Number<1>{}; |
| 1568 | + static constexpr auto I2 = Number<2>{}; |
| 1569 | + static constexpr auto I3 = Number<3>{}; |
| 1570 | + |
| 1571 | + using LowerIndex = MultiIndex<4>; // N, Ho, Wo, K |
| 1572 | + using UpperIndex = MultiIndex<3>; // K0, M, K1 |
| 1573 | + |
| 1574 | + index_t N_, Ho_, Wo_, K_; |
| 1575 | + index_t XDot_; |
| 1576 | + index_t HTilde_, WTilde_; |
| 1577 | + index_t WTildeSlice_, TildeSlice_; |
| 1578 | + index_t IHTildeSliceBegin_, IWTildeSliceBegin_; |
| 1579 | + index_t HRatio_, WRatio_; |
| 1580 | + index_t XDotSlice_K_; |
| 1581 | + index_t MPad_, KPad_; |
| 1582 | + Tuple<index_t, index_t, index_t> up_lengths_; // K0_, MPadded, K1_; |
| 1583 | + |
| 1584 | + Tuple<index_t, index_t, index_t, index_t> |
| 1585 | + low_lengths_magic_divisor_multiplier_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_ |
| 1586 | + Tuple<index_t, index_t, index_t, index_t> |
| 1587 | + low_lengths_magic_divisor_shift_; // XDotSlice_K_, K_, TildeSlice_, WTildeSlice_ |
| 1588 | + |
| 1589 | + __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform() = default; |
| 1590 | + |
| 1591 | + __host__ __device__ constexpr ConvBwdDataImplicitGemmOutTransform(index_t N, |
| 1592 | + index_t Ho, |
| 1593 | + index_t Wo, |
| 1594 | + index_t K, |
| 1595 | + index_t XDot, |
| 1596 | + index_t HTilde, |
| 1597 | + index_t WTilde, |
| 1598 | + index_t WTildeSlice, |
| 1599 | + index_t HWTildeSlice, |
| 1600 | + index_t IHTildeSliceBegin, |
| 1601 | + index_t IWTildeSliceBegin, |
| 1602 | + index_t HRatio, |
| 1603 | + index_t WRatio, |
| 1604 | + index_t XDotSlice_K, |
| 1605 | + index_t K0, |
| 1606 | + index_t MPadded, |
| 1607 | + index_t K1, |
| 1608 | + index_t MPad, |
| 1609 | + index_t KPad) |
| 1610 | + : N_{N}, |
| 1611 | + Ho_{Ho}, |
| 1612 | + Wo_{Wo}, |
| 1613 | + K_{K}, |
| 1614 | + XDot_{XDot}, |
| 1615 | + HTilde_{HTilde}, |
| 1616 | + WTilde_{WTilde}, |
| 1617 | + WTildeSlice_{WTildeSlice}, |
| 1618 | + TildeSlice_{HWTildeSlice}, |
| 1619 | + IHTildeSliceBegin_{IHTildeSliceBegin}, |
| 1620 | + IWTildeSliceBegin_{IWTildeSliceBegin}, |
| 1621 | + HRatio_{HRatio}, |
| 1622 | + WRatio_{WRatio}, |
| 1623 | + XDotSlice_K_{XDotSlice_K}, |
| 1624 | + MPad_{MPad}, |
| 1625 | + KPad_{KPad}, |
| 1626 | + up_lengths_{make_tuple(K0, MPadded, K1)}, |
| 1627 | + low_lengths_magic_divisor_multiplier_{ |
| 1628 | + MagicDivision::CalculateMagicMultiplier(XDotSlice_K_), |
| 1629 | + MagicDivision::CalculateMagicMultiplier(K_), |
| 1630 | + MagicDivision::CalculateMagicMultiplier(TildeSlice_), |
| 1631 | + MagicDivision::CalculateMagicMultiplier(WTildeSlice_)}, |
| 1632 | + low_lengths_magic_divisor_shift_{MagicDivision::CalculateMagicShift(XDotSlice_K_), |
| 1633 | + MagicDivision::CalculateMagicShift(K_), |
| 1634 | + MagicDivision::CalculateMagicShift(TildeSlice_), |
| 1635 | + MagicDivision::CalculateMagicShift(WTildeSlice_)} |
| 1636 | + { |
| 1637 | + } |
| 1638 | + |
| 1639 | + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 4; } |
| 1640 | + |
| 1641 | + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 3; } |
| 1642 | + |
| 1643 | + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } |
| 1644 | + |
| 1645 | + template <typename UpIdx> |
| 1646 | + __host__ __device__ constexpr auto CalculateLowerIndexN(const UpIdx& idx_up) const |
| 1647 | + { |
| 1648 | + index_t NStep, HStep, WStep; |
| 1649 | + // Merge |
| 1650 | + // NStep = M_id / TildeSlice_ |
| 1651 | + NStep = MagicDivision::DoMagicDivision(idx_up[I1], |
| 1652 | + this->low_lengths_magic_divisor_multiplier_[I2], |
| 1653 | + this->low_lengths_magic_divisor_shift_[I2]); |
| 1654 | + HStep = idx_up[I1] - NStep * TildeSlice_; |
| 1655 | + // HStep = HStep / WTildeSlice_ |
| 1656 | + HStep = MagicDivision::DoMagicDivision(HStep, |
| 1657 | + this->low_lengths_magic_divisor_multiplier_[I3], |
| 1658 | + this->low_lengths_magic_divisor_shift_[I3]); |
| 1659 | + WStep = idx_up[I1] - NStep * TildeSlice_ - HStep * WTildeSlice_; |
| 1660 | + // Slice |
| 1661 | + HStep += IHTildeSliceBegin_; |
| 1662 | + WStep += IWTildeSliceBegin_; |
| 1663 | + |
| 1664 | + return make_tuple(NStep, HStep, WStep, 0); |
| 1665 | + } |
| 1666 | + |
| 1667 | + template <typename UpIdx> |
| 1668 | + __host__ __device__ constexpr auto CalculateLowerIndexK(const UpIdx& idx_up) const |
| 1669 | + { |
| 1670 | + // UnMerge |
| 1671 | + // K_idx <- K0_idx * K1 + K1_idx |
| 1672 | + index_t K_idx = idx_up[I0] * up_lengths_[I2] + idx_up[I2]; |
| 1673 | + // Merge |
| 1674 | + // YStep = K_idx / XDotSlice_K_ |
| 1675 | + index_t YStep = |
| 1676 | + MagicDivision::DoMagicDivision(K_idx, |
| 1677 | + this->low_lengths_magic_divisor_multiplier_[I0], |
| 1678 | + this->low_lengths_magic_divisor_shift_[I0]); |
| 1679 | + index_t KStep = K_idx - YStep * XDotSlice_K_; |
| 1680 | + // Xstep = KStep / K_ |
| 1681 | + index_t XStep = |
| 1682 | + MagicDivision::DoMagicDivision(KStep, |
| 1683 | + this->low_lengths_magic_divisor_multiplier_[I1], |
| 1684 | + this->low_lengths_magic_divisor_shift_[I1]); |
| 1685 | + KStep -= XStep * K_; |
| 1686 | + // Embed |
| 1687 | + YStep *= HRatio_; |
| 1688 | + XStep *= WRatio_; |
| 1689 | + |
| 1690 | + return make_tuple(0, YStep, XStep, KStep); |
| 1691 | + } |
| 1692 | + |
| 1693 | + template <typename LowIdx, typename UpIdx> |
| 1694 | + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, |
| 1695 | + const UpIdx& idx_up) const |
| 1696 | + { |
| 1697 | + idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up); |
| 1698 | + } |
| 1699 | + |
| 1700 | + template <typename LowIdxDiff, |
| 1701 | + typename UpIdxDiff, |
| 1702 | + typename LowIdx, |
| 1703 | + typename UpIdx, |
| 1704 | + index_t Hack> |
| 1705 | + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, |
| 1706 | + const UpIdxDiff& /* idx_diff_up */, |
| 1707 | + LowIdx& idx_low, |
| 1708 | + const UpIdx& idx_up, |
| 1709 | + Number<Hack>) const |
| 1710 | + { |
| 1711 | + LowIdx low_old = idx_low; |
| 1712 | + idx_low = CalculateLowerIndexN(idx_up) + CalculateLowerIndexK(idx_up); |
| 1713 | + idx_diff_low = idx_low - low_old; |
| 1714 | + } |
| 1715 | + |
| 1716 | + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } |
| 1717 | + |
| 1718 | + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() |
| 1719 | + { |
| 1720 | + return true; |
| 1721 | + } |
| 1722 | + |
| 1723 | + template <typename UpIdx> |
| 1724 | + __host__ __device__ constexpr bool |
| 1725 | + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const |
| 1726 | + { |
| 1727 | + // Padding |
| 1728 | + index_t K_idx = idx_up[Number<0>{}] * up_lengths_[Number<2>{}] + idx_up[Number<2>{}]; |
| 1729 | + index_t& M_idx = idx_up[Number<1>{}]; |
| 1730 | + |
| 1731 | + bool pad_valid = M_idx < up_lengths_[Number<1>{}] - MPad_ && |
| 1732 | + K_idx < up_lengths_[Number<0>{}] * up_lengths_[Number<2>{}] - KPad_; |
| 1733 | + return pad_valid; |
| 1734 | + } |
| 1735 | + |
| 1736 | + __host__ __device__ static constexpr bool IsKnownAtCompileTime() { return false; } |
| 1737 | + |
| 1738 | + __host__ __device__ void Print() const |
| 1739 | + { |
| 1740 | + printf("{"); |
| 1741 | + printf("ConvBwdDataImplicitGemmOutTransform, "); |
| 1742 | + printf("up_lengths_"); |
| 1743 | + print_multi_index(up_lengths_); |
| 1744 | + printf("}"); |
| 1745 | + } |
| 1746 | +}; |
| 1747 | + |
1556 | 1748 | template <typename LowerIndex>
|
1557 | 1749 | struct Freeze
|
1558 | 1750 | {
|
|
0 commit comments