@@ -22304,26 +22304,40 @@ TEST_F(NVFuserTest, FusionTrivialReductionForwarding3_CUDA) {
   auto tv2 = add(tv1, IrBuilder::create<Double>(1));
   fusion.addOutput(tv2);
 
-  // Similar pattern as FusionTrivialReductionForwarding2 but no
-  // trivial reduciton at the root domain
+  // Similar pattern as FusionTrivialReductionForwarding2 but the
+  // trivial reduction is at a non-root domain
 
   // Create a trivial reduction by splitting with a factor of 1
   tv1->split(1, 1, false);
   // Merging with a trivial reduction
   tv1->merge(0, 1);
+  auto tv1_merge_out_id = tv1->axis(0);
   tv1->split(0, 5);
 
   tv2->split(0, 5);
 
-  // While the merge of tv1 is done with a trivial reduciton, it's not
-  // a root domain, so forwarding is not enabled. BestEffortReplay
-  // should only map the first axis of each tensor.
+  // The merge of tv1 is done with a non-root trivial
+  // reduction. BestEffortReplay should forward the merge.
 
   PairwiseRootDomainMap root_map(tv1, tv2);
   auto p2c = BestEffortReplay::replayCasP(tv2, tv1, 2, root_map).getReplay();
-  TORCH_CHECK(p2c.size() == 1, "Expected only one mapping found");
-  TORCH_CHECK(p2c.begin()->first == tv1->getRootDomain().at(0));
-  TORCH_CHECK(p2c.begin()->second == tv2->getRootDomain().at(0));
+
+  // The two tensors should look like:
+  // tv1: [I1*1//5, 5, I2//1]
+  // tv2: [I1//5, 5]
+  //
+  // BestEffortReplay should forward the merge of (I1 * 1) and create
+  // mappings of:
+  // I1*1//5 -> I1//5
+  // 5 -> 5
+  // I1*1 -> I1
+
+  TORCH_CHECK(p2c.size() == 3, "Unexpected number of mappings");
+  TORCH_CHECK(p2c.count(tv1->axis(0)) && p2c[tv1->axis(0)] == tv2->axis(0));
+  TORCH_CHECK(p2c.count(tv1->axis(1)) && p2c[tv1->axis(1)] == tv2->axis(1));
+  TORCH_CHECK(
+      p2c.count(tv1_merge_out_id) &&
+      p2c[tv1_merge_out_id] == tv2->getRootDomain()[0]);
 }
 
 TEST_F(NVFuserTest, FusionTrivialReductionForwarding4_CUDA) {
@@ -26125,6 +26139,39 @@ TEST_F(NVFuserTest, FusionTrivialInputForwarding_CUDA) {
   testValidate(fusion, cg_outputs2, {t0, t1}, {t0}, __LINE__, __FILE__);
 }
 
+// Simplified repro of issue #2008
+TEST_F(NVFuserTest, FusionReplayTrivialReductionAndBroadcast2_CUDA) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  Fusion& fusion = *fusion_ptr;
+  FusionGuard fg(fusion_ptr.get());
+
+  std::vector<int64_t> shape({10, 1, 1});
+
+  auto tv0 = makeConcreteTensor(shape);
+  fusion.addInput(tv0);
+
+  auto tv1 = add(tv0, IrBuilder::create<Double>(1));
+  auto tv2 = sum(tv1, {1, 2});
+  auto tv3 = broadcast(tv2, {false, true, true});
+  fusion.addOutput(tv3);
+
+  tv0->merge(-2, -1)->merge(-2, -1)->split(0, 4);
+
+  MaxRootDomainInfoSpanningTree tree(tv0);
+  TransformPropagator tp(tv0);
+  tree.traverse(&tp);
+
+  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+  at::Tensor t0 = at::randn(shape, options);
+  std::vector<IValue> aten_inputs({t0});
+
+  FusionExecutor fe;
+  fe.compileFusion(fusion_ptr.get(), aten_inputs);
+  auto outputs = fe.runFusion(aten_inputs);
+
+  testValidate(&fusion, outputs, aten_inputs, {t0 + 1}, __LINE__, __FILE__);
+}
+
 namespace {
 
 size_t getVecSizeForPointwise(FusionExecutorCache& fec) {