Skip to content

Commit 07af41d

Browse files
authored
[Typing][B-75] Add type annotations for python/paddle/sparse/nn/functional/activation.py (#65873)
1 parent 7238195 commit 07af41d

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

python/paddle/sparse/nn/functional/activation.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,21 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
1519
__all__ = []
1620

1721
from paddle import _C_ops
1822
from paddle.base.framework import dygraph_only, in_dynamic_or_pir_mode
1923
from paddle.base.layer_helper import LayerHelper
2024

25+
if TYPE_CHECKING:
26+
from paddle import Tensor
27+
2128

22-
def relu(x, name=None):
29+
def relu(x: Tensor, name: str | None = None) -> Tensor:
2330
"""
2431
sparse relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
2532
@@ -29,7 +36,7 @@ def relu(x, name=None):
2936
3037
Parameters:
3138
x (Tensor): The input Sparse Tensor with data type float32, float64.
32-
name (str, optional): Name for the operation (optional, default is None).
39+
name (str|None, optional): Name for the operation (optional, default is None).
3340
For more information, please refer to :ref:`api_guide_Name`.
3441
3542
Returns:
@@ -60,7 +67,7 @@ def relu(x, name=None):
6067
return out
6168

6269

63-
def softmax(x, axis=-1, name=None):
70+
def softmax(x: Tensor, axis: int = -1, name: str | None = None) -> Tensor:
6471
r"""
6572
sparse softmax activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
6673
@@ -78,7 +85,7 @@ def softmax(x, axis=-1, name=None):
7885
Parameters:
7986
x (Tensor): The input tensor. It can be SparseCooTensor/SparseCsrTensor. The data type can be float32 or float64.
8087
axis (int, optional): The axis along which to perform softmax calculations. Only support -1 for SparseCsrTensor.
81-
name (str, optional): Name for the operation (optional, default is None).
88+
name (str|None, optional): Name for the operation (optional, default is None).
8289
For more information, please refer to :ref:`api_guide_Name`.
8390
8491
Returns:
@@ -91,7 +98,7 @@ def softmax(x, axis=-1, name=None):
9198
>>> paddle.seed(100)
9299
93100
>>> mask = paddle.rand((3, 4)) < 0.5
94-
>>> x = paddle.rand((3, 4)) * mask
101+
>>> x = paddle.rand((3, 4)) * mask.astype('float32')
95102
>>> print(x)
96103
Tensor(shape=[3, 4], dtype=float32, place=Place(cpu), stop_gradient=True,
97104
[[0. , 0.95717543, 0.43864486, 0. ],
@@ -146,7 +153,7 @@ def softmax(x, axis=-1, name=None):
146153

147154

148155
@dygraph_only
149-
def relu6(x, name=None):
156+
def relu6(x: Tensor, name: str | None = None) -> Tensor:
150157
"""
151158
sparse relu6 activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
152159
@@ -156,7 +163,7 @@ def relu6(x, name=None):
156163
157164
Parameters:
158165
x (Tensor): The input Sparse Tensor with data type float32, float64.
159-
name (str, optional): Name for the operation (optional, default is None).
166+
name (str|None, optional): Name for the operation (optional, default is None).
160167
For more information, please refer to :ref:`api_guide_Name`.
161168
162169
Returns:
@@ -175,7 +182,9 @@ def relu6(x, name=None):
175182

176183

177184
@dygraph_only
178-
def leaky_relu(x, negative_slope=0.01, name=None):
185+
def leaky_relu(
186+
x: Tensor, negative_slope: float = 0.01, name: str | None = None
187+
) -> Tensor:
179188
r"""
180189
sparse leaky_relu activation, requiring x to be a SparseCooTensor or SparseCsrTensor.
181190
@@ -192,7 +201,7 @@ def leaky_relu(x, negative_slope=0.01, name=None):
192201
x (Tensor): The input Sparse Tensor with data type float32, float64.
193202
negative_slope (float, optional): Slope of the activation function at
194203
:math:`x < 0` . Default is 0.01.
195-
name (str, optional): Name for the operation (optional, default is None).
204+
name (str|None, optional): Name for the operation (optional, default is None).
196205
For more information, please refer to :ref:`api_guide_Name`.
197206
198207
Returns:

0 commit comments

Comments
 (0)