Skip to content

Commit 7c40def

Browse files
reset transpose_op.cc
1 parent 29216da commit 7c40def

File tree

2 files changed

+232
-2
lines changed

2 files changed

+232
-2
lines changed

paddle/fluid/operators/transpose_op.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/fluid/framework/op_registry.h"
16-
#include "paddle/phi/kernels/funcs/transpose_functor.h"
15+
#include "paddle/fluid/operators/transpose_op.h"
1716

1817
#include <memory>
1918
#include <string>
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
24+
namespace paddle {
25+
namespace operators {
26+
27+
// Tags distinguishing the FP32 and INT8 MKLDNN transpose kernel variants.
// NOTE(review): the explicit values 1/2 look like an external contract
// (e.g. an op attribute) — confirm against callers before changing them.
enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 };
28+
29+
template <typename DeviceContext, typename T>
30+
inline void TransCompute(const int dim,
31+
const DeviceContext& dev_ctx,
32+
const phi::DenseTensor& in,
33+
phi::DenseTensor* out,
34+
const std::vector<int>& axis) {
35+
switch (dim) {
36+
case 0:
37+
phi::Copy<DeviceContext>(dev_ctx, in, dev_ctx.GetPlace(), false, out);
38+
break;
39+
case 1:
40+
phi::funcs::Transpose<DeviceContext, T, 1> trans1;
41+
trans1(dev_ctx, in, out, axis);
42+
break;
43+
case 2:
44+
phi::funcs::Transpose<DeviceContext, T, 2> trans2;
45+
trans2(dev_ctx, in, out, axis);
46+
break;
47+
case 3:
48+
phi::funcs::Transpose<DeviceContext, T, 3> trans3;
49+
trans3(dev_ctx, in, out, axis);
50+
break;
51+
case 4:
52+
phi::funcs::Transpose<DeviceContext, T, 4> trans4;
53+
trans4(dev_ctx, in, out, axis);
54+
break;
55+
case 5:
56+
phi::funcs::Transpose<DeviceContext, T, 5> trans5;
57+
trans5(dev_ctx, in, out, axis);
58+
break;
59+
case 6:
60+
phi::funcs::Transpose<DeviceContext, T, 6> trans6;
61+
trans6(dev_ctx, in, out, axis);
62+
break;
63+
default:
64+
// for dim >= 7 situation
65+
phi::funcs::TransposeNormal<DeviceContext, T> trans_normal;
66+
trans_normal(dev_ctx, in, out, axis);
67+
}
68+
}
69+
70+
// Classification of a (simplified) permutation, from cheapest to most
// general, as assigned by TranposeTypeClassifier below:
//   kCopy           - at most one real dim survives simplification, so the
//                     permutation is the identity
//   kTranspose      - only the last two simplified dims are swapped
//   kVecPermute     - the last dim stays in place, enabling vectorized
//                     reads and writes
//   kGeneralPermute - arbitrary permutation; generic kernel required
enum PermuteType {
  kCopy = 1,
  kTranspose = 2,
  kVecPermute = 3,
  kGeneralPermute = 4
};

// Tile geometry constants; presumably the shared-memory tile shape of the
// tiled-transpose kernel (32x32 tile, 16 rows handled per block) — confirm
// against the kernel that consumes these.
constexpr int kBlockRows = 16;
constexpr int kTileSize = 32;
80+
// Simplify the input dims and permute dims if possible.
81+
template <typename T>
82+
class TranposeTypeClassifier {
83+
public:
84+
TranposeTypeClassifier(const int sm_count,
85+
const size_t rank,
86+
const int64_t numel,
87+
const std::vector<int32_t>& perm,
88+
const std::vector<int64_t>& dims,
89+
const T* src,
90+
T* dst)
91+
: perm_(rank), src_dims(rank) {
92+
SimplifyPermAndDims(rank, dims, perm);
93+
if (rank_ > 1) {
94+
vec_size_ = GetPermVecSize(sm_count, src, dst);
95+
}
96+
perm_.resize(rank_);
97+
src_dims.resize(rank_);
98+
dst_dims.resize(rank_);
99+
100+
for (auto i = 0; i < rank_; ++i) {
101+
dst_dims[i] = src_dims[perm_[i]];
102+
}
103+
}
104+
105+
int GetRank() const { return rank_; }
106+
int GetVecSize() const { return vec_size_; }
107+
PermuteType GetPermType() const { return type_; }
108+
109+
std::vector<int> GetPerm() const { return perm_; }
110+
std::vector<int64_t> GetSrcDims() const { return src_dims; }
111+
std::vector<int64_t> GetDstDims() const { return dst_dims; }
112+
113+
private:
114+
int rank_{1};
115+
int vec_size_{1};
116+
std::vector<int> perm_;
117+
std::vector<int64_t> src_dims;
118+
std::vector<int64_t> dst_dims;
119+
PermuteType type_{kCopy};
120+
121+
void SimplifyPermAndDims(const size_t rank,
122+
const std::vector<int64_t>& in_dims,
123+
const std::vector<int32_t>& perm) {
124+
int64_t combined_dims[phi::DDim::kMaxRank];
125+
int valid_map[phi::DDim::kMaxRank];
126+
127+
// Merge consecutive dims to the fist one dim and
128+
// leave original dim to be 1. Example below :
129+
// perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5]
130+
// new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1]
131+
int start_perm_idx = 0;
132+
while (start_perm_idx < rank) {
133+
const int start_dim_idx = perm[start_perm_idx];
134+
combined_dims[start_dim_idx] = in_dims[start_dim_idx];
135+
int end_perm_idx = start_perm_idx + 1;
136+
137+
while (end_perm_idx < rank &&
138+
perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) {
139+
const int end_dim_idx = perm[end_perm_idx];
140+
combined_dims[start_dim_idx] *= in_dims[end_dim_idx];
141+
combined_dims[end_dim_idx] = 1;
142+
end_perm_idx += 1;
143+
}
144+
start_perm_idx = end_perm_idx;
145+
}
146+
147+
// Reorder combined dims and marked useless dim as -1.
148+
// for example, if combined dims is [32, 1, 10, 1],
149+
// valid_map is [0, -1, 1, -1] and generate simplified
150+
// dims as [32, 10]
151+
int valid_dim_idx = 0;
152+
bool sequential_flag = false;
153+
for (auto i = 0; i < rank; ++i) {
154+
const int src_dim = combined_dims[i];
155+
if (src_dim == 1) {
156+
valid_map[i] = -1;
157+
} else {
158+
sequential_flag = true;
159+
valid_map[i] = valid_dim_idx;
160+
src_dims[valid_dim_idx] = src_dim;
161+
valid_dim_idx += 1;
162+
}
163+
}
164+
165+
if (valid_dim_idx == 0) {
166+
src_dims[0] = 1;
167+
perm_[0] = 0;
168+
return;
169+
} else if (valid_dim_idx == 1) {
170+
type_ = PermuteType::kCopy;
171+
}
172+
173+
// Acquire simplified perm with help of combined dims
174+
// and original perm, finally simplified perm is [1, 0]
175+
int perm_idx = 0;
176+
for (auto i = 0; i < rank; ++i) {
177+
const int mapped = valid_map[perm[i]];
178+
if (mapped >= 0) {
179+
perm_[perm_idx] = mapped;
180+
perm_idx += 1;
181+
}
182+
}
183+
rank_ = valid_dim_idx;
184+
}
185+
186+
int GetPermVecSize(const int sm_count, const T* src, T* dst) {
187+
// For gerneal_permute kernel, there is good chance for
188+
// vectorized write.
189+
type_ = PermuteType::kGeneralPermute;
190+
int vec_size = phi::GetVectorizedSize<T>(dst);
191+
192+
// While the last dim is fixed, there is good chance for
193+
// both vectorized read and write.
194+
if (perm_[rank_ - 1] == rank_ - 1) {
195+
int tmp_size = std::min(vec_size, phi::GetVectorizedSize<T>(src));
196+
tmp_size = GetDimVesSize(tmp_size, src_dims[rank_ - 1]);
197+
if (tmp_size > 1) {
198+
type_ = kVecPermute;
199+
vec_size = tmp_size;
200+
}
201+
}
202+
203+
// Once only transpose at the last 2 dims, there is good
204+
// chance for vectorized read.
205+
if ((rank_ == 2 && perm_[1] == 0 && perm_[0] == 1) ||
206+
(rank_ == 3 && perm_[2] == 1 && perm_[1] == 2)) {
207+
type_ = PermuteType::kTranspose;
208+
int tmp_vec = std::min(vec_size, phi::GetVectorizedSize<T>(src));
209+
// With bytes limitation of shared_memory, the VecSize shall be
210+
// restricted for the type whose byte-size is less than 8 (double).
211+
vec_size =
212+
sizeof(T) > 8 ? 1 : GetDimVesSize(tmp_vec, src_dims[rank_ - 1]);
213+
}
214+
return vec_size;
215+
}
216+
217+
// To find if highest common divisor and make it as vec_size.
218+
int GetDimVesSize(const int vec_size, const size_t target_dim) {
219+
int dim_vec_size = 1;
220+
for (auto size = vec_size; size > 0; size /= 2) {
221+
if (target_dim % size == 0) {
222+
dim_vec_size = size;
223+
break;
224+
}
225+
}
226+
return dim_vec_size;
227+
}
228+
};
229+
230+
} // namespace operators
231+
} // namespace paddle

0 commit comments

Comments
 (0)