Skip to content

Commit 53e711b

Browse files
authored
[Typing][B-78] Add type annotations for python/paddle/sparse/nn/functional/transformer.py (#65876)
1 parent 601413f commit 53e711b

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

python/paddle/sparse/nn/functional/transformer.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,29 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
from typing import TYPE_CHECKING
18+
1519
__all__ = []
1620

1721
from paddle import _C_ops
1822
from paddle.base.framework import dygraph_only
1923

24+
if TYPE_CHECKING:
25+
from paddle import Tensor
26+
2027

2128
@dygraph_only
2229
def attention(
23-
query,
24-
key,
25-
value,
26-
sparse_mask,
27-
key_padding_mask=None,
28-
attn_mask=None,
29-
name=None,
30-
):
30+
query: Tensor,
31+
key: Tensor,
32+
value: Tensor,
33+
sparse_mask: Tensor,
34+
key_padding_mask: Tensor | None = None,
35+
attn_mask: Tensor | None = None,
36+
name: str | None = None,
37+
) -> Tensor:
3138
r"""
3239
Note:
3340
This API is only used from ``CUDA 11.8`` .
@@ -52,11 +59,11 @@ def attention(
5259
sparse_mask (SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
5360
is `[batch_size*num_heads, seq_len, seq_len]`. `nnz` of each batch must be the same.
5461
dtype of `crows` and `cols` must be int64, dtype of `values` can be float32 or float64.
55-
key_padding_mask (DenseTensor, optional): The key padding mask tensor in the Attention module.
62+
key_padding_mask (DenseTensor|None, optional): The key padding mask tensor in the Attention module.
5663
2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64. Default: None.
57-
attn_mask (DenseTensor, optional): The attention mask tensor in the Attention module.
64+
attn_mask (DenseTensor|None, optional): The attention mask tensor in the Attention module.
5865
2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64. Default: None.
59-
name (str, optional): The default value is None. Normally there is no need for user
66+
name (str|None, optional): The default value is None. Normally there is no need for user
6067
to set this property. For more information, please refer to :ref:`api_guide_Name`.
6168
6269
Returns:

0 commit comments

Comments
 (0)