@@ -30,11 +30,6 @@ std::string GetSerializedTag<Negative<DimExpr>>() {
3030 return " Negative" ;
3131}
3232
33- template <>
34- std::string GetSerializedTag<Reciprocal<DimExpr>>() {
35- return " Reciprocal" ;
36- }
37-
3833template <>
3934std::string GetSerializedTag<Add<DimExpr>>() {
4035 return " Add" ;
@@ -45,6 +40,11 @@ std::string GetSerializedTag<Mul<DimExpr>>() {
4540 return " Mul" ;
4641}
4742
43+ template <>
44+ std::string GetSerializedTag<Div<DimExpr>>() {
45+ return " Div" ;
46+ }
47+
4848template <>
4949std::string GetSerializedTag<Max<DimExpr>>() {
5050 return " Max" ;
@@ -80,13 +80,20 @@ ::pir::Attribute ConvertUnaryDimExprToAttributeImpl(::pir::IrContext* ctx,
8080 return pir::ArrayAttribute::get (ctx, attr_vecs);
8181}
8282
83- ::pir::Attribute ConvertDimExprToAttributeImpl (
84- ::pir::IrContext* ctx, const Negative<DimExpr>& dim_expr) {
85- return ConvertUnaryDimExprToAttributeImpl (ctx, dim_expr);
83+ template <typename T>
84+ ::pir::Attribute ConvertBinaryDimExprToAttributeImpl (::pir::IrContext* ctx,
85+ const T& dim_expr) {
86+ std::vector<::pir::Attribute> attr_vecs{};
87+ attr_vecs.push_back (pir::StrAttribute::get (ctx, GetSerializedTag<T>()));
88+ const auto & lhs = dim_expr->lhs ;
89+ const auto & rhs = dim_expr->rhs ;
90+ attr_vecs.push_back (ConvertDimExprToAttribute (ctx, lhs));
91+ attr_vecs.push_back (ConvertDimExprToAttribute (ctx, rhs));
92+ return pir::ArrayAttribute::get (ctx, attr_vecs);
8693}
8794
8895::pir::Attribute ConvertDimExprToAttributeImpl (
89- ::pir::IrContext* ctx, const Reciprocal <DimExpr>& dim_expr) {
96+ ::pir::IrContext* ctx, const Negative <DimExpr>& dim_expr) {
9097 return ConvertUnaryDimExprToAttributeImpl (ctx, dim_expr);
9198}
9299
@@ -112,6 +119,11 @@ ::pir::Attribute ConvertDimExprToAttributeImpl(::pir::IrContext* ctx,
112119 return ConvertVariadicDimExprToAttribute (ctx, dim_expr);
113120}
114121
122+ ::pir::Attribute ConvertDimExprToAttributeImpl (::pir::IrContext* ctx,
123+ const Div<DimExpr>& dim_expr) {
124+ return ConvertBinaryDimExprToAttributeImpl (ctx, dim_expr);
125+ }
126+
115127::pir::Attribute ConvertDimExprToAttributeImpl (::pir::IrContext* ctx,
116128 const Max<DimExpr>& dim_expr) {
117129 return ConvertVariadicDimExprToAttribute (ctx, dim_expr);
@@ -150,6 +162,23 @@ std::optional<DimExpr> ConvertArrayAttributeToUnaryDimExpr(
150162 return T{operand.value ()};
151163}
152164
165+ template <typename T>
166+ std::optional<DimExpr> ConvertArrayAttributeToBinaryDimExpr (
167+ const ::pir::ArrayAttribute& attribute) {
168+ if (attribute.size () != 3 ) {
169+ return std::nullopt ;
170+ }
171+ std::optional<DimExpr> lhs = ConvertAttributeToDimExpr (attribute.at (1 ));
172+ if (!lhs.has_value ()) {
173+ return std::nullopt ;
174+ }
175+ std::optional<DimExpr> rhs = ConvertAttributeToDimExpr (attribute.at (2 ));
176+ if (!rhs.has_value ()) {
177+ return std::nullopt ;
178+ }
179+ return T{lhs.value (), rhs.value ()};
180+ }
181+
153182template <typename T>
154183std::optional<DimExpr> ConvertArrayAttributeToVariadicDimExpr (
155184 const ::pir::ArrayAttribute& attribute) {
@@ -175,12 +204,12 @@ std::optional<ArrayAttributeConverterT> GetArrayAttributeConverter(
175204 static std::unordered_map<std::string, ArrayAttributeConverterT> map{
176205 {GetSerializedTag<Negative<DimExpr>>(),
177206 &ConvertArrayAttributeToUnaryDimExpr<Negative<DimExpr>>},
178- {GetSerializedTag<Reciprocal<DimExpr>>(),
179- &ConvertArrayAttributeToUnaryDimExpr<Reciprocal<DimExpr>>},
180207 {GetSerializedTag<Add<DimExpr>>(),
181208 &ConvertArrayAttributeToVariadicDimExpr<Add<DimExpr>>},
182209 {GetSerializedTag<Mul<DimExpr>>(),
183210 &ConvertArrayAttributeToVariadicDimExpr<Mul<DimExpr>>},
211+ {GetSerializedTag<Div<DimExpr>>(),
212+ &ConvertArrayAttributeToBinaryDimExpr<Div<DimExpr>>},
184213 {GetSerializedTag<Max<DimExpr>>(),
185214 &ConvertArrayAttributeToVariadicDimExpr<Max<DimExpr>>},
186215 {GetSerializedTag<Min<DimExpr>>(),
@@ -276,9 +305,6 @@ class SubstituteDimExprHelper final {
276305 std::optional<DimExpr> SubstituteImpl (const Negative<DimExpr>& dim_expr) {
277306 return SubstituteUnary (dim_expr);
278307 }
279- std::optional<DimExpr> SubstituteImpl (const Reciprocal<DimExpr>& dim_expr) {
280- return SubstituteUnary (dim_expr);
281- }
282308
283309 template <typename T>
284310 std::optional<DimExpr> SubstituteUnary (const T& dim_expr) {
@@ -298,6 +324,25 @@ class SubstituteDimExprHelper final {
298324 return SubstituteVariadic (dim_expr);
299325 }
300326
327+ std::optional<DimExpr> SubstituteImpl (const Div<DimExpr>& dim_expr) {
328+ return SubstituteBinary (dim_expr);
329+ }
330+
331+ template <typename T>
332+ std::optional<DimExpr> SubstituteBinary (const T& dim_expr) {
333+ const auto & lhs = dim_expr->lhs ;
334+ const auto & rhs = dim_expr->rhs ;
335+ const auto & substituted_lhs = Substitute (lhs);
336+ if (!substituted_lhs.has_value ()) {
337+ return std::nullopt ;
338+ }
339+ const auto & substituted_rhs = Substitute (rhs);
340+ if (!substituted_rhs.has_value ()) {
341+ return std::nullopt ;
342+ }
343+ return T{substituted_lhs.value (), substituted_rhs.value ()};
344+ }
345+
301346 std::optional<DimExpr> SubstituteImpl (const Max<DimExpr>& dim_expr) {
302347 return SubstituteVariadic (dim_expr);
303348 }
@@ -412,12 +457,12 @@ bool IsAtomicImpl(const std::string&) { return true; }
412457
413458bool IsAtomicImpl (const symbol::Negative<symbol::DimExpr>&) { return false ; }
414459
415- bool IsAtomicImpl (const symbol::Reciprocal<symbol::DimExpr>&) { return false ; }
416-
417460bool IsAtomicImpl (const symbol::Add<symbol::DimExpr>&) { return false ; }
418461
419462bool IsAtomicImpl (const symbol::Mul<symbol::DimExpr>&) { return false ; }
420463
464+ bool IsAtomicImpl (const symbol::Div<symbol::DimExpr>&) { return false ; }
465+
421466bool IsAtomicImpl (const symbol::Max<symbol::DimExpr>&) { return false ; }
422467
423468bool IsAtomicImpl (const symbol::Min<symbol::DimExpr>&) { return false ; }
@@ -484,9 +529,12 @@ void CollectSymbolNamesImpl(const symbol::Negative<symbol::DimExpr>& dim_expr,
484529 CollectSymbolNamesImplForUnary (dim_expr, ret);
485530}
486531
487- void CollectSymbolNamesImpl (const symbol::Reciprocal<symbol::DimExpr>& dim_expr,
488- std::set<std::string>* ret) {
489- CollectSymbolNamesImplForUnary (dim_expr, ret);
532+ template <typename T>
533+ void CollectSymbolNamesImplForBinary (const T& dim_expr,
534+ std::set<std::string>* ret) {
535+ const auto & [lhs, rhs] = *dim_expr;
536+ CollectSymbolNames (lhs, ret);
537+ CollectSymbolNames (rhs, ret);
490538}
491539
492540template <typename T>
@@ -508,6 +556,11 @@ void CollectSymbolNamesImpl(const symbol::Mul<symbol::DimExpr>& dim_expr,
508556 CollectSymbolNamesImplForVariadic (dim_expr, ret);
509557}
510558
559+ void CollectSymbolNamesImpl (const symbol::Div<symbol::DimExpr>& dim_expr,
560+ std::set<std::string>* ret) {
561+ CollectSymbolNamesImplForBinary (dim_expr, ret);
562+ }
563+
511564void CollectSymbolNamesImpl (const symbol::Max<symbol::DimExpr>& dim_expr,
512565 std::set<std::string>* ret) {
513566 CollectSymbolNamesImplForVariadic (dim_expr, ret);
0 commit comments