
Commit a2dfe40

Fix vectorize size calculation (#2035)
1 parent e040676 commit a2dfe40

File tree

  torch/csrc/jit/codegen/cuda/scheduler/registry.cpp
  torch/csrc/jit/codegen/cuda/scheduler/registry.h
  torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp
  torch/csrc/jit/codegen/cuda/test/test_gpu.cpp

4 files changed: +202 −21 lines


torch/csrc/jit/codegen/cuda/scheduler/registry.cpp

Lines changed: 25 additions & 0 deletions
@@ -463,6 +463,24 @@ void SchedulerRuntimeInfo::initialize(
       auto fusion_inp = complete_fusion_->inputs()[inp_i];
       auto data_ptr = tensor_arg_abstract->getPointer();
       input_ptrs_[fusion_inp] = (size_t)data_ptr;
+
+      // find and push discontiguous stride
+      auto dtype_size = dataTypeSize(tensor_arg_abstract->getDataType());
+      input_discontig_strides_[fusion_inp] = {};
+      auto dims = tensor_arg_abstract->getRank();
+      auto expected_stride = 1;
+      for (auto dim = dims - 1; dim >= 0; dim--) {
+        auto size = tensor_arg_abstract->getSize(dim);
+        if (size <= 1) {
+          continue;
+        }
+        auto stride = tensor_arg_abstract->getStride(dim);
+        if (stride != expected_stride) {
+          input_discontig_strides_[fusion_inp].push_back(stride * dtype_size);
+          expected_stride = stride;
+        }
+        expected_stride *= size;
+      }
     }
   }
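To see what this scan records, here is a minimal standalone sketch (a hypothetical helper, not part of the commit) that applies the same walk to plain size/stride vectors: dimensions are visited from innermost to outermost, size-0/1 dimensions are skipped, and any stride that breaks dense packing is recorded in bytes.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical mirror of the loop added to SchedulerRuntimeInfo::initialize:
// returns the strides (in bytes) at which the tensor stops being densely packed.
std::vector<int64_t> discontigStridesInBytes(
    const std::vector<int64_t>& sizes,
    const std::vector<int64_t>& strides,
    int64_t dtype_size) {
  std::vector<int64_t> result;
  int64_t expected_stride = 1;
  for (int64_t dim = (int64_t)sizes.size() - 1; dim >= 0; dim--) {
    auto size = sizes[dim];
    if (size <= 1) {
      continue; // size-0/1 dimensions never constrain vectorization
    }
    auto stride = strides[dim];
    if (stride != expected_stride) {
      result.push_back(stride * dtype_size);
      expected_stride = stride;
    }
    expected_stride *= size;
  }
  return result;
}

int main() {
  // A float tensor of shape {1000000, 16} narrowed out of {1000000, 18}:
  // the inner dimension is dense (stride 1) but rows are 18 elements apart.
  for (auto s : discontigStridesInBytes({1000000, 16}, {18, 1}, /*dtype_size=*/4)) {
    std::cout << s << "\n"; // prints 72
  }
  return 0;
}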

@@ -529,6 +547,13 @@ size_t SchedulerRuntimeInfo::getAlignmentSize(TensorView* tv) {
   }

   auto alignment_size = SchedulerRuntimeInfo::computeAlignmentSize(ptrOf(tv));
+  auto strides_it = input_discontig_strides_.find(tv);
+  if (strides_it != input_discontig_strides_.end()) {
+    for (auto stride : strides_it->second) {
+      alignment_size = std::min(
+          alignment_size, SchedulerRuntimeInfo::computeAlignmentSize(stride));
+    }
+  }
   alignment_map_[tv] = alignment_size;
   return alignment_size;
 }
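The effect on getAlignmentSize is that a tensor's usable alignment is now the minimum of the base pointer's alignment and the alignment implied by each recorded discontiguous stride. A condensed sketch of that combination, assuming computeAlignmentSize returns the largest power of two (capped at the 16-byte max vector width) that divides its argument:

#include <algorithm>
#include <cstddef>
#include <vector>

// Assumed behavior of SchedulerRuntimeInfo::computeAlignmentSize: the largest
// power of two dividing the value, capped at the 16-byte max vector size.
size_t computeAlignmentSizeSketch(size_t value) {
  size_t alignment = 1;
  while (alignment * 2 <= 16 && value % (alignment * 2) == 0) {
    alignment *= 2;
  }
  return alignment;
}

// Hypothetical condensed view of the updated getAlignmentSize logic.
size_t effectiveAlignment(
    size_t base_ptr,
    const std::vector<size_t>& discontig_strides_in_bytes) {
  size_t alignment = computeAlignmentSizeSketch(base_ptr);
  for (auto stride : discontig_strides_in_bytes) {
    alignment = std::min(alignment, computeAlignmentSizeSketch(stride));
  }
  return alignment;
}

// e.g. a 16-byte-aligned base pointer combined with a 72-byte row stride gives
// an effective alignment of 8 bytes, i.e. at most 2 packed floats per access.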

torch/csrc/jit/codegen/cuda/scheduler/registry.h

Lines changed: 4 additions & 0 deletions
@@ -27,6 +27,7 @@ class ExpressionEvaluator;
 //! segmenter and schedulers.
 //! It is important that input id encoding should be up to date with any change
 //! of this class to avoid launching compiled kernels with illegal inputs.
+
 class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable {
  public:
   // Max vector size we will consider, in bytes,

@@ -112,6 +113,9 @@ class TORCH_CUDA_CU_API SchedulerRuntimeInfo : public NonCopyable {
   // TODO: Support output tensor pointers
   std::unordered_map<Val*, size_t> input_ptrs_;

+  // Copy of aten input tensor strides (in bytes)
+  std::unordered_map<Val*, std::vector<size_t>> input_discontig_strides_;
+
   // Cache for getAlignmentSize
   std::unordered_map<TensorView*, size_t> alignment_map_;
   // Cache for getMaxVectorizableWidth

torch/csrc/jit/codegen/cuda/scheduler/vectorize_helper.cpp

Lines changed: 14 additions & 16 deletions
@@ -80,18 +80,6 @@ size_t collectMaxVectorizeSizeWithContigMerge(
     size_t max_vector_size_in_byte,
     ExpressionEvaluator& expression_evaluator,
     DataType index_type) {
-  // Maybe too conservative, but only handles fully contiguous tensors
-  // TODO: Relax the contiguity constraint to be similar to that in index
-  // computing. Just looking for all merged root domains in the right order,
-  // all merged root dimensions are contiguous, all merged root dimensions are
-  // next to eachother (exlcuding broadcast).
-  if (std::any_of(
-          tv->domain()->contiguity().begin(),
-          tv->domain()->contiguity().end(),
-          [](const auto contig) { return !contig; })) {
-    return 1;
-  }
-
   auto dtype_size = dataTypeSize(tv->dtype(), index_type);
   const size_t max_vector_size = max_vector_size_in_byte / dtype_size;

@@ -202,8 +190,16 @@ size_t expandVectorizationToContigMergedDomains(

   // Merge the domains right of the break point
   const auto& ref_root = reference_tv->getMaybeRFactorDomain();
-  const int num_merged_domains =
+  const int max_num_merged_domains =
       static_cast<int>(ref_root.size()) - static_cast<int>(break_point);
+  int64_t num_merged_domains = 0;
+  while (num_merged_domains < max_num_merged_domains) {
+    auto pos = (int64_t)ref_root.size() - 1 - num_merged_domains;
+    if (!reference_tv->domain()->contiguity()[pos]) {
+      break;
+    }
+    num_merged_domains++;
+  }

   // No expansion with no merged domain
   if (num_merged_domains == 0) {
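The rewritten counting no longer assumes that every domain to the right of the break point can be merged; it stops at the first reference domain that is not contiguous. A small sketch of the same count on a plain contiguity vector (hypothetical helper, illustrative names):

#include <cstdint>
#include <vector>

// Count how many innermost dimensions can actually be merged for vectorization:
// at most (ndims - break_point), and only while each dimension is contiguous.
int64_t countMergeableInnerDims(
    const std::vector<bool>& contiguity,
    int64_t break_point) {
  const int64_t max_merged = (int64_t)contiguity.size() - break_point;
  int64_t merged = 0;
  while (merged < max_merged) {
    const int64_t pos = (int64_t)contiguity.size() - 1 - merged;
    if (!contiguity[pos]) {
      break;
    }
    merged++;
  }
  return merged;
}

// countMergeableInnerDims({false, true, true}, 1) == 2: both inner dims merge.
// countMergeableInnerDims({false, false, true}, 1) == 1: the middle dimension's
// discontiguous stride stops the merge after the innermost dimension.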
@@ -242,14 +238,16 @@ size_t expandVectorizationToContigMergedDomains(
     const auto& tv_root = tv->getMaybeRFactorDomain();

     int tv_num_merged_domains = 0;
-    for (const auto i : c10::irange(num_merged_domains)) {
+    for (const auto i : c10::irange(max_num_merged_domains)) {
       if (i == tv_root.size()) {
         break;
       }
       auto ref_id = ref_root.at(ref_root.size() - 1 - i);
-      IterDomain* tv_id = tv_root.at(tv_root.size() - 1 - i);
+      auto pos = tv_root.size() - 1 - i;
+      IterDomain* tv_id = tv_root.at(pos);
       // If not mapped, stop expanding.
-      if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT)) {
+      if (!ca_map.areMapped(ref_id, tv_id, IdMappingMode::EXACT) ||
+          !tv->domain()->contiguity()[pos]) {
         break;
       } else {
         ++tv_num_merged_domains;
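The same limit now applies per input/output tensor: expansion over a consumer's root domains stops as soon as a domain is either not exactly mapped to the reference or not contiguous in that tensor. A condensed, hypothetical sketch of that inner loop, with the compute-at map replaced by a plain boolean vector:

#include <cstdint>
#include <vector>

// For one tensor: walk innermost-out, up to max_merged domains, and stop at the
// first domain that is not exact-mapped to the reference or not contiguous here.
// mapped_to_ref[i] stands in for ca_map.areMapped(ref_id, tv_id, EXACT).
int64_t mergeableDimsForTensor(
    const std::vector<bool>& mapped_to_ref, // indexed innermost-first
    const std::vector<bool>& contiguity,    // indexed outermost-first
    int64_t max_merged) {
  int64_t merged = 0;
  for (int64_t i = 0; i < max_merged; i++) {
    if (i == (int64_t)contiguity.size()) {
      break; // ran out of dimensions in this tensor
    }
    const int64_t pos = (int64_t)contiguity.size() - 1 - i;
    if (!mapped_to_ref[i] || !contiguity[pos]) {
      break;
    }
    merged++;
  }
  return merged;
}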

torch/csrc/jit/codegen/cuda/test/test_gpu.cpp

Lines changed: 159 additions & 5 deletions
@@ -19138,17 +19138,17 @@ TEST_F(NVFuserTest, FusionChannelsLastParser_CUDA) {
   // 2. use a fuzzy compare (ignore non-significant whitespaces for example)
   const std::string expected_kernel = R"(
 __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2, Tensor<__half, 4> T7) {
-  int64_t i171;
-  i171 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
-  if ((i171 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) {
+  int64_t i165;
+  i165 = (((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x);
+  if ((i165 < (T0.size[0] * (T0.size[1] * (T0.size[2] * T0.size[3]))))) {
     __half T9[1];
     T9[0] = 0;
     T9[0]
       = T2[((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) / (T0.size[1] * (T0.size[2] * T0.size[3]))) * ((T0.size[2] * T0.size[1]) * T0.size[3])) + ((((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) % T0.size[3]) * (T0.size[2] * T0.size[1])) + (((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) / (T0.size[2] * T0.size[3])) * T0.size[2]) + (((((((nvfuser_index_t)blockIdx.x) * 128) + ((nvfuser_index_t)threadIdx.x)) % (T0.size[1] * (T0.size[2] * T0.size[3]))) % (T0.size[2] * T0.size[3])) / T0.size[3])];
     __half T8[1];
     T8[0] = 0;
     T8[0]
-      = T0[i171];
+      = T0[i165];
     float T3[1];
     T3[0]
       = __half2float(T9[0]);
@@ -19168,7 +19168,7 @@ __global__ void CUDAGeneratedKernel(Tensor<__half, 4> T0, Tensor<__half, 4> T2,
     __half T10[1];
     T10[0]
       = __float2half(T6[0]);
-    T7[i171]
+    T7[i165]
       = T10[0];
   }
 }
@@ -26125,6 +26125,160 @@ TEST_F(NVFuserTest, FusionTrivialInputForwarding_CUDA) {
   testValidate(fusion, cg_outputs2, {t0, t1}, {t0}, __LINE__, __FILE__);
 }

+namespace {
+
+size_t getVecSizeForPointwise(FusionExecutorCache& fec) {
+  auto most_recent_params =
+      fec.getMostRecentKernelRuntime()->getMostRecentExecutorLog().params;
+  auto params = dynamic_cast<PointwiseParams*>(most_recent_params.get());
+  if (params->vectorize) {
+    return params->unroll_factor;
+  }
+  return 1;
+}
+
+} // namespace
+
+TEST_F(NVFuserTest, FusionVectorizeStrideContiguity2D_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 =
+      TensorViewBuilder().ndims(2).contiguity({false, true}).build();
+  fusion->addInput(tv0);
+  auto tv1 = set(tv0);
+  fusion->addOutput(tv1);
+
+  FusionExecutorCache fec(std::move(fusion_ptr));
+  fec.profile(true);
+
+  std::vector<std::pair<int, int>> size_and_vec{{17, 1}, {18, 2}, {32, 4}};
+
+  for (auto pair : size_and_vec) {
+    auto size = pair.first;
+    auto vec = pair.second;
+    auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+    at::Tensor t0 = at::randn({1000000, size}, options).narrow(1, 0, 16);
+    auto cg_outputs = fec.runFusionWithInputs({t0});
+
+    TORCH_CHECK(getVecSizeForPointwise(fec) == vec);
+
+    testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+  }
+}
+
+TEST_F(NVFuserTest, FusionVectorizeStrideContiguity3D_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 =
+      TensorViewBuilder().ndims(3).contiguity({false, true, true}).build();
+  fusion->addInput(tv0);
+  auto tv1 = set(tv0);
+  fusion->addOutput(tv1);
+
+  FusionExecutorCache fec(std::move(fusion_ptr));
+  fec.profile(true);
+
+  std::vector<std::pair<int, int>> size_and_vec{{17, 1}, {10, 2}, {16, 4}};
+
+  for (auto pair : size_and_vec) {
+    auto size = pair.first;
+    auto vec = pair.second;
+    auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+    at::Tensor t0 = at::randn({1000000, size, 3}, options).narrow(1, 0, 8);
+    auto cg_outputs = fec.runFusionWithInputs({t0});
+
+    TORCH_CHECK(getVecSizeForPointwise(fec) == vec);
+
+    testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+  }
+}
+
+TEST_F(NVFuserTest, FusionVectorizeStrideContiguity5D_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 = TensorViewBuilder()
+                        .ndims(5)
+                        .contiguity({false, true, false, true, true})
+                        .build();
+  fusion->addInput(tv0);
+  auto tv1 = set(tv0);
+  fusion->addOutput(tv1);
+
+  FusionExecutorCache fec(std::move(fusion_ptr));
+  fec.profile(true);
+
+  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+
+  std::vector<std::tuple<int, int, int>> sizes_and_vec{
+      {9, 17, 1}, {9, 10, 2}, {9, 16, 4}};
+
+  for (auto tup : sizes_and_vec) {
+    auto size1 = std::get<0>(tup);
+    auto size2 = std::get<1>(tup);
+    auto vec = std::get<2>(tup);
+    at::Tensor t0 = at::randn({4, size1, 12345, size2, 3}, options)
+                        .narrow(1, 0, 8)
+                        .narrow(3, 0, 4);
+    auto cg_outputs = fec.runFusionWithInputs({t0});
+
+    TORCH_CHECK(getVecSizeForPointwise(fec) == vec);
+
+    testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+  }
+}
+
+TEST_F(NVFuserTest, FusionVectorizeStrideContiguitySelfOverlapping_CUDA) {
+  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
+  auto fusion = fusion_ptr.get();
+  FusionGuard fg(fusion);
+
+  TensorView* tv0 = TensorViewBuilder()
+                        .ndims(5)
+                        .contiguity({false, true, false, true, true})
+                        .build();
+  fusion->addInput(tv0);
+  auto tv1 = set(tv0);
+  fusion->addOutput(tv1);
+
+  FusionExecutorCache fec(std::move(fusion_ptr));
+  fec.profile(true);
+
+  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
+
+  std::vector<std::tuple<int, int, int, int>> sizes_strides_and_vec{
+      {4, 4, 4, 4},
+      {4, 4, 2, 2},
+      {4, 2, 4, 2},
+      {2, 4, 4, 2},
+      {4, 4, 1, 1},
+      {4, 1, 4, 1},
+      {1, 4, 4, 1},
+      {2, 2, 2, 2},
+      {2, 2, 1, 1},
+      {2, 1, 2, 1},
+      {1, 2, 2, 1}};
+
+  for (auto tup : sizes_strides_and_vec) {
+    auto size = std::get<0>(tup);
+    auto stride1 = std::get<1>(tup);
+    auto stride2 = std::get<2>(tup);
+    auto vec = std::get<3>(tup);
+    std::vector<int64_t> shape = {4, 4, 12345, size, 3};
+    std::vector<int64_t> stride = {stride1, stride2 * 12345, stride2, 3, 1};
+    at::Tensor t0 = at::empty_strided(shape, stride, options);
+    t0.random_();
+    auto cg_outputs = fec.runFusionWithInputs({t0});
+    TORCH_CHECK(getVecSizeForPointwise(fec) == vec);
+    testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
+  }
+}
+
 } // namespace jit
 } // namespace torch
 #endif // #if defined(USE_CUDA)
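As a sanity check on the expected widths in FusionVectorizeStrideContiguity2D: the narrowed input has 16 contiguous floats per row, but rows are separated by the full row stride of the un-narrowed tensor (size floats), so the vector width is bounded by the largest power of two (in bytes, capped at 16) dividing that row stride. A back-of-the-envelope sketch, assuming the base pointers returned by at::randn are at least 16-byte aligned:

#include <cstdint>
#include <iostream>

// Largest power of two, capped at the 16-byte max vector size, dividing `bytes`.
int64_t alignmentOf(int64_t bytes) {
  int64_t a = 1;
  while (a * 2 <= 16 && bytes % (a * 2) == 0) {
    a *= 2;
  }
  return a;
}

int main() {
  // Row strides from the 2D test: {17, 18, 32} floats, 4 bytes each.
  for (int64_t size : {17, 18, 32}) {
    int64_t stride_bytes = size * 4;
    int64_t vec_elems = alignmentOf(stride_bytes) / 4; // floats per vector load
    std::cout << "size " << size << " -> vectorize " << vec_elems << "\n";
    // prints 1, 2, 4 -- matching size_and_vec in the test
  }
  return 0;
}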
