Skip to content

Commit f56e672

Browse files
authored
[CINN] [Infer Symbolic Shape BUAA] Add flashmask_attention op (#68385)
* first * fix build * fix
1 parent 4fc054c commit f56e672

File tree

3 files changed

+71
-1
lines changed

3 files changed

+71
-1
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1748,6 +1748,75 @@ bool FlashAttnVarlenQkvpackedOpInferSymbolicShape(
17481748
// return true;
17491749
// }
17501750

1751+
bool FlashmaskAttentionOpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
  // Infers symbolic shapes for flashmask_attention.
  // Inputs  : q/k/v with layout [batch_size, seq_len, num_heads, head_dim],
  //           plus optional startend_row_indices (operand 3).
  // Outputs : out (result 0) = q's shape with head_dim taken from v;
  //           softmax (result 1), softmax_lse (result 2),
  //           seed_offset (result 3).
  const symbol::ShapeOrDataDimExprs &q =
      infer_context->GetShapeOrDataForValue(op->operand_source(0));
  const symbol::ShapeOrDataDimExprs &k =
      infer_context->GetShapeOrDataForValue(op->operand_source(1));
  const symbol::ShapeOrDataDimExprs &v =
      infer_context->GetShapeOrDataForValue(op->operand_source(2));

  PADDLE_ENFORCE_EQ(q.shape().size(),
                    4,
                    common::errors::InvalidArgument(
                        "flash_attn receive input with dim "
                        "[batch_size, seq_len, num_heads, head_dim]"));

  // q/k/v must share the batch dim; k and v must share the seq_len dim.
  infer_context->AddEqualCstr(q.shape()[0], k.shape()[0]);
  infer_context->AddEqualCstr(q.shape()[0], v.shape()[0]);
  infer_context->AddEqualCstr(k.shape()[1], v.shape()[1]);

  if (op->operand_source(3)) {
    // BUG FIX: the original guarded on operand_source(3) but then read
    // operand_source(4); the inspected operand must be the guarded one.
    const std::vector<symbol::DimExpr> &startend_row_indices =
        infer_context->GetShapeOrDataForValue(op->operand_source(3)).shape();
    PADDLE_ENFORCE_EQ(
        startend_row_indices.size(),
        4,
        common::errors::InvalidArgument(
            "flashmask_attention receive startend_row_indices with dim "
            "[batch_size, num_heads,seq_len, mask_bounds]"));
  }

  // Output keeps q's layout but takes v's head_dim (last dim).
  std::vector<symbol::DimExpr> out_shape = q.shape();
  out_shape.back() = v.shape().back();

  infer_context->SetShapeOrDataForValue(
      op->result(0), symbol::TensorShapeOrDataDimExprs(out_shape));

  // GPU has round for seqlen, but XPU has not. Here we align with the GPU
  // version: round up to a multiple of 128.
  auto round_multiple = [](symbol::DimExpr x) {
    auto m = symbol::DimExpr{128};
    auto m_minus_one = symbol::DimExpr{127};
    return (x + m_minus_one) / m * m;
  };
  auto batch_size_expr = q.shape()[0];
  auto num_heads_expr = q.shape()[2];
  auto seqlen_q_rounded_expr = round_multiple(q.shape()[1]);
  auto seqlen_k_rounded_expr = round_multiple(k.shape()[1]);

  if (op->result(1)) {
    std::vector<symbol::DimExpr> softmax_shape{batch_size_expr,
                                               num_heads_expr,
                                               seqlen_q_rounded_expr,
                                               seqlen_k_rounded_expr};
    infer_context->SetShapeOrDataForValue(
        op->result(1), symbol::TensorShapeOrDataDimExprs(softmax_shape));
  }
  if (op->result(2)) {
    std::vector<symbol::DimExpr> softmax_lse_shape{
        batch_size_expr, num_heads_expr, seqlen_q_rounded_expr};
    infer_context->SetShapeOrDataForValue(
        op->result(2), symbol::TensorShapeOrDataDimExprs(softmax_lse_shape));
  }
  if (op->result(3)) {
    // BUG FIX: the original built seed_offset_shape ({2}) but then set
    // out_shape for result(3), leaving seed_offset_shape unused and the
    // seed_offset result with the wrong shape.
    std::vector<symbol::DimExpr> seed_offset_shape{symbol::DimExpr{2}};
    infer_context->SetShapeOrDataForValue(
        op->result(3), symbol::TensorShapeOrDataDimExprs(seed_offset_shape));
  }
  return true;
}
17511820
bool FusedBatchNormActOpInferSymbolicShape(
17521821
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
17531822
return BatchNormOpInferSymbolicShape(op, infer_context);

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedFeedforward)
5757
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedAttention)
5858
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FlashAttnVarlenQkvpacked)
5959
// OP_DECLARE_INFER_SYMBOLIC_SHAPE(FlashAttnUnpadded)
60+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FlashmaskAttention)
6061
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBatchNormAct)
6162
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBatchNormAct_)
6263
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FusedBnAddActivation)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1996,7 +1996,7 @@
19961996
func : flashmask_attention
19971997
data_type : q
19981998
backward : flashmask_attention_grad
1999-
# interfaces : paddle::dialect::InferSymbolicShapeInterface
1999+
interfaces : paddle::dialect::InferSymbolicShapeInterface
20002000

20012001
- op : flatten
20022002
args : (Tensor x, int start_axis = 1, int stop_axis = 1)

0 commit comments

Comments
 (0)