@@ -1937,52 +1937,55 @@ std::vector<Val*> Index::getNonGlobalProducerStridedIndices(
   return strided_inds;
 }
 
-std::vector<Val*> Index::getLinearIndex(
-    TensorView* consumer_tv,
-    const std::vector<kir::ForLoop*>& loops) {
+template <typename func_t>
+auto evaluateWithOverridenContiguity(
+    TensorView* tv,
+    bool contiguity,
+    const func_t& functor) -> decltype(functor()) {
   // Use domain guard to ignore the contiguity of
   // consumer tv.
-  TensorDomain* consumer_tv_no_contiguity_domain = nullptr;
-  auto contiguity_vector =
-      std::vector<bool>(consumer_tv->getMaybeRFactorDomain().size(), true);
-  if (consumer_tv->hasRFactor()) {
-    consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
-        consumer_tv->getRootDomain(),
-        consumer_tv->getRFactorDomain(),
-        consumer_tv->domain()->domain(),
+  TensorDomain* domain_with_specified_contiguity = nullptr;
+  std::vector<bool> contiguity_vector(
+      tv->getMaybeRFactorDomain().size(), contiguity);
+  if (tv->hasRFactor()) {
+    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
+        tv->getRootDomain(),
+        tv->getRFactorDomain(),
+        tv->domain()->domain(),
         contiguity_vector);
   } else {
-    consumer_tv_no_contiguity_domain = IrBuilder::create<TensorDomain>(
-        consumer_tv->getRootDomain(),
-        consumer_tv->domain()->domain(),
-        contiguity_vector);
+    domain_with_specified_contiguity = IrBuilder::create<TensorDomain>(
+        tv->getRootDomain(), tv->domain()->domain(), contiguity_vector);
   }
 
-  ir_utils::TVDomainGuard domain_guard(
-      consumer_tv, consumer_tv_no_contiguity_domain);
+  ir_utils::TVDomainGuard domain_guard(tv, domain_with_specified_contiguity);
 
-  // TODO:
-  //  More optimization on the underlying tensor layout
-  //  will be done in a follow up.
-  return getGlobalConsumerStridedIndices(consumer_tv, loops);
+  return functor();
 }
 
-std::vector<Val*> Index::getGlobalConsumerStridedIndices(
-    const TensorView* consumer_tv,
+std::vector<Val*> Index::getLinearLogicalIndex(
+    TensorView* consumer_tv,
     const std::vector<kir::ForLoop*>& loops) {
-  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
-
-  auto gpu_lower = GpuLower::current();
-
-  auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);
+  return evaluateWithOverridenContiguity(consumer_tv, true, [&]() {
+    return getGlobalConsumerStridedIndices(consumer_tv, loops);
+  });
+}
 
-  auto consumer_indexing = index_from_id_graph.index;
+std::vector<Val*> Index::getPerDimLogicalIndex(
+    TensorView* consumer_tv,
+    const std::vector<kir::ForLoop*>& loops) {
+  return evaluateWithOverridenContiguity(consumer_tv, false, [&]() {
+    IndexFromIdGraph index_from_id_graph =
+        getTensorIndexFromIdGraph(loops, consumer_tv);
+    return getRootIndices(consumer_tv, loops, index_from_id_graph);
+  });
+}
 
+std::vector<Val*> Index::getStrides(const TensorView* tv) {
   // Indices should now be mapped onto IterDomains in consumer, so just grab
   // and use them.
-  auto root_dom = consumer_tv->getMaybeRFactorDomain();
+  auto root_dom = tv->getMaybeRFactorDomain();
 
-  // TODO: Abstract stride logic to reuse with producer indexing
   std::vector<Val*> strides(
       root_dom.size(), GpuLower::current()->kernel()->oneVal());
   {
@@ -1993,39 +1996,21 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
         continue;
       }
       std::stringstream ss;
-      ss << "T" << consumer_tv->name() << ".stride[" << stride_i++ << "]";
+      ss << "T" << tv->name() << ".stride[" << stride_i++ << "]";
       strides[i] =
           SimplifyingIrBuilder::create<NamedScalar>(ss.str(), DataType::Int);
     }
   }
 
-  TORCH_INTERNAL_ASSERT(
-      root_dom.size() == consumer_tv->domain()->contiguity().size());
+  TORCH_INTERNAL_ASSERT(root_dom.size() == tv->domain()->contiguity().size());
   Val* cur_contig_stride = GpuLower::current()->kernel()->oneVal();
   for (const auto i : c10::irange(root_dom.size())) {
     auto dim = root_dom.size() - i - 1;
     if (root_dom[dim]->isReduction() || root_dom[dim]->isStride()) {
       continue;
     }
 
-    Val* root_ind = nullptr;
-    if (consumer_indexing.indexMap().find(root_dom[dim]) !=
-        consumer_indexing.indexMap().end()) {
-      root_ind = consumer_indexing.indexMap().at(root_dom[dim]);
-    } else if (root_dom[dim]->isBroadcast()) {
-      root_ind = GpuLower::current()->kernel()->zeroVal();
-    }
-
-    TORCH_INTERNAL_ASSERT(
-        root_ind != nullptr,
-        "Couldn't find root mapping for ",
-        consumer_tv->toString(),
-        " dim: ",
-        dim,
-        " id: ",
-        root_dom[dim]->toString());
-
-    if (consumer_tv->domain()->contiguity()[dim]) {
+    if (tv->domain()->contiguity()[dim]) {
       // If contig, used the stored stride which may be the previous
       // dimensions stride * previous dimensions size
       strides[dim] = cur_contig_stride;
@@ -2041,12 +2026,18 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
           strides[dim], getHaloExtentOfRootAxis(root_dom[dim]));
     }
   }
+  return strides;
+}
 
-  auto vectorize_shift =
-      loops.empty() ? nullptr : loops.back()->vectorize_shift();
+std::vector<Val*> Index::getRootIndices(
+    const TensorView* tv,
+    const std::vector<kir::ForLoop*>& loops,
+    const IndexFromIdGraph& index_from_id_graph) {
+  auto gpu_lower = GpuLower::current();
+  auto root_dom = tv->getMaybeRFactorDomain();
+  auto indexing = index_from_id_graph.index;
 
-  // Global striding
-  std::vector<Val*> strided_inds(
+  std::vector<Val*> root_inds(
       root_dom.size(), GpuLower::current()->kernel()->zeroVal());
   for (const auto i : c10::irange(root_dom.size())) {
     // See a comment in indexing to root domains in getGlobalProducerIndex.
@@ -2057,35 +2048,55 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
     }
 
     TORCH_INTERNAL_ASSERT(
-        consumer_indexing.indexMap().find(root_dom[i]) !=
-            consumer_indexing.indexMap().end(),
+        indexing.indexMap().find(root_dom[i]) != indexing.indexMap().end(),
         "Couldn't find root mapping for ",
-        consumer_tv->toString(),
+        tv->toString(),
         " dim: ",
         i,
         " id: ",
         root_dom[i]->toString());
 
-    auto root_ind = consumer_indexing.indexMap().at(root_dom[i]);
+    auto root_ind = indexing.indexMap().at(root_dom[i]);
 
     // index hoist must be done before the adjustments for halo
     root_ind = hoistConsumerIndex(
         root_dom[i],
-        consumer_tv,
-        consumer_indexing,
+        tv,
+        indexing,
         index_from_id_graph.resolved_loop_domains,
         index_from_id_graph.initial_concrete_index_map,
         loops,
         root_ind);
 
     root_ind = SimplifyingIrBuilder::addExpr(
         root_ind, getGlobalConsumerOffsetWithPartialSplit(root_dom[i]));
+    root_inds[i] = root_ind;
+  }
+  return root_inds;
+}
 
-    if (root_ind->isZeroInt()) {
+std::vector<Val*> Index::getGlobalConsumerStridedIndices(
+    const TensorView* consumer_tv,
+    const std::vector<kir::ForLoop*>& loops) {
+  FUSER_PERF_SCOPE("GpuLower::Lower::getGlobalConsumerIndex");
+
+  auto index_from_id_graph = getTensorIndexFromIdGraph(loops, consumer_tv);
+  auto consumer_indexing = index_from_id_graph.index;
+  auto strides = getStrides(consumer_tv);
+  auto root_inds = getRootIndices(consumer_tv, loops, index_from_id_graph);
+
+  // Global striding
+  auto vectorize_shift =
+      loops.empty() ? nullptr : loops.back()->vectorize_shift();
+  std::vector<Val*> strided_inds(
+      root_inds.size(), GpuLower::current()->kernel()->zeroVal());
+  for (const auto i : c10::irange(root_inds.size())) {
+    if (root_inds[i]->isZeroInt()) {
       continue;
     } else {
-      auto strided_ind = SimplifyingIrBuilder::mulExpr(root_ind, strides[i]);
-      if (i == root_dom.size() - 1 && vectorize_shift != nullptr) {
+      auto strided_ind =
+          SimplifyingIrBuilder::mulExpr(root_inds[i], strides[i]);
+      if (i == strides.size() - 1 && vectorize_shift != nullptr) {
         strided_inds[i] =
             SimplifyingIrBuilder::addExpr(strided_ind, vectorize_shift);
       } else {
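For reference, below is a minimal, standalone C++ sketch of the guard-plus-functor pattern that the new evaluateWithOverridenContiguity helper relies on: an RAII guard temporarily overrides the per-dimension contiguity flags, the callback is evaluated under that override, and the original state is restored when the guard leaves scope. This is not nvFuser code; FakeTensor, DomainGuard, and withOverriddenContiguity are hypothetical names used only for illustration.

// Standalone sketch of the override-and-restore pattern (hypothetical types).
#include <iostream>
#include <utility>
#include <vector>

// Stand-in for a tensor that carries per-dimension contiguity flags.
struct FakeTensor {
  std::vector<bool> contiguity;
};

// RAII guard: installs an overriding contiguity vector and restores the
// original one when the guard goes out of scope.
class DomainGuard {
 public:
  DomainGuard(FakeTensor& tv, std::vector<bool> override_contig)
      : tv_(tv), saved_(std::move(tv.contiguity)) {
    tv_.contiguity = std::move(override_contig);
  }
  ~DomainGuard() {
    tv_.contiguity = std::move(saved_);
  }

 private:
  FakeTensor& tv_;
  std::vector<bool> saved_;
};

// Evaluate `functor` while every dimension's contiguity is forced to `flag`.
template <typename func_t>
auto withOverriddenContiguity(FakeTensor& tv, bool flag, const func_t& functor)
    -> decltype(functor()) {
  DomainGuard guard(tv, std::vector<bool>(tv.contiguity.size(), flag));
  return functor();
}

int main() {
  FakeTensor tv{{true, false, true}};
  // Inside the callback all dims look contiguous; the original flags are
  // restored automatically when the guard is destroyed.
  int num_contig = withOverriddenContiguity(tv, true, [&]() {
    int n = 0;
    for (bool c : tv.contiguity) {
      n += c ? 1 : 0;
    }
    return n;
  });
  std::cout << "contig dims under override: " << num_contig
            << ", restored contiguity[1]: " << tv.contiguity[1] << "\n";
  return 0;
}

Returning decltype(functor()) lets one helper forward whatever the callback produces, which is how getLinearLogicalIndex (all-true contiguity, linear strided indices) and getPerDimLogicalIndex (all-false contiguity, per-dimension root indices) share it in the diff above.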