Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions python/paddle/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@
unfold,
masked_fill,
masked_fill_,
masked_scatter_,
masked_scatter,
index_fill,
index_fill_,
diagonal_scatter,
Expand Down Expand Up @@ -929,6 +931,8 @@
'polygamma_',
'masked_fill',
'masked_fill_',
'masked_scatter',
'masked_scatter_',
'hypot',
'hypot_',
'index_fill',
Expand Down
4 changes: 4 additions & 0 deletions python/paddle/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@
index_put_,
masked_fill,
masked_fill_,
masked_scatter,
masked_scatter_,
moveaxis,
put_along_axis,
put_along_axis_,
Expand Down Expand Up @@ -763,6 +765,8 @@
'atleast_2d',
'atleast_3d',
'diagonal_scatter',
'masked_scatter',
'masked_scatter_',
"combinations",
]

Expand Down
86 changes: 86 additions & 0 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4034,6 +4034,92 @@ def get_attr_shape(list_shape):
return out


def masked_scatter(x, mask, value, name=None):
"""
Copies elements from `value` into `x` tensor at positions where the `mask` is True.

Elements from source are copied into `x` starting at position 0 of `value` and continuing in order one-by-one for
each occurrence of `mask` being True. The shape of `mask` must be broadcastable with the shape of the underlying tensor.
The `value` should have at least as many elements as the number of ones in `mask`.

Args:
x (Tensor): An N-D Tensor. The data type is ``float16``, ``float32``, ``float64``, ``int32``,
``int64`` or ``bfloat16``.
mask (Tensor): The boolean tensor indicate the position to be filled.
The data type of mask must be bool.
value (Tensor): The value used to fill the target tensor.
Supported data types are same as x.
name (str, optional): Name for the operation (optional, default is None). For more information,
please refer to :ref:`api_guide_Name`.

Returns:
Tensor, A reshaped Tensor with the same data type as ``x``.

Examples:
.. code-block:: python

>>> import paddle
>>> paddle.seed(2048)
>>> x = paddle.randn([2, 2])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

估计得加 seed 固定输出,否则示例检查过不了,参考 固定的输出优于随机

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

示例代码检查没过,log 如下,可能是因为 x = paddle.randn([2, 2]) 没有 print(x),导致没有获得输出。代码示例部分也都注意一下这吧~ 可以本地跑一下试试
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seed 固定的情况下 CI 和 本地的结果不同,我这里又重新改了一次。

Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[[ 0.74132639, -1.79502666],
[-0.01776697, -0.93422651]])

>>> mask = paddle.randn([2, 2])
>>> mask = mask>0.6
Tensor(shape=[2, 2], dtype=bool, place=Place(gpu:0), stop_gradient=True,
[[True , True ],
[False, False]])
>>> value = paddle.to_tensor([1, 2, 3, 4, 5,], dtype="float32")

>>> out = paddle.masked_scatter(x, mask, value)
>>> print(out)
Tensor(shape=[2, 2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[[ 1, 2],
[-0.01776697, -0.93422651]])

"""
# make sure the dtype of x and value is the same
assert (
x.dtype == value.dtype
), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
assert mask.dtype == paddle.bool

zeros_like_x = paddle.zeros_like(x, dtype=int)
mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)
mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)
assert (
mask_prefix[-1] <= value.numel()
), f'mask true nums must be <= value size, but got mask true nums is {mask.sum().item()}, value size is {value.numel().item()}'

value = value.flatten()[mask_prefix].reshape(mask.shape)
mask = paddle.logical_not(mask)
return paddle.where(mask, x, value)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

实现逻辑上,masked_scatter masked_scatter_两个API有不同的实现逻辑。是某些操作在静态图下不支持,又希望masked_scatter 使用同一套代码的原因吗。

理论上这组API的实现内部只应该有少量inplace / outplace API的差异,可以简单分析下动态图下当前两种方式的实现有性能差异吗。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,主要是 broadcast_to 的这个操作,在静态图下的 shape 如果是 None 或者 -1 之类的时候没办法正确的 broadcast,后面是模仿 paddle.where 中的 broadcast 操作。我感觉在动态图下的 inplace 操作直接使用 broadcast_to 可以避免多余的 Op 操作,性能上应该会快一些。后面我也可以测试一下看看有没有明显的差距

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修改为两个 API 的逻辑一致,性能只有较为轻微的差异


@inplace_apis_in_dygraph_only
def masked_scatter_(x, mask, value, name=None):
"""
Inplace version of ``masked_scatter`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_paddle_masked_scatter`.
"""
assert (
x.dtype == value.dtype
), f'x and value must have the same dtype, but got x dtype is {x.dtype}, value dtype is {value.dtype}'
assert mask.dtype == paddle.bool
zeros_like_x = paddle.zeros_like(x, dtype=int)
mask = paddle.add(paddle.cast(mask, dtype="int"), zeros_like_x)
mask_prefix = paddle.clip(mask.cumsum() - 1, min=0)
assert (
mask_prefix[-1] <= value.numel()
), f'mask true nums must be <= value size, but got mask true nums is {mask_prefix[-1].item()}, value size is {value.numel().item()}'

value = value.flatten()[mask_prefix].reshape(mask.shape)
mask = paddle.logical_not(mask)
out = paddle.where_(mask, x, value)
return out


@inplace_apis_in_dygraph_only
def reshape_(x, shape, name=None):
"""
Expand Down
16 changes: 16 additions & 0 deletions test/legacy_test/test_inplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,22 @@ def init_data(self):
self.mask = paddle.to_tensor(self.mask, dtype='bool')


class TestDygraphInplaceMaskedScatter(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.masked_scatter(var, self.mask, self.value)

def inplace_api_processing(self, var):
return paddle.masked_scatter_(var, self.mask, self.value)

def init_data(self):
self.dtype = "float32"
self.input_var_numpy = np.random.uniform(-5, 5, [30, 3])
self.value = np.random.uniform(size=(30, 30))
self.value = paddle.to_tensor(self.value, dtype=self.dtype)
self.mask = np.random.randint(0, 2, [30, 1]).astype('bool')
self.mask = paddle.to_tensor(self.mask, dtype='bool')


class TestDygraphInplaceWithContinuous(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
Expand Down
Loading