@@ -21,26 +21,20 @@ namespace scheduler_utils {
2121
2222// Returns number of "valid" dimensions. e.g. if tv has
2323// [I1, R2, I3, I4, R3{1}]
24- // where R3{1} is in dont_merge, resulting domain should be:
25- // [I1, I3*I4, R2, R3{1}] with return value 3
24+ // resulting domain should be:
25+ // [I1, I3*I4, R2* R3{1}] with return value 3
2626//
2727// if tv has
2828// [R1, I2, R3, I4, R4, R5{1}, R6{1}]
29- // where R5{1} and R6{1} are in dont_merge, resulting domain should be:
30- // [I2*I4, R1*R3, R4, R5{1}, R6{1}]
29+ // resulting domain should be:
30+ // [I2*I4, R1*R3, R4* R5{1}* R6{1}]
3131// with return value 3
32- size_t merge_3d (
33- TensorView* tv,
34- const std::unordered_set<IterDomain*>& dont_merge) {
32+ size_t merge_3d (TensorView* tv) {
3533 bool active_is_reduction = false ;
3634 bool first_dim = true ;
3735 int prev_i = -1 ;
3836
3937 for (int i = static_cast <int >(tv->nDims ()) - 1 ; i >= 0 ; i--) {
40- if (dont_merge.count (tv->axis (i))) {
41- continue ;
42- }
43-
4438 if (first_dim) {
4539 active_is_reduction = tv->axis (i)->isReduction ();
4640 prev_i = i;
@@ -67,10 +61,6 @@ size_t merge_3d(
6761
6862 for (int i = static_cast <int >(tv->nDims ()) - 2 ; i >= 0 ; i--) {
6963 auto id = tv->axis (i);
70- if (dont_merge.count (id)) {
71- continue ;
72- }
73-
7464 if (first_dim) {
7565 active_is_reduction = id->isReduction ();
7666 prev_i = i;
@@ -96,10 +86,6 @@ size_t merge_3d(
9686 prev_i = -1 ;
9787
9888 for (int i = static_cast <int >(tv->nDims ()) - 3 ; i >= 0 ; i--) {
99- if (dont_merge.count (tv->axis (i))) {
100- continue ;
101- }
102-
10389 if (first_dim) {
10490 active_is_reduction = tv->axis (i)->isReduction ();
10591 prev_i = i;
@@ -114,7 +100,7 @@ size_t merge_3d(
114100 if (prev_i == -1 ) {
115101 // Two dimensional, put merged dimensions first
116102 tv->reorder ({{-1 , 0 }, {-2 , 1 }});
117- // [outer, inner, dont_merge... ]
103+ // [outer, inner]
118104 if (tv->axis (0 )->isReduction ()) {
119105 // put reductions as second axis
120106 tv->reorder ({{0 , 1 }, {1 , 0 }});
@@ -195,13 +181,11 @@ c10::optional<size_t> mergeDims(
195181 return left;
196182}
197183
198- size_t mergeReduction (
199- TensorView* tv,
200- const std::unordered_set<IterDomain*>& dont_merge) {
184+ size_t mergeReduction (TensorView* tv) {
201185 int prev_i = -1 ;
202186 size_t num_merged = 0 ;
203187 for (int i = static_cast <int >(tv->nDims ()) - 1 ; i >= 0 ; i--) {
204- if (!tv->axis (i)->isReduction () || dont_merge. count (tv-> axis (i)) ) {
188+ if (!tv->axis (i)->isReduction ()) {
205189 continue ;
206190 }
207191 if (prev_i == -1 ) {
@@ -219,16 +203,14 @@ size_t mergeReduction(
219203 return prev_i == -1 ? 0 : num_merged + 1 ;
220204}
221205
222- size_t mergeNonReduction (
223- TensorView* tv,
224- const std::unordered_set<IterDomain*>& dont_merge) {
206+ size_t mergeNonReduction (TensorView* tv) {
225207 int prev_i = -1 ;
226208 size_t num_merged = 0 ;
227209 if (tv->nDims () == 0 ) {
228210 return 0 ;
229211 }
230212 for (int i = static_cast <int >(tv->nDims ()) - 1 ; i >= 0 ; i--) {
231- if (tv->axis (i)->isReduction () || dont_merge. count (tv-> axis (i)) ) {
213+ if (tv->axis (i)->isReduction ()) {
232214 continue ;
233215 }
234216 if (prev_i == -1 ) {
@@ -905,63 +887,21 @@ PersistentBufferSizeReturn persistentBufferSize(
905887 return persistent_buffer_size;
906888}
907889
908- std::unordered_set<IterDomain*> getTrivialReductionMap (Fusion* fusion) {
909- auto all_tvs = ir_utils::allTvs (fusion);
910- std::unordered_set<IterDomain*> mapped_to_trivial_reduction;
911- for (auto tv : all_tvs) {
912- // root domain vs domain shouldn't matter as at this point we shouldn't have
913- // any transformations.
914- for (auto id : tv->getRootDomain ()) {
915- if (id->isTrivialReduction ()) {
916- mapped_to_trivial_reduction.emplace (id);
917- }
918- }
919- }
920-
921- if (!mapped_to_trivial_reduction.empty ()) {
922- // Use the loop map as that is the most permissive
923- auto ca_map = ComputeAtMap (fusion);
924- // Make a copy we need to check mappings of all
925- auto trivial_ids = mapped_to_trivial_reduction;
926- for (auto tv : all_tvs) {
927- for (auto id : tv->getRootDomain ()) {
928- if (!id->extent ()->isOneInt ()) {
929- continue ;
930- }
931- if (std::any_of (
932- trivial_ids.begin (),
933- trivial_ids.end (),
934- [&ca_map, &id](IterDomain* trivial_id) {
935- return ca_map.areMapped (
936- id, trivial_id, IdMappingMode::PERMISSIVE);
937- })) {
938- mapped_to_trivial_reduction.emplace (id);
939- }
940- }
941- }
942- }
943- return mapped_to_trivial_reduction;
944- }
945-
946890std::pair<bool , bool > canonicalDimReduction (
947891 Fusion* fusion,
948892 TensorView* tv,
949893 bool schedule_3D) {
950- std::unordered_set<IterDomain*> mapped_to_trivial_reduction =
951- getTrivialReductionMap (fusion);
952-
953894 TORCH_INTERNAL_ASSERT (tv != nullptr );
954895
955896 if (!schedule_3D) {
956897 // We coalesce all reduction axes to the right;
957- bool has_red_axis = mergeReduction (tv, mapped_to_trivial_reduction ) > 0 ;
898+ bool has_red_axis = mergeReduction (tv) > 0 ;
958899
959- bool has_iter_axis = mergeNonReduction (tv, mapped_to_trivial_reduction ) > 0 ;
900+ bool has_iter_axis = mergeNonReduction (tv) > 0 ;
960901 return {has_iter_axis, has_red_axis};
961902 } else {
962903 TORCH_INTERNAL_ASSERT (
963- merge_3d (tv, mapped_to_trivial_reduction) == 3 ,
964- " Tried 3D merge, but result is not 3D." );
904+ merge_3d (tv) == 3 , " Tried 3D merge, but result is not 3D." );
965905 return {true , true };
966906 }
967907}
0 commit comments