Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@
view,
view_as,
unfold,
index_fill,
index_fill_,
)

from .tensor.math import ( # noqa: F401
Expand Down Expand Up @@ -899,4 +901,6 @@
'i1e',
'polygamma',
'polygamma_',
'index_fill',
"index_fill_",
]
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@
from .manipulation import view # noqa: F401
from .manipulation import view_as # noqa: F401
from .manipulation import unfold # noqa: F401
from .manipulation import index_fill # noqa: F401
from .manipulation import index_fill_ # noqa: F401
from .math import abs # noqa: F401
from .math import abs_ # noqa: F401
from .math import acos # noqa: F401
Expand Down Expand Up @@ -746,6 +748,8 @@
'create_array',
'einsum',
'normal_',
'index_fill',
'index_fill_',
]

# this list used in math_op_patch.py for magic_method bind
Expand Down
104 changes: 104 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5166,3 +5166,107 @@ def unfold(x, axis, size, step, name=None):
}
for name, func in __METHODS.items():
setattr(core.eager.Tensor, name, func)


def _index_fill_impl(x, index, axis, value, inplace):
    """Shared implementation behind ``index_fill`` and ``index_fill_``.

    Args:
        x (Tensor): Tensor whose entries are filled.
        index (Tensor): Indices, along ``axis``, of the slices to overwrite.
        axis (int): Dimension that ``index`` refers to. Negative values count
            from the last dimension.
        value (scalar|Tensor): Fill value. Anything that is not a Tensor is
            converted to a Tensor with the dtype of ``x``; a Tensor must be
            0-D.
        inplace (bool): When True the result is written back into ``x`` and
            ``x`` itself is returned; otherwise a new Tensor is returned.

    Returns:
        Tensor: ``x`` when ``inplace`` is True, otherwise the filled result.

    Raises:
        ValueError: If ``index`` is not a Tensor, if ``value`` is a Tensor
            with one or more dimensions, or if ``axis`` is not an int in
            ``[-rank(x), rank(x))``.
    """
    if not isinstance(index, Variable):
        raise ValueError("index must be Tensor")

    if not isinstance(value, Variable):
        value = paddle.to_tensor(value, dtype=x.dtype)
    elif len(value.shape) > 0:
        raise ValueError("value must be scalar or 0-D tensor")

    x_dim = len(x.shape)

    # Validate the axis BEFORE normalising it. Checking after ``axis += x_dim``
    # lets values in [-2 * rank, -rank - 1] slip through the range test, and a
    # non-int axis would fail with TypeError on ``axis < 0`` instead of the
    # documented ValueError.
    if not isinstance(axis, int) or axis < -x_dim or axis > x_dim - 1:
        raise ValueError(
            "The axis should be a int, and in range [-rank(x), rank(x))"
        )
    if axis < 0:
        axis += x_dim

    # Swap ``axis`` with dimension 0 so that ``index`` addresses the leading
    # dimension. A single swap is its own inverse, so the same permutation
    # restores the original layout afterwards.
    perm = list(range(x_dim))
    perm[0], perm[axis] = axis, 0

    # No up-front clone of ``x``: the out-of-place ``paddle.index_put`` is
    # expected to leave its input untouched (see the paddle.index_put docs).
    out = paddle.transpose(x, perm)
    out = paddle.index_put(out, (index,), value)
    out = paddle.transpose(out, perm)

    if inplace:
        # Write the result back into the caller's tensor.
        x[:] = out
        return x
    return out


def index_fill(x, index, axis, value, name=None):
    """
    Fill the entries of ``x`` selected by ``index`` along ``axis`` with ``value``.

    This is the out-of-place API: the result is returned as a new Tensor and
    the example below shows ``x`` unchanged. Use ``index_fill_`` to write the
    result back into ``x``.

    Args:
        x (Tensor): The input Tensor. Supported data types are int32, int64, float16, float32, float64.
        index (Tensor): The Tensor containing the indices to fill along ``axis``.
            The data type of ``index`` must be int32 or int64.
        axis (int): The dimension along which to index. Must be in ``[-rank(x), rank(x))``.
        value (scalar|Tensor): The value to fill with. A Tensor must be 0-D.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor, same dimension and dtype as ``x``.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle

            arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype('int64')
            input_tensor = paddle.to_tensor(arr)
            index = paddle.to_tensor([0, 2], dtype="int32")
            value = -1
            res = paddle.index_fill(input_tensor, index, 0, value)
            print(input_tensor)
            # Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
            #        [[1, 2, 3],
            #         [4, 5, 6],
            #         [7, 8, 9]])
            print(res)
            # Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
            #        [[-1, -1, -1],
            #         [ 4,  5,  6],
            #         [-1, -1, -1]])
    """
    return _index_fill_impl(x, index, axis, value, False)


@inplace_apis_in_dygraph_only
def index_fill_(x, index, axis, value, name=None):
    """
    Fill the entries of ``x`` selected by ``index`` along ``axis`` with ``value``, in place.

    This is the in-place counterpart of ``index_fill``: the result is written
    back into ``x`` and ``x`` is returned.

    Args:
        x (Tensor): The destination Tensor. Supported data types are int32, int64, float16, float32, float64.
        index (Tensor): The Tensor containing the indices to fill along ``axis``.
            The data type of ``index`` must be int32 or int64.
        axis (int): The dimension along which to index. Must be in ``[-rank(x), rank(x))``.
        value (scalar|Tensor): The value to fill with. A Tensor must be 0-D.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor, same dimension and dtype as ``x``.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle

            arr = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype('int64')
            input_tensor = paddle.to_tensor(arr)
            index = paddle.to_tensor([0, 2], dtype="int32")
            value = -1
            res = paddle.index_fill_(input_tensor, index, 0, value)
            print(input_tensor)
            # Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
            #        [[-1, -1, -1],
            #         [ 4,  5,  6],
            #         [-1, -1, -1]])
            print(res)
            # Tensor(shape=[3, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
            #        [[-1, -1, -1],
            #         [ 4,  5,  6],
            #         [-1, -1, -1]])
    """
    return _index_fill_impl(x, index, axis, value, True)
164 changes: 164 additions & 0 deletions test/legacy_test/test_index_fill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from itertools import combinations

import numpy as np

import paddle
from paddle.base import Program

paddle.enable_static()


def compute_index_put_ref(x, axis, index, value):
    """NumPy reference for ``index_fill``: overwrite slices of ``x`` with ``value``.

    The name is kept for existing callers even though the function models
    ``index_fill`` rather than ``index_put``.

    Args:
        x (numpy.ndarray): Input array. It is never modified.
        axis (int): Dimension that ``index`` refers to. Negative values count
            from the last dimension, as with ordinary NumPy indexing.
        index (array-like of int): Positions along ``axis`` to overwrite.
        value (scalar): Fill value; it is cast to the dtype of ``x``.

    Returns:
        numpy.ndarray: A copy of ``x`` with the same shape and dtype in which
        every slice ``index[i]`` along ``axis`` is set to ``value``.
    """
    out = np.array(x)  # np.array copies, so the caller's data stays intact

    positions = np.asarray(index)
    if positions.size == 0:
        # Nothing to fill; also avoids indexing with an empty float array,
        # which is what ``np.asarray([])`` produces.
        return out

    # Build ``out[:, ..., positions, ..., :]`` with ``positions`` at ``axis``.
    selector = [slice(None)] * out.ndim
    selector[axis] = positions
    out[tuple(selector)] = value
    return out


class TestIndexFillAPIBase(unittest.TestCase):
    """Checks ``paddle.index_fill`` against the NumPy reference.

    Subclasses override ``modify_setting`` to change dtype, shape, axis or
    fill value; ``setUp`` then builds random inputs from those settings.
    """

    def setUp(self):
        self.init_setting()
        self.modify_setting()
        self.x_np = np.random.random(self.x_shape).astype(self.dtype_np)
        # Draw one precomputed index combination. The bound comes from the
        # list itself so that changing ``index_size`` cannot go out of range.
        chosen = self.combs[np.random.randint(0, len(self.combs))]
        self.index_np = np.array(chosen).astype(self.index_type)

        self.place = ['cpu']
        if paddle.is_compiled_with_cuda():
            self.place.append('gpu')

    def init_setting(self):
        """Default configuration: float64 data of shape (20, 40), axis 0."""
        self.dtype_np = 'float64'
        self.index_type = 'int64'
        self.x_shape = (20, 40)
        self.index_size = (5,)
        self.axis = 0
        self.value = -1
        # Every way to pick ``index_size[0]`` distinct indices out of 0..9.
        self.combs = list(combinations(list(range(10)), self.index_size[0]))

    def modify_setting(self):
        """Hook for subclasses; the base configuration is used unchanged."""
        pass

    def test_static_graph(self):
        paddle.enable_static()
        for place in self.place:
            with paddle.static.program_guard(Program()):
                x = paddle.static.data(
                    name="x", shape=self.x_shape, dtype=self.dtype_np
                )
                index = paddle.static.data(
                    name="index", shape=self.index_size, dtype=self.index_type
                )
                out = paddle.index_fill(x, index, self.axis, self.value)
                exe = paddle.static.Executor(place=place)
                feed_list = {"x": self.x_np, "index": self.index_np}
                pd_res = exe.run(
                    paddle.static.default_main_program(),
                    feed=feed_list,
                    fetch_list=[out],
                )[0]
                ref_res = compute_index_put_ref(
                    self.x_np, self.axis, self.index_np, self.value
                )
                # The fill copies values exactly, so the default tolerance
                # is sufficient.
                np.testing.assert_allclose(ref_res, pd_res)

    def test_dygraph(self):
        paddle.disable_static()
        for place in self.place:
            paddle.device.set_device(place)
            x_pd = paddle.to_tensor(self.x_np)
            index_pd = paddle.to_tensor(self.index_np)
            pd_res = paddle.index_fill(x_pd, index_pd, self.axis, self.value)
            ref_res = compute_index_put_ref(
                self.x_np, self.axis, self.index_np, self.value
            )
            np.testing.assert_allclose(ref_res, pd_res)

    def test_errors(self):
        # Select dygraph mode explicitly so the test does not depend on which
        # mode an earlier test left active.
        paddle.disable_static()
        x = paddle.to_tensor(np.random.random((10, 10)).astype(np.float32))
        index = paddle.to_tensor([0, 2])

        # ``index`` given as a plain list instead of a Tensor.
        with self.assertRaises(ValueError):
            paddle.index_fill(x, [0, 2], axis=-1, value=-1)

        # ``value`` given as a 1-D Tensor instead of a scalar / 0-D Tensor.
        with self.assertRaises(ValueError):
            paddle.index_fill(
                x, index, axis=-1, value=paddle.to_tensor([-1, -4])
            )

        # ``axis`` beyond the rank of ``x``.
        with self.assertRaises(ValueError):
            paddle.index_fill(x, index, axis=4, value=-1)


class TestIndexFillAPI1(TestIndexFillAPIBase):
    """Variant: int64 data, int32 indices, 3-D input filled along axis 1."""

    def modify_setting(self):
        # Only the settings that differ from the base configuration are set.
        self.x_shape = (10, 15, 10)
        self.axis = 1
        self.index_type = 'int32'
        self.dtype_np = 'int64'


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

补充下complex类型的测试吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index_put不支持complex类型的输入

增加了float16类型的测试

class TestIndexFillAPI2(TestIndexFillAPIBase):
    """Variant: bool data, int32 indices, 3-D input filled along axis 1."""

    def modify_setting(self):
        self.dtype_np = 'bool'
        self.index_type = 'int32'
        self.x_shape = (10, 15, 10)
        self.axis = 1
        # The input is built with ``np.random.random(...).astype('bool')``,
        # which is True for every non-zero sample. Filling with True would
        # leave the array essentially unchanged and could not expose a broken
        # fill, so fill with False instead.
        self.value = False
15 changes: 15 additions & 0 deletions test/legacy_test/test_inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -1482,5 +1482,20 @@ def test_forward_version(self):
self.assertEqual(var.inplace_version, 2)


class TestDygraphInplaceIndexFill(TestDygraphInplace):
    """Pairs ``paddle.index_fill_`` with ``paddle.index_fill``.

    NOTE(review): ``TestDygraphInplace`` is defined outside this view; the
    three methods below are presumably the hooks it calls - confirm against
    the base class.
    """

    def init_data(self):
        # Input data and dtype for the checks.
        self.dtype = "float32"
        self.input_var_numpy = np.random.random((20, 40))
        # Arguments forwarded to both API variants.
        self.value = -1
        self.axis = 0
        self.index = paddle.to_tensor([0, 2])

    def inplace_api_processing(self, var):
        filled = paddle.index_fill_(var, self.index, self.axis, self.value)
        return filled

    def non_inplace_api_processing(self, var):
        filled = paddle.index_fill(var, self.index, self.axis, self.value)
        return filled


if __name__ == '__main__':
unittest.main()