Skip to content

Commit 535a222

Browse files
gongshaotianAndPuQingooooo-create
authored
[CINN] Add Div to replace Recipical in DimExpr (PaddlePaddle#70376)
* add div shape op * fix ut * Refactor SimplifyDiv to streamline numerator and denominator handling * fix * Refactor Div structure and update related operations for improved clarity and functionality * fix * fix * fix * fix * add BinaryExprMatchTrait * fix broadcast_tree typo and fix bug of BinaryDimExprMatchTrait * refine code * optimize error message and delete FlattenOperands<Div> * fix bug * refine * accuracy * prevent infinite loops * delete VisitEachOperandStruct<Div> * refine code * clean code * Add Mul(S0,0) => 0 and check for divided by 0 * fix more cinn tests * Code Reuse * refine code * revert parallel_run * refine error message --------- Co-authored-by: PuQing <[email protected]> Co-authored-by: ooooo <[email protected]>
1 parent b203799 commit 535a222

File tree

17 files changed

+607
-263
lines changed

17 files changed

+607
-263
lines changed

paddle/cinn/adt/dim_expr_match_trait.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,20 @@ struct UnaryDimExprMatchTrait {
3333
}
3434
};
3535

36+
template <template <typename> class Op, typename T0>
37+
struct BinaryDimExprMatchTrait {
38+
using base_type = Op<DimExpr>;
39+
40+
static constexpr int is_template = true;
41+
42+
template <template <typename, typename> class Matcher>
43+
static bool MatchChildren(const base_type& value) {
44+
const auto& lhs = value->lhs;
45+
const auto& rhs = value->rhs;
46+
return Matcher<T0, DimExpr>::Call(lhs) && Matcher<T0, DimExpr>::Call(rhs);
47+
}
48+
};
49+
3650
template <template <typename> class Op, typename T0>
3751
struct ListDimExprMatchTrait {
3852
using base_type = Op<DimExpr>;
@@ -65,10 +79,6 @@ template <typename T0>
6579
struct MatchTrait<DimExpr, ::symbol::Negative<T0>> final
6680
: public UnaryDimExprMatchTrait<::symbol::Negative, T0> {};
6781

68-
template <typename T0>
69-
struct MatchTrait<DimExpr, ::symbol::Reciprocal<T0>> final
70-
: public UnaryDimExprMatchTrait<::symbol::Reciprocal, T0> {};
71-
7282
template <typename T0>
7383
struct MatchTrait<DimExpr, ::symbol::Add<T0>> final
7484
: public ListDimExprMatchTrait<::symbol::Add, T0> {};
@@ -77,6 +87,10 @@ template <typename T0>
7787
struct MatchTrait<DimExpr, ::symbol::Mul<T0>> final
7888
: public ListDimExprMatchTrait<::symbol::Mul, T0> {};
7989

90+
template <typename T0>
91+
struct MatchTrait<DimExpr, ::symbol::Div<T0>> final
92+
: public BinaryDimExprMatchTrait<::symbol::Div, T0> {};
93+
8094
template <typename T0>
8195
struct MatchTrait<DimExpr, ::symbol::Broadcast<T0>> final
8296
: public ListDimExprMatchTrait<::symbol::Broadcast, T0> {};

paddle/cinn/common/broadcast_tree.cc

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,17 @@ bool SearchBroadcastImplForUnary(const T& unary, const DoEachT& DoEach) {
4343
return SearchBroadcast(operand, DoEach);
4444
}
4545

46-
template <typename DoEachT>
47-
bool SearchBroadcastImpl(const symbol::Negative<symbol::DimExpr>& unary,
48-
const DoEachT& DoEach) {
49-
return SearchBroadcastImplForUnary(unary, DoEach);
46+
template <typename T, typename DoEachT>
47+
bool SearchBroadcastImplForBinary(const T& binary, const DoEachT& DoEach) {
48+
const auto& lhs = binary->lhs;
49+
const auto& rhs = binary->rhs;
50+
if (SearchBroadcast(lhs, DoEach)) return true;
51+
if (SearchBroadcast(rhs, DoEach)) return true;
52+
return false;
5053
}
5154

5255
template <typename DoEachT>
53-
bool SearchBroadcastImpl(const symbol::Reciprocal<symbol::DimExpr>& unary,
56+
bool SearchBroadcastImpl(const symbol::Negative<symbol::DimExpr>& unary,
5457
const DoEachT& DoEach) {
5558
return SearchBroadcastImplForUnary(unary, DoEach);
5659
}
@@ -76,6 +79,12 @@ bool SearchBroadcastImpl(const symbol::Mul<symbol::DimExpr>& variadic,
7679
return SearchBroadcastImplForVariadic(variadic, DoEach);
7780
}
7881

82+
template <typename DoEachT>
83+
bool SearchBroadcastImpl(const symbol::Div<symbol::DimExpr>& binary,
84+
const DoEachT& DoEach) {
85+
return SearchBroadcastImplForBinary(binary, DoEach);
86+
}
87+
7988
template <typename DoEachT>
8089
bool SearchBroadcastImpl(const symbol::Max<symbol::DimExpr>& variadic,
8190
const DoEachT& DoEach) {

paddle/cinn/common/dim_expr_converter.cc

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,6 @@ struct DimExprToIrExprVisitor {
4545
return ir::Sub::Make(ir::Expr(std::int64_t(0)), ConvertToIrExpr(operand));
4646
}
4747

48-
ir::Expr operator()(const Reciprocal<DimExpr>& dim_expr) {
49-
const auto& [operand] = *dim_expr;
50-
return ir::Div::Make(ir::Expr(std::int64_t(1)), ConvertToIrExpr(operand));
51-
}
52-
5348
ir::Expr operator()(const Add<DimExpr>& dim_expr) {
5449
const auto& [operands] = dim_expr;
5550
if (operands->empty()) {
@@ -69,21 +64,17 @@ struct DimExprToIrExprVisitor {
6964
}
7065
ir::Expr product = ConvertToIrExpr(operands->at(0));
7166
for (std::size_t i = 1; i < operands->size(); ++i) {
72-
// Convert Reciprocal<DimExpr>(S0) to (1 / S0) will result in precision
73-
// error. For example, (S0 * S1 / S2) != (S0 * S1 * (1 / S2)). So we
74-
// should use Div instead of Reciprocal here.
75-
if (operands->at(i).isa<Reciprocal<DimExpr>>()) {
76-
product = ir::Div::Make(
77-
product,
78-
ConvertToIrExpr(
79-
operands->at(i).dyn_cast<Reciprocal<DimExpr>>()->data));
80-
} else {
81-
product = ir::Mul::Make(product, ConvertToIrExpr(operands->at(i)));
82-
}
67+
product = ir::Mul::Make(product, ConvertToIrExpr(operands->at(i)));
8368
}
8469
return product;
8570
}
8671

72+
ir::Expr operator()(const Div<DimExpr>& dim_expr) {
73+
const auto& lhs = ConvertToIrExpr(dim_expr->lhs);
74+
const auto& rhs = ConvertToIrExpr(dim_expr->rhs);
75+
return ir::Div::Make(lhs, rhs);
76+
}
77+
8778
ir::Expr operator()(const Max<DimExpr>& dim_expr) {
8879
const auto& [operands] = dim_expr;
8980
PADDLE_ENFORCE_EQ(

paddle/cinn/hlir/dialect/operator/ir/generate_shape_util.cc

Lines changed: 72 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,6 @@ std::string GetSerializedTag<Negative<DimExpr>>() {
3030
return "Negative";
3131
}
3232

33-
template <>
34-
std::string GetSerializedTag<Reciprocal<DimExpr>>() {
35-
return "Reciprocal";
36-
}
37-
3833
template <>
3934
std::string GetSerializedTag<Add<DimExpr>>() {
4035
return "Add";
@@ -45,6 +40,11 @@ std::string GetSerializedTag<Mul<DimExpr>>() {
4540
return "Mul";
4641
}
4742

43+
template <>
44+
std::string GetSerializedTag<Div<DimExpr>>() {
45+
return "Div";
46+
}
47+
4848
template <>
4949
std::string GetSerializedTag<Max<DimExpr>>() {
5050
return "Max";
@@ -80,13 +80,20 @@ ::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::IrContext* ctx,
8080
return pir::ArrayAttribute::get(ctx, attr_vecs);
8181
}
8282

83-
::pir::Attribute ConvertDimExprToAttributeImpl(
84-
::pir::IrContext* ctx, const Negative<DimExpr>& dim_expr) {
85-
return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr);
83+
template <typename T>
84+
::pir::Attribute ConvertBinaryDimExprToAttributeImpl(::pir::IrContext* ctx,
85+
const T& dim_expr) {
86+
std::vector<::pir::Attribute> attr_vecs{};
87+
attr_vecs.push_back(pir::StrAttribute::get(ctx, GetSerializedTag<T>()));
88+
const auto& lhs = dim_expr->lhs;
89+
const auto& rhs = dim_expr->rhs;
90+
attr_vecs.push_back(ConvertDimExprToAttribute(ctx, lhs));
91+
attr_vecs.push_back(ConvertDimExprToAttribute(ctx, rhs));
92+
return pir::ArrayAttribute::get(ctx, attr_vecs);
8693
}
8794

8895
::pir::Attribute ConvertDimExprToAttributeImpl(
89-
::pir::IrContext* ctx, const Reciprocal<DimExpr>& dim_expr) {
96+
::pir::IrContext* ctx, const Negative<DimExpr>& dim_expr) {
9097
return ConvertUnaryDimExprToAttributeImpl(ctx, dim_expr);
9198
}
9299

@@ -112,6 +119,11 @@ ::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
112119
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
113120
}
114121

122+
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
123+
const Div<DimExpr>& dim_expr) {
124+
return ConvertBinaryDimExprToAttributeImpl(ctx, dim_expr);
125+
}
126+
115127
::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
116128
const Max<DimExpr>& dim_expr) {
117129
return ConvertVariadicDimExprToAttribute(ctx, dim_expr);
@@ -150,6 +162,23 @@ std::optional<DimExpr> ConvertArrayAttributeToUnaryDimExpr(
150162
return T{operand.value()};
151163
}
152164

165+
template <typename T>
166+
std::optional<DimExpr> ConvertArrayAttributeToBinaryDimExpr(
167+
const ::pir::ArrayAttribute& attribute) {
168+
if (attribute.size() != 3) {
169+
return std::nullopt;
170+
}
171+
std::optional<DimExpr> lhs = ConvertAttributeToDimExpr(attribute.at(1));
172+
if (!lhs.has_value()) {
173+
return std::nullopt;
174+
}
175+
std::optional<DimExpr> rhs = ConvertAttributeToDimExpr(attribute.at(2));
176+
if (!rhs.has_value()) {
177+
return std::nullopt;
178+
}
179+
return T{lhs.value(), rhs.value()};
180+
}
181+
153182
template <typename T>
154183
std::optional<DimExpr> ConvertArrayAttributeToVariadicDimExpr(
155184
const ::pir::ArrayAttribute& attribute) {
@@ -175,12 +204,12 @@ std::optional<ArrayAttributeConverterT> GetArrayAttributeConverter(
175204
static std::unordered_map<std::string, ArrayAttributeConverterT> map{
176205
{GetSerializedTag<Negative<DimExpr>>(),
177206
&ConvertArrayAttributeToUnaryDimExpr<Negative<DimExpr>>},
178-
{GetSerializedTag<Reciprocal<DimExpr>>(),
179-
&ConvertArrayAttributeToUnaryDimExpr<Reciprocal<DimExpr>>},
180207
{GetSerializedTag<Add<DimExpr>>(),
181208
&ConvertArrayAttributeToVariadicDimExpr<Add<DimExpr>>},
182209
{GetSerializedTag<Mul<DimExpr>>(),
183210
&ConvertArrayAttributeToVariadicDimExpr<Mul<DimExpr>>},
211+
{GetSerializedTag<Div<DimExpr>>(),
212+
&ConvertArrayAttributeToBinaryDimExpr<Div<DimExpr>>},
184213
{GetSerializedTag<Max<DimExpr>>(),
185214
&ConvertArrayAttributeToVariadicDimExpr<Max<DimExpr>>},
186215
{GetSerializedTag<Min<DimExpr>>(),
@@ -276,9 +305,6 @@ class SubstituteDimExprHelper final {
276305
std::optional<DimExpr> SubstituteImpl(const Negative<DimExpr>& dim_expr) {
277306
return SubstituteUnary(dim_expr);
278307
}
279-
std::optional<DimExpr> SubstituteImpl(const Reciprocal<DimExpr>& dim_expr) {
280-
return SubstituteUnary(dim_expr);
281-
}
282308

283309
template <typename T>
284310
std::optional<DimExpr> SubstituteUnary(const T& dim_expr) {
@@ -298,6 +324,25 @@ class SubstituteDimExprHelper final {
298324
return SubstituteVariadic(dim_expr);
299325
}
300326

327+
std::optional<DimExpr> SubstituteImpl(const Div<DimExpr>& dim_expr) {
328+
return SubstituteBinary(dim_expr);
329+
}
330+
331+
template <typename T>
332+
std::optional<DimExpr> SubstituteBinary(const T& dim_expr) {
333+
const auto& lhs = dim_expr->lhs;
334+
const auto& rhs = dim_expr->rhs;
335+
const auto& substituted_lhs = Substitute(lhs);
336+
if (!substituted_lhs.has_value()) {
337+
return std::nullopt;
338+
}
339+
const auto& substituted_rhs = Substitute(rhs);
340+
if (!substituted_rhs.has_value()) {
341+
return std::nullopt;
342+
}
343+
return T{substituted_lhs.value(), substituted_rhs.value()};
344+
}
345+
301346
std::optional<DimExpr> SubstituteImpl(const Max<DimExpr>& dim_expr) {
302347
return SubstituteVariadic(dim_expr);
303348
}
@@ -412,12 +457,12 @@ bool IsAtomicImpl(const std::string&) { return true; }
412457

413458
bool IsAtomicImpl(const symbol::Negative<symbol::DimExpr>&) { return false; }
414459

415-
bool IsAtomicImpl(const symbol::Reciprocal<symbol::DimExpr>&) { return false; }
416-
417460
bool IsAtomicImpl(const symbol::Add<symbol::DimExpr>&) { return false; }
418461

419462
bool IsAtomicImpl(const symbol::Mul<symbol::DimExpr>&) { return false; }
420463

464+
bool IsAtomicImpl(const symbol::Div<symbol::DimExpr>&) { return false; }
465+
421466
bool IsAtomicImpl(const symbol::Max<symbol::DimExpr>&) { return false; }
422467

423468
bool IsAtomicImpl(const symbol::Min<symbol::DimExpr>&) { return false; }
@@ -484,9 +529,12 @@ void CollectSymbolNamesImpl(const symbol::Negative<symbol::DimExpr>& dim_expr,
484529
CollectSymbolNamesImplForUnary(dim_expr, ret);
485530
}
486531

487-
void CollectSymbolNamesImpl(const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
488-
std::set<std::string>* ret) {
489-
CollectSymbolNamesImplForUnary(dim_expr, ret);
532+
template <typename T>
533+
void CollectSymbolNamesImplForBinary(const T& dim_expr,
534+
std::set<std::string>* ret) {
535+
const auto& [lhs, rhs] = *dim_expr;
536+
CollectSymbolNames(lhs, ret);
537+
CollectSymbolNames(rhs, ret);
490538
}
491539

492540
template <typename T>
@@ -508,6 +556,11 @@ void CollectSymbolNamesImpl(const symbol::Mul<symbol::DimExpr>& dim_expr,
508556
CollectSymbolNamesImplForVariadic(dim_expr, ret);
509557
}
510558

559+
void CollectSymbolNamesImpl(const symbol::Div<symbol::DimExpr>& dim_expr,
560+
std::set<std::string>* ret) {
561+
CollectSymbolNamesImplForBinary(dim_expr, ret);
562+
}
563+
511564
void CollectSymbolNamesImpl(const symbol::Max<symbol::DimExpr>& dim_expr,
512565
std::set<std::string>* ret) {
513566
CollectSymbolNamesImplForVariadic(dim_expr, ret);

paddle/cinn/hlir/dialect/operator/transforms/check_infer_symbolic_util.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,6 @@ struct ShapeSignatureGenerator {
138138
[&](const symbol::Negative<symbol::DimExpr>& negative) {
139139
GetSymbolsForOneDimExpr(negative->data, symbols);
140140
},
141-
[&](const symbol::Reciprocal<symbol::DimExpr>& reciprocal) {
142-
GetSymbolsForOneDimExpr(reciprocal->data, symbols);
143-
},
144141
[&](const symbol::Add<symbol::DimExpr>& add) {
145142
for (const auto& dim_expr : *add.operands) {
146143
GetSymbolsForOneDimExpr(dim_expr, symbols);
@@ -151,6 +148,10 @@ struct ShapeSignatureGenerator {
151148
GetSymbolsForOneDimExpr(dim_expr, symbols);
152149
}
153150
},
151+
[&](const symbol::Div<symbol::DimExpr>& div) {
152+
GetSymbolsForOneDimExpr(div->lhs, symbols);
153+
GetSymbolsForOneDimExpr(div->rhs, symbols);
154+
},
154155
[&](const symbol::Max<symbol::DimExpr>& max) {
155156
for (const auto& dim_expr : *max.operands) {
156157
GetSymbolsForOneDimExpr(dim_expr, symbols);

paddle/cinn/hlir/dialect/operator/transforms/group_merge/convert_static_dim_to_dynamic_pass.cc

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,16 @@ struct StaticDimToDynamicConverter {
198198
return AppliedOnceUnaryImpl(dim_expr, symbol);
199199
}
200200

201-
bool AppliedOnceImpl(const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
201+
template <typename T>
202+
bool AppliedOnceBinaryImpl(const T& dim_expr, const std::string& symbol) {
203+
const auto& lhs = dim_expr->lhs;
204+
const auto& rhs = dim_expr->rhs;
205+
return AppliedOnce(lhs, symbol) || AppliedOnce(rhs, symbol);
206+
}
207+
208+
bool AppliedOnceImpl(const symbol::Div<symbol::DimExpr>& dim_expr,
202209
const std::string& symbol) {
203-
return AppliedOnceUnaryImpl(dim_expr, symbol);
210+
return AppliedOnceBinaryImpl(dim_expr, symbol);
204211
}
205212

206213
template <typename T>
@@ -272,6 +279,24 @@ struct StaticDimToDynamicConverter {
272279
return T{converted_operand.value()};
273280
}
274281

282+
template <typename T>
283+
std::optional<symbol::DimExpr> ConvertBinaryDimExprImpl(
284+
const T& dim_expr, int64_t c, const std::string& symbol) {
285+
const auto& lhs = dim_expr->lhs;
286+
const auto& rhs = dim_expr->rhs;
287+
const auto& converted_lhs = ConvertDimExpr(lhs, c, symbol);
288+
const auto& converted_rhs = ConvertDimExpr(rhs, c, symbol);
289+
if (!converted_lhs.has_value() && !converted_rhs.has_value())
290+
return std::nullopt;
291+
if (converted_lhs.has_value() && converted_rhs.has_value()) {
292+
return T{converted_lhs.value(), converted_rhs.value()};
293+
}
294+
if (converted_lhs.has_value()) {
295+
return T{converted_lhs.value(), rhs};
296+
}
297+
return T{lhs, converted_rhs.value()};
298+
}
299+
275300
template <typename T>
276301
std::optional<symbol::DimExpr> ConvertListDimExprImpl(
277302
const T& dim_expr, int64_t c, const std::string& symbol) {
@@ -297,24 +322,24 @@ struct StaticDimToDynamicConverter {
297322
}
298323

299324
std::optional<symbol::DimExpr> ConvertDimExprImpl(
300-
const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
325+
const symbol::Add<symbol::DimExpr>& dim_expr,
301326
int64_t c,
302327
const std::string& symbol) {
303-
return ConvertUnaryDimExprImpl(dim_expr, c, symbol);
328+
return ConvertListDimExprImpl(dim_expr, c, symbol);
304329
}
305330

306331
std::optional<symbol::DimExpr> ConvertDimExprImpl(
307-
const symbol::Add<symbol::DimExpr>& dim_expr,
332+
const symbol::Mul<symbol::DimExpr>& dim_expr,
308333
int64_t c,
309334
const std::string& symbol) {
310335
return ConvertListDimExprImpl(dim_expr, c, symbol);
311336
}
312337

313338
std::optional<symbol::DimExpr> ConvertDimExprImpl(
314-
const symbol::Mul<symbol::DimExpr>& dim_expr,
339+
const symbol::Div<symbol::DimExpr>& dim_expr,
315340
int64_t c,
316341
const std::string& symbol) {
317-
return ConvertListDimExprImpl(dim_expr, c, symbol);
342+
return ConvertBinaryDimExprImpl(dim_expr, c, symbol);
318343
}
319344

320345
std::optional<symbol::DimExpr> ConvertDimExprImpl(

0 commit comments

Comments
 (0)