@@ -1306,7 +1306,7 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
   EAGER_TRY
   PyObject* _index = PyTuple_GET_ITEM(args, 0);
   VLOG(4) << "Call _getitem_index_not_tensor";
-  std::vector<int> slice_axes, slice_starts, slice_ends, slice_strides,
+  std::vector<int64_t> slice_axes, slice_starts, slice_ends, slice_strides,
       decrease_axis, none_axes, infer_flags;
   std::vector<int64_t> list_select_idxs;
   // if index is a list, list_select_flag will be true
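This first hunk widens the slice bookkeeping vectors from `int` to `int64_t` (the same widening recurs below in `tensor_method__setitem_eager_tensor`). Judging from the rest of the diff, this lines the containers up with the `int64_t`-based `slice_ad_func` signature so the temporary copies in the next hunk can go away; it also lets slice bounds represent dimensions beyond `INT_MAX`. A standalone sketch of that second point, not Paddle code:

```cpp
// Minimal illustration: a slice end-point on a dimension larger than
// INT_MAX survives in int64_t but wraps when squeezed through int.
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  const int64_t dim = int64_t{1} << 32;          // 4,294,967,296 elements
  const int64_t wide_end = dim;                  // preserved exactly
  const int narrow_end = static_cast<int>(dim);  // wraps to 0 on typical
                                                 // two's-complement targets
  std::cout << "INT_MAX     = " << std::numeric_limits<int>::max() << "\n"
            << "int64_t end = " << wide_end << "\n"
            << "int end     = " << narrow_end << "\n";
  return 0;
}
```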
@@ -1353,26 +1353,25 @@ static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
       break;
     }
   }
-  std::vector<int64_t> slice_axes_tmp(slice_axes.begin(), slice_axes.end());
-  std::vector<int64_t> infer_flags_tmp(infer_flags.begin(),
-                                       infer_flags.end());
-  std::vector<int64_t> decrease_axis_tmp(decrease_axis.begin(),
-                                         decrease_axis.end());
 
   if (op_type == "slice") {
     eager_gil_scoped_release guard;
     out = slice_ad_func(self->tensor,
-                        slice_axes_tmp,
+                        slice_axes,
                         slice_starts,
                         slice_ends,
-                        infer_flags_tmp,
-                        decrease_axis_tmp);
+                        infer_flags,
+                        decrease_axis);
   } else if (op_type == "strided_slice") {
     eager_gil_scoped_release guard;
-    out = strided_slice_ad_func(
-        self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
-    if (!decrease_axis_tmp.empty()) {
-      out = squeeze_ad_func(out, decrease_axis_tmp);
+    std::vector<int> slice_axes_tmp(slice_axes.begin(), slice_axes.end());
+    out = strided_slice_ad_func(self->tensor,
+                                slice_axes_tmp,
+                                slice_starts,
+                                slice_ends,
+                                slice_strides);
+    if (!decrease_axis.empty()) {
+      out = squeeze_ad_func(out, decrease_axis);
     }
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
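With the vectors already `int64_t`, the `slice` branch now passes them straight through and the three `_tmp` copies disappear. The `strided_slice` branch gains a copy in the opposite direction, which suggests (an inference from this diff, not something shown in it) that `strided_slice_ad_func` still takes `std::vector<int>` axes. The conversion idiom in both directions is the iterator-pair constructor; a self-contained sketch with a hypothetical 32-bit callee:

```cpp
// Standalone sketch of the narrowing-copy pattern kept in the
// strided_slice branch: build a std::vector<int> from a
// std::vector<int64_t> only where a legacy 32-bit signature demands it.
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for an API that still takes 32-bit axes.
void legacy_take_int_axes(const std::vector<int>& axes) {
  for (int a : axes) std::cout << a << ' ';
  std::cout << '\n';
}

int main() {
  std::vector<int64_t> slice_axes = {0, 2, 3};
  // Element-wise conversion via the iterator-pair constructor; each
  // int64_t is narrowed to int (safe while axis indices stay small).
  std::vector<int> slice_axes_tmp(slice_axes.begin(), slice_axes.end());
  legacy_take_int_axes(slice_axes_tmp);
  return 0;
}
```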
@@ -1607,7 +1606,7 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
   // TODO(liym27): Try not to call TensorToPyArray because it always
   // copys data to cpu place, which reduces performance.
   if (parse_index) {
-    std::vector<int> axes, starts, ends, steps, decrease_axes, none_axes,
+    std::vector<int64_t> axes, starts, ends, steps, decrease_axes, none_axes,
         infer_flags;
     std::vector<int64_t> list_select_idxs;
     // if index is a list, list_select_flag will be true
@@ -1624,13 +1623,6 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
                        &list_select_idxs,
                        &list_select_flag);
 
-    framework::AttributeMap attrs = {{"axes", axes},
-                                     {"starts", starts},
-                                     {"ends", ends},
-                                     {"steps", steps},
-                                     {"decrease_axes", decrease_axes},
-                                     {"none_axes", none_axes}};
-
     if (egr::Controller::Instance().HasGrad()) {
       PADDLE_ENFORCE_EQ(
           egr::EagerUtils::IsLeafTensor(self->tensor) &&
@@ -1643,6 +1635,8 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
     }
 
     paddle::Tensor value_tensor;
+    std::vector<phi::Scalar> values;
+    std::vector<int64_t> shape = std::vector<int64_t>{1};
 
     if (PyCheckTensor(value_obj)) {
       value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
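These two hunks begin retiring the string-keyed `framework::AttributeMap`: the parsed index vectors no longer get packed into `attrs`, because the final hunk passes them as ordinary typed arguments, and the value payload moves into the new `values`/`shape` locals. The shape of the refactor, caricatured with invented stand-in types (`std::map` plus `std::variant` playing the attribute map), looks roughly like this:

```cpp
// Standalone caricature of the refactor: attribute-bag style vs. typed
// parameters. All names are invented; only the shape of the change matters.
#include <cstdint>
#include <map>
#include <string>
#include <variant>
#include <vector>

using Attribute = std::variant<std::vector<int64_t>, double>;
using AttributeMap = std::map<std::string, Attribute>;

// Before: the callee rummages through a string-keyed map; a typo or a
// wrong value type surfaces only at runtime.
void set_value_legacy(const AttributeMap& attrs) {
  const auto& axes = std::get<std::vector<int64_t>>(attrs.at("axes"));
  (void)axes;
}

// After: the compiler checks every argument's presence and type.
void set_value_typed(const std::vector<int64_t>& axes,
                     const std::vector<int64_t>& starts,
                     const std::vector<int64_t>& ends) {
  (void)axes; (void)starts; (void)ends;
}

int main() {
  std::vector<int64_t> axes = {0}, starts = {1}, ends = {3};
  set_value_legacy({{"axes", axes}, {"starts", starts}, {"ends", ends}});
  set_value_typed(axes, starts, ends);
  return 0;
}
```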
@@ -1706,25 +1700,20 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
                PyComplex_Check(value_obj)) {
       if (self->tensor.dtype() == phi::DataType::FLOAT32 ||
           self->tensor.dtype() == phi::DataType::FLOAT16) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
-            value_obj_tmp.cast<float>()};
+        values = std::vector<phi::Scalar>{value_obj_tmp.cast<float>()};
       } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
-            value_obj_tmp.cast<double>()};
+        values = std::vector<phi::Scalar>{value_obj_tmp.cast<double>()};
       } else if (self->tensor.dtype() == phi::DataType::INT32) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
-            value_obj_tmp.cast<int32_t>()};
+        values = std::vector<phi::Scalar>{value_obj_tmp.cast<int32_t>()};
       } else if (self->tensor.dtype() == phi::DataType::INT64) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
-            value_obj_tmp.cast<int64_t>()};
+        values = std::vector<phi::Scalar>{value_obj_tmp.cast<int64_t>()};
       } else if (self->tensor.dtype() == phi::DataType::BOOL) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
-            value_obj_tmp.cast<bool>()};
+        values = std::vector<phi::Scalar>{value_obj_tmp.cast<bool>()};
       } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
+        values = std::vector<phi::Scalar>{
             value_obj_tmp.cast<std::complex<float>>()};
       } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
-        attrs["values"] = std::vector<paddle::experimental::Scalar>{
+        values = std::vector<phi::Scalar>{
             value_obj_tmp.cast<std::complex<double>>()};
       } else {
         PADDLE_THROW(platform::errors::InvalidArgument(
@@ -1734,8 +1723,6 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
             "float16, "
             "please check the type of tensor."));
       }
-      attrs["shape"] = std::vector<int64_t>{1};
-
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Value type error. The assign value allows "
@@ -1748,25 +1735,46 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
     // Release gil and do tracing
     py::gil_scoped_release release;
     // use inplace set_value_ operator
-    if (value_tensor.initialized() &&
-        (self->tensor.dtype() != value_tensor.dtype())) {
-      if (egr::Controller::Instance().GetAMPLevel() !=
-          paddle::imperative::AmpLevel::O0) {
-        paddle::small_vector<std::vector<paddle::Tensor>,
-                             egr::kSlotSmallVectorSize>
-            tmps = {{self->tensor}, {value_tensor}};
-        auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
-        self->tensor = egr::EagerAmpAutoCast(
-            self->tensor.name(), self->tensor, amp_dtype, "set_value");
-        value_tensor = egr::EagerAmpAutoCast(
-            value_tensor.name(), value_tensor, amp_dtype, "set_value");
-      }
+    if (value_tensor.initialized()) {
       if (self->tensor.dtype() != value_tensor.dtype()) {
-        value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
+        if (egr::Controller::Instance().GetAMPLevel() !=
+            paddle::imperative::AmpLevel::O0) {
+          paddle::small_vector<std::vector<paddle::Tensor>,
+                               egr::kSlotSmallVectorSize>
+              tmps = {{self->tensor}, {value_tensor}};
+          auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
+          self->tensor = egr::EagerAmpAutoCast(
+              self->tensor.name(), self->tensor, amp_dtype, "set_value");
+          value_tensor = egr::EagerAmpAutoCast(
+              value_tensor.name(), value_tensor, amp_dtype, "set_value");
+        }
+        if (self->tensor.dtype() != value_tensor.dtype()) {
+          value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
+        }
+      }
+      const phi::distributed::ProcessMesh* mesh = nullptr;
+      if (InputsContainDistTensor(&mesh, self->tensor, value_tensor)) {
+        ConvertAllInputsToDistTensor(mesh, self->tensor, value_tensor);
       }
+      self->tensor = set_value_with_tensor__ad_func(self->tensor,
+                                                    value_tensor,
+                                                    starts,
+                                                    ends,
+                                                    steps,
+                                                    axes,
+                                                    decrease_axes,
+                                                    none_axes);
+    } else {
+      self->tensor = set_value__ad_func(self->tensor,
+                                        starts,
+                                        ends,
+                                        steps,
+                                        axes,
+                                        decrease_axes,
+                                        none_axes,
+                                        shape,
+                                        values);
     }
-    self->tensor = set_value__dygraph_function(
-        self->tensor, value_tensor, {}, {}, {}, attrs);
   }
   if (PyCheckTensor(value_obj)) {
     // pass the stop_gradient from value to tensor.
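The payoff hunk: the single `set_value__dygraph_function(..., attrs)` call becomes two statically typed paths. When the right-hand side is a real tensor, dtypes are reconciled first (AMP autocast when AMP is on, then an explicit `cast_ad_func` if they still differ), inputs are converted to distributed tensors when a `ProcessMesh` is involved, and `set_value_with_tensor__ad_func` runs; when the right-hand side was a scalar, `set_value__ad_func` receives the `shape`/`values` locals instead. The control-flow skeleton, reduced to a standalone toy (all names below are invented stand-ins, not Paddle APIs):

```cpp
// Standalone skeleton of the new dispatch: one typed entry point per kind
// of right-hand side, instead of a single function fed by an attribute map.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

struct Tensor { std::vector<double> data; };  // toy stand-in

// Stand-in for set_value_with_tensor__ad_func: the RHS is a tensor.
Tensor set_value_with_tensor(Tensor self, const Tensor& value,
                             const std::vector<int64_t>& starts,
                             const std::vector<int64_t>& ends) {
  for (int64_t i = starts[0]; i < ends[0]; ++i)
    self.data[static_cast<size_t>(i)] =
        value.data[static_cast<size_t>(i - starts[0])];
  return self;
}

// Stand-in for set_value__ad_func: the RHS is a broadcast scalar.
Tensor set_value_scalar(Tensor self, const std::vector<int64_t>& starts,
                        const std::vector<int64_t>& ends, double value) {
  for (int64_t i = starts[0]; i < ends[0]; ++i)
    self.data[static_cast<size_t>(i)] = value;
  return self;
}

int main() {
  Tensor t{{0, 0, 0, 0}};
  std::optional<Tensor> value_tensor;  // empty: mimics !initialized()
  // Mirrors the diff: tensor path when an RHS tensor exists, else scalar path.
  t = value_tensor ? set_value_with_tensor(t, *value_tensor, {1}, {3})
                   : set_value_scalar(t, {1}, {3}, 7.0);
  for (double v : t.data) std::cout << v << ' ';  // prints: 0 7 7 0
  std::cout << '\n';
  return 0;
}
```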