Commit b7a206e

Move scheduler vectorize utilities into their own file (pytorch#1959)
1 parent d9420e4 commit b7a206e

7 files changed: +291 −273 lines changed

build_variables.bzl

Lines changed: 1 addition & 0 deletions

@@ -730,6 +730,7 @@ libtorch_cuda_core_sources = [
     "torch/csrc/jit/codegen/cuda/scheduler/reduction_utils.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/registry.cpp",
     "torch/csrc/jit/codegen/cuda/scheduler/utils.cpp",
+    "torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp",
     "torch/csrc/jit/codegen/cuda/type_inference.cpp",
     "torch/csrc/jit/codegen/cuda/type_promotion.cpp",
     "torch/csrc/jit/codegen/cuda/fusion_segmenter.cpp",

torch/csrc/jit/codegen/cuda/scheduler/normalization.cpp

Lines changed: 1 addition & 1 deletion

@@ -909,7 +909,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getPersistentHeuristics(
   }
 
   // Try expanding vectorization to contig merged domains
-  vectorize_factor = scheduler_utils::expandVectorizationToContigMergedDomains(
+  vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains(
       fusion,
       runtime_info,
       vectorizable_inputs_outputs,

torch/csrc/jit/codegen/cuda/scheduler/pointwise.cpp

Lines changed: 1 addition & 1 deletion

@@ -344,7 +344,7 @@ std::shared_ptr<PointwiseParams> getPointwiseHeuristics(
   // TODO: This is an expensive function that shouldn't be in heuristics without
   // caching.
   auto expanded_vector_word_size =
-      scheduler_utils::expandVectorizationToContigMergedDomains(
+      vectorize_helper::expandVectorizationToContigMergedDomains(
          fusion,
          runtime_info,
          vectorizable_inputs_outputs,

torch/csrc/jit/codegen/cuda/scheduler/reduction.cpp

Lines changed: 1 addition & 1 deletion

@@ -954,7 +954,7 @@ TORCH_CUDA_CU_API std::shared_ptr<ReductionParams> getReductionHeuristics(
   }
 
   // Try expanding vectorization to contig merged domains
-  vectorize_factor = scheduler_utils::expandVectorizationToContigMergedDomains(
+  vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains(
       fusion,
       runtime_info,
       vectorizable_inputs_outputs,
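
All three heuristics above call the same moved helper, and this excerpt truncates each argument list after vectorizable_inputs_outputs. Judging from the signature removed from scheduler/utils.cpp further down, the trailing arguments are the reference TensorView, the break point, and the default word size, so the full call has roughly the shape below (the trailing names are illustrative placeholders, not copied from the real call sites):

// Illustrative call shape only; the last three names stand in for whatever
// each heuristic actually passes at the truncated call sites.
vectorize_factor = vectorize_helper::expandVectorizationToContigMergedDomains(
    fusion,
    runtime_info,
    vectorizable_inputs_outputs,
    reference_tv,       // TensorView whose merged innermost domains drive the expansion
    break_point,        // only domains to the right of this point are merged
    vectorize_factor);  // fallback word size if expansion doesn't find anything wider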

torch/csrc/jit/codegen/cuda/scheduler/utils.cpp

Lines changed: 0 additions & 260 deletions

@@ -1620,89 +1620,6 @@ BroadcastMultipleInformation getBroadcastMultiples(
   return bcast_info;
 }
 
-size_t collectMaxVectorizeSizeWithContigMerge(
-    TensorView* tv,
-    IterDomain* leaf_merged_domain,
-    size_t max_vector_size_in_byte,
-    ExpressionEvaluator& expression_evaluator,
-    DataType index_type) {
-  // Maybe too conservative, but only handles fully contiguous tensors
-  // TODO: Relax the contiguity constraint to be similar to that in index
-  // computing. Just looking for all merged root domains in the right order,
-  // all merged root dimensions are contiguous, all merged root dimensions are
-  // next to each other (excluding broadcast).
-  if (std::any_of(
-          tv->domain()->contiguity().begin(),
-          tv->domain()->contiguity().end(),
-          [](const auto contig) { return !contig; })) {
-    return 1;
-  }
-
-  auto dtype_size = dataTypeSize(tv->dtype(), index_type);
-  const size_t max_vector_size = max_vector_size_in_byte / dtype_size;
-
-  // Assume no halo-related expression appears in the fusion. No
-  // broadcast is merged, so indexability can be assumed to be true.
-  ContigIDs contigIds(
-      {leaf_merged_domain},
-      tv->getMaybeRFactorDomain(),
-      tv->domain()->contiguity(),
-      {},
-      {},
-      true,
-      true);
-
-  auto innermost_root_id = tv->getMaybeRFactorDomain().back();
-  auto indexed_id = contigIds.rootToIndexedID().at(innermost_root_id);
-
-  size_t merged_size = 1;
-  // If the indexed ID is a contig merged domain, i.e., it is
-  // different from innermost_root_id, we accumulate the extents of
-  // all the root domains covered by the contig indexed ID. Otherwise,
-  // just look at the extent of the innermost root ID.
-  if (indexed_id != innermost_root_id) {
-    const auto& within_root = contigIds.withinContigIDs().at(indexed_id);
-    for (auto root_id : tv->getMaybeRFactorDomain()) {
-      if (within_root.find(root_id) == within_root.end()) {
-        continue;
-      }
-      auto maybe_dimension_size =
-          expression_evaluator.evaluate(root_id->extent());
-      TORCH_INTERNAL_ASSERT(
-          maybe_dimension_size.has_value(),
-          "Unknown extent of tv: ",
-          tv->toString(),
-          ", id: ",
-          root_id->toString());
-      merged_size *= maybe_dimension_size->as<int64_t>();
-    }
-  } else {
-    auto maybe_dimension_size =
-        expression_evaluator.evaluate(innermost_root_id->extent());
-    TORCH_INTERNAL_ASSERT(
-        maybe_dimension_size.has_value(),
-        "Unknown extent of tv: ",
-        tv->toString(),
-        ", id: ",
-        innermost_root_id->toString());
-    merged_size = maybe_dimension_size->as<int64_t>();
-  }
-
-  size_t vector_size = 1;
-  size_t next_vector_size = vector_size * 2;
-
-  // Try until vector size exceeds the max allowed size
-  while (next_vector_size <= max_vector_size) {
-    if (merged_size % next_vector_size != 0) {
-      break;
-    }
-    vector_size = next_vector_size;
-    next_vector_size *= 2;
-  }
-
-  return vector_size;
-}
-
 namespace matmul_utils {
 
 void scheduleWarpTileWithReduction(TensorView* tv, MatMulTileOptions tile) {

@@ -2260,183 +2177,6 @@ void BoundedDirectionalTransformPropagator::bothWays(
   propagate(from, pos, included_tvs, *options);
 }
 
-// Grab all values and expressions used to make the merged_domain and remove
-// them from the fusion
-void cleanUpInnermostMergedDomains(
-    const std::vector<IterDomain*>& root_domain,
-    IterDomain* merged_domain) {
-  TORCH_INTERNAL_ASSERT(merged_domain != nullptr);
-  TORCH_INTERNAL_ASSERT(!root_domain.empty());
-
-  std::unordered_set<Val*> root_set({root_domain.begin(), root_domain.end()});
-
-  auto vals = DependencyCheck::getAllValsBetween(root_set, {merged_domain});
-
-  for (auto it = vals.rbegin(); it != vals.rend(); ++it) {
-    TORCH_INTERNAL_ASSERT((*it)->isA<IterDomain>());
-    auto id = (*it)->as<IterDomain>();
-    if (root_set.find(id) != root_set.end()) {
-      continue;
-    }
-    Fusion* fusion = id->container()->as<Fusion>();
-    auto id_def = id->definition();
-    TORCH_INTERNAL_ASSERT(
-        id_def->isA<Merge>(),
-        "Invalid ID: ",
-        id->toString(),
-        ". Expected definition of a Merge expression: ",
-        (id_def != nullptr ? id_def->toString() : "nullptr"));
-    fusion->removeExpr(id_def);
-    fusion->removeVal(id);
-  }
-}
-
-// Merge innermost domains for finding the widest vectorizable
-// size. Return the merged domain or nullptr if no merge is done.
-IterDomain* mergeInnermostDomains(
-    const std::vector<IterDomain*>& domain,
-    int num_merged_domains) {
-  const auto ndims = domain.size();
-  IterDomain* merged_id = nullptr;
-  bool is_merge_done = false;
-  for (const auto i : c10::irange(num_merged_domains)) {
-    auto id = domain.at(ndims - 1 - i);
-    // broadcast and trivial reductions are ignored
-    if (id->isBroadcast() || id->isTrivialReduction()) {
-      continue;
-    }
-    if (merged_id == nullptr) {
-      merged_id = id;
-    } else {
-      auto id_inner = merged_id;
-      auto id_outer = id;
-      merged_id = IterDomain::merge(id_outer, id_inner);
-      is_merge_done = true;
-    }
-  }
-  return is_merge_done ? merged_id : nullptr;
-}
-
-//! Attempt to expand vectorized domains to contig merged domains. Break point
-//! identifies the point in which you can't propagate contiguous merges. For
-//! example in pointwise this is the point where we want to split the
-//! parallelization to take advantage of broadcast, and for reduction
-//! schedulers it's the point where we switch from a reduction domain to an
-//! iter domain (or vice versa).
-size_t expandVectorizationToContigMergedDomains(
-    Fusion* fusion,
-    SchedulerRuntimeInfo& runtime_info,
-    const std::vector<TensorView*> vectorizable_inputs_outputs,
-    TensorView* reference_tv,
-    int break_point,
-    size_t default_word_size) {
-  size_t max_expand_size = SchedulerRuntimeInfo::max_alignment_size_in_byte;
-  size_t common_alignment_size =
-      SchedulerRuntimeInfo::max_alignment_size_in_byte;
-
-  for (auto inp_out : vectorizable_inputs_outputs) {
-    auto dtype_size = dataTypeSize(
-        inp_out->dtype(), indexModeToDtype(runtime_info.getIndexMode()));
-
-    max_expand_size = std::min(
-        max_expand_size,
-        SchedulerRuntimeInfo::max_alignment_size_in_byte / dtype_size);
-    max_expand_size = std::min(
-        max_expand_size, runtime_info.getMaxVectorizableWidth(inp_out));
-    common_alignment_size =
-        std::min(common_alignment_size, runtime_info.getAlignmentSize(inp_out));
-  }
-
-  // If there's no possibility to increase vector size of provided tensors,
-  // then don't bother doing a more complex analysis to try and do so, just
-  // return early.
-  if (max_expand_size == default_word_size) {
-    return default_word_size;
-  }
-
-  auto ca_map = ComputeAtMap(fusion);
-
-  // Merge the domains right of the break point
-  const auto& ref_root = reference_tv->getMaybeRFactorDomain();
-  const int num_merged_domains =
-      static_cast<int>(ref_root.size()) - static_cast<int>(break_point);
-
-  // No expansion with no merged domain
-  if (num_merged_domains == 0) {
-    return default_word_size;
-  }
-
-  // Merge the domains but don't modify TensorDomain
-  auto merged_domain = mergeInnermostDomains(ref_root, num_merged_domains);
-
-  // No expansion is done if no merge is done.
-  if (merged_domain == nullptr) {
-    return default_word_size;
-  }
-
-  // Find the vectorizable word size with the merged domains
-  size_t word_size = scheduler_utils::collectMaxVectorizeSizeWithContigMerge(
-      reference_tv,
-      merged_domain,
-      common_alignment_size,
-      runtime_info.expressionEvaluator(),
-      indexModeToDtype(runtime_info.getIndexMode()));
-
-  cleanUpInnermostMergedDomains(ref_root, merged_domain);
-
-  // Stop if the reference doesn't get a larger word size.
-  if (word_size <= default_word_size) {
-    return default_word_size;
-  }
-
-  // Check the other TVs and take the minimum of the valid word sizes
-  for (const auto tv : vectorizable_inputs_outputs) {
-    if (tv == reference_tv) {
-      continue;
-    }
-
-    const auto& tv_root = tv->getMaybeRFactorDomain();
-
-    int tv_num_merged_domains = 0;
-    for (const auto i : c10::irange(num_merged_domains)) {
-      if (i == tv_root.size()) {
-        break;
-      }
-      auto ref_id = ref_root.at(ref_root.size() - 1 - i);
-      IterDomain* tv_id = tv_root.at(tv_root.size() - 1 - i);
-      // If not mapped, stop expanding.
-      if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT)) {
-        break;
-      } else {
-        ++tv_num_merged_domains;
-      }
-    }
-
-    size_t tv_word_size = 1;
-    if (tv_num_merged_domains > 1) {
-      auto tv_merged_domain =
-          mergeInnermostDomains(tv_root, tv_num_merged_domains);
-      if (tv_merged_domain == nullptr) {
-        tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv);
-      } else {
-        tv_word_size = scheduler_utils::collectMaxVectorizeSizeWithContigMerge(
-            tv,
-            tv_merged_domain,
-            common_alignment_size,
-            runtime_info.expressionEvaluator(),
-            indexModeToDtype(runtime_info.getIndexMode()));
-        cleanUpInnermostMergedDomains(tv_root, tv_merged_domain);
-      }
-    } else {
-      tv_word_size = runtime_info.getInnerDimVectorizableWidth(tv);
-    }
-
-    word_size = std::min(word_size, tv_word_size);
-  }
-
-  return word_size;
-}
-
 DisjointSets<IterDomain*> disjointViewSets(Fusion* fusion) {
   // Start from the exact iter domain graph of the fusion
   IterDomainGraph id_graph(fusion);
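
The heart of the moved collectMaxVectorizeSizeWithContigMerge is the power-of-two search at its end: after computing the merged innermost extent and the byte budget allowed by alignment and dtype, it keeps doubling the candidate word size while the extent stays evenly divisible. Below is a minimal, self-contained illustration of just that search; the function name and the numbers in main() are example values for this sketch, not anything from the commit.

// Standalone sketch of the power-of-two word-size search used above.
#include <cstddef>
#include <iostream>

size_t maxPow2WordSize(
    size_t merged_size,             // product of the merged contiguous innermost extents
    size_t max_vector_size_in_byte, // alignment/hardware budget in bytes
    size_t dtype_size) {            // element size in bytes
  const size_t max_vector_size = max_vector_size_in_byte / dtype_size;
  size_t vector_size = 1;
  size_t next_vector_size = vector_size * 2;
  // Double the word size until it stops dividing the extent or exceeds the budget.
  while (next_vector_size <= max_vector_size) {
    if (merged_size % next_vector_size != 0) {
      break;
    }
    vector_size = next_vector_size;
    next_vector_size *= 2;
  }
  return vector_size;
}

int main() {
  // Example: a [8, 12] float tensor merged to extent 96 with a 16-byte budget.
  // 96 divides by 2 and 4, and 16 bytes / 4-byte floats caps the width at 4.
  std::cout << maxPow2WordSize(96, 16, 4) << "\n"; // prints 4
  return 0;
}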
