
Commit b2e095c

[Cherry-pick] Add truncated_normal/unique/swish/unbind yaml and polish Getting tensor place impl (#41539)
* [Phi] Polish truncated normal kernel and add yaml (#41280)
  * polish truncated normal kernel
  * add yaml
  * add truncated normal kernel and add yaml
  * polish unittests and yaml
  * import dygraph method
* add unique yaml and final state api (#41460)
* fix get tensor backend set bug (#41478)
* [Phi] Add unbind yaml and final state api (#41277)
  * add unbind yaml
  * fix unittest
* [Phi] Add swish yaml and final state api (#41479)
  * add swish yaml and final state api
  * skip mkldnn test
  * fix grad mkldnn test
* add cherry-pick lost code
1 parent ae34db3 commit b2e095c

20 files changed, with 385 additions and 167 deletions.

paddle/phi/api/lib/api_custom_impl.cc

Lines changed: 48 additions & 0 deletions
@@ -475,6 +475,54 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
   return api_output;
 }
 
+std::vector<Tensor> unbind_impl(const Tensor& input, int axis) {
+  auto kernel_key_set = ParseKernelKeyByInputArgs(input);
+  auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
+
+  Backend kernel_backend = kernel_key.backend();
+  DataLayout kernel_layout = kernel_key.layout();
+  DataType kernel_data_type = kernel_key.dtype();
+
+  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
+      "unbind", {kernel_backend, kernel_layout, kernel_data_type});
+  VLOG(6) << "unbind API kernel key: [" << kernel_backend << ", "
+          << kernel_layout << ", " << kernel_data_type << "]";
+  VLOG(6) << "unbind API kernel: " << kernel;
+
+  auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
+
+  auto dense_input = PrepareData(input, kernel.InputAt(0), {});
+
+  // Calculate the number of out tensors
+  auto input_shape = input.dims();
+  if (axis < 0) {
+    axis = input_shape.size() + axis;
+  }
+  auto out_num = input_shape[axis];
+
+  std::vector<Tensor> out;
+  auto dense_outs = SetKernelOutput(out_num, kernel_backend, &out);
+  std::vector<phi::MetaTensor> meta_outs;
+  meta_outs.reserve(out_num);
+  std::vector<phi::MetaTensor*> meta_out_ptrs;
+  meta_out_ptrs.reserve(out_num);
+  for (int64_t i = 0; i < out_num; ++i) {
+    meta_outs.push_back(dense_outs[i]);
+    meta_out_ptrs.push_back(&meta_outs.back());
+  }
+
+  phi::UnbindInferMeta(MakeMetaTensor(*dense_input), axis, meta_out_ptrs);
+
+  using kernel_signature = void (*)(const phi::DeviceContext&,
+                                    const phi::DenseTensor&,
+                                    int,
+                                    std::vector<phi::DenseTensor*>&);
+  auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+  (*kernel_fn)(*dev_ctx, *dense_input, axis, dense_outs);
+
+  return out;
+}
+
 ////////////////// Backward(grad) api impls //////////////////////
 
 // TODO(chenweihang): the original sum grad op can support higher-level
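The only arithmetic in this impl is the axis normalization and the output-count computation; the rest is the standard manual-API boilerplate (kernel selection, data preparation, infer-meta, variadic kernel call). A standalone sketch of that arithmetic, with Paddle's DDim replaced by std::vector<int> as a simplifying assumption:

#include <cassert>
#include <vector>

// Models the out_num computation in unbind_impl above: a negative axis
// counts from the back, and the number of outputs equals the extent of
// the selected dimension.
int UnbindOutNum(const std::vector<int>& dims, int axis) {
  if (axis < 0) {
    axis = static_cast<int>(dims.size()) + axis;  // e.g. -1 -> last dim
  }
  return dims[axis];
}

int main() {
  assert(UnbindOutNum({3, 4, 5}, 0) == 3);   // three [4, 5] slices
  assert(UnbindOutNum({3, 4, 5}, -1) == 5);  // five [3, 4] slices
  return 0;
}

Note also that the reserve() calls in the impl are load-bearing: meta_out_ptrs stores addresses of elements of meta_outs, so a reallocation mid-loop would leave dangling pointers.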

paddle/phi/api/lib/api_custom_impl.h

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,8 @@ limitations under the License. */
 
 #pragma once
 
+#include <vector>
+
 #include "paddle/phi/api/include/tensor.h"
 #include "paddle/phi/common/int_array.h"
 #include "paddle/phi/common/place.h"
@@ -73,6 +75,8 @@ std::tuple<Tensor, Tensor, Tensor> momentum_impl(
     bool multi_precision,
     float rescale_grad);
 
+std::vector<Tensor> unbind_impl(const Tensor& input, int axis);
+
 ////////////////// Backward(grad) api impls //////////////////////
 
 std::vector<Tensor> add_n_grad_impl(const std::vector<Tensor>& x,

paddle/phi/api/lib/kernel_dispatch.cc

Lines changed: 31 additions & 3 deletions
@@ -14,18 +14,46 @@ limitations under the License. */
 
 #include "paddle/phi/api/lib/kernel_dispatch.h"
 
-#include "paddle/phi/api/include/context_pool.h"
-#include "paddle/phi/core/compat/convert_utils.h"
 #ifdef _MSC_VER
 #include <intrin.h>
 #endif
 
+#include "paddle/phi/api/include/context_pool.h"
+#include "paddle/phi/core/compat/convert_utils.h"
+#include "paddle/phi/core/string_tensor_utils.h"
+#include "paddle/phi/core/tensor_utils.h"
+
 namespace paddle {
 namespace experimental {
 namespace detail {
 
+// We need to judge whether the allocation is nullptr and whether it is
+// initialized, so we need the GetHolder method
+bool HasAllocation(const phi::TensorBase& t) {
+  if (phi::DenseTensor::classof(&t)) {
+    return phi::DenseTensorUtils::GetHolder(
+               static_cast<const phi::DenseTensor&>(t)) != nullptr;
+  } else if (phi::SelectedRows::classof(&t)) {
+    return phi::DenseTensorUtils::GetHolder(
+               static_cast<const phi::SelectedRows&>(t).value()) != nullptr;
+  } else if (phi::SparseCsrTensor::classof(&t)) {
+    return phi::DenseTensorUtils::GetHolder(
+               static_cast<const phi::SparseCsrTensor&>(t)
+                   .non_zero_elements()) != nullptr;
+  } else if (phi::SparseCooTensor::classof(&t)) {
+    return phi::DenseTensorUtils::GetHolder(
+               static_cast<const phi::SparseCooTensor&>(t)
+                   .non_zero_elements()) != nullptr;
+  } else if (phi::StringTensor::classof(&t)) {
+    return phi::StringTensorUtils::GetHolder(
+               static_cast<const phi::StringTensor&>(t)) != nullptr;
+  } else {
+    return false;
+  }
+}
+
 BackendSet GetTensorBackendSet(const phi::TensorBase& t) {
-  if (t.initialized()) {
+  if (HasAllocation(t)) {
     BackendSet backend_set(phi::TransToPhiBackend(t.place()));
     switch (t.layout()) {
       case DataLayout::MKLDNN:
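Two notes on this change. First, the bug fix: as I read it, TensorBase::initialized() can report false for a tensor whose buffer already exists but is not yet fully set up, in which case the old code silently skipped backend detection; keying on the holder itself preserves the tensor's actual place. Second, the dispatch: HasAllocation fans out over phi's classof-based lightweight RTTI rather than dynamic_cast. A self-contained sketch of that pattern, with illustrative types that are not Paddle's:

#include <iostream>

// Each subclass exposes a static classof(const Base*) predicate, so free
// functions can branch on the concrete type without dynamic_cast.
struct TensorLike {
  enum class Kind { kDense, kString };
  explicit TensorLike(Kind k) : kind(k) {}
  virtual ~TensorLike() = default;
  Kind kind;
};

struct DenseLike : TensorLike {
  DenseLike() : TensorLike(Kind::kDense) {}
  static bool classof(const TensorLike* t) { return t->kind == Kind::kDense; }
};

struct StringLike : TensorLike {
  StringLike() : TensorLike(Kind::kString) {}
  static bool classof(const TensorLike* t) { return t->kind == Kind::kString; }
};

bool IsDense(const TensorLike& t) { return DenseLike::classof(&t); }

int main() {
  DenseLike d;
  StringLike s;
  std::cout << IsDense(d) << IsDense(s) << "\n";  // prints "10"
}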

paddle/phi/core/string_tensor_utils.h

Lines changed: 5 additions & 0 deletions
@@ -23,6 +23,11 @@ class StringTensorUtils {
   static StringTensorMeta* GetMutableMeta(StringTensor* tensor) {
     return &(tensor->meta_);
   }
+
+  static const std::shared_ptr<phi::Allocation>& GetHolder(
+      const StringTensor& tensor) {
+    return tensor.holder_;
+  }
 };
 
 }  // namespace phi

paddle/phi/core/tensor_utils.h

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,11 @@ class DenseTensorUtils {
     return &(tensor->meta_);
   }
 
+  static const std::shared_ptr<phi::Allocation>& GetHolder(
+      const DenseTensor& tensor) {
+    return tensor.holder_;
+  }
+
   static DenseTensor Slice(const DenseTensor& tensor,
                            int64_t begin_idx,
                            int64_t end_idx) {
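Both new accessors read a private holder_ member, which only compiles because the corresponding tensor class befriends its utils class (DenseTensorUtils and StringTensorUtils here). A minimal sketch of that friend-accessor pattern, with illustrative names:

#include <memory>

struct Buffer {};

class Tensor {
 public:
  Tensor() = default;

 private:
  friend class TensorUtils;  // grants internal helpers access to holder_
  std::shared_ptr<Buffer> holder_;
};

class TensorUtils {
 public:
  // Exposes the raw allocation to selected internal call sites (like
  // kernel_dispatch.cc above) without widening Tensor's public API.
  static const std::shared_ptr<Buffer>& GetHolder(const Tensor& t) {
    return t.holder_;
  }
};

int main() {
  Tensor t;
  return TensorUtils::GetHolder(t) == nullptr ? 0 : 1;  // holder starts null
}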

paddle/phi/infermeta/unary.cc

Lines changed: 6 additions & 6 deletions
@@ -2429,7 +2429,7 @@ void TransposeGradInferMeta(const MetaTensor& x,
 
 void UnbindInferMeta(const MetaTensor& x,
                      int axis,
-                     std::vector<MetaTensor>* outs) {
+                     std::vector<MetaTensor*> outs) {
   auto in_dims = x.dims();
   std::vector<int> out_dim;
   axis = axis < 0 ? in_dims.size() + axis : axis;
@@ -2438,11 +2438,11 @@ void UnbindInferMeta(const MetaTensor& x,
   }
   auto out_dims = phi::make_ddim(out_dim);
 
-  for (size_t i = 0; i < outs->size(); ++i) {
-    (*outs)[i].set_dtype(x.dtype());
-    (*outs)[i].set_dims(out_dims);
-    (*outs)[i].set_layout(x.layout());
-    (*outs)[i].share_lod(x);
+  for (size_t i = 0; i < outs.size(); ++i) {
+    outs[i]->set_dtype(x.dtype());
+    outs[i]->set_dims(out_dims);
+    outs[i]->set_layout(x.layout());
+    outs[i]->share_lod(x);
   }
 }
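As a concrete check of the loop above, assuming the elided out_dim construction between the two hunks copies every dimension except axis: for x with dims [3, 4, 5] and axis = 1 (or axis = -2, after normalization), out_dims is [3, 5], and the caller is expected to pass in_dims[1] = 4 MetaTensor pointers, each of which receives x's dtype, layout, and LoD.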

paddle/phi/infermeta/unary.h

Lines changed: 1 addition & 1 deletion
@@ -365,7 +365,7 @@ void TrilTriuInferMeta(const MetaTensor& x,
 
 void UnbindInferMeta(const MetaTensor& x,
                      int axis,
-                     std::vector<MetaTensor>* outs);
+                     std::vector<MetaTensor*> outs);
 
 void UnchangedInferMeta(const MetaTensor& x, MetaTensor* out);

paddle/phi/kernels/cpu/truncated_gaussian_random_kernel.cc

Lines changed: 140 additions & 3 deletions
@@ -21,10 +21,141 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 
-#include "paddle/fluid/framework/generator.h"
-
 namespace phi {
 
+// reference: https://gist.github.com/lakshayg/d80172fe5ae3c5d2c2aedb53c250320e
+template <typename T>
+T Erfinv(T x) {
+  if (x < -1 || x > 1) {
+    return std::numeric_limits<T>::quiet_NaN();
+  } else if (x == 1.0) {
+    return std::numeric_limits<T>::infinity();
+  } else if (x == -1.0) {
+    return -std::numeric_limits<T>::infinity();
+  }
+
+  const T LN2 = 6.931471805599453094172321214581e-1;
+
+  const T A0 = 1.1975323115670912564578e0;
+  const T A1 = 4.7072688112383978012285e1;
+  const T A2 = 6.9706266534389598238465e2;
+  const T A3 = 4.8548868893843886794648e3;
+  const T A4 = 1.6235862515167575384252e4;
+  const T A5 = 2.3782041382114385731252e4;
+  const T A6 = 1.1819493347062294404278e4;
+  const T A7 = 8.8709406962545514830200e2;
+
+  const T B0 = 1.0000000000000000000e0;
+  const T B1 = 4.2313330701600911252e1;
+  const T B2 = 6.8718700749205790830e2;
+  const T B3 = 5.3941960214247511077e3;
+  const T B4 = 2.1213794301586595867e4;
+  const T B5 = 3.9307895800092710610e4;
+  const T B6 = 2.8729085735721942674e4;
+  const T B7 = 5.2264952788528545610e3;
+
+  const T C0 = 1.42343711074968357734e0;
+  const T C1 = 4.63033784615654529590e0;
+  const T C2 = 5.76949722146069140550e0;
+  const T C3 = 3.64784832476320460504e0;
+  const T C4 = 1.27045825245236838258e0;
+  const T C5 = 2.41780725177450611770e-1;
+  const T C6 = 2.27238449892691845833e-2;
+  const T C7 = 7.74545014278341407640e-4;
+
+  const T D0 = 1.4142135623730950488016887e0;
+  const T D1 = 2.9036514445419946173133295e0;
+  const T D2 = 2.3707661626024532365971225e0;
+  const T D3 = 9.7547832001787427186894837e-1;
+  const T D4 = 2.0945065210512749128288442e-1;
+  const T D5 = 2.1494160384252876777097297e-2;
+  const T D6 = 7.7441459065157709165577218e-4;
+  const T D7 = 1.4859850019840355905497876e-9;
+
+  const T E0 = 6.65790464350110377720e0;
+  const T E1 = 5.46378491116411436990e0;
+  const T E2 = 1.78482653991729133580e0;
+  const T E3 = 2.96560571828504891230e-1;
+  const T E4 = 2.65321895265761230930e-2;
+  const T E5 = 1.24266094738807843860e-3;
+  const T E6 = 2.71155556874348757815e-5;
+  const T E7 = 2.01033439929228813265e-7;
+
+  const T F0 = 1.414213562373095048801689e0;
+  const T F1 = 8.482908416595164588112026e-1;
+  const T F2 = 1.936480946950659106176712e-1;
+  const T F3 = 2.103693768272068968719679e-2;
+  const T F4 = 1.112800997078859844711555e-3;
+  const T F5 = 2.611088405080593625138020e-5;
+  const T F6 = 2.010321207683943062279931e-7;
+  const T F7 = 2.891024605872965461538222e-15;
+
+  T abs_x = abs(x);
+
+  if (abs_x <= 0.85) {
+    T r = 0.180625 - 0.25 * x * x;
+    T num =
+        (((((((A7 * r + A6) * r + A5) * r + A4) * r + A3) * r + A2) * r + A1) *
+             r +
+         A0);
+    T den =
+        (((((((B7 * r + B6) * r + B5) * r + B4) * r + B3) * r + B2) * r + B1) *
+             r +
+         B0);
+    return x * num / den;
+  }
+
+  T r = sqrt(LN2 - log(1.0 - abs_x));
+
+  T num, den;
+  if (r <= 5.0) {
+    r = r - 1.6;
+    num =
+        (((((((C7 * r + C6) * r + C5) * r + C4) * r + C3) * r + C2) * r + C1) *
+             r +
+         C0);
+    den =
+        (((((((D7 * r + D6) * r + D5) * r + D4) * r + D3) * r + D2) * r + D1) *
+             r +
+         D0);
+  } else {
+    r = r - 5.0;
+    num =
+        (((((((E7 * r + E6) * r + E5) * r + E4) * r + E3) * r + E2) * r + E1) *
+             r +
+         E0);
+    den =
+        (((((((F7 * r + F6) * r + F5) * r + F4) * r + F3) * r + F2) * r + F1) *
+             r +
+         F0);
+  }
+
+  if (x < 0) {
+    return -num / den;
+  } else {
+    return num / den;
+  }
+}
+
+template <typename T>
+struct TruncatedNormal {
+  T mean, std;
+  T a_normal_cdf;
+  T b_normal_cdf;
+  TruncatedNormal(T mean, T std) : mean(mean), std(std) {
+    auto normal_cdf = [](T x) {
+      return (1.0 + std::erf(x / std::sqrt(2.0))) / 2.0;
+    };
+    a_normal_cdf = normal_cdf(-2.0);
+    b_normal_cdf = normal_cdf(2.0);
+  }
+
+  T operator()(T value) const {
+    auto p = a_normal_cdf + (b_normal_cdf - a_normal_cdf) * value;
+    return std::sqrt(2.0) * Erfinv(2 * p - 1) * std + mean;
+  }
+};
+
 template <typename T, typename Context>
 void TruncatedGaussianRandomKernel(const Context& dev_ctx,
                                    const std::vector<int>& shape,
@@ -42,7 +173,13 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx,
   TruncatedNormal<T> truncated_normal(mean, std);
   int64_t size = tensor->numel();
 
-  auto engine = paddle::framework::GetCPURandomEngine(seed);
+  std::shared_ptr<std::mt19937_64> engine;
+  if (seed) {
+    engine = std::make_shared<std::mt19937_64>();
+    engine->seed(seed);
+  } else {
+    engine = dev_ctx.GetGenerator()->GetCPUEngine();
+  }
   for (int64_t i = 0; i < size; ++i) {
     data[i] = truncated_normal(dist(*engine));
   }
paddle/phi/kernels/gpu/truncated_gaussian_random_kernel.cu

Lines changed: 1 addition & 4 deletions
@@ -24,8 +24,6 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 
-#include "paddle/fluid/framework/generator.h"
-
 namespace phi {
 
 template <typename T>
@@ -106,8 +104,7 @@ void TruncatedGaussianRandomKernel(const Context& dev_ctx,
   thrust::counting_iterator<int64_t> index_sequence_begin(0);
   int64_t size = tensor->numel();
 
-  int device_id = dev_ctx.GetPlace().GetDeviceId();
-  auto gen_cuda = paddle::framework::GetDefaultCUDAGenerator(device_id);
+  auto gen_cuda = dev_ctx.GetGenerator();
 
   if (gen_cuda->GetIsInitPy() && seed_flag) {
     auto seed_offset = gen_cuda->IncrementOffset(1);
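Both kernels now draw their random engine from the passed-in device context instead of the deleted paddle/fluid/framework/generator.h globals, removing a fluid dependency from phi. The CPU-side policy, a fresh deterministic engine for an explicit non-zero seed and the context's shared engine otherwise, can be sketched standalone (simplified types by assumption):

#include <memory>
#include <random>

// Models the engine selection in the CPU kernel above: a non-zero seed
// yields a private, reproducible engine; seed == 0 falls through to the
// shared engine, whose state advances across calls.
std::shared_ptr<std::mt19937_64> SelectEngine(
    uint64_t seed, const std::shared_ptr<std::mt19937_64>& shared_engine) {
  if (seed != 0) {
    auto engine = std::make_shared<std::mt19937_64>();
    engine->seed(seed);
    return engine;
  }
  return shared_engine;
}

int main() {
  auto shared = std::make_shared<std::mt19937_64>(42);
  auto seeded = SelectEngine(/*seed=*/7, shared);  // independent stream
  auto global = SelectEngine(/*seed=*/0, shared);  // aliases `shared`
  (void)(*seeded)();
  (void)(*global)();
  return 0;
}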
