@@ -16,6 +16,7 @@ limitations under the License. */
1616#include < string>
1717
1818#include " paddle/fluid/framework/op_registry.h"
19+ #include " paddle/fluid/operators/utils.h"
1920#include " paddle/fluid/platform/device/npu/npu_op_runner.h"
2021
2122namespace paddle {
@@ -25,23 +26,93 @@ template <typename DeviceContext, typename T>
2526class Reshape2NPUKernel : public framework ::OpKernel<T> {
2627 public:
2728 void Compute (const framework::ExecutionContext& ctx) const override {
29+ auto stream =
30+ ctx.template device_context <paddle::platform::NPUDeviceContext>()
31+ .stream ();
32+ auto place = ctx.GetPlace ();
2833 auto * x = ctx.Input <framework::Tensor>(" X" );
2934 auto * out = ctx.Output <framework::Tensor>(" Out" );
30- auto list_new_shape_tensor =
31- ctx.MultiInput <framework::Tensor>(" ShapeTensor" );
32- if (list_new_shape_tensor.size () > 0 ) {
33- PADDLE_THROW (platform::errors::Unimplemented (
34- " Input(ShapeTensor) is not supported on NPU." ));
35+
36+ std::vector<int32_t > target_shape_vector;
37+ auto shape_tensor_vector = ctx.MultiInput <framework::Tensor>(" ShapeTensor" );
38+ if (shape_tensor_vector.size () > 0 ) {
39+ for (auto * shape_tensor : shape_tensor_vector) {
40+ PADDLE_ENFORCE_EQ (
41+ shape_tensor->dims ().size (), 1 ,
42+ platform::errors::InvalidArgument (
43+ " If the element type of 'shape' in Reshape Op is Tensor, "
44+ " the element's shape must be [1]. But received the element's "
45+ " shape is [%d]" ,
46+ shape_tensor->dims ().size ()));
47+
48+ target_shape_vector.push_back (GetDataFromTensor<int >(shape_tensor)[0 ]);
49+ }
50+ } else {
51+ auto * shape_tensor = ctx.HasInput (" Shape" )
52+ ? ctx.Input <framework::LoDTensor>(" Shape" )
53+ : nullptr ;
54+ if (shape_tensor) {
55+ target_shape_vector = GetDataFromTensor<int >(shape_tensor);
56+ } else {
57+ target_shape_vector = ctx.Attr <std::vector<int >>(" shape" );
58+ PADDLE_ENFORCE_GT (
59+ target_shape_vector.size (), 0 ,
60+ platform::errors::InvalidArgument (
61+ " The length of shape attribute should be larger than 0 when "
62+ " input ShapeTensor and Shape are empty!" ));
63+ }
3564 }
36- PADDLE_ENFORCE_EQ (ctx.Input <framework::LoDTensor>(" Shape" ), nullptr ,
37- platform::errors::Unimplemented (
38- " Input(Shape) is not supported on NPU." ));
39- auto shape = out->dims ();
40- out->mutable_data (ctx.GetPlace (), x->type ());
41- framework::TensorCopy (
42- *x, ctx.GetPlace (),
43- ctx.template device_context <platform::DeviceContext>(), out);
44- out->Resize (shape);
65+
66+ int num_negative =
67+ std::count (target_shape_vector.begin (), target_shape_vector.end (), -1 );
68+ PADDLE_ENFORCE_LE (
69+ num_negative, 1 ,
70+ platform::errors::InvalidArgument (
71+ " The max number of -1 in shape attribute or shape tensor is 1 "
72+ " but received %d." ,
73+ num_negative));
74+ auto it_zero =
75+ std::find (target_shape_vector.begin (), target_shape_vector.end (), 0 );
76+ if (it_zero != target_shape_vector.end ()) {
77+ int x_rank = x->dims ().size ();
78+ for (size_t i = 0 ; i < target_shape_vector.size (); i++) {
79+ if (target_shape_vector[i] == 0 ) {
80+ PADDLE_ENFORCE_LT (
81+ i, x_rank,
82+ platform::errors::InvalidArgument (
83+ " The index of 0 in shape attribute or shape tensor" ,
84+ " should be less than input dim size, " ,
85+ " but the index is %d and input dim size is %d" , i, x_rank));
86+ target_shape_vector[i] = x->dims ().at (i);
87+ }
88+ }
89+ }
90+
91+ auto it =
92+ std::find (target_shape_vector.begin (), target_shape_vector.end (), -1 );
93+ if (it != target_shape_vector.end ()) {
94+ auto ddim_out_vec = framework::vectorize (x->dims ());
95+ int ddim_out_product = std::accumulate (
96+ ddim_out_vec.begin (), ddim_out_vec.end (), 1 , std::multiplies<int >());
97+ int reshape_out_product = std::accumulate (target_shape_vector.begin (),
98+ target_shape_vector.end (), -1 ,
99+ std::multiplies<int >());
100+ int index = std::distance (target_shape_vector.begin (), it);
101+ target_shape_vector[index] = ddim_out_product / reshape_out_product;
102+ }
103+
104+ auto out_dims = framework::make_ddim (target_shape_vector);
105+ out->mutable_data <T>(out_dims, place);
106+
107+ NpuOpRunner runner;
108+ // the shape input must be on the host side
109+ runner.SetType (" Reshape" )
110+ .AddInput (*x)
111+ .AddInput (std::vector<int32_t >(target_shape_vector))
112+ .AddOutput (*out)
113+ .AddAttr (" axis" , 0 )
114+ .AddAttr (" num_axes" , -1 );
115+ runner.Run (stream);
45116 }
46117};
47118
0 commit comments