@@ -58,15 +58,49 @@ class DomainMap : public pointwise_utils::DomainMap {
5858 domain_map.findReferenceFor (grouped_inputs_outputs[1 ]) != nullptr ;
5959 }
6060
// Diff hunk: the method `getPosMappedTo(tv, id)` is replaced by
// `getInnerLeafDim(tv, root_dim)`.
//
// The new method returns the index, within `tv->domain()->domain()`, of the
// IterDomain reached by starting from the root IterDomain of `tv` that is
// permissive-mapped to `root_dim` and following its uses down to an id that
// has no further uses. It works in three steps:
//   1. scan `tv->getRootDomain()` for the first id permissive-mapped to
//      `root_dim`;
//   2. walk that id through its consumer expressions, taking the inner output
//      of each Split and the output of each Merge;
//   3. look up the resulting id by pointer equality in the tensor's domain.
// Each step is guarded by TORCH_INTERNAL_ASSERT.
61- int getPosMappedTo (TensorView* tv, IterDomain* id) const {
61+ int getInnerLeafDim (TensorView* tv, IterDomain* root_dim) const {
62+ // Find the root id mapped to `root_dim`
63+ const auto & root_dom = tv->getRootDomain ();
64+ IterDomain* mapped_id = nullptr ;
// Only the first matching root id is used: the loop breaks on the first hit.
65+ for (auto i : c10::irange (root_dom.size ())) {
66+ if (ca_map_.idGraph ().permissiveNodes ().permissiveAreMapped (
67+ root_dom[i], root_dim)) {
68+ mapped_id = root_dom[i];
69+ break ;
70+ }
71+ }
// A root id mapped to `root_dim` is required; the assert below checks that
// the scan above found one.
72+ TORCH_INTERNAL_ASSERT (
73+ mapped_id != nullptr ,
74+ " Can not find ID mapped to " ,
75+ root_dim,
76+ " in tensor " ,
77+ tv);
78+ // Project the root id to leaf id
// The walk continues until the current id has no uses. At every step the id
// is asserted to have exactly one use.
79+ while (!mapped_id->uses ().empty ()) {
80+ TORCH_INTERNAL_ASSERT (mapped_id->uses ().size () == 1 );
81+ auto expr = mapped_id->uses ()[0 ];
82+ if (expr->isA <Split>()) {
// Through a Split, continue with the inner output.
83+ mapped_id = expr->as <Split>()->inner ();
84+ } else {
// NOTE(review): every non-Split use is cast with as<Merge>() without an
// isA<Merge>() check -- confirm that no other expression kind can appear in
// `uses()` here.
85+ auto merge = expr->as <Merge>();
// The current id must be the inner input of the merge; being the outer input
// fails the assert.
86+ TORCH_INTERNAL_ASSERT (
87+ mapped_id == merge->inner (),
88+ " Can not find ID mapped to " ,
89+ root_dim,
90+ " in tensor " ,
91+ tv);
92+ mapped_id = merge->out ();
93+ }
94+ }
95+ // Find the position of the leaf id
6296 const auto & dom = tv->domain ()->domain ();
6397 for (auto i : c10::irange (dom.size ())) {
// The removed line matched with `areExactMapped(id, tv->axis(i))`; the added
// line compares the domain entry by pointer against the projected id.
64- if (areExactMapped (id, tv-> axis (i)) ) {
98+ if (dom[i] == mapped_id ) {
// NOTE(review): `i` iterates over `dom.size()` while the method is declared
// to return int -- presumably an implicit narrowing conversion; confirm.
6599 return i;
66100 }
67101 }
// Reached only when no entry of `dom` matched. The assert condition is the
// constant `false`, and no return statement follows it.
68102 TORCH_INTERNAL_ASSERT (
69- false , " Can not find ID mapped to " , id , " in tensor " , tv);
103+ false , " Can not find ID mapped to " , root_dim , " in tensor " , tv);
70104 }
71105
72106 // Group inputs and outputs of a fusion by its inner most domain. For example
@@ -240,22 +274,37 @@ void maybeBuildVirtualInnerDims(
240274 // both virtual innermost dim.
241275 // 2. The satisfied one did not merge in anything. For example,
242276 // T0[I0{1024*1024}, I1{2}]
277+ // If this is the case, this means that we need to split the large
278+ // inner-most dimension to satisfy the small innermost dimension
243279 int64_t large_dim;
244280 int64_t split_factor;
281+ bool split_inner_most;
245282 if (merged_size1 < params.tile_size1 ) {
246283 if (params.dims_merged_with_2 .empty ()) {
247284 // case 2
248- return ;
285+ split_inner_most = true ;
286+ large_dim = inner_most2;
287+ split_factor = params.tile_size2 ;
288+ } else {
289+ // case 1
290+ split_inner_most = false ;
291+ large_dim = params.dims_merged_with_2 .back ();
292+ auto prev_merged_size2 = merged_size2 / shape_in_ref1[large_dim];
293+ split_factor = ceilDiv (params.tile_size2 , prev_merged_size2);
249294 }
250- large_dim = params.dims_merged_with_2 .back ();
251- split_factor = ceilDiv (params.tile_size1 , merged_size1);
252295 } else {
253296 if (params.dims_merged_with_1 .empty ()) {
254297 // case 2
255- return ;
298+ split_inner_most = true ;
299+ large_dim = inner_most1;
300+ split_factor = params.tile_size1 ;
301+ } else {
302+ // case 1
303+ split_inner_most = false ;
304+ large_dim = params.dims_merged_with_1 .back ();
305+ auto prev_merged_size1 = merged_size1 / shape_in_ref1[large_dim];
306+ split_factor = ceilDiv (params.tile_size1 , prev_merged_size1);
256307 }
257- large_dim = params.dims_merged_with_1 .back ();
258- split_factor = ceilDiv (params.tile_size2 , merged_size2);
259308 }
260309 params.split_before_tiling .push_back ({large_dim, split_factor});
261310 // adjust all dims to after-split
@@ -271,12 +320,16 @@ void maybeBuildVirtualInnerDims(
271320 }
272321 // Give the split-out dim to the unsatisfied one, so that both are satisfied.
273322 if (merged_size1 < params.tile_size1 ) {
274- params.dims_merged_with_2 .pop_back ();
275- params.dims_merged_with_2 .push_back (large_dim + 1 );
323+ if (!split_inner_most) {
324+ params.dims_merged_with_2 .pop_back ();
325+ params.dims_merged_with_2 .push_back (large_dim + 1 );
326+ }
276327 params.dims_merged_with_1 .push_back (large_dim);
277328 } else {
278- params.dims_merged_with_1 .pop_back ();
279- params.dims_merged_with_1 .push_back (large_dim + 1 );
329+ if (!split_inner_most) {
330+ params.dims_merged_with_1 .pop_back ();
331+ params.dims_merged_with_1 .push_back (large_dim + 1 );
332+ }
280333 params.dims_merged_with_2 .push_back (large_dim);
281334 }
282335}
@@ -369,12 +422,6 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
369422 if (n_elems < device_multiprocessor_count * kMaxTileSize * kMaxTileSize ) {
370423 params->tile_size1 = 8 ;
371424 params->tile_size2 = 8 ;
372- // TODO: I was trying the following but I got silent wrong result
373- // params->tile_size1 = 8;
374- // params->tile_size2 = 4;
375- // This should not happen, because the correctness should be irrevalent to
376- // schedulers. We don't have to use tile size (8, 4), but we need to fix our
377- // bug in codegen.
378425 }
379426
380427 // Expand inner-most dims to virtual inner-most dims so that the inner-most
@@ -383,9 +430,9 @@ std::shared_ptr<TransposeParams> getTransposeHeuristics(
383430 auto inner_most_id2 = scheduler_utils::innerMostRootDim (reference2);
384431
385432 auto inner_most_pos1_in_ref1 =
386- domain_map.getPosMappedTo (reference1, inner_most_id1);
433+ domain_map.getInnerLeafDim (reference1, inner_most_id1);
387434 auto inner_most_pos2_in_ref1 =
388- domain_map.getPosMappedTo (reference1, inner_most_id2);
435+ domain_map.getInnerLeafDim (reference1, inner_most_id2);
389436
390437 // See note [Supporting small transpose dimensions]
391438 maybeBuildVirtualInnerDims (
@@ -643,9 +690,9 @@ void scheduleTranspose(Fusion* fusion, TransposeParams params) {
643690
644691 // merge with inner most dims to get virtual inner most dims
645692 size_t inner_most_pos1_in_ref1 =
646- domain_map.getPosMappedTo (reference1, inner_most_id1);
693+ domain_map.getInnerLeafDim (reference1, inner_most_id1);
647694 size_t inner_most_pos2_in_ref1 =
648- domain_map.getPosMappedTo (reference1, inner_most_id2);
695+ domain_map.getInnerLeafDim (reference1, inner_most_id2);
649696 if (merged1.has_value ()) {
650697 if (inner_most_pos1_in_ref1 < *merged1) {
651698 reference1->reorder (
0 commit comments