@@ -323,7 +323,7 @@ def flash_attn_qkvpacked(
     ``d`` represents the size of the last dimension of the three parameters.

     Warning:
-        This API is only support inputs with dtype float16 and bfloat16.
+        This API only supports inputs with dtype float16 and bfloat16.
         Don't call this API if flash_attn is not supported.

     Args:
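That warning is easy to trip over, since paddle.rand defaults to float32. A minimal guard sketch; the cast-to-float16 fallback is our assumption, not something this patch adds:

    import paddle

    def to_flash_attn_dtype(x):
        # flash_attn kernels accept only float16/bfloat16 (see the Warning above);
        # casting to float16 here is an assumed fallback, not part of the patch.
        if x.dtype not in (paddle.float16, paddle.bfloat16):
            x = x.astype('float16')
        return x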
@@ -342,9 +342,7 @@ def flash_attn_qkvpacked(
             :ref:`api_guide_Name`.

     Returns:
-        out(Tensor): The attention tensor.
-            4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
-            The dtype can be float16 or bfloat16.
+        out(Tensor): The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.
         softmax(Tensor): The softmax tensor. None if return_softmax is False.

     Examples:
@@ -355,7 +353,7 @@ def flash_attn_qkvpacked(

             >>> paddle.seed(2023)
             >>> q = paddle.rand((1, 128, 2, 16))
-            >>> qkv = paddle.stack([q,q, q], axis=2)
+            >>> qkv = paddle.stack([q, q, q], axis=2)
             >>> output = paddle.nn.functional.flash_attn_qkvpacked(qkv, 0.9, False, False)
             >>> print(output)
             (Tensor(shape=[1, 128, 2, 16], dtype=float32, place=Place(cpu), stop_gradient=True,
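Since the doctest prints a tuple, callers normally unpack both returns. A short sketch using distinct key/value projections instead of the repeated q above; the distinct k and v tensors are illustrative, not from the patch:

    import paddle

    paddle.seed(2023)
    q = paddle.rand((1, 128, 2, 16))
    k = paddle.rand((1, 128, 2, 16))
    v = paddle.rand((1, 128, 2, 16))
    # Pack along a new axis 2 -> [batch_size, seq_len, 3, num_heads, head_dim]
    qkv = paddle.stack([q, k, v], axis=2)
    out, softmax = paddle.nn.functional.flash_attn_qkvpacked(qkv, 0.9, False, False)
    # softmax is None here because return_softmax=False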
@@ -516,6 +514,4 @@ def flash_attn_unpadded(
             :ref:`api_guide_Name`.

     Returns:
-        out(Tensor): The attention tensor.
-            4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim].
-            The dtype can be float16 or bfloat16.
+        out(Tensor): The attention tensor. 4-D tensor with shape: [batch_size, seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.
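For context, a hedged sketch of calling flash_attn_unpadded itself, assuming it takes unpadded 3-D [total_seq_len, num_heads, head_dim] tensors plus cumulative offsets positionally, mirroring the varlen_qkvpacked example further down:

    import paddle

    # Two length-128 sequences packed into one unpadded tensor. Assumption:
    # positional args are q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q,
    # max_seqlen_k, scale, dropout, causal, return_softmax.
    q = paddle.rand((256, 8, 16), dtype='float16')
    cu = paddle.arange(0, 384, 128, dtype='int32')   # offsets [0, 128, 256]
    out, softmax = paddle.nn.functional.flash_attn_unpadded(
        q, q, q, cu, cu, 128, 128, 0.25, 0.0, False, False)
    # scale=0.25 == 1/sqrt(head_dim=16)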
@@ -621,7 +622,7 @@ def flash_attn_varlen_qkvpacked(
     ``d`` represents the size of the last dimension of the three parameters.

     Warning:
-        This API is only support inputs with dtype float16 and bfloat16.
+        This API only supports inputs with dtype float16 and bfloat16.

     Args:
         qkv(Tensor): The padded query/key/value packed tensor in the Attention module. The padding part won't be computed
@@ -646,9 +647,7 @@ def flash_attn_varlen_qkvpacked(
             :ref:`api_guide_Name`.

     Returns:
-        out(Tensor): The attention tensor. The tensor is padded by zeros.
-            3-D tensor with shape: [total_seq_len, num_heads, head_dim].
-            The dtype can be float16 or bfloat16.
+        out(Tensor): The attention tensor. The tensor is padded by zeros. 3-D tensor with shape: [total_seq_len, num_heads, head_dim]. The dtype can be float16 or bfloat16.
         softmax(Tensor): The softmax tensor. None if return_softmax is False.

     Examples:
@@ -660,7 +659,7 @@ def flash_attn_varlen_qkvpacked(
             >>> q = paddle.rand((2, 128, 8, 16), dtype='float16')
             >>> cu = paddle.arange(0, 384, 128, dtype='int32')
             >>> qq = paddle.reshape(q, [256, 8, 16])
-            >>> qkv = paddle.stack([qq,qq,qq], axis=2)
+            >>> qkv = paddle.stack([qq, qq, qq], axis=2)
             >>> output = paddle.nn.functional.flash_attn_varlen_qkvpacked(qkv, cu, cu, 128, 128, 0.25, 0.0, False, False)
             >>> # doctest: -SKIP

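The arange in that doctest works only because both sequences have length 128. For ragged batches, the cumulative offsets can be built from per-sequence lengths, as in this small sketch (the lengths are illustrative):

    import paddle

    # Per-sequence lengths summing to total_seq_len.
    lens = paddle.to_tensor([100, 156], dtype='int32')
    # Prepend 0 and take the running sum -> offsets [0, 100, 256].
    cu = paddle.concat([paddle.zeros([1], dtype='int32'), paddle.cumsum(lens)])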