66#include < torch/csrc/jit/codegen/cuda/root_domain_map.h>
77#include < torch/csrc/jit/codegen/cuda/transform_iter.h>
88
9+ #include < tuple>
10+
911namespace torch {
1012namespace jit {
1113namespace fuser {
@@ -29,8 +31,22 @@ bool idIsALeafDomain(IterDomain* id, TensorView* tv) {
2931
3032} // namespace
3133
32- IterDomainGraph::IterDomainGraph (Fusion* fusion) {
34+ IterDomainGraph::IterDomainGraph (Fusion* fusion, bool allow_self_mapping ) {
3335 build (fusion);
36+
37+ if (!allow_self_mapping) {
38+ TORCH_INTERNAL_ASSERT (
39+ !hasSelfMapping (),
40+ " Unsupported domain mapping detected in " ,
41+ std::get<0 >(*self_mapping_info_)->toString (),
42+ " . " ,
43+ std::get<3 >(*self_mapping_info_),
44+ " domains, " ,
45+ std::get<1 >(*self_mapping_info_)->toString (),
46+ " and " ,
47+ std::get<2 >(*self_mapping_info_)->toString (),
48+ " , are mapped with each other." );
49+ }
3450}
3551
3652// ! Map corresponding inputs and outputs of swizzle op together
@@ -197,7 +213,8 @@ c10::optional<std::pair<IterDomain*, IterDomain*>> detectMappablePair(
197213// those domains should never be mapped with each other. It may be
198214// possible to lift this assumption, but it's unclear if it could
199215// matter in practice.
200- void failIfSelfMappingExists (Fusion* fusion, const IterDomainGraph& id_graph) {
216+ c10::optional<std::tuple<TensorView*, IterDomain*, IterDomain*, std::string>>
217+ findFirstSelfMapping (Fusion* fusion, const IterDomainGraph& id_graph) {
201218 for (auto tv : ir_utils::allTvs (fusion)) {
202219 // For each tensor, make sure root, rfactor and leaf domains
203220 // should not include domains that are mapped with another domain
@@ -207,44 +224,39 @@ void failIfSelfMappingExists(Fusion* fusion, const IterDomainGraph& id_graph) {
207224 // Root domains
208225 auto self_mappped_root_pair =
209226 detectMappablePair (tv->getRootDomain (), id_graph);
210- TORCH_INTERNAL_ASSERT (
211- !self_mappped_root_pair.has_value (),
212- " Unsupported domain mapping detected in " ,
213- tv->toString (),
214- " . Root domains, " ,
215- self_mappped_root_pair->first ->toString (),
216- " and " ,
217- self_mappped_root_pair->second ->toString (),
218- " , are mapped with each other." );
227+ if (self_mappped_root_pair.has_value ()) {
228+ return std::make_tuple (
229+ tv,
230+ self_mappped_root_pair->first ,
231+ self_mappped_root_pair->second ,
232+ " Root" );
233+ }
219234
220235 // Rfactor domains
221236 if (tv->hasRFactor ()) {
222237 auto self_mappped_rf_pair =
223238 detectMappablePair (tv->getRFactorDomain (), id_graph);
224- TORCH_INTERNAL_ASSERT (
225- !self_mappped_rf_pair.has_value (),
226- " Unsupported domain mapping detected in " ,
227- tv->toString (),
228- " . RFactor domains, " ,
229- self_mappped_rf_pair->first ->toString (),
230- " and " ,
231- self_mappped_rf_pair->second ->toString (),
232- " , are mapped with each other." );
239+ if (self_mappped_rf_pair.has_value ()) {
240+ return std::make_tuple (
241+ tv,
242+ self_mappped_rf_pair->first ,
243+ self_mappped_rf_pair->second ,
244+ " RFactor" );
245+ }
233246 }
234247
235248 // Leaf domains
236249 auto self_mappped_leaf_pair =
237250 detectMappablePair (tv->domain ()->domain (), id_graph);
238- TORCH_INTERNAL_ASSERT (
239- !self_mappped_leaf_pair.has_value (),
240- " Unsupported domain mapping detected in " ,
241- tv->toString (),
242- " . Leaf domains, " ,
243- self_mappped_leaf_pair->first ->toString (),
244- " and " ,
245- self_mappped_leaf_pair->second ->toString (),
246- " , are mapped with each other." );
251+ if (self_mappped_leaf_pair.has_value ()) {
252+ return std::make_tuple (
253+ tv,
254+ self_mappped_leaf_pair->first ,
255+ self_mappped_leaf_pair->second ,
256+ " Leaf" );
257+ }
247258 }
259+ return c10::nullopt ;
248260}
249261
250262} // namespace
@@ -591,8 +603,7 @@ void IterDomainGraph::build(Fusion* fusion) {
591603 }
592604 }
593605 }
594-
595- failIfSelfMappingExists (fusion, *this );
606+ self_mapping_info_ = findFirstSelfMapping (fusion, *this );
596607}
597608
598609void IterDomainGraph::initializeId (
0 commit comments