@@ -2598,9 +2598,39 @@ bool CombineReductions::shouldRun(
25982598 return false ;
25992599}
26002600
2601- bool SegmentCandidateFinder::codeGenSupportedMerge (SegmentedEdge* edge) {
2601+ namespace {
2602+
2603+ // ! Returns true if group1 and group2 are an immediate producer-consumer pair.
2604+ bool areDirectlyConnected (SegmentedGroup* group1, SegmentedGroup* group2) {
2605+ // Check if group1 is a immediate consumer of group2
2606+ if (std::any_of (
2607+ group1->producer_edges .begin (),
2608+ group1->producer_edges .end (),
2609+ [group2](SegmentedEdge* edge) { return edge->from == group2; })) {
2610+ return true ;
2611+ }
2612+
2613+ // Check if group1 is a immediate producer of group2
2614+ if (std::any_of (
2615+ group1->consumer_edges .begin (),
2616+ group1->consumer_edges .end (),
2617+ [group2](SegmentedEdge* edge) { return edge->to == group2; })) {
2618+ return true ;
2619+ }
2620+
2621+ return false ;
2622+ }
2623+
2624+ } // namespace
2625+
2626+ bool SegmentCandidateFinder::codeGenSupportedMerge (
2627+ SegmentedGroup* group1,
2628+ SegmentedGroup* group2) {
2629+ TORCH_INTERNAL_ASSERT (
2630+ areDirectlyConnected (group1, group2),
2631+ " only support testing immediate producer-consumer groups" );
26022632 Fusion* fusion = segmented_fusion_->completeFusion ();
2603- auto h = tryMerge (fusion, runtime_info_, edge-> from , edge-> to );
2633+ auto h = tryMerge (fusion, runtime_info_, group1, group2 );
26042634 return h.has_value ();
26052635}
26062636
@@ -2827,7 +2857,7 @@ void SegmentCandidateFinder::findSegments() {
28272857
28282858 auto candidate_it = candidates.begin ();
28292859 while (candidate_it != candidates.end () &&
2830- !codeGenSupportedMerge (candidate_it->edge )) {
2860+ !codeGenSupportedMerge (group, candidate_it->group )) {
28312861 candidate_it++;
28322862 }
28332863 if (candidate_it == candidates.end ()) {
@@ -2896,7 +2926,7 @@ void SegmentCandidateFinder::finalMerge() {
28962926 for (auto consumer : all_consumers_of_producer_group) {
28972927 if (!producer_check->isConsumerOfAny (
28982928 consumer, all_consumers_of_producer_group) &&
2899- codeGenSupportedMerge (consumer_edge_map. at ( consumer) )) {
2929+ codeGenSupportedMerge (producer_group, consumer)) {
29002930 to_merge_.emplace_back (producer_group);
29012931 to_merge_.emplace_back (consumer);
29022932 producer_group->merged_ = true ;
0 commit comments