Skip to content

Commit 31b1f70

Browse files
authored
refactor the forward implementation of reshape npu op (#38748)
* refactor the forward implementation of reshape npu op * update reshape npu op * update reshape npu op
1 parent 7d4ce5b commit 31b1f70

File tree

1 file changed

+85
-14
lines changed

1 file changed

+85
-14
lines changed

paddle/fluid/operators/reshape_op_npu.cc

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616
#include <string>
1717

1818
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/operators/utils.h"
1920
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
2021

2122
namespace paddle {
@@ -25,23 +26,93 @@ template <typename DeviceContext, typename T>
2526
class Reshape2NPUKernel : public framework::OpKernel<T> {
2627
public:
2728
void Compute(const framework::ExecutionContext& ctx) const override {
29+
auto stream =
30+
ctx.template device_context<paddle::platform::NPUDeviceContext>()
31+
.stream();
32+
auto place = ctx.GetPlace();
2833
auto* x = ctx.Input<framework::Tensor>("X");
2934
auto* out = ctx.Output<framework::Tensor>("Out");
30-
auto list_new_shape_tensor =
31-
ctx.MultiInput<framework::Tensor>("ShapeTensor");
32-
if (list_new_shape_tensor.size() > 0) {
33-
PADDLE_THROW(platform::errors::Unimplemented(
34-
"Input(ShapeTensor) is not supported on NPU."));
35+
36+
std::vector<int32_t> target_shape_vector;
37+
auto shape_tensor_vector = ctx.MultiInput<framework::Tensor>("ShapeTensor");
38+
if (shape_tensor_vector.size() > 0) {
39+
for (auto* shape_tensor : shape_tensor_vector) {
40+
PADDLE_ENFORCE_EQ(
41+
shape_tensor->dims().size(), 1,
42+
platform::errors::InvalidArgument(
43+
"If the element type of 'shape' in Reshape Op is Tensor, "
44+
"the element's shape must be [1]. But received the element's "
45+
"shape is [%d]",
46+
shape_tensor->dims().size()));
47+
48+
target_shape_vector.push_back(GetDataFromTensor<int>(shape_tensor)[0]);
49+
}
50+
} else {
51+
auto* shape_tensor = ctx.HasInput("Shape")
52+
? ctx.Input<framework::LoDTensor>("Shape")
53+
: nullptr;
54+
if (shape_tensor) {
55+
target_shape_vector = GetDataFromTensor<int>(shape_tensor);
56+
} else {
57+
target_shape_vector = ctx.Attr<std::vector<int>>("shape");
58+
PADDLE_ENFORCE_GT(
59+
target_shape_vector.size(), 0,
60+
platform::errors::InvalidArgument(
61+
"The length of shape attribute should be larger than 0 when "
62+
"input ShapeTensor and Shape are empty!"));
63+
}
3564
}
36-
PADDLE_ENFORCE_EQ(ctx.Input<framework::LoDTensor>("Shape"), nullptr,
37-
platform::errors::Unimplemented(
38-
"Input(Shape) is not supported on NPU."));
39-
auto shape = out->dims();
40-
out->mutable_data(ctx.GetPlace(), x->type());
41-
framework::TensorCopy(
42-
*x, ctx.GetPlace(),
43-
ctx.template device_context<platform::DeviceContext>(), out);
44-
out->Resize(shape);
65+
66+
int num_negative =
67+
std::count(target_shape_vector.begin(), target_shape_vector.end(), -1);
68+
PADDLE_ENFORCE_LE(
69+
num_negative, 1,
70+
platform::errors::InvalidArgument(
71+
"The max number of -1 in shape attribute or shape tensor is 1 "
72+
"but received %d.",
73+
num_negative));
74+
auto it_zero =
75+
std::find(target_shape_vector.begin(), target_shape_vector.end(), 0);
76+
if (it_zero != target_shape_vector.end()) {
77+
int x_rank = x->dims().size();
78+
for (size_t i = 0; i < target_shape_vector.size(); i++) {
79+
if (target_shape_vector[i] == 0) {
80+
PADDLE_ENFORCE_LT(
81+
i, x_rank,
82+
platform::errors::InvalidArgument(
83+
"The index of 0 in shape attribute or shape tensor",
84+
"should be less than input dim size, ",
85+
"but the index is %d and input dim size is %d", i, x_rank));
86+
target_shape_vector[i] = x->dims().at(i);
87+
}
88+
}
89+
}
90+
91+
auto it =
92+
std::find(target_shape_vector.begin(), target_shape_vector.end(), -1);
93+
if (it != target_shape_vector.end()) {
94+
auto ddim_out_vec = framework::vectorize(x->dims());
95+
int ddim_out_product = std::accumulate(
96+
ddim_out_vec.begin(), ddim_out_vec.end(), 1, std::multiplies<int>());
97+
int reshape_out_product = std::accumulate(target_shape_vector.begin(),
98+
target_shape_vector.end(), -1,
99+
std::multiplies<int>());
100+
int index = std::distance(target_shape_vector.begin(), it);
101+
target_shape_vector[index] = ddim_out_product / reshape_out_product;
102+
}
103+
104+
auto out_dims = framework::make_ddim(target_shape_vector);
105+
out->mutable_data<T>(out_dims, place);
106+
107+
NpuOpRunner runner;
108+
// the shape input must be on the host side
109+
runner.SetType("Reshape")
110+
.AddInput(*x)
111+
.AddInput(std::vector<int32_t>(target_shape_vector))
112+
.AddOutput(*out)
113+
.AddAttr("axis", 0)
114+
.AddAttr("num_axes", -1);
115+
runner.Run(stream);
45116
}
46117
};
47118

0 commit comments

Comments
 (0)