
Commit d06a456

MikhayEeer and megemini authored and committed
[Typing][C-67,C-69,C-71,C-72] Add type annotations for 4 files in python/paddle/incubate/nn/functional/ (PaddlePaddle#66616)
---------
Co-authored-by: megemini <[email protected]>
1 parent ec2c76a · commit d06a456

File tree: 4 files changed, +67 −36 lines


python/paddle/incubate/nn/functional/blha_get_max_len.py

Lines changed: 10 additions & 1 deletion
@@ -12,11 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from paddle import _C_ops
 from paddle.framework import LayerHelper, in_dynamic_or_pir_mode
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+
 
-def blha_get_max_len(seq_lens_encoder, seq_lens_decoder, batch_size):
+def blha_get_max_len(
+    seq_lens_encoder: Tensor, seq_lens_decoder: Tensor, batch_size: Tensor
+) -> tuple[Tensor, Tensor]:
     """
     Apply Fused BlhaGetMaxLen kernel. Typically used before the block_multihead_attention operator.
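
The pattern introduced here is repeated in the other three files: from __future__ import annotations turns the annotations into strings that are never evaluated at runtime, and the TYPE_CHECKING guard makes the Tensor import visible to static checkers only. A minimal, self-contained sketch of the same idiom (the scale function is hypothetical and not part of this commit):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only during static analysis; no runtime cost or import-cycle risk.
    from paddle import Tensor


def scale(x: Tensor, factor: float = 1.0) -> Tensor:
    # Hypothetical example function, not from the commit. At runtime the
    # annotation is just the string "Tensor", so the guard above suffices.
    return x * factor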

python/paddle/incubate/nn/functional/fused_dot_product_attention.py

Lines changed: 19 additions & 17 deletions
@@ -12,24 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 from paddle import Tensor, _C_ops
 from paddle.framework import LayerHelper, in_dynamic_or_pir_mode
 
 
 def cudnn_flash_attention(
-    q,
-    k,
-    v,
-    bias=None,
-    cu_seqlen_q=None,
-    cu_seqlen_k=None,
-    scaling_factor=1.0,
-    dropout_prob=0.0,
-    training=True,
-    mask_type=None,
-    bias_type=None,
-    name=None,
-):
+    q: Tensor,
+    k: Tensor,
+    v: Tensor,
+    bias: Tensor | None = None,
+    cu_seqlen_q: Tensor | None = None,
+    cu_seqlen_k: Tensor | None = None,
+    scaling_factor: float = 1.0,
+    dropout_prob: float = 0.0,
+    training: bool = True,
+    mask_type: str | None = None,
+    bias_type: str | None = None,
+    name: str | None = None,
+) -> Tensor:
     r"""
     Fused Dot Product Attention. This is a fusion operator to compute scaled dot product attention in transformer
     model architecture. This operator only supports running on Ampere and Hopper GPU and need cudnn version >= 8906.

@@ -128,13 +130,13 @@ def fused_dot_product_attention(
     query: Tensor,
     key: Tensor,
     value: Tensor,
-    attn_mask: Tensor = None,
+    attn_mask: Tensor | None = None,
     dropout_p: float = 0.0,
     is_causal: bool = False,
-    scaling_factor: float = None,
+    scaling_factor: float | None = None,
     training: bool = True,
-    name: str = None,
-):
+    name: str | None = None,
+) -> Tensor:
     r"""
     Fused Dot Product Attention. This is a fusion operator to compute scaled dot product attention in transformer
     model architecture. This operator only supports running on Ampere and Hopper GPU and need cudnn version >= 8906.
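
For context, a hedged usage sketch of the annotated fused_dot_product_attention signature. The [batch, seq_len, num_heads, head_dim] layout and the float16 cast are illustrative assumptions, the import assumes the function is re-exported from paddle.incubate.nn.functional as the file location suggests, and running it needs an Ampere/Hopper GPU with cuDNN >= 8906:

import paddle
from paddle.incubate.nn.functional import fused_dot_product_attention

# Assumed layout: [batch, seq_len, num_heads, head_dim], cast to float16 for the fused kernel.
q = paddle.rand([2, 128, 8, 64]).astype("float16")
k = paddle.rand([2, 128, 8, 64]).astype("float16")
v = paddle.rand([2, 128, 8, 64]).astype("float16")

# attn_mask, scaling_factor, and name can stay None, matching the new "| None" defaults.
out = fused_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(out.shape)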

python/paddle/incubate/nn/functional/fused_ec_moe.py

Lines changed: 15 additions & 2 deletions
@@ -12,12 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from paddle.base.layer_helper import LayerHelper
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+
 
 def fused_ec_moe(
-    x, gate, bmm0_weight, bmm0_bias, bmm1_weight, bmm1_bias, act_type
-):
+    x: Tensor,
+    gate: Tensor,
+    bmm0_weight: Tensor,
+    bmm0_bias: Tensor,
+    bmm1_weight: Tensor,
+    bmm1_bias: Tensor,
+    act_type: str,
+) -> Tensor:
     """
     Applies fused ec_moe kernel.
     This method requires SM_ARCH in sm75, sm80, sm86.
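
A hedged call sketch of the annotated fused_ec_moe signature. The shapes below (x as [bsz, seq_len, d_model], gate as [bsz, seq_len, num_experts], expert weights batched over the expert dimension) are illustrative assumptions, the fp16 helper exists only for this sketch, and the kernel itself needs the SM 75/80/86 hardware noted in the docstring:

import paddle
from paddle.incubate.nn.functional import fused_ec_moe


def fp16(shape):
    # Helper for this sketch only: random float16 tensors of the assumed shapes.
    return paddle.randn(shape).astype("float16")


bsz, seq_len, d_model, d_ffn, num_experts = 2, 128, 768, 3072, 8
x = fp16([bsz, seq_len, d_model])
gate = fp16([bsz, seq_len, num_experts])
bmm0_weight = fp16([num_experts, d_model, d_ffn])
bmm0_bias = fp16([num_experts, 1, d_ffn])
bmm1_weight = fp16([num_experts, d_ffn, d_model])
bmm1_bias = fp16([num_experts, 1, d_model])

out = fused_ec_moe(
    x, gate, bmm0_weight, bmm0_bias, bmm1_weight, bmm1_bias, act_type="gelu"
)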

python/paddle/incubate/nn/functional/fused_gate_attention.py

Lines changed: 23 additions & 16 deletions
@@ -12,27 +12,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from paddle import _legacy_C_ops
 from paddle.framework import in_dynamic_mode
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+
 
 def fused_gate_attention(
-    query,
-    key=None,
-    query_weight=None,
-    key_weight=None,
-    value_weight=None,
-    qkv_weight=None,
-    gate_linear_weight=None,
-    gate_linear_bias=None,
-    out_linear_weight=None,
-    out_linear_bias=None,
-    nonbatched_bias=None,
-    attn_mask=None,
-    has_gating=True,
-    merge_qkv=True,
-    use_flash_attn=False,
-):
+    query: Tensor,
+    key: Tensor | None = None,
+    query_weight: Tensor | None = None,
+    key_weight: Tensor | None = None,
+    value_weight: Tensor | None = None,
+    qkv_weight: Tensor | None = None,
+    gate_linear_weight: Tensor | None = None,
+    gate_linear_bias: Tensor | None = None,
+    out_linear_weight: Tensor | None = None,
+    out_linear_bias: Tensor | None = None,
+    nonbatched_bias: Tensor | None = None,
+    attn_mask: Tensor | None = None,
+    has_gating: bool = True,
+    merge_qkv: bool = True,
+    use_flash_attn: bool = False,
+) -> Tensor:
     r"""
     Attention maps queries and a set of key-value pairs to outputs, and
     Gate Attention performs multiple parallel attention to jointly attending
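
Because deferred annotations are stored as strings, the new hints can be read back without evaluating them, for example via inspect.signature once the module is importable (assumes a Paddle build that ships this file):

import inspect

from paddle.incubate.nn.functional.fused_gate_attention import (
    fused_gate_attention,
)

sig = inspect.signature(fused_gate_attention)
# Each parameter now reports a string annotation such as 'Tensor | None'.
for name, param in sig.parameters.items():
    print(f"{name}: {param.annotation} = {param.default}")
print("return ->", sig.return_annotation)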
