Skip to content

Commit 992e17c

Browse files
authored
test the groups the same order as they are merged (pytorch#1949)
1 parent 208262b commit 992e17c

File tree

2 files changed

+35
-5
lines changed

2 files changed

+35
-5
lines changed

torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2598,9 +2598,39 @@ bool CombineReductions::shouldRun(
25982598
return false;
25992599
}
26002600

2601-
bool SegmentCandidateFinder::codeGenSupportedMerge(SegmentedEdge* edge) {
2601+
namespace {
2602+
2603+
//! Returns true if group1 and group2 are an immediate producer-consumer pair.
2604+
bool areDirectlyConnected(SegmentedGroup* group1, SegmentedGroup* group2) {
2605+
// Check if group1 is a immediate consumer of group2
2606+
if (std::any_of(
2607+
group1->producer_edges.begin(),
2608+
group1->producer_edges.end(),
2609+
[group2](SegmentedEdge* edge) { return edge->from == group2; })) {
2610+
return true;
2611+
}
2612+
2613+
// Check if group1 is a immediate producer of group2
2614+
if (std::any_of(
2615+
group1->consumer_edges.begin(),
2616+
group1->consumer_edges.end(),
2617+
[group2](SegmentedEdge* edge) { return edge->to == group2; })) {
2618+
return true;
2619+
}
2620+
2621+
return false;
2622+
}
2623+
2624+
} // namespace
2625+
2626+
bool SegmentCandidateFinder::codeGenSupportedMerge(
2627+
SegmentedGroup* group1,
2628+
SegmentedGroup* group2) {
2629+
TORCH_INTERNAL_ASSERT(
2630+
areDirectlyConnected(group1, group2),
2631+
"only support testing immediate producer-consumer groups");
26022632
Fusion* fusion = segmented_fusion_->completeFusion();
2603-
auto h = tryMerge(fusion, runtime_info_, edge->from, edge->to);
2633+
auto h = tryMerge(fusion, runtime_info_, group1, group2);
26042634
return h.has_value();
26052635
}
26062636

@@ -2827,7 +2857,7 @@ void SegmentCandidateFinder::findSegments() {
28272857

28282858
auto candidate_it = candidates.begin();
28292859
while (candidate_it != candidates.end() &&
2830-
!codeGenSupportedMerge(candidate_it->edge)) {
2860+
!codeGenSupportedMerge(group, candidate_it->group)) {
28312861
candidate_it++;
28322862
}
28332863
if (candidate_it == candidates.end()) {
@@ -2896,7 +2926,7 @@ void SegmentCandidateFinder::finalMerge() {
28962926
for (auto consumer : all_consumers_of_producer_group) {
28972927
if (!producer_check->isConsumerOfAny(
28982928
consumer, all_consumers_of_producer_group) &&
2899-
codeGenSupportedMerge(consumer_edge_map.at(consumer))) {
2929+
codeGenSupportedMerge(producer_group, consumer)) {
29002930
to_merge_.emplace_back(producer_group);
29012931
to_merge_.emplace_back(consumer);
29022932
producer_group->merged_ = true;

torch/csrc/jit/codegen/cuda/fusion_segmenter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ class TORCH_CUDA_CU_API SegmentCandidateFinder {
488488

489489
SegmentedGroup* mergeNodes();
490490

491-
bool codeGenSupportedMerge(SegmentedEdge* edge);
491+
bool codeGenSupportedMerge(SegmentedGroup* group1, SegmentedGroup* group2);
492492

493493
void findSegments();
494494

0 commit comments

Comments
 (0)