Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 84 additions & 5 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2804,6 +2804,14 @@ def tensor_split(
"""
Split the input tensor into multiple sub-Tensors along ``axis``, allowing not being of equal size.

In the following figure, the shape of Tenser x is [6], and after paddle.tensor_split(x, num_or_indices=4) transformation, we get four sub-Tensors out0, out1, out2, and out3 :

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-1_en.png

since the length of x in axis = 0 direction 6 is not divisible by num_or_indices = 4,
the size of the first int(6 % 4) part after splitting will be int(6 / 4) + 1
and the size of the remaining parts will be int(6 / 4).

Args:
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections along ``axis``.
Expand All @@ -2821,19 +2829,30 @@ def tensor_split(

Examples:
.. code-block:: python
:name: tensor-split-example-1

>>> import paddle

>>> # x is a Tensor of shape [8]
>>> # evenly split
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1 = paddle.tensor_split(x, num_or_indices=2)
>>> print(out0.shape)
[4]
>>> print(out1.shape)
[4]


.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-2.png

.. code-block:: python
:name: tensor-split-example-2

>>> import paddle

>>> # not evenly split
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=3)
>>> print(out0.shape)
[3]
Expand All @@ -2842,7 +2861,16 @@ def tensor_split(
>>> print(out2.shape)
[2]

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-3_en.png

.. code-block:: python
:name: tensor-split-example-3

>>> import paddle

>>> # split with indices
>>> # x is a Tensor of shape [8]
>>> x = paddle.rand([8])
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3])
>>> print(out0.shape)
[2]
Expand All @@ -2851,6 +2879,13 @@ def tensor_split(
>>> print(out2.shape)
[5]

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-4.png

.. code-block:: python
:name: tensor-split-example-4

>>> import paddle

>>> # split along axis
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
Expand All @@ -2860,6 +2895,16 @@ def tensor_split(
>>> print(out1.shape)
[7, 4]

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-5.png

.. code-block:: python
:name: tensor-spilt-example-5

>>> import paddle

>>> # split along axis with indices
>>> # x is a Tensor of shape [7, 8]
>>> x = paddle.rand([7, 8])
>>> out0, out1, out2 = paddle.tensor_split(x, num_or_indices=[2, 3], axis=1)
>>> print(out0.shape)
[7, 2]
Expand All @@ -2868,6 +2913,8 @@ def tensor_split(
>>> print(out2.shape)
[7, 5]

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/tensor_split/tensor_split-6.png

"""
if x.ndim <= 0 or x.ndim <= axis:
raise ValueError(
Expand Down Expand Up @@ -2921,8 +2968,17 @@ def hsplit(
x: Tensor, num_or_indices: int | Sequence[int], name: str | None = None
) -> list[Tensor]:
"""
Split the input tensor into multiple sub-Tensors along the horizontal axis, which is equivalent to ``paddle.tensor_split`` with ``axis=1``
when ``x`` 's dimension is larger than 1, or equivalent to ``paddle.tensor_split`` with ``axis=0`` when ``x`` 's dimension is 1.

``hsplit`` Full name Horizontal Split, splits the input Tensor into multiple sub-Tensors along the horizontal axis, in the following two cases:

1. When the dimension of x is equal to 1, it is equivalent to ``paddle.tensor_split`` with ``axis=0``;

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/hsplit/hsplit-1.png

2. when the dimension of x is greater than 1, it is equivalent to ``paddle.tensor_split`` with ``axis=1``.

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/hsplit/hsplit-2.png


Args:
x (Tensor): A Tensor whose dimension must be greater than 0. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
Expand Down Expand Up @@ -2977,7 +3033,16 @@ def dsplit(
x: Tensor, num_or_indices: int | Sequence[int], name: str | None = None
) -> list[Tensor]:
"""
Split the input tensor into multiple sub-Tensors along the depth axis, which is equivalent to ``paddle.tensor_split`` with ``axis=2``.

``dsplit`` Full name Depth Split, splits the input Tensor into multiple sub-Tensors along the depth axis, which is equivalent to ``paddle.tensor_split`` with ``axis=2``.

Note:
Make sure that the number of Tensor dimensions transformed using ``paddle.dsplit`` must be no less than 3.

In the following figure, Tenser ``x`` has shape [4, 4, 4], and after ``paddle.dsplit(x, num_or_indices=2)`` transformation, we get ``out0`` and ``out1`` sub-Tensors whose shapes are both [4, 4, 2] :

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/dsplit/dsplit.png


Args:
x (Tensor): A Tensor whose dimension must be greater than 2. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
Expand Down Expand Up @@ -3021,14 +3086,28 @@ def vsplit(
x: Tensor, num_or_indices: int | Sequence[int], name: str | None = None
) -> list[Tensor]:
"""
Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.tensor_split`` with ``axis=0``.

``vsplit`` Full name Vertical Split, splits the input Tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.tensor_split`` with ``axis=0``.

1. When the number of Tensor dimensions is equal to 2:

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/vsplit/vsplit-1.png

2. When the number of Tensor dimensions is greater than 2:

.. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/images/api_legend/vsplit/vsplit-2.png


Note:
Make sure that the number of Tensor dimensions transformed using ``paddle.vsplit`` must be not less than 2.

Args:
x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, bfloat16, float16, float32, float64, uint8, int32 or int64.
num_or_indices (int|list|tuple): If ``num_or_indices`` is an int ``n``, ``x`` is split into ``n`` sections.
If ``num_or_indices`` is a list or tuple of integer indices, ``x`` is split at each of the indices.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .

Returns:
list[Tensor], The list of segmented Tensors.

Expand Down