1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from __future__ import annotations
16+
17+ from typing import TYPE_CHECKING
18+
1519__all__ = []
1620
1721from paddle import _C_ops
1822from paddle .base .framework import dygraph_only
1923
24+ if TYPE_CHECKING :
25+ from paddle import Tensor
26+
2027
2128@dygraph_only
2229def attention (
23- query ,
24- key ,
25- value ,
26- sparse_mask ,
27- key_padding_mask = None ,
28- attn_mask = None ,
29- name = None ,
30- ):
30+ query : Tensor ,
31+ key : Tensor ,
32+ value : Tensor ,
33+ sparse_mask : Tensor ,
34+ key_padding_mask : Tensor | None = None ,
35+ attn_mask : Tensor | None = None ,
36+ name : str | None = None ,
37+ ) -> Tensor :
3138 r"""
3239 Note:
3340 This API is only used from ``CUDA 11.8`` .
@@ -52,11 +59,11 @@ def attention(
5259 sparse_mask (SparseCsrTensor): The sparse layout in the Attention module. Its dense shape
5360 is `[batch_size*num_heads, seq_len, seq_len]`. `nnz` of each batch must be the same.
5461 dtype of `crows` and `cols` must be int64, dtype of `values` can be float32 or float64.
55- key_padding_mask (DenseTensor, optional): The key padding mask tensor in the Attention module.
62+ key_padding_mask (DenseTensor|None, optional): The key padding mask tensor in the Attention module.
5663 2D tensor with shape: [batch_size, seq_len]. dtype can be float32 or float64. Default: None.
57- attn_mask (DenseTensor, optional): The attention mask tensor in the Attention module.
64+ attn_mask (DenseTensor|None, optional): The attention mask tensor in the Attention module.
5865 2D tensor with shape: [seq_len, seq_len]. dtype can be float32 or float64. Default: None.
59- name (str, optional): The default value is None. Normally there is no need for user
66+ name (str|None, optional): The default value is None. Normally there is no need for user
6067 to set this property. For more information, please refer to :ref:`api_guide_Name`.
6168
6269 Returns:
0 commit comments