|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
| 14 | + |
#pragma once

#include <algorithm>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
| 23 | + |
| 24 | +namespace paddle { |
| 25 | +namespace operators { |
| 26 | + |
// Flag values distinguishing FP32 vs INT8 variants of the MKLDNN
// transpose kernel. Not referenced in this header — presumably consumed
// by kernel-selection code elsewhere; verify against the op definition.
enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 };
| 28 | + |
| 29 | +template <typename DeviceContext, typename T> |
| 30 | +inline void TransCompute(const int dim, |
| 31 | + const DeviceContext& dev_ctx, |
| 32 | + const phi::DenseTensor& in, |
| 33 | + phi::DenseTensor* out, |
| 34 | + const std::vector<int>& axis) { |
| 35 | + switch (dim) { |
| 36 | + case 0: |
| 37 | + phi::Copy<DeviceContext>(dev_ctx, in, dev_ctx.GetPlace(), false, out); |
| 38 | + break; |
| 39 | + case 1: |
| 40 | + phi::funcs::Transpose<DeviceContext, T, 1> trans1; |
| 41 | + trans1(dev_ctx, in, out, axis); |
| 42 | + break; |
| 43 | + case 2: |
| 44 | + phi::funcs::Transpose<DeviceContext, T, 2> trans2; |
| 45 | + trans2(dev_ctx, in, out, axis); |
| 46 | + break; |
| 47 | + case 3: |
| 48 | + phi::funcs::Transpose<DeviceContext, T, 3> trans3; |
| 49 | + trans3(dev_ctx, in, out, axis); |
| 50 | + break; |
| 51 | + case 4: |
| 52 | + phi::funcs::Transpose<DeviceContext, T, 4> trans4; |
| 53 | + trans4(dev_ctx, in, out, axis); |
| 54 | + break; |
| 55 | + case 5: |
| 56 | + phi::funcs::Transpose<DeviceContext, T, 5> trans5; |
| 57 | + trans5(dev_ctx, in, out, axis); |
| 58 | + break; |
| 59 | + case 6: |
| 60 | + phi::funcs::Transpose<DeviceContext, T, 6> trans6; |
| 61 | + trans6(dev_ctx, in, out, axis); |
| 62 | + break; |
| 63 | + default: |
| 64 | + // for dim >= 7 situation |
| 65 | + phi::funcs::TransposeNormal<DeviceContext, T> trans_normal; |
| 66 | + trans_normal(dev_ctx, in, out, axis); |
| 67 | + } |
| 68 | +} |
| 69 | + |
// Kind of data movement a permutation reduces to, roughly ordered from
// cheapest to most general (see TranposeTypeClassifier below).
enum PermuteType {
  kCopy = 1,            // identity permutation: plain memory copy
  kTranspose = 2,       // only the last two dims are swapped
  kVecPermute = 3,      // last dim kept in place: vectorizable load/store
  kGeneralPermute = 4   // arbitrary permutation
};

// Tile geometry constants. Not referenced in this header — presumably
// used by tiled transpose kernels in the corresponding .cu file; verify.
constexpr int kBlockRows = 16;
constexpr int kTileSize = 32;
| 79 | + |
| 80 | +// Simplify the input dims and permute dims if possible. |
| 81 | +template <typename T> |
| 82 | +class TranposeTypeClassifier { |
| 83 | + public: |
| 84 | + TranposeTypeClassifier(const int sm_count, |
| 85 | + const size_t rank, |
| 86 | + const int64_t numel, |
| 87 | + const std::vector<int32_t>& perm, |
| 88 | + const std::vector<int64_t>& dims, |
| 89 | + const T* src, |
| 90 | + T* dst) |
| 91 | + : perm_(rank), src_dims(rank) { |
| 92 | + SimplifyPermAndDims(rank, dims, perm); |
| 93 | + if (rank_ > 1) { |
| 94 | + vec_size_ = GetPermVecSize(sm_count, src, dst); |
| 95 | + } |
| 96 | + perm_.resize(rank_); |
| 97 | + src_dims.resize(rank_); |
| 98 | + dst_dims.resize(rank_); |
| 99 | + |
| 100 | + for (auto i = 0; i < rank_; ++i) { |
| 101 | + dst_dims[i] = src_dims[perm_[i]]; |
| 102 | + } |
| 103 | + } |
| 104 | + |
| 105 | + int GetRank() const { return rank_; } |
| 106 | + int GetVecSize() const { return vec_size_; } |
| 107 | + PermuteType GetPermType() const { return type_; } |
| 108 | + |
| 109 | + std::vector<int> GetPerm() const { return perm_; } |
| 110 | + std::vector<int64_t> GetSrcDims() const { return src_dims; } |
| 111 | + std::vector<int64_t> GetDstDims() const { return dst_dims; } |
| 112 | + |
| 113 | + private: |
| 114 | + int rank_{1}; |
| 115 | + int vec_size_{1}; |
| 116 | + std::vector<int> perm_; |
| 117 | + std::vector<int64_t> src_dims; |
| 118 | + std::vector<int64_t> dst_dims; |
| 119 | + PermuteType type_{kCopy}; |
| 120 | + |
| 121 | + void SimplifyPermAndDims(const size_t rank, |
| 122 | + const std::vector<int64_t>& in_dims, |
| 123 | + const std::vector<int32_t>& perm) { |
| 124 | + int64_t combined_dims[phi::DDim::kMaxRank]; |
| 125 | + int valid_map[phi::DDim::kMaxRank]; |
| 126 | + |
| 127 | + // Merge consecutive dims to the fist one dim and |
| 128 | + // leave original dim to be 1. Example below : |
| 129 | + // perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5] |
| 130 | + // new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1] |
| 131 | + int start_perm_idx = 0; |
| 132 | + while (start_perm_idx < rank) { |
| 133 | + const int start_dim_idx = perm[start_perm_idx]; |
| 134 | + combined_dims[start_dim_idx] = in_dims[start_dim_idx]; |
| 135 | + int end_perm_idx = start_perm_idx + 1; |
| 136 | + |
| 137 | + while (end_perm_idx < rank && |
| 138 | + perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) { |
| 139 | + const int end_dim_idx = perm[end_perm_idx]; |
| 140 | + combined_dims[start_dim_idx] *= in_dims[end_dim_idx]; |
| 141 | + combined_dims[end_dim_idx] = 1; |
| 142 | + end_perm_idx += 1; |
| 143 | + } |
| 144 | + start_perm_idx = end_perm_idx; |
| 145 | + } |
| 146 | + |
| 147 | + // Reorder combined dims and marked useless dim as -1. |
| 148 | + // for example, if combined dims is [32, 1, 10, 1], |
| 149 | + // valid_map is [0, -1, 1, -1] and generate simplified |
| 150 | + // dims as [32, 10] |
| 151 | + int valid_dim_idx = 0; |
| 152 | + bool sequential_flag = false; |
| 153 | + for (auto i = 0; i < rank; ++i) { |
| 154 | + const int src_dim = combined_dims[i]; |
| 155 | + if (src_dim == 1) { |
| 156 | + valid_map[i] = -1; |
| 157 | + } else { |
| 158 | + sequential_flag = true; |
| 159 | + valid_map[i] = valid_dim_idx; |
| 160 | + src_dims[valid_dim_idx] = src_dim; |
| 161 | + valid_dim_idx += 1; |
| 162 | + } |
| 163 | + } |
| 164 | + |
| 165 | + if (valid_dim_idx == 0) { |
| 166 | + src_dims[0] = 1; |
| 167 | + perm_[0] = 0; |
| 168 | + return; |
| 169 | + } else if (valid_dim_idx == 1) { |
| 170 | + type_ = PermuteType::kCopy; |
| 171 | + } |
| 172 | + |
| 173 | + // Acquire simplified perm with help of combined dims |
| 174 | + // and original perm, finally simplified perm is [1, 0] |
| 175 | + int perm_idx = 0; |
| 176 | + for (auto i = 0; i < rank; ++i) { |
| 177 | + const int mapped = valid_map[perm[i]]; |
| 178 | + if (mapped >= 0) { |
| 179 | + perm_[perm_idx] = mapped; |
| 180 | + perm_idx += 1; |
| 181 | + } |
| 182 | + } |
| 183 | + rank_ = valid_dim_idx; |
| 184 | + } |
| 185 | + |
| 186 | + int GetPermVecSize(const int sm_count, const T* src, T* dst) { |
| 187 | + // For gerneal_permute kernel, there is good chance for |
| 188 | + // vectorized write. |
| 189 | + type_ = PermuteType::kGeneralPermute; |
| 190 | + int vec_size = phi::GetVectorizedSize<T>(dst); |
| 191 | + |
| 192 | + // While the last dim is fixed, there is good chance for |
| 193 | + // both vectorized read and write. |
| 194 | + if (perm_[rank_ - 1] == rank_ - 1) { |
| 195 | + int tmp_size = std::min(vec_size, phi::GetVectorizedSize<T>(src)); |
| 196 | + tmp_size = GetDimVesSize(tmp_size, src_dims[rank_ - 1]); |
| 197 | + if (tmp_size > 1) { |
| 198 | + type_ = kVecPermute; |
| 199 | + vec_size = tmp_size; |
| 200 | + } |
| 201 | + } |
| 202 | + |
| 203 | + // Once only transpose at the last 2 dims, there is good |
| 204 | + // chance for vectorized read. |
| 205 | + if ((rank_ == 2 && perm_[1] == 0 && perm_[0] == 1) || |
| 206 | + (rank_ == 3 && perm_[2] == 1 && perm_[1] == 2)) { |
| 207 | + type_ = PermuteType::kTranspose; |
| 208 | + int tmp_vec = std::min(vec_size, phi::GetVectorizedSize<T>(src)); |
| 209 | + // With bytes limitation of shared_memory, the VecSize shall be |
| 210 | + // restricted for the type whose byte-size is less than 8 (double). |
| 211 | + vec_size = |
| 212 | + sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, src_dims[rank_ - 1]); |
| 213 | + } |
| 214 | + return vec_size; |
| 215 | + } |
| 216 | + |
| 217 | + // To find if highest common divisor and make it as vec_size. |
| 218 | + int GetDimVesSize(const int vec_size, const size_t target_dim) { |
| 219 | + int dim_vec_size = 1; |
| 220 | + for (auto size = vec_size; size > 0; size /= 2) { |
| 221 | + if (target_dim % size == 0) { |
| 222 | + dim_vec_size = size; |
| 223 | + break; |
| 224 | + } |
| 225 | + } |
| 226 | + return dim_vec_size; |
| 227 | + } |
| 228 | +}; |
| 229 | + |
| 230 | +} // namespace operators |
| 231 | +} // namespace paddle |