
Commit 0b8e83f

naoyam and csarofeen authored
Allow non-root trivial reductions (pytorch#2037)
* Allow non-root trivial reductions

Fixes pytorch#2008

Co-authored-by: Christian Sarofeen <[email protected]>
1 parent a2dfe40 commit 0b8e83f
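
To illustrate the change: a "non-root trivial reduction" is a size-1 reduction IterDomain produced by scheduling transforms (e.g., a split by a factor of 1) rather than one appearing in a tensor's root domain. Below is a minimal sketch based on the FusionTrivialReductionForwarding3 test updated in this commit; the setup of tv0/tv1 is an assumption for illustration, not part of the diff.

  // Sketch only (assumed setup): tv1 = sum(tv0, {1}), so tv1's root domain is
  // [I1, R(I2)] with the reduction axis at position 1.
  tv1->split(1, 1, false);  // outer split by 1: creates a reduction IterDomain of
                            // extent 1 that is a split output, not a root domain
  tv1->merge(0, 1);         // merge the iteration axis with that trivial reduction
  // Previously, BestEffortReplay forwarding skipped such merges because the trivial
  // reduction was required to be a root (rfactor) domain; with this commit the
  // merge is forwarded just as it would be for a root-domain trivial reduction.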

4 files changed: +93 −27 lines changed


torch/csrc/jit/codegen/cuda/ir_internal_nodes.h

Lines changed: 2 additions & 10 deletions
@@ -1416,16 +1416,8 @@ class TORCH_CUDA_CU_API IterDomain : public Val {
   }

   //! Check if IterDomain is a reduction axis with size of 1, i.e.
-  //! a "squeeze" operator.
-  //!
-  //! NOTE: Detection of trivial reduction here is not
-  //! comprehensive. See detectTrivialReductionDerivedDomains for more
-  //! comprehensive analysis. We typically use this for root domain trivial
-  //! reduction checks. So we ship to the correct scheduler. It may
-  //! not be incredibly robust, but it makes sense to keep it for now.
-  bool isTrivialReduction() const {
-    return isReduction() && extent()->isOneInt();
-  }
+  //! a "squeeze" operator, or solely derived from such axes.
+  bool isTrivialReduction() const;

   //! Split for stride by a given factor. It effectively does an inner
   //! split by the factor and sets the inner domain as a Stride

torch/csrc/jit/codegen/cuda/ir_nodes.cpp

Lines changed: 36 additions & 1 deletion
@@ -1720,6 +1720,37 @@ IterDomain* IterDomain::cloneWithoutRFactor() const {
   return cloned;
 }

+bool IterDomain::isTrivialReduction() const {
+  if (!isReduction()) {
+    return false;
+  }
+
+  if (extent()->isOneInt()) {
+    return true;
+  }
+
+  // If this domain is an output of an expression, i.e., not a root
+  // domain, check if all root domains are trivial reductions. This is
+  // almost the same as the analysis done in TrivialReductionInfo, but
+  // is limited within a single tensor, whereas TrivialReductionInfo
+  // does more expensive analysis potentially traversing through
+  // rfactor domains
+  if (definition()) {
+    // Note: There's no const version of IterVisitor.
+    auto id_inputs = InputsOf::output(fusion(), const_cast<IterDomain*>(this));
+    if (std::all_of(
+            ir_utils::filterByType<IterDomain>(id_inputs).begin(),
+            ir_utils::filterByType<IterDomain>(id_inputs).end(),
+            [](IterDomain* root_id) {
+              return root_id->isReduction() && root_id->extent()->isOneInt();
+            })) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
 std::vector<IterDomain*> IterDomain::clone(
     const std::vector<IterDomain*>& domains) {
   std::vector<IterDomain*> cloned_domains;
@@ -1744,7 +1775,11 @@ IterDomain* IterDomain::merge(IterDomain* outer, IterDomain* inner) {
       outer->isReduction() == inner->isReduction() ||
           (!outer->isReduction() && inner->isTrivialReduction()) ||
           (outer->isTrivialReduction() && !inner->isReduction()),
-      "Merging IterDomains requires that their iteration types match.");
+      "Merging IterDomains requires that their iteration types match. ",
+      "Outer: ",
+      outer->toString(),
+      ", Inner: ",
+      inner->toString());
   TORCH_CHECK(
       (outer->isGather() && inner->isGather()) ||
           (!outer->isGather() && !inner->isGather()),
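
For context on the new out-of-line definition above, a hedged sketch of the case it is meant to cover, modeled on the FusionReplayTrivialReductionAndBroadcast2 test added in this commit (tv2 and its root domain are assumptions for illustration):

  // Sketch only: tv2 is assumed to come from a sum over two size-1 axes, so its
  // root domain is [i{10}, r{1}, r{1}].
  tv2->merge(-2, -1);  // r{1} x r{1}: the merged extent is the product 1 * 1,
                       // which need not simplify to a literal constant 1
  // The old inline check (isReduction() && extent()->isOneInt()) could therefore
  // miss this merge output; the new implementation also walks back through
  // definition() to the root inputs via InputsOf::output and, since every input
  // is an extent-1 reduction, still classifies the domain as trivial:
  TORCH_INTERNAL_ASSERT(tv2->axis(-1)->isTrivialReduction());
  tv2->merge(-2, -1);  // merging i{10} with the trivial reduction now passes the
                       // iteration-type check in IterDomain::merge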

torch/csrc/jit/codegen/cuda/test/test_gpu.cpp

Lines changed: 55 additions & 8 deletions
@@ -22304,26 +22304,40 @@ TEST_F(NVFuserTest, FusionTrivialReductionForwarding3_CUDA) {
   auto tv2 = add(tv1, IrBuilder::create<Double>(1));
   fusion.addOutput(tv2);

-  // Similar pattern as FusionTrivialReductionForwarding2 but no
-  // trivial reduciton at the root domain
+  // Similar pattern as FusionTrivialReductionForwarding2 but trivial
+  // reduction at a non-root domain

   // Create a trivial reduction by splitting with a factor of 1
   tv1->split(1, 1, false);
   // Merging with a trivial reduction
   tv1->merge(0, 1);
+  auto tv1_merge_out_id = tv1->axis(0);
   tv1->split(0, 5);

   tv2->split(0, 5);

-  // While the merge of tv1 is done with a trivial reduciton, it's not
-  // a root domain, so forwarding is not enabled. BestEffortReplay
-  // should only map the first axis of each tensor.
+  // The merge of tv1 is done with a non-root trivial
+  // reduction. BestEffortReplay should forward the merge.

   PairwiseRootDomainMap root_map(tv1, tv2);
   auto p2c = BestEffortReplay::replayCasP(tv2, tv1, 2, root_map).getReplay();
-  TORCH_CHECK(p2c.size() == 1, "Expected only one mapping found");
-  TORCH_CHECK(p2c.begin()->first == tv1->getRootDomain().at(0));
-  TORCH_CHECK(p2c.begin()->second == tv2->getRootDomain().at(0));
+
+  // The two tensors should look like:
+  // tv1: [I1*1//5, 5, I2//1]
+  // tv2: [I1//5, 5]
+  //
+  // BestEffortReplay should forward the merge of (I1 * 1) and create
+  // mappings of:
+  // I1*1//5 -> I1//5
+  // 5 -> 5
+  // I1*1 -> I1
+
+  TORCH_CHECK(p2c.size() == 3, "Unexpected number of mappings");
+  TORCH_CHECK(p2c.count(tv1->axis(0)) && p2c[tv1->axis(0)] == tv2->axis(0));
+  TORCH_CHECK(p2c.count(tv1->axis(1)) && p2c[tv1->axis(1)] == tv2->axis(1));
+  TORCH_CHECK(
+      p2c.count(tv1_merge_out_id) &&
+      p2c[tv1_merge_out_id] == tv2->getRootDomain()[0]);
 }

 TEST_F(NVFuserTest, FusionTrivialReductionForwarding4_CUDA) {
@@ -26125,6 +26139,39 @@ TEST_F(NVFuserTest, FusionTrivialInputForwarding_CUDA) {
   testValidate(fusion, cg_outputs2, {t0, t1}, {t0}, __LINE__, __FILE__);
 }

+// Simplified repro of issue #2008
+TEST_F(NVFuserTest, FusionReplayTrivialReductionAndBroadcast2_CUDA) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr;
+  FusionGuard fg(fusion_ptr.get());
+
+  std::vector<int64_t> shape({10, 1, 1});
+
+  auto tv0 = makeConcreteTensor(shape);
+  fusion.addInput(tv0);
+
+  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
+  auto tv2 = sum(tv1, {1, 2});
+  auto tv3 = broadcast(tv2, {false, true, true});
+  fusion.addOutput(tv3);
+
+  tv0->merge(-2, -1)->merge(-2, -1)->split(0, 4);
+
+  MaxRootDomainInfoSpanningTree tree(tv0);
+  TransformPropagator tp(tv0);
+  tree.traverse(&tp);
+
+  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn(shape, options);
+  std::vector<IValue> aten_inputs({t0});
+
+  FusionExecutor fe;
+  fe.compileFusion(fusion_ptr.get(), aten_inputs);
+  auto outputs = fe.runFusion(aten_inputs);
+
+  testValidate(&fusion, outputs, aten_inputs, {t0 + 1}, __LINE__, __FILE__);
+}
+
 namespace {

 size_t getVecSizeForPointwise(FusionExecutorCache& fec) {

torch/csrc/jit/codegen/cuda/transform_iter.cpp

Lines changed: 0 additions & 8 deletions
@@ -762,14 +762,6 @@ struct ProducerForwardingInfo {
           (outer->isTrivialReduction() && !inner->isReduction())) {
         auto compliment_id = inner->isTrivialReduction() ? inner : outer;
         auto forwarded_id = inner->isTrivialReduction() ? outer : inner;
-        // Only allow forwarding when the trivial reduction domain is
-        // an root domain
-        if (std::find(
-                producer->getMaybeRFactorDomain().begin(),
-                producer->getMaybeRFactorDomain().end(),
-                compliment_id) == producer->getMaybeRFactorDomain().end()) {
-          continue;
-        }
         forwarding_map.emplace(std::make_pair(forwarded_id, merge->out()));
         compliment_map.emplace(std::make_pair(
             forwarded_id, std::vector<IterDomain*>{compliment_id}));

0 commit comments
