Commit 4b1c567

update doc, test=document_fix
1 parent 2dfbba8 commit 4b1c567

File tree: 1 file changed (+9, −10 lines)


python/paddle/nn/functional/flash_attention.py

Lines changed: 9 additions & 10 deletions
@@ -323,7 +323,7 @@ def flash_attn_qkvpacked(
 ``d`` represents the size of the last dimension of the three parameters.

 Warning:
-    This API is only support inputs with dtype float16 and bfloat16.
+    This API only supports inputs with dtype float16 and bfloat16.
     Don't call this API if flash_attn is not supported.

 Args:
@@ -342,9 +342,7 @@ def flash_attn_qkvpacked(
     :ref:`api_guide_Name`.

 Returns:
-    out(Tensor): The attention tensor.
-    4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
-    The dtype can be float16 or bfloat16.
+    out(Tensor): The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.
     softmax(Tensor): The softmax tensor. None if return_softmax is False.

 Examples:
@@ -355,7 +353,7 @@ def flash_attn_qkvpacked(

     >>> paddle.seed(2023)
     >>> q = paddle.rand((1, 128, 2, 16))
-    >>> qkv = paddle.stack([q,q,q], axis=2)
+    >>> qkv = paddle.stack([q, q, q], axis=2)
     >>> output = paddle.nn.functional.flash_attn_qkvpacked(qkv, 0.9, False, False)
     >>> print(output)
     (Tensor(shape=[1, 128, 2, 16], dtype=float32, place=Place(cpu), stop_gradient=True,
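For reference, the doctest touched above can be expanded into a standalone sketch. The meaning of the positional arguments (dropout, causal, return_softmax) and the (out, softmax) return pair are inferred from the example call and Returns section in this diff, not verified against the current signature, so treat them as assumptions.

import paddle
import paddle.nn.functional as F

paddle.seed(2023)

# [batch_size, seq_len, num_heads, head_dim]
q = paddle.rand((1, 128, 2, 16))

# Pack Q, K and V along axis=2, exactly as in the doctest above.
qkv = paddle.stack([q, q, q], axis=2)

# Assumed positional order from the example call: (qkv, dropout, causal, return_softmax).
output = F.flash_attn_qkvpacked(qkv, 0.9, False, False)

# Per the Returns section, output holds the attention tensor plus a softmax
# tensor that is None because return_softmax is False.
out, softmax = output
print(out.shape)  # expected: [1, 128, 2, 16]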
@@ -516,6 +514,9 @@ def flash_attn_unpadded(
     :ref:`api_guide_Name`.

 Returns:
+    out(Tensor): The attention tensor.
+    4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
+    The dtype can be float16 or bfloat16.
     out(Tensor): The attention tensor.
     4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
     The dtype can be float16 or bfloat16.
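The hunk above only touches the Returns entry of flash_attn_unpadded, so for context here is a hedged sketch of a call in the unpadded layout. The argument order (q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, scale, dropout, causal, return_softmax) is an assumption modeled on the flash_attn_varlen_qkvpacked doctest later in this file; verify it against the actual signature before relying on it.

import paddle
import paddle.nn.functional as F

paddle.seed(2023)

# Two sequences of length 128 flattened into one [total_seq_len, num_heads, head_dim]
# tensor, with cumulative offsets [0, 128, 256] marking the sequence boundaries.
# As the docstring Warning notes, only float16 and bfloat16 inputs are supported.
q = paddle.rand((256, 8, 16), dtype='float16')
k = paddle.rand((256, 8, 16), dtype='float16')
v = paddle.rand((256, 8, 16), dtype='float16')
cu_seqlens = paddle.arange(0, 384, 128, dtype='int32')

# Assumed positional order, mirroring the varlen qkvpacked example below;
# scale = 0.25 = 1 / sqrt(head_dim) for head_dim = 16.
out, softmax = F.flash_attn_unpadded(
    q, k, v, cu_seqlens, cu_seqlens, 128, 128, 0.25, 0.0, False, False
)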
@@ -621,7 +622,7 @@ def flash_attn_varlen_qkvpacked(
 ``d`` represents the size of the last dimension of the three parameters.

 Warning:
-    This API is only support inputs with dtype float16 and bfloat16.
+    This API only supports inputs with dtype float16 and bfloat16.

 Args:
     qkv(Tensor): The padded query/key/value packed tensor in the Attention module. The padding part won't be computed
@@ -646,9 +647,7 @@ def flash_attn_varlen_qkvpacked(
     :ref:`api_guide_Name`.

 Returns:
-    out(Tensor): The attention tensor. The tensor is padded by zeros.
-    3-D tensor with shape: [total_seq_len, num_heads, head_dim].
-    The dtype can be float16 or bfloat16.
+    out(Tensor): The attention tensor. The tensor is padded by zeros. 3-D tensor with shape: [total_seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.
     softmax(Tensor): The softmax tensor. None if return_softmax is False.

 Examples:
@@ -660,7 +659,7 @@ def flash_attn_varlen_qkvpacked(
     >>> q = paddle.rand((2, 128, 8, 16), dtype='float16')
     >>> cu = paddle.arange(0, 384, 128, dtype='int32')
     >>> qq = paddle.reshape(q, [256, 8, 16])
-    >>> qkv = paddle.stack([qq,qq,qq], axis=2)
+    >>> qkv = paddle.stack([qq, qq, qq], axis=2)
    >>> output = paddle.nn.functional.flash_attn_varlen_qkvpacked(qkv, cu, cu, 128, 128, 0.25, 0.0, False, False)
     >>> # doctest: -SKIP
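Expanded from the doctest in the last two hunks, a minimal sketch of the varlen packed call. Here cu holds cumulative sequence offsets [0, 128, 256] for two length-128 sequences; the argument order and the (out, softmax) return pair are taken from the example call and Returns section above and should be checked against the current signature.

import paddle
import paddle.nn.functional as F

paddle.seed(2023)

# Two sequences of length 128 flattened to total_seq_len = 256.
# Only float16 and bfloat16 inputs are supported, per the Warning above.
q = paddle.rand((2, 128, 8, 16), dtype='float16')
qq = paddle.reshape(q, [256, 8, 16])

# Cumulative sequence offsets [0, 128, 256].
cu = paddle.arange(0, 384, 128, dtype='int32')

# Pack Q, K and V along axis=2, exactly as in the doctest.
qkv = paddle.stack([qq, qq, qq], axis=2)

# Assumed positional order from the example call:
# (qkv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
#  scale, dropout, causal, return_softmax); scale = 1 / sqrt(head_dim) = 0.25.
output = F.flash_attn_varlen_qkvpacked(qkv, cu, cu, 128, 128, 0.25, 0.0, False, False)

# Per the Returns section: a zero-padded attention tensor of shape
# [total_seq_len, num_heads, head_dim], plus a softmax tensor (None here).
out, softmax = output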
