Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,18 @@
func : flash_attn_grad
data_type: q

# Backward of flash_attn_qkvpacked: consumes the packed QKV input, forward
# outputs (out, softmax_lse, seed_offset) and out_grad, and produces a single
# packed gradient tensor qkv_grad whose meta mirrors qkv.
- backward_op : flash_attn_qkvpacked_grad
  forward : flash_attn_qkvpacked (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
  args : (Tensor qkv, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false)
  optional : attn_mask
  output : Tensor(qkv_grad)
  infer_meta :
    func : FlashAttnQKVPackedGradInferMeta
    param : [qkv]
  kernel :
    func : flash_attn_qkvpacked_grad
    data_type: qkv

- backward_op : flash_attn_unpadded_grad
forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
Expand All @@ -859,6 +871,18 @@
func : flash_attn_unpadded_grad
data_type: q

# Backward of flash_attn_varlen_qkvpacked (variable-length sequences described
# by cu_seqlens_q/cu_seqlens_k). Produces one packed gradient tensor qkv_grad;
# varlen_padded selects padded vs unpadded packed layout, matching the forward.
- backward_op : flash_attn_varlen_qkvpacked_grad
  forward : flash_attn_varlen_qkvpacked (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
  args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool varlen_padded = true)
  optional : attn_mask
  output : Tensor(qkv_grad)
  infer_meta :
    func : FlashAttnQKVPackedGradInferMeta
    param : [qkv]
  kernel :
    func : flash_attn_varlen_qkvpacked_grad
    data_type: qkv

- backward_op : flash_attn_with_sparse_mask_grad
forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0)
Expand Down
25 changes: 25 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,18 @@
backward : flash_attn_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

# FlashAttention forward with Q/K/V packed into one tensor
# ([..., nheads/nheads_k + 2, nheads_k, headdim]); infer_meta derives the
# output meta from qkv alone (param : [qkv]).
- op : flash_attn_qkvpacked
  args : (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
  optional : fixed_seed_offset, attn_mask
  infer_meta :
    func : FlashAttnQKVPackedInferMeta
    param : [qkv]
  kernel :
    func : flash_attn_qkvpacked
    data_type : qkv
  backward : flash_attn_qkvpacked_grad

- op : flash_attn_unpadded
args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
Expand All @@ -1057,6 +1069,19 @@
intermediate : softmax_lse, seed_offset
backward : flash_attn_unpadded_grad

# Variable-length FlashAttention forward with packed QKV input.
# softmax_lse and seed_offset are intermediates consumed only by the backward
# op (flash_attn_varlen_qkvpacked_grad), mirroring flash_attn_unpadded.
- op : flash_attn_varlen_qkvpacked
  args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true)
  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
  optional : fixed_seed_offset, attn_mask
  infer_meta :
    func : FlashAttnQKVPackedInferMeta
    param : [qkv]
  kernel :
    func : flash_attn_varlen_qkvpacked
    data_type : qkv
  intermediate : softmax_lse, seed_offset
  backward : flash_attn_varlen_qkvpacked_grad

- op : flash_attn_with_sparse_mask
args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "")
output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
Expand Down
7 changes: 4 additions & 3 deletions paddle/phi/core/kernel_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
auto kernel_iter = iter->second.find(
{Backend::GPUDNN, phi::DataLayout::ALL_LAYOUT, kernel_key.dtype()});
if (kernel_iter != iter->second.end()) {
return {kernel_iter->second, false, false};
return {
kernel_iter->second, false, kernel_iter->second.IsSupportStride()};
}
kernel_key =
KernelKey(Backend::GPU, kernel_key.layout(), kernel_key.dtype());
Expand Down Expand Up @@ -351,7 +352,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
<< ", expected_kernel_key:" << kernel_key
<< ", fallbacking to CPU one!";

return {kernel_iter->second, true, false};
return {kernel_iter->second, true, kernel_iter->second.IsSupportStride()};
}

PADDLE_ENFORCE_NE(
Expand All @@ -366,7 +367,7 @@ KernelResult KernelFactory::SelectKernelOrThrowError(
kernel_name,
KernelSelectionErrorMessage(kernel_name, kernel_key)));

return {kernel_iter->second, false, false};
return {kernel_iter->second, false, kernel_iter->second.IsSupportStride()};
}

const KernelArgsDef& KernelFactory::GetFirstKernelArgsDef(
Expand Down
3 changes: 3 additions & 0 deletions paddle/phi/core/kernel_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ class Kernel {
return kernel_registered_type_;
}

// True if this kernel can operate on strided (non-contiguous) tensors;
// queried by KernelFactory::SelectKernelOrThrowError when building the
// KernelResult.
bool IsSupportStride() const { return support_stride_; }
// Records stride support for this kernel. NOTE(review): presumably called
// from the kernel registration path — confirm against the registrar.
void SetSupportStride(bool support) { support_stride_ = support; }
GetKernelTypeForVarFn get_kerneltype_forvar_fn_{nullptr};
std::function<bool(const KernelContext* ctx)> check_if_onednn_kernel_support_{
nullptr};
Expand All @@ -290,6 +292,7 @@ class Kernel {
void* variadic_fn_ = nullptr;
KernelArgsDef args_def_;
KernelRegisteredType kernel_registered_type_ = KernelRegisteredType::FUNCTION;
bool support_stride_ = false;
};

using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,12 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
}
}

// Infers the meta of the packed QKV gradient: it is identical to the packed
// QKV input's meta (same dims/dtype/layout), so simply share it.
void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dqkv) {
  if (dqkv == nullptr) {
    return;  // gradient output not requested
  }
  dqkv->share_meta(qkv);
}

void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
const MetaTensor& out_grad,
MetaTensor* x_grad,
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
MetaTensor* dk,
MetaTensor* dv);

void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dq);

void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
const MetaTensor& out_grad,
MetaTensor* x_grad,
Expand Down
28 changes: 28 additions & 0 deletions paddle/phi/infermeta/ternary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ limitations under the License. */
#include "glog/logging.h"

#include "paddle/common/ddim.h"
#include "paddle/common/errors.h"
#include "paddle/common/layout.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/common_shape.h"
#include "paddle/phi/kernels/impl/box_coder.h"

Expand Down Expand Up @@ -371,6 +374,31 @@ void FlashAttnInferMeta(const MetaTensor& q,
seed_offset->set_dtype(phi::DataType::INT64);
}
}
void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv,
MetaTensor* out,
MetaTensor* softmax,
MetaTensor* softmax_lse,
MetaTensor* seed_offset) {
const auto& qkvdims = qkv.dims();
PADDLE_ENFORCE(qkvdims.size() == 4 || qkvdims.size() == 5,
phi::errors::InvalidArgument(
"qkv dims must be 4(unpadded) or 5(padded batch)"));
// qkv [total_*,nheads/nheads_k+2,nheads_k,headdim]
auto out_dims = DDim({qkvdims[0], (qkvdims[1] - 2) * qkvdims[2], qkvdims[3]});
if (qkvdims.size() == 5) {
// qkv [batchsize,seqlen,nheads/nheads_k+2,nheads_k,headdim]
out_dims =
DDim{qkvdims[0], qkvdims[1], (qkvdims[2] - 2) * qkvdims[3], qkvdims[4]};
}
out->set_dims(out_dims);
out->set_dtype(qkv.dtype());
out->set_layout(qkv.layout());
softmax->set_dtype(qkv.dtype());
softmax_lse->set_dtype(qkv.dtype());
if (seed_offset) {
seed_offset->set_dtype(phi::DataType::INT64);
}
Comment on lines +415 to +419
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这几个output可以设置dim吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分是和原有的FlashAttnInferMeta一样的,这几个output的shape是在flash_attn_utils.h的FlashAttnFwdParamsV2中Resize的,没有在这里设置dim

}

void ArangeTensorInferMeta(const MetaTensor& start,
const MetaTensor& end,
Expand Down
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/ternary.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@ void FlashAttnInferMeta(const MetaTensor& q,
MetaTensor* softmax_lse,
MetaTensor* seed_offset);

// Infers out/softmax/softmax_lse/seed_offset meta for the qkvpacked
// FlashAttention forward ops, where qkv packs Q, K and V along an extra
// heads dimension (rank 4 unpadded, rank 5 padded batch).
void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv,
                                 MetaTensor* out,
                                 MetaTensor* softmax,
                                 MetaTensor* softmax_lse,
                                 MetaTensor* seed_offset);

void InstanceNormInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
Expand Down
Loading