
Commit 8bc09e7

[AutoParallel] Set value phi yaml (#58893)
* set value to phi
Parent: 3ab6f55
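
In short: the set_value and set_value_with_tensor op definitions move out of the hand-maintained PIR dialect yaml files (ops.yaml, ops_backward.yaml) and into the PHI yaml files (legacy_ops.yaml, legacy_backward.yaml), where they gain proper backward declarations (set_value_grad and the new set_value_with_tensor_grad). The eager __setitem__ path in eager_method.cc and the dynamic branch of _setitem_static in variable_index.py switch from the legacy set_value__dygraph_function / _legacy_C_ops call to the generated set_value_ / set_value_with_tensor_ functions, and the indexing metadata parsed from Python is widened from int to int64_t to match the new signatures. Two small fixes ride along: a corrected assertion message in eager_gen.py and the grad-op mapping in op_translator.cc.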

13 files changed (+145, -111 lines)

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -826,7 +826,7 @@ def BackwardValidationCheck(self):
         max_grad_tensor_position = -1
         for _, (_, _, pos) in backward_grad_inputs_map.items():
             assert pos > max_fwd_input_position, AssertMessage(
-                pos, max_grad_tensor_position
+                pos, max_fwd_input_position
             )
             max_grad_tensor_position = max(max_grad_tensor_position, pos)
```
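The only change here is the assertion's failure message: the condition compares `pos` against `max_fwd_input_position`, but the old message printed the unrelated running maximum `max_grad_tensor_position`. A minimal sketch of the corrected pattern (a standalone illustration, not Paddle's actual helper):

```python
# Standalone sketch: the failure message should report the bound actually
# compared (max_fwd_input_position), not an unrelated running maximum.
def check_grad_positions(grad_positions, max_fwd_input_position):
    max_grad_tensor_position = -1
    for pos in grad_positions:
        assert pos > max_fwd_input_position, (
            f"expected {pos} > {max_fwd_input_position}"
        )
        max_grad_tensor_position = max(max_grad_tensor_position, pos)
    return max_grad_tensor_position

print(check_grad_positions([3, 4], 2))  # prints 4
```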

paddle/fluid/ir_adaptor/translator/op_translator.cc

Lines changed: 2 additions & 2 deletions
```diff
@@ -2216,12 +2216,12 @@ struct SetValueWithTensorOpTranscriber : public SetValueOpTranscriber {
 struct SetValueGradOpTranscriber : public SetValueWithTensorOpTranscriber {
   pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx,
                             const OpDesc& op_desc) override {
-    std::string target_op_name = dialect::SetValueGradOp::name();
+    std::string target_op_name = dialect::SetValueWithTensorGradOp::name();
     const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
     if (!op_info) {
       IR_THROW(
           "Op set_value_grad should have corresponding OpInfo "
-          "pd_op.set_value_grad");
+          "pd_op.set_value_with_tensor_grad");
     }
 
     return op_info;
```

paddle/fluid/pir/dialect/operator/ir/ops.yaml

Lines changed: 0 additions & 24 deletions
```diff
@@ -140,30 +140,6 @@
     func : send_v2
     param : [x, ring_id, dynamic_shape, peer, use_calc_stream]
 
-- op : set_value
-  args : (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values)
-  output : Tensor(out)
-  infer_meta:
-    func: SetValueInferMeta
-    param: [x]
-  kernel:
-    func: set_value
-    param: [x, starts, ends, steps, axes, decrease_axes, none_axes, shape, values]
-  inplace: (x -> out)
-  backward: set_value_grad
-
-- op : set_value_with_tensor
-  args : (Tensor x, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes)
-  output : Tensor(out)
-  infer_meta:
-    func: SetValueInferMeta
-    param: [x]
-  kernel:
-    func: set_value_with_tensor
-    param: [x, values, starts, ends, steps, axes, decrease_axes, none_axes]
-  inplace: (x -> out)
-  backward: set_value_grad
-
 - op : shadow_feed
   args : (Tensor x)
   output : Tensor(out)
```

paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml

Lines changed: 0 additions & 10 deletions
```diff
@@ -17,13 +17,3 @@
   kernel:
     func: fused_feedforward_grad
   optional: linear1_bias, linear2_bias, ln1_scale, ln1_bias, ln1_out, ln1_mean, ln1_variance, ln2_scale, ln2_bias, ln2_mean, ln2_variance, dropout2_out, ln1_scale_grad, ln1_bias_grad, ln2_scale_grad, ln2_bias_grad, linear2_bias_grad
-
-- backward_op : set_value_grad
-  args : (Tensor out_grad, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes)
-  output : Tensor(x_grad), Tensor(values_grad)
-  infer_meta:
-    func: SetValueGradInferMeta
-    param: [out_grad, values]
-  kernel:
-    func: set_value_grad
-    param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]
```

paddle/fluid/pybind/eager_method.cc

Lines changed: 59 additions & 51 deletions
```diff
@@ -1306,7 +1306,7 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
   EAGER_TRY
   PyObject* _index = PyTuple_GET_ITEM(args, 0);
   VLOG(4) << "Call _getitem_index_not_tensor";
-  std::vector<int> slice_axes, slice_starts, slice_ends, slice_strides,
+  std::vector<int64_t> slice_axes, slice_starts, slice_ends, slice_strides,
       decrease_axis, none_axes, infer_flags;
   std::vector<int64_t> list_select_idxs;
   // if index is a list, list_select_flag will be true
@@ -1353,26 +1353,25 @@
       break;
     }
   }
-  std::vector<int64_t> slice_axes_tmp(slice_axes.begin(), slice_axes.end());
-  std::vector<int64_t> infer_flags_tmp(infer_flags.begin(),
-                                       infer_flags.end());
-  std::vector<int64_t> decrease_axis_tmp(decrease_axis.begin(),
-                                         decrease_axis.end());
 
   if (op_type == "slice") {
     eager_gil_scoped_release guard;
     out = slice_ad_func(self->tensor,
-                        slice_axes_tmp,
+                        slice_axes,
                         slice_starts,
                         slice_ends,
-                        infer_flags_tmp,
-                        decrease_axis_tmp);
+                        infer_flags,
+                        decrease_axis);
   } else if (op_type == "strided_slice") {
     eager_gil_scoped_release guard;
-    out = strided_slice_ad_func(
-        self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
-    if (!decrease_axis_tmp.empty()) {
-      out = squeeze_ad_func(out, decrease_axis_tmp);
+    std::vector<int> slice_axes_tmp(slice_axes.begin(), slice_axes.end());
+    out = strided_slice_ad_func(self->tensor,
+                                slice_axes_tmp,
+                                slice_starts,
+                                slice_ends,
+                                slice_strides);
+    if (!decrease_axis.empty()) {
+      out = squeeze_ad_func(out, decrease_axis);
     }
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
@@ -1607,7 +1606,7 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
   // TODO(liym27): Try not to call TensorToPyArray because it always
   // copys data to cpu place, which reduces performance.
   if (parse_index) {
-    std::vector<int> axes, starts, ends, steps, decrease_axes, none_axes,
+    std::vector<int64_t> axes, starts, ends, steps, decrease_axes, none_axes,
         infer_flags;
     std::vector<int64_t> list_select_idxs;
     // if index is a list, list_select_flag will be true
@@ -1624,13 +1623,6 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
                        &list_select_idxs,
                        &list_select_flag);
 
-    framework::AttributeMap attrs = {{"axes", axes},
-                                     {"starts", starts},
-                                     {"ends", ends},
-                                     {"steps", steps},
-                                     {"decrease_axes", decrease_axes},
-                                     {"none_axes", none_axes}};
-
     if (egr::Controller::Instance().HasGrad()) {
       PADDLE_ENFORCE_EQ(
           egr::EagerUtils::IsLeafTensor(self->tensor) &&
@@ -1643,6 +1635,8 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
     }
 
     paddle::Tensor value_tensor;
+    std::vector<phi::Scalar> values;
+    std::vector<int64_t> shape = std::vector<int64_t>{1};
 
     if (PyCheckTensor(value_obj)) {
       value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
@@ -1706,25 +1700,20 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
                PyComplex_Check(value_obj)) {
       if (self->tensor.dtype() == phi::DataType::FLOAT32 ||
           self->tensor.dtype() == phi::DataType::FLOAT16) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
-            value_obj_tmp.cast<float>()};
+        values = std::vector<phi::Scalar>{value_obj_tmp.cast<float>()};
       } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
-            value_obj_tmp.cast<double>()};
+        values = std::vector<phi::Scalar>{value_obj_tmp.cast<double>()};
       } else if (self->tensor.dtype() == phi::DataType::INT32) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
-            value_obj_tmp.cast<int32_t>()};
+        values = std::vector<phi::Scalar>{value_obj_tmp.cast<int32_t>()};
       } else if (self->tensor.dtype() == phi::DataType::INT64) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
-            value_obj_tmp.cast<int64_t>()};
+        values = std::vector<phi::Scalar>{value_obj_tmp.cast<int64_t>()};
       } else if (self->tensor.dtype() == phi::DataType::BOOL) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
-            value_obj_tmp.cast<bool>()};
+        values = std::vector<phi::Scalar>{value_obj_tmp.cast<bool>()};
       } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
+        values = std::vector<phi::Scalar>{
             value_obj_tmp.cast<std::complex<float>>()};
       } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
+        values = std::vector<phi::Scalar>{
             value_obj_tmp.cast<std::complex<double>>()};
       } else {
         PADDLE_THROW(platform::errors::InvalidArgument(
@@ -1734,8 +1723,6 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
             "float16, "
             "please check the type of tensor."));
       }
-      attrs["shape"] = std::vector<int64_t>{1};
-
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Value type error. The assign value allows "
@@ -1748,25 +1735,46 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
     // Release gil and do tracing
     py::gil_scoped_release release;
     // use inplace set_value_ operator
-    if (value_tensor.initialized() &&
-        (self->tensor.dtype() != value_tensor.dtype())) {
-      if (egr::Controller::Instance().GetAMPLevel() !=
-          paddle::imperative::AmpLevel::O0) {
-        paddle::small_vector<std::vector<paddle::Tensor>,
-                             egr::kSlotSmallVectorSize>
-            tmps = {{self->tensor}, {value_tensor}};
-        auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
-        self->tensor = egr::EagerAmpAutoCast(
-            self->tensor.name(), self->tensor, amp_dtype, "set_value");
-        value_tensor = egr::EagerAmpAutoCast(
-            value_tensor.name(), value_tensor, amp_dtype, "set_value");
-      }
+    if (value_tensor.initialized()) {
       if (self->tensor.dtype() != value_tensor.dtype()) {
-        value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
+        if (egr::Controller::Instance().GetAMPLevel() !=
+            paddle::imperative::AmpLevel::O0) {
+          paddle::small_vector<std::vector<paddle::Tensor>,
+                               egr::kSlotSmallVectorSize>
+              tmps = {{self->tensor}, {value_tensor}};
+          auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
+          self->tensor = egr::EagerAmpAutoCast(
+              self->tensor.name(), self->tensor, amp_dtype, "set_value");
+          value_tensor = egr::EagerAmpAutoCast(
+              value_tensor.name(), value_tensor, amp_dtype, "set_value");
+        }
+        if (self->tensor.dtype() != value_tensor.dtype()) {
+          value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
+        }
+      }
+      const phi::distributed::ProcessMesh* mesh = nullptr;
+      if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
+        ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
       }
+      self->tensor = set_value_with_tensor__ad_func(self->tensor,
+                                                    value_tensor,
+                                                    starts,
+                                                    ends,
+                                                    steps,
+                                                    axes,
+                                                    decrease_axes,
+                                                    none_axes);
+    } else {
+      self->tensor = set_value__ad_func(self->tensor,
+                                        starts,
+                                        ends,
+                                        steps,
+                                        axes,
+                                        decrease_axes,
+                                        none_axes,
+                                        shape,
+                                        values);
     }
-    self->tensor = set_value__dygraph_function(
-        self->tensor, value_tensor, {}, {}, {}, attrs);
   }
   if (PyCheckTensor(value_obj)) {
     // pass the stop_gradient from value to tensor.
```
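
After this change, eager __setitem__ no longer funnels everything through set_value__dygraph_function with an attribute map. When the assigned value is a live tensor, it goes through set_value_with_tensor__ad_func (after AMP/dtype casting and dist-tensor conversion); otherwise the scalar is packed into values with shape = [1] and goes through set_value__ad_func. From the Python side (a minimal illustration; a Paddle build with this change is assumed):

```python
import paddle

x = paddle.zeros([3, 4])

# Scalar RHS: packed into `values` as a phi::Scalar with shape=[1],
# dispatched to the generated set_value__ad_func.
x[0, 1:3] = 1.0

# Tensor RHS: value_tensor is initialized, dispatched to
# set_value_with_tensor__ad_func (with AMP/dtype casting applied first).
x[1] = paddle.ones([4], dtype='float32')
```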

paddle/fluid/pybind/slice_utils.h

Lines changed: 7 additions & 7 deletions
```diff
@@ -143,13 +143,13 @@ static int _PySlice_GetIndices(PySliceObject* r,
 
 static void ParseIndexingSlice(phi::DDim shape,
                                PyObject* _index,
-                               std::vector<int>* slice_axes,
-                               std::vector<int>* slice_starts,
-                               std::vector<int>* slice_ends,
-                               std::vector<int>* slice_strides,
-                               std::vector<int>* decrease_axis,
-                               std::vector<int>* none_axes,
-                               std::vector<int>* infer_flags,
+                               std::vector<int64_t>* slice_axes,
+                               std::vector<int64_t>* slice_starts,
+                               std::vector<int64_t>* slice_ends,
+                               std::vector<int64_t>* slice_strides,
+                               std::vector<int64_t>* decrease_axis,
+                               std::vector<int64_t>* none_axes,
+                               std::vector<int64_t>* infer_flags,
                                std::vector<int64_t>* list_select_idxs,
                                bool* list_select_flag) {
   // We allow indexing by Integers, Slices, Ellipsis, None, tuples of those
```
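
Widening these out-parameters to int64_t keeps the parsed indexing metadata in one integer width end to end, matching the IntArray and int64_t[] parameter types declared in the new yaml entries. It is also what lets tensor__getitem_index_not_tensor above drop the per-call int-to-int64_t copies (slice_axes_tmp, infer_flags_tmp, decrease_axis_tmp); only the strided_slice branch still narrows back to int through a local copy.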

paddle/phi/api/yaml/legacy_backward.yaml

Lines changed: 22 additions & 0 deletions
```diff
@@ -610,6 +610,28 @@
     func : rrelu_grad
     data_type : x
 
+- backward_op : set_value_grad
+  forward : set_value (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values) -> Tensor(out)
+  args : (Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta:
+    func: UnchangedInferMeta
+    param: [out_grad]
+  kernel:
+    func: assign
+    param: [out_grad]
+
+- backward_op : set_value_with_tensor_grad
+  forward: set_value_with_tensor (Tensor x, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes) -> Tensor(out)
+  args : (Tensor values,Tensor out_grad, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes)
+  output : Tensor(x_grad), Tensor(values_grad)
+  infer_meta:
+    func: SetValueGradInferMeta
+    param: [out_grad, values]
+  kernel:
+    func: set_value_grad
+    param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]
+
 - backward_op : slice_double_grad
   forward : slice_grad (Tensor input, Tensor grad_out, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis) -> Tensor(grad_input)
   args : (Tensor grad_input_grad, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
```
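
The backward split mirrors the forward split: the scalar variant declares a backward that simply forwards out_grad through the assign kernel, while the tensor variant routes to the set_value_grad kernel, which produces both x_grad and values_grad. As a rough reference for the tensor case, here is a numpy sketch of the expected semantics (illustration only, assuming a single contiguous region and ignoring steps/decrease_axes/none_axes handling; not the actual kernel):

```python
import numpy as np

def set_value_grad_reference(out_grad, region):
    # Elements inside the region were overwritten by `values`, so their
    # gradient flows to `values` rather than to `x`.
    x_grad = out_grad.copy()
    x_grad[region] = 0
    values_grad = out_grad[region]
    return x_grad, values_grad

out_grad = np.ones((3, 4), dtype=np.float32)
x_grad, values_grad = set_value_grad_reference(
    out_grad, (slice(0, 1), slice(1, 3))
)
# x_grad is zero at [0, 1:3] and one elsewhere; values_grad has shape (1, 2).
```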

paddle/phi/api/yaml/legacy_ops.yaml

Lines changed: 22 additions & 0 deletions
```diff
@@ -961,6 +961,28 @@
   intermediate : noise
   backward : rrelu_grad
 
+- op : set_value
+  args : (Tensor x, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes, int64_t[] shape, Scalar[] values)
+  output : Tensor(out)
+  inplace: (x -> out)
+  infer_meta :
+    func : SetValueInferMeta
+    param : [x]
+  kernel :
+    func : set_value
+  backward: set_value_grad
+
+- op : set_value_with_tensor
+  args : (Tensor x, Tensor values, IntArray starts, IntArray ends, IntArray steps, int64_t[] axes, int64_t[] decrease_axes, int64_t[] none_axes)
+  output : Tensor(out)
+  inplace: (x -> out)
+  infer_meta:
+    func: SetValueInferMeta
+    param: [x]
+  kernel:
+    func: set_value_with_tensor
+  backward: set_value_with_tensor_grad
+
 - op : slice
   args : (Tensor input, int64_t[] axes, IntArray starts, IntArray ends, int64_t[] infer_flags, int64_t[] decrease_axis)
   output : Tensor
```
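
Declaring the ops here, rather than in the PIR dialect yaml, is what makes the API generator emit the C++ ad_funcs used in eager_method.cc and the paddle._C_ops bindings used in variable_index.py, including the inplace variants implied by `inplace: (x -> out)`. A quick sanity check (assumes a Paddle build that includes this change):

```python
import paddle

# The trailing underscore marks the generated inplace variants.
assert hasattr(paddle._C_ops, 'set_value_')
assert hasattr(paddle._C_ops, 'set_value_with_tensor_')
```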

python/paddle/base/variable_index.py

Lines changed: 24 additions & 9 deletions
```diff
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools
 import warnings
 from functools import reduce
 
@@ -875,6 +874,7 @@ def _setitem_static(x, indices, values):
     StartsTensorList = None
     EndsTensorList = None
     StepsTensorList = None
+    shape = None
 
     if paddle.utils._contain_var(starts):
         StartsTensorList = paddle.utils._convert_to_tensor_list(starts)
@@ -919,14 +919,29 @@ def _setitem_static(x, indices, values):
 
     # step3.1: Only basic indexing, use OP set_value to set value.
     if paddle.in_dynamic_mode():
-        return paddle._legacy_C_ops.set_value_(
-            x,
-            value_tensor,
-            StartsTensorList,
-            EndsTensorList,
-            StepsTensorList,
-            *itertools.chain.from_iterable(attrs.items()),
-        )
+        if value_tensor is None:
+            return paddle._C_ops.set_value_(
+                x,
+                starts,
+                ends,
+                steps,
+                axes,
+                decrease_axes,
+                none_axes,
+                shape,
+                values,
+            )
+        else:
+            return paddle._C_ops.set_value_with_tensor_(
+                x,
+                value_tensor,
+                starts,
+                ends,
+                steps,
+                axes,
+                decrease_axes,
+                none_axes,
+            )
     else:
         helper = paddle.base.layer_helper.LayerHelper(
             'set_value', **locals()
```
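
The dynamic branch now calls the generated bindings with the parsed lists passed positionally, matching the yaml signatures, instead of flattening an attribute dict through *itertools.chain.from_iterable(attrs.items()) for the legacy op; that is also why the itertools import could be dropped. The StartsTensorList / EndsTensorList / StepsTensorList variables are still built, but only the static-graph branch below consumes them.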

test/indexing/test_setitem.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -503,3 +503,7 @@ def test_indexing_is_multi_dim_list(self):
         res = self.exe.run(fetch_list=[y.name])
 
         np.testing.assert_allclose(res[0], np_data)
+
+
+if __name__ == '__main__':
+    unittest.main()
```
