@@ -1748,6 +1748,75 @@ bool FlashAttnVarlenQkvpackedOpInferSymbolicShape(
17481748// return true;
17491749// }
17501750
1751+ bool FlashmaskAttentionOpInferSymbolicShape (
1752+ pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
1753+ const symbol::ShapeOrDataDimExprs &q =
1754+ infer_context->GetShapeOrDataForValue (op->operand_source (0 ));
1755+ const symbol::ShapeOrDataDimExprs &k =
1756+ infer_context->GetShapeOrDataForValue (op->operand_source (1 ));
1757+ const symbol::ShapeOrDataDimExprs &v =
1758+ infer_context->GetShapeOrDataForValue (op->operand_source (2 ));
1759+
1760+ PADDLE_ENFORCE_EQ (q.shape ().size (),
1761+ 4 ,
1762+ common::errors::InvalidArgument (
1763+ " flash_attn receive input with dim "
1764+ " [batch_size, seq_len, num_heads, head_dim]" ));
1765+
1766+ infer_context->AddEqualCstr (q.shape ()[0 ], k.shape ()[0 ]);
1767+ infer_context->AddEqualCstr (q.shape ()[0 ], v.shape ()[0 ]);
1768+ infer_context->AddEqualCstr (k.shape ()[1 ], v.shape ()[1 ]);
1769+
1770+ if (op->operand_source (3 )) {
1771+ const std::vector<symbol::DimExpr> &startend_row_indices =
1772+ infer_context->GetShapeOrDataForValue (op->operand_source (4 )).shape ();
1773+ PADDLE_ENFORCE_EQ (
1774+ startend_row_indices.size (),
1775+ 4 ,
1776+ common::errors::InvalidArgument (
1777+ " flashmask_attention receive startend_row_indices with dim "
1778+ " [batch_size, num_heads,seq_len, mask_bounds]" ));
1779+ }
1780+ std::vector<symbol::DimExpr> out_shape = q.shape ();
1781+
1782+ out_shape.back () = v.shape ().back ();
1783+
1784+ infer_context->SetShapeOrDataForValue (
1785+ op->result (0 ), symbol::TensorShapeOrDataDimExprs (out_shape));
1786+
1787+ // GPU has round for seqlen, but XPU has not. Here we align with the GPU
1788+ // version.
1789+ auto round_multiple = [](symbol::DimExpr x) {
1790+ auto m = symbol::DimExpr{128 };
1791+ auto m_minus_one = symbol::DimExpr{127 };
1792+ return (x + m_minus_one) / m * m;
1793+ };
1794+ auto batch_size_expr = q.shape ()[0 ];
1795+ auto num_heads_expr = q.shape ()[2 ];
1796+ auto seqlen_q_rounded_expr = round_multiple (q.shape ()[1 ]);
1797+ auto seqlen_k_rounded_expr = round_multiple (k.shape ()[1 ]);
1798+
1799+ if (op->result (1 )) {
1800+ std::vector<symbol::DimExpr> softmax_shape{batch_size_expr,
1801+ num_heads_expr,
1802+ seqlen_q_rounded_expr,
1803+ seqlen_k_rounded_expr};
1804+ infer_context->SetShapeOrDataForValue (
1805+ op->result (1 ), symbol::TensorShapeOrDataDimExprs (softmax_shape));
1806+ }
1807+ if (op->result (2 )) {
1808+ std::vector<symbol::DimExpr> softmax_lse_shape{
1809+ batch_size_expr, num_heads_expr, seqlen_q_rounded_expr};
1810+ infer_context->SetShapeOrDataForValue (
1811+ op->result (2 ), symbol::TensorShapeOrDataDimExprs (softmax_lse_shape));
1812+ }
1813+ if (op->result (3 )) {
1814+ std::vector<symbol::DimExpr> seed_offset_shape{symbol::DimExpr{2 }};
1815+ infer_context->SetShapeOrDataForValue (
1816+ op->result (3 ), symbol::TensorShapeOrDataDimExprs (out_shape));
1817+ }
1818+ return true ;
1819+ }
17511820bool FusedBatchNormActOpInferSymbolicShape (
17521821 pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
17531822 return BatchNormOpInferSymbolicShape (op, infer_context);
0 commit comments