@@ -1620,89 +1620,6 @@ BroadcastMultipleInformation getBroadcastMultiples(
16201620 return bcast_info;
16211621}
16221622
1623- size_t collectMaxVectorizeSizeWithContigMerge (
1624- TensorView* tv,
1625- IterDomain* leaf_merged_domain,
1626- size_t max_vector_size_in_byte,
1627- ExpressionEvaluator& expression_evaluator,
1628- DataType index_type) {
1629- // Maybe too conservative, but only handles fully contiguous tensors
1630- // TODO: Relax the contiguity constraint to be similar to that in index
1631- // computing. Just looking for all merged root domains in the right order,
1632- // all merged root dimensions are contiguous, all merged root dimensions are
1633- // next to eachother (exlcuding broadcast).
1634- if (std::any_of (
1635- tv->domain ()->contiguity ().begin (),
1636- tv->domain ()->contiguity ().end (),
1637- [](const auto contig) { return !contig; })) {
1638- return 1 ;
1639- }
1640-
1641- auto dtype_size = dataTypeSize (tv->dtype (), index_type);
1642- const size_t max_vector_size = max_vector_size_in_byte / dtype_size;
1643-
1644- // Assume no halo-related expression appears in the fusion. No
1645- // broadcast is merged, so indexability can be assumed to be true.
1646- ContigIDs contigIds (
1647- {leaf_merged_domain},
1648- tv->getMaybeRFactorDomain (),
1649- tv->domain ()->contiguity (),
1650- {},
1651- {},
1652- true ,
1653- true );
1654-
1655- auto innermost_root_id = tv->getMaybeRFactorDomain ().back ();
1656- auto indexed_id = contigIds.rootToIndexedID ().at (innermost_root_id);
1657-
1658- size_t merged_size = 1 ;
1659- // If the indexed ID is a contig merged domain, i.e., it is
1660- // different from innermost_root_id, we accumulate the extents of
1661- // all the root domains covered by the contig indexed ID. Otherwise,
1662- // just look at the extent of the innermost root ID.
1663- if (indexed_id != innermost_root_id) {
1664- const auto & within_root = contigIds.withinContigIDs ().at (indexed_id);
1665- for (auto root_id : tv->getMaybeRFactorDomain ()) {
1666- if (within_root.find (root_id) == within_root.end ()) {
1667- continue ;
1668- }
1669- auto maybe_dimension_size =
1670- expression_evaluator.evaluate (root_id->extent ());
1671- TORCH_INTERNAL_ASSERT (
1672- maybe_dimension_size.has_value (),
1673- " Unknown extent of tv: " ,
1674- tv->toString (),
1675- " , id: " ,
1676- root_id->toString ());
1677- merged_size *= maybe_dimension_size->as <int64_t >();
1678- }
1679- } else {
1680- auto maybe_dimension_size =
1681- expression_evaluator.evaluate (innermost_root_id->extent ());
1682- TORCH_INTERNAL_ASSERT (
1683- maybe_dimension_size.has_value (),
1684- " Unknown extent of tv: " ,
1685- tv->toString (),
1686- " , id: " ,
1687- innermost_root_id->toString ());
1688- merged_size = maybe_dimension_size->as <int64_t >();
1689- }
1690-
1691- size_t vector_size = 1 ;
1692- size_t next_vector_size = vector_size * 2 ;
1693-
1694- // Try until vector size exceeds the max allowed size
1695- while (next_vector_size <= max_vector_size) {
1696- if (merged_size % next_vector_size != 0 ) {
1697- break ;
1698- }
1699- vector_size = next_vector_size;
1700- next_vector_size *= 2 ;
1701- }
1702-
1703- return vector_size;
1704- }
1705-
17061623namespace matmul_utils {
17071624
17081625void scheduleWarpTileWithReduction (TensorView* tv, MatMulTileOptions tile) {
@@ -2260,183 +2177,6 @@ void BoundedDirectionalTransformPropagator::bothWays(
22602177 propagate (from, pos, included_tvs, *options);
22612178}
22622179
2263- // Grab all values and expressions used to make the merged_domain and remove
2264- // them from the fusion
2265- void cleanUpInnermostMergedDomains (
2266- const std::vector<IterDomain*>& root_domain,
2267- IterDomain* merged_domain) {
2268- TORCH_INTERNAL_ASSERT (merged_domain != nullptr );
2269- TORCH_INTERNAL_ASSERT (!root_domain.empty ());
2270-
2271- std::unordered_set<Val*> root_set ({root_domain.begin (), root_domain.end ()});
2272-
2273- auto vals = DependencyCheck::getAllValsBetween (root_set, {merged_domain});
2274-
2275- for (auto it = vals.rbegin (); it != vals.rend (); ++it) {
2276- TORCH_INTERNAL_ASSERT ((*it)->isA <IterDomain>());
2277- auto id = (*it)->as <IterDomain>();
2278- if (root_set.find (id) != root_set.end ()) {
2279- continue ;
2280- }
2281- Fusion* fusion = id->container ()->as <Fusion>();
2282- auto id_def = id->definition ();
2283- TORCH_INTERNAL_ASSERT (
2284- id_def->isA <Merge>(),
2285- " Invalid ID: " ,
2286- id->toString (),
2287- " . Expected definition of a Merge expression: " ,
2288- (id_def != nullptr ? id_def->toString () : " nullptr" ));
2289- fusion->removeExpr (id_def);
2290- fusion->removeVal (id);
2291- }
2292- }
2293-
2294- // Merge innermost domains for finding the widest vectorizable
2295- // size. Return the merged domain or nullptr if no merge is done.
2296- IterDomain* mergeInnermostDomains (
2297- const std::vector<IterDomain*>& domain,
2298- int num_merged_domains) {
2299- const auto ndims = domain.size ();
2300- IterDomain* merged_id = nullptr ;
2301- bool is_merge_done = false ;
2302- for (const auto i : c10::irange (num_merged_domains)) {
2303- auto id = domain.at (ndims - 1 - i);
2304- // broadcast and trivial reductions are ignored
2305- if (id->isBroadcast () || id->isTrivialReduction ()) {
2306- continue ;
2307- }
2308- if (merged_id == nullptr ) {
2309- merged_id = id;
2310- } else {
2311- auto id_inner = merged_id;
2312- auto id_outer = id;
2313- merged_id = IterDomain::merge (id_outer, id_inner);
2314- is_merge_done = true ;
2315- }
2316- }
2317- return is_merge_done ? merged_id : nullptr ;
2318- }
2319-
2320- // ! Attempt to expand vectorized domains to contig merged domains. Break point
2321- // ! identifies the point in which you can't propagate contiguous merges. For
2322- // ! example in pointwise this is the point where we want to split the
2323- // ! parallelization to take advantage of broadcast, and for reduction
2324- // ! schedulers it's the point where we switch from a reduction domain to an
2325- // ! iter domain (or vice versa).
2326- size_t expandVectorizationToContigMergedDomains (
2327- Fusion* fusion,
2328- SchedulerRuntimeInfo& runtime_info,
2329- const std::vector<TensorView*> vectorizable_inputs_outputs,
2330- TensorView* reference_tv,
2331- int break_point,
2332- size_t default_word_size) {
2333- size_t max_expand_size = SchedulerRuntimeInfo::max_alignment_size_in_byte;
2334- size_t common_alignment_size =
2335- SchedulerRuntimeInfo::max_alignment_size_in_byte;
2336-
2337- for (auto inp_out : vectorizable_inputs_outputs) {
2338- auto dtype_size = dataTypeSize (
2339- inp_out->dtype (), indexModeToDtype (runtime_info.getIndexMode ()));
2340-
2341- max_expand_size = std::min (
2342- max_expand_size,
2343- SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_size);
2344- max_expand_size = std::min (
2345- max_expand_size, runtime_info.getMaxVectorizableWidth (inp_out));
2346- common_alignment_size =
2347- std::min (common_alignment_size, runtime_info.getAlignmentSize (inp_out));
2348- }
2349-
2350- // If there's no possibility to increase vector size of provided tensors,
2351- // then don't bother doing a more complex analysis to try and do so, just
2352- // return early.
2353- if (max_expand_size == default_word_size) {
2354- return default_word_size;
2355- }
2356-
2357- auto ca_map = ComputeAtMap (fusion);
2358-
2359- // Merge the domains right of the break point
2360- const auto & ref_root = reference_tv->getMaybeRFactorDomain ();
2361- const int num_merged_domains =
2362- static_cast <int >(ref_root.size ()) - static_cast <int >(break_point);
2363-
2364- // No expansion with no merged domain
2365- if (num_merged_domains == 0 ) {
2366- return default_word_size;
2367- }
2368-
2369- // Merge the domains but don't modify TensorDomain
2370- auto merged_domain = mergeInnermostDomains (ref_root, num_merged_domains);
2371-
2372- // No expansion is done if no merge is done.
2373- if (merged_domain == nullptr ) {
2374- return default_word_size;
2375- }
2376-
2377- // Find the vectorizable word size with the merged domains
2378- size_t word_size = scheduler_utils::collectMaxVectorizeSizeWithContigMerge (
2379- reference_tv,
2380- merged_domain,
2381- common_alignment_size,
2382- runtime_info.expressionEvaluator (),
2383- indexModeToDtype (runtime_info.getIndexMode ()));
2384-
2385- cleanUpInnermostMergedDomains (ref_root, merged_domain);
2386-
2387- // Stop if the reference doesn't get a larger word size.
2388- if (word_size <= default_word_size) {
2389- return default_word_size;
2390- }
2391-
2392- // Check the other TVs and take the minimum of the valid word sizes
2393- for (const auto tv : vectorizable_inputs_outputs) {
2394- if (tv == reference_tv) {
2395- continue ;
2396- }
2397-
2398- const auto & tv_root = tv->getMaybeRFactorDomain ();
2399-
2400- int tv_num_merged_domains = 0 ;
2401- for (const auto i : c10::irange (num_merged_domains)) {
2402- if (i == tv_root.size ()) {
2403- break ;
2404- }
2405- auto ref_id = ref_root.at (ref_root.size () - 1 - i);
2406- IterDomain* tv_id = tv_root.at (tv_root.size () - 1 - i);
2407- // If not mapped, stop expanding.
2408- if (!ca_map.areMapped (ref_id, tv_id, IdMappingMode::EXACT)) {
2409- break ;
2410- } else {
2411- ++tv_num_merged_domains;
2412- }
2413- }
2414-
2415- size_t tv_word_size = 1 ;
2416- if (tv_num_merged_domains > 1 ) {
2417- auto tv_merged_domain =
2418- mergeInnermostDomains (tv_root, tv_num_merged_domains);
2419- if (tv_merged_domain == nullptr ) {
2420- tv_word_size = runtime_info.getInnerDimVectorizableWidth (tv);
2421- } else {
2422- tv_word_size = scheduler_utils::collectMaxVectorizeSizeWithContigMerge (
2423- tv,
2424- tv_merged_domain,
2425- common_alignment_size,
2426- runtime_info.expressionEvaluator (),
2427- indexModeToDtype (runtime_info.getIndexMode ()));
2428- cleanUpInnermostMergedDomains (tv_root, tv_merged_domain);
2429- }
2430- } else {
2431- tv_word_size = runtime_info.getInnerDimVectorizableWidth (tv);
2432- }
2433-
2434- word_size = std::min (word_size, tv_word_size);
2435- }
2436-
2437- return word_size;
2438- }
2439-
24402180DisjointSets<IterDomain*> disjointViewSets (Fusion* fusion) {
24412181 // Start from the exact iter domain graph of the fusion
24422182 IterDomainGraph id_graph (fusion);
0 commit comments