@@ -382,6 +382,77 @@ LogicalResult mlir::scf::peelAndCanonicalizeForLoop(RewriterBase &rewriter,
382382 return success ();
383383}
384384
385+ // / Canonicalize AffineMinOp operations in the context of for loops with a known
386+ // / range. Call `canonicalizeAffineMinOp` and add the following constraints to
387+ // / the constraint system (along with the missing dimensions):
388+ // /
389+ // / * iv >= lb
390+ // / * iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
391+ // /
392+ // / Note: Due to limitations of FlatAffineConstraints, only constant step sizes
393+ // / are currently supported.
394+ LogicalResult mlir::scf::canonicalizeAffineMinOpInLoop (
395+ AffineMinOp minOp, RewriterBase &rewriter,
396+ function_ref<LogicalResult(Value, Value &, Value &, Value &)> loopMatcher) {
397+ FlatAffineValueConstraints constraints;
398+ DenseSet<Value> allIvs;
399+
400+ // Find all iteration variables among `minOp`'s operands add constrain them.
401+ for (Value operand : minOp.operands ()) {
402+ // Skip duplicate ivs.
403+ if (llvm::find (allIvs, operand) != allIvs.end ())
404+ continue ;
405+
406+ // If `operand` is an iteration variable: Find corresponding loop
407+ // bounds and step.
408+ Value iv = operand;
409+ Value lb, ub, step;
410+ if (failed (loopMatcher (operand, lb, ub, step)))
411+ continue ;
412+ allIvs.insert (iv);
413+
414+ // FlatAffineConstraints does not support semi-affine expressions.
415+ // Therefore, only constant step values are supported.
416+ auto stepInt = getConstantIntValue (step);
417+ if (!stepInt)
418+ continue ;
419+
420+ unsigned dimIv = constraints.addDimId (iv);
421+ unsigned dimLb = constraints.addDimId (lb);
422+ unsigned dimUb = constraints.addDimId (ub);
423+
424+ // If loop lower/upper bounds are constant: Add EQ constraint.
425+ Optional<int64_t > lbInt = getConstantIntValue (lb);
426+ Optional<int64_t > ubInt = getConstantIntValue (ub);
427+ if (lbInt)
428+ constraints.addBound (FlatAffineConstraints::EQ, dimLb, *lbInt);
429+ if (ubInt)
430+ constraints.addBound (FlatAffineConstraints::EQ, dimUb, *ubInt);
431+
432+ // iv >= lb (equiv.: iv - lb >= 0)
433+ SmallVector<int64_t > ineqLb (constraints.getNumCols (), 0 );
434+ ineqLb[dimIv] = 1 ;
435+ ineqLb[dimLb] = -1 ;
436+ constraints.addInequality (ineqLb);
437+
438+ // iv < lb + step * ((ub - lb - 1) floorDiv step) + 1
439+ AffineExpr exprLb = lbInt ? rewriter.getAffineConstantExpr (*lbInt)
440+ : rewriter.getAffineDimExpr (dimLb);
441+ AffineExpr exprUb = ubInt ? rewriter.getAffineConstantExpr (*ubInt)
442+ : rewriter.getAffineDimExpr (dimUb);
443+ AffineExpr ivUb =
444+ exprLb + 1 + (*stepInt * ((exprUb - exprLb - 1 ).floorDiv (*stepInt)));
445+ auto map = AffineMap::get (
446+ /* dimCount=*/ constraints.getNumDimIds (),
447+ /* symbolCount=*/ constraints.getNumSymbolIds (), /* result=*/ ivUb);
448+
449+ if (failed (constraints.addBound (FlatAffineConstraints::UB, dimIv, map)))
450+ return failure ();
451+ }
452+
453+ return canonicalizeAffineMinOp (rewriter, minOp, constraints);
454+ }
455+
// Attribute labels used to mark loops already processed by loop peeling, so
// the pattern driver does not peel the same loop (or its partial copy) twice.
static constexpr char kPeeledLoopLabel[] = "__peeled_loop__";
static constexpr char kPartialIterationLabel[] = "__partial_iteration__";
387458
@@ -423,6 +494,39 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
423494 // / the direct parent.
424495 bool skipPartial;
425496};
497+
498+ // / Canonicalize AffineMinOp operations in the context of scf.for and
499+ // / scf.parallel loops with a known range.
500+ struct AffineMinSCFCanonicalizationPattern
501+ : public OpRewritePattern<AffineMinOp> {
502+ using OpRewritePattern<AffineMinOp>::OpRewritePattern;
503+
504+ LogicalResult matchAndRewrite (AffineMinOp minOp,
505+ PatternRewriter &rewriter) const override {
506+ auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
507+ if (scf::ForOp forOp = scf::getForInductionVarOwner (iv)) {
508+ lb = forOp.lowerBound ();
509+ ub = forOp.upperBound ();
510+ step = forOp.step ();
511+ return success ();
512+ }
513+ if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner (iv)) {
514+ for (unsigned idx = 0 ; idx < parOp.getNumLoops (); ++idx) {
515+ if (parOp.getInductionVars ()[idx] == iv) {
516+ lb = parOp.lowerBound ()[idx];
517+ ub = parOp.upperBound ()[idx];
518+ step = parOp.step ()[idx];
519+ return success ();
520+ }
521+ }
522+ return failure ();
523+ }
524+ return failure ();
525+ };
526+
527+ return scf::canonicalizeAffineMinOpInLoop (minOp, rewriter, loopMatcher);
528+ }
529+ };
426530} // namespace
427531
428532namespace {
@@ -456,8 +560,24 @@ struct ForLoopPeeling : public SCFForLoopPeelingBase<ForLoopPeeling> {
456560 });
457561 }
458562};
563+
564+ struct AffineMinSCFCanonicalization
565+ : public AffineMinSCFCanonicalizationBase<AffineMinSCFCanonicalization> {
566+ void runOnFunction () override {
567+ FuncOp funcOp = getFunction ();
568+ MLIRContext *ctx = funcOp.getContext ();
569+ RewritePatternSet patterns (ctx);
570+ patterns.add <AffineMinSCFCanonicalizationPattern>(ctx);
571+ if (failed (applyPatternsAndFoldGreedily (funcOp, std::move (patterns))))
572+ signalPassFailure ();
573+ }
574+ };
459575} // namespace
460576
577+ std::unique_ptr<Pass> mlir::createAffineMinSCFCanonicalizationPass () {
578+ return std::make_unique<AffineMinSCFCanonicalization>();
579+ }
580+
461581std::unique_ptr<Pass> mlir::createParallelLoopSpecializationPass () {
462582 return std::make_unique<ParallelLoopSpecialization>();
463583}
@@ -469,3 +589,8 @@ std::unique_ptr<Pass> mlir::createForLoopSpecializationPass() {
469589std::unique_ptr<Pass> mlir::createForLoopPeelingPass () {
470590 return std::make_unique<ForLoopPeeling>();
471591}
592+
593+ void mlir::scf::populateSCFLoopBodyCanonicalizationPatterns (
594+ RewritePatternSet &patterns) {
595+ patterns.insert <AffineMinSCFCanonicalizationPattern>(patterns.getContext ());
596+ }
0 commit comments