@@ -63,46 +63,52 @@ namespace operators {
6363/*
6464* Pack input and output tensors into respective vectors with
6565* consideration of varible X`s class type.
66+ * Input variable X is supported to be whether LoDTensor or
67+ * SelectedRows class type in this package function, once X
68+ * was SelectedRows type, a valid pointer x_for_selectedrows
69+ * is excepted to be passed in from op kernel for acquisition
70+ * of the valid address of LoDTensor created ahead in the function.
6671*/
6772template <typename OutT>
6873int PackTensorsIntoVector (const framework::ExecutionContext &ctx,
6974 std::vector<const framework::Tensor *> *ins,
7075 std::vector<framework::Tensor *> *outs,
71- framework::Tensor *x_ptr = nullptr ) {
76+ framework::Tensor *x_for_selectedrows = nullptr ) {
7277 int axis = -1 ;
73- int x_dims_size = 0 ;
74- framework::Tensor *z;
75-
7678 auto x_var = ctx.InputVar (" X" );
7779 PADDLE_ENFORCE_NOT_NULL (
7880 x_var, platform::errors::InvalidArgument (
79- " Unable get input Variable X, Variable name is %s.\n " ,
81+ " Unable to get input Variable X, Variable name is %s.\n " ,
8082 ctx.InputName (" X" )));
8183 auto *y = ctx.Input <framework::LoDTensor>(" Y" );
84+ framework::Tensor *z;
8285
83- if (x_ptr == nullptr || x_var->IsType <framework::LoDTensor>()) {
86+ if (x_var->IsType <framework::LoDTensor>()) {
8487 auto *x = ctx.Input <framework::LoDTensor>(" X" );
8588 z = ctx.Output <framework::LoDTensor>(" Out" );
8689 ins->emplace_back (x);
87- x_dims_size = x->dims ().size ();
88-
8990 } else if (x_var->IsType <framework::SelectedRows>()) {
9091 PADDLE_ENFORCE_EQ (y->dims ().size () == 1 && y->dims ()[0 ] == 1 , true ,
9192 platform::errors::InvalidArgument (
9293 " For elementwise_op, if X is Sparse, Y must be "
9394 " scalar. But reveived the size of Y = %d." ,
9495 y->dims ().size ()));
96+ PADDLE_ENFORCE_NOT_NULL (
97+ x_for_selectedrows,
98+ platform::errors::InvalidArgument (
99+ " The parameter x_for_selectedrows is excepted to "
100+ " be valid, once input varible X`s class type is "
101+ " SelectedRows.\n " ));
95102 auto &x_sele = x_var->Get <framework::SelectedRows>();
96103 auto out_sele = ctx.Output <framework::SelectedRows>(" Out" );
97- *x_ptr = x_sele.value ();
104+ *x_for_selectedrows = x_sele.value ();
98105 out_sele->set_rows (x_sele.rows ());
99106 out_sele->set_height (x_sele.height ());
100107 out_sele->mutable_value ()->Resize (x_sele.value ().dims ());
101- out_sele->mutable_value ()->mutable_data (ctx.GetPlace (), x_ptr->type ());
108+ out_sele->mutable_value ()->mutable_data (ctx.GetPlace (),
109+ x_for_selectedrows->type ());
102110 z = ctx.Output <framework::SelectedRows>(" Out" )->mutable_value ();
103- ins->emplace_back (x_ptr);
104- x_dims_size = x_ptr->dims ().size ();
105-
111+ ins->emplace_back (x_for_selectedrows);
106112 } else {
107113 PADDLE_THROW (platform::errors::InvalidArgument (
108114 " X's type[%s] is not supported by elementwise_op. X's type should be "
@@ -115,7 +121,6 @@ int PackTensorsIntoVector(const framework::ExecutionContext &ctx,
115121 if (y != nullptr ) {
116122 ins->emplace_back (y);
117123 axis = ctx.HasAttr (" axis" ) ? ctx.Attr <int >(" axis" ) : -1 ;
118- axis = axis == -1 ? std::abs (y->dims ().size () - x_dims_size) : axis;
119124 }
120125 return axis;
121126}
0 commit comments