Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 109 additions & 20 deletions python/paddle/nn/functional/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1690,25 +1690,38 @@ def pad(
mode: _PaddingTensorMode = 'constant',
value: float = 0.0,
data_format: DataLayoutND = "NCHW",
pad_from_left_axis: bool = True,
name: str | None = None,
) -> Tensor:
"""
Pad tensor according to ``'pad'`` and ``'mode'``.
If mode is ``'constant'`` and length of pad is twice as length of x dimension,
then the padding will be started from the first dimension and moved back onto x
according to ``'pad'`` and ``'value'``.
If mode is ``'reflect'``, pad[0] and pad[1] must be no greater
than width-1. The height and depth dimension has the same condition.

Parameters:
x (Tensor): The input tensor with data type float32/double/int32/int64_t/complex64/complex128.
pad (Tensor|list[int]|tuple[int]): The padding size with data type int.
If mode is ``'constant'`` and length of pad is twice as length of x dimension, then x will
be padded from the first dimension to the last dimension.
Else: 1. If input dimension is 3, then the pad has the form (pad_left,
pad_right). 2. If the input dimension is 4, then the pad has the form (pad_left, pad_right,
pad_top, pad_bottom). 3. If the input dimension is 5, then the pad has the form
(pad_left, pad_right, pad_top, pad_bottom, pad_front, pad_back).
Note:
1. Denote ``'x'``'s dimension as N (same in the following). If mode is ``'constant'``, the length
of ``'pad'`` can be any even number less than or equal to 2*N.

2. When mode is ``'constant'``, and ``'pad'`` is a list or tuple, and the length of ``'pad'`` is not
equal to 2*(N - 2):
2.1. If the length of ``'pad'`` is 2*N, the order of padding can be customized by ``'pad_from_left_axis'``.
if ``'pad_from_left_axis'`` is True, then the padding order will be started from the first dimension of
``'x'`` and moving backward according to ``'pad'``; else if ``'pad_from_left_axis'`` is False, then the
padding order will be started from the last dimension of ``'x'`` and moving forward according to ``'pad'``.
2.2. Otherwise, the padding will be started from the last dimension.

3. When mode is any of ``'reflect'``, ``'replicate'``, ``'circular'``, or ``'pad'`` is a tensor, or the
length of ``'pad'`` is 2*(N - 2), and the dimension of ``'x'`` only supports 3-D, 4-D and 5-D.
In these cases, input ``'x'`` will be padded on [D, H, W] axes according to ``'data_format'``. It will pad
from the last dimension to the first dimension of [D, H, W] axes.
Specifically, if N = 3, then the pad has the form (pad_left, pad_right); if N = 4, then the pad has the form
(pad_left, pad_right, pad_top, pad_bottom); if N = 5, then the pad has the form (pad_left, pad_right,
pad_top, pad_bottom, pad_front, pad_back).

4. If mode is ``'reflect'``, pad[0] and pad[1] must be no greater than width-1. The height and depth
dimension has the same condition.

Args:
x (Tensor): The input tensor with data type float32, float64, int32, int64, complex64 or complex128.
pad (Tensor|list[int]|tuple[int]): The padding size with data type int. Refer to Note for details.
mode (str, optional): Four modes: ``'constant'`` (default), ``'reflect'``, ``'replicate'``, ``'circular'``. Default is ``'constant'``.

- 'constant' mode, uses a constant value to pad the input tensor.
Expand All @@ -1718,8 +1731,12 @@ def pad(

value (float, optional): The value to fill the padded areas in 'constant' mode . Default is :math:`0.0`.
data_format (str, optional): An string from: ``'NCL'``, ``'NLC'``, ``'NHWC'``, ``'NCHW'``, ``'NCDHW'``, ``'NDHWC'``. Specify the data format of
the input data. Default: ``'NCHW'``.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: ``'None'``.
the input data when: 1. mode is any of ``'reflect'``, ``'replicate'`` or ``'circular'``; or 2. the input ``'pad'`` is a tensor;
or 3. the length of ``'pad'`` is ``2*(x.ndim - 2)``. Default: ``'NCHW'``.
pad_from_left_axis (bool, optional): The parameter is only valid when mode is ``'constant'`` and the input ``'pad'`` is
length of ``'pad'`` is ``2*x.ndim``, the order of padding can be customized. If True, the padding will be started from
the first axis of ``'x'``; if False, it will be started from the last axis of ``'x'``. Default: True.
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: ``'None'``.

Returns:
Tensor, a Tensor padded according to pad and mode and data type is same as input.
Expand All @@ -1735,43 +1752,71 @@ def pad(
pad = [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
mode = 'constant'
value = 0
pad_from_left_axis = True
Out = [[[[[0., 0., 0.],
[1., 2., 3.],
[4., 5., 6.],
[0., 0., 0.]]]]]
Out.shape = [1, 1, 1, 4, 3]

Case 1:
pad = [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
mode = 'constant'
value = 0
pad_from_left_axis = False
Out = [[[[[0., 0., 0.],
[0., 0., 0.]]],
[[[1., 2., 3.],
[4., 5., 6.]]],
[[[0., 0., 0.],
[0., 0., 0.]]]]]
Out.shape = [1, 3, 1, 2, 3]

Case 3:
pad = [1, 0, 0, 1],
mode = 'constant'
value = 0
Out = [[[[[0., 1., 2., 3.],
[0., 4., 5., 6.],
[0., 0., 0., 0.]]]]]
Out.shape = [1, 1, 1, 3, 4]

Case 4:
pad = [2, 2, 1, 1, 0, 0],
mode = 'constant'
value = 0
Out = [[[[[0. 0. 0. 0. 0. 0. 0.]
[0. 0. 1. 2. 3. 0. 0.]
[0. 0. 4. 5. 6. 0. 0.]
[0. 0. 0. 0. 0. 0. 0.]]]]]
Out.shape = [1, 1, 1, 4, 7]

Case 2:
Case 5:
pad = [2, 2, 1, 1, 0, 0],
mode = 'reflect'
Out = [[[[[6. 5. 4. 5. 6. 5. 4.]
[3. 2. 1. 2. 3. 2. 1.]
[6. 5. 4. 5. 6. 5. 4.]
[3. 2. 1. 2. 3. 2. 1.]]]]]
Out.shape = [1, 1, 1, 4, 7]

Case 3:
Case 6:
pad = [2, 2, 1, 1, 0, 0],
mode = 'replicate'
Out = [[[[[1. 1. 1. 2. 3. 3. 3.]
[1. 1. 1. 2. 3. 3. 3.]
[4. 4. 4. 5. 6. 6. 6.]
[4. 4. 4. 5. 6. 6. 6.]]]]]
Out.shape = [1, 1, 1, 4, 7]

Case 4:
Case 7:
pad = [2, 2, 1, 1, 0, 0],
mode = 'circular'
Out = [[[[[5. 6. 4. 5. 6. 4. 5.]
[2. 3. 1. 2. 3. 1. 2.]
[5. 6. 4. 5. 6. 4. 5.]
[2. 3. 1. 2. 3. 1. 2.]]]]]
Out.shape = [1, 1, 1, 4, 7]

Examples:
.. code-block:: python
Expand Down Expand Up @@ -1805,6 +1850,35 @@ def pad(
[3., 1., 2., 3., 1., 2.],
[6., 4., 5., 6., 4., 5.],
[3., 1., 2., 3., 1., 2.]]]])

>>> # example 4
>>> x_shape = (1, 1, 3)
>>> x = paddle.arange(paddle.prod(paddle.to_tensor(x_shape)), dtype="float32").reshape(x_shape) + 1
>>> y = F.pad(x, [1, 0, 0, 1, 0, 0], value=0, mode='constant', pad_from_left_axis=True)
>>> print(y)
Tensor(shape=[2, 2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[0., 0., 0.],
[0., 0., 0.]],
[[1., 2., 3.],
[0., 0., 0.]]])

>>> # example 5
>>> x_shape = (1, 1, 3)
>>> x = paddle.arange(paddle.prod(paddle.to_tensor(x_shape)), dtype="float32").reshape(x_shape) + 1
>>> y = F.pad(x, [1, 0, 0, 1, 0, 0], value=0, mode='constant', pad_from_left_axis=False)
>>> print(y)
Tensor(shape=[1, 2, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[0., 1., 2., 3.],
[0., 0., 0., 0.]]])

>>> # example 6
>>> x_shape = (1, 1, 3)
>>> x = paddle.arange(paddle.prod(paddle.to_tensor(x_shape)), dtype="float32").reshape(x_shape) + 1
>>> y = F.pad(x, [1, 0, 0, 1], value=0, mode='constant')
>>> print(y)
Tensor(shape=[1, 2, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
[[[0., 1., 2., 3.],
[0., 0., 0., 0.]]])
"""
assert mode in [
'reflect',
Expand All @@ -1824,11 +1898,26 @@ def pad(
if (
mode == "constant"
and isinstance(pad, (list, tuple))
and len(pad) == x_dim * 2
and len(pad) != (x_dim - 2) * 2
):
paddings = pad
pad_value = value

padding_len = len(paddings)
# pad the length of paddings to 2*x_dim
if padding_len < 2 * x_dim:
pad_len_for_paddings = 2 * x_dim - padding_len
paddings = paddings + ([0] if isinstance(pad, list) else (0,)) * (
pad_len_for_paddings
)

# since the kernel pad from left axis, if we want to pad from right axis, we need to reverse the paddings
if not (len(pad) == x_dim * 2 and pad_from_left_axis):
paddings = [
paddings[i - 1] if i % 2 == 1 else paddings[i + 1]
for i in range(2 * x_dim - 1, -1, -1)
]

if in_dynamic_mode():
out = _C_ops.pad(x, paddings, float(pad_value))
return out
Expand Down
108 changes: 108 additions & 0 deletions test/legacy_test/test_pad_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,114 @@ def test_check_grad(self):
)


class TestPadOrder2N(unittest.TestCase):
def init_case(self):
self.shape = [2, 3]
self.paddings = [(0, 1), (1, 0)]
self.pad_value = 0.5

def test_order(self):
self.init_case()
x_np = np.random.random(self.shape).astype('float32')
paddings_np = self.paddings.copy()
x = paddle.to_tensor(x_np)
paddings = list(np.array(self.paddings).flatten())

# pad_from_left_axis
pad_from_left_axis = True
out_np = np.pad(
x_np, paddings_np, mode="constant", constant_values=self.pad_value
)
out = paddle.nn.functional.pad(
x,
paddings,
mode='constant',
pad_from_left_axis=pad_from_left_axis,
)
np.testing.assert_array_equal(out, out_np)

# pad_from_right_axis:
pad_from_left_axis = False
paddings_np.reverse()
out_np = np.pad(
x_np, paddings_np, mode="constant", constant_values=self.pad_value
)
out = paddle.nn.functional.pad(
x,
paddings,
mode='constant',
pad_from_left_axis=pad_from_left_axis,
)
np.testing.assert_array_equal(out, out_np)


# test padding order for cases when length of padding is not 2(N-2) or 2N
class TestPadOrder(unittest.TestCase):
def init_case(self):
self.shape = [2, 3]
self.paddings = [(0, 1)]
self.pad_value = 0.5

def test_order(self):
self.init_case()
x_np = np.random.random(self.shape).astype('float32')
paddings_np = self.paddings.copy()
paddings_np += [(0, 0)] * (len(self.shape) - len(paddings_np))

x = paddle.to_tensor(x_np)
paddings = list(np.array(self.paddings).flatten())

# pad from last axis by default
paddings_np.reverse()
out_np = np.pad(
x_np, paddings_np, mode="constant", constant_values=self.pad_value
)
out = paddle.nn.functional.pad(x, paddings, mode='constant')
np.testing.assert_array_equal(out, out_np)


class TestPadOrder2N3D(TestPadOrder2N):
def init_case(self):
self.shape = [2, 3, 4]
self.paddings = [(0, 1), (2, 3), (2, 1)]
self.pad_value = 0.5


class TestPadOrder2N4D(TestPadOrder2N):
def init_case(self):
self.shape = [2, 3, 4, 5]
self.paddings = [(0, 1), (2, 3), (2, 1), (1, 1)]
self.pad_value = 0.5


class TestPadOrder2N5D(TestPadOrder2N):
def init_case(self):
self.shape = [1, 2, 3, 4, 5]
self.paddings = [(0, 1), (2, 3), (2, 1), (1, 1), (1, 0)]
self.pad_value = 0.5


class TestPadOrder1(TestPadOrder):
def init_case(self):
self.shape = [2, 3, 4]
self.paddings = [(0, 1), (2, 3)]
self.pad_value = 0.5


class TestPadOrder2(TestPadOrder):
def init_case(self):
self.shape = [2, 3, 4, 5]
self.paddings = [(0, 1), (2, 3), (2, 1)]
self.pad_value = 0.5


class TestPadOrder3(TestPadOrder):
def init_case(self):
self.shape = [2, 3, 4, 5]
self.paddings = [(0, 1)]
self.pad_value = 0.5


if __name__ == "__main__":
# paddle.enable_static()
unittest.main()