
Commit 15f2f6d

Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (pytorch#1988)
1 parent 8f1c7f5 commit 15f2f6d

8 files changed: +20 additions, -17 deletions
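
The change follows one pattern across all eight files: GpuLower stops holding the analysis by value and calling build() on it, and instead stores a fully constructed, immutable instance behind a shared_ptr<const ...>. Below is a minimal sketch of the before/after shape, using simplified stand-in types rather than the actual nvFuser classes:

#include <memory>

struct Fusion;      // stand-in for the real nvFuser Fusion
struct Analysis {   // stand-in for ConcretizedBroadcastDomains
  explicit Analysis(Fusion* fusion);  // does all the work in the constructor
};

// Before (sketch): value member, default-constructed, filled in later:
//   Analysis analysis_;
//   analysis_.build(fusion_);
//
// After (sketch): const, shared, created already built.
struct LowerSketch {
  std::shared_ptr<const Analysis> analysis_;
  void lower(Fusion* fusion) {
    analysis_ = std::make_shared<const Analysis>(fusion);
  }
};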

torch/csrc/jit/codegen/cuda/lower2device.cpp

Lines changed: 2 additions & 1 deletion
@@ -257,7 +257,8 @@ void GpuLower::lower(Fusion* fusion, DataType index_type) {
   compute_at_map_->validateAndPropagatePType();
 
   // Used in parallel dimension map
-  concretized_broadcast_domains_.build(fusion_);
+  concretized_broadcast_domains_ =
+      std::make_shared<const ConcretizedBroadcastDomains>(fusion_);
 
   parallelDimensionMap().build(fusion_);
   if (isDebugDumpEnabled(DebugDumpOption::ParallelDimensions)) {

torch/csrc/jit/codegen/cuda/lower2device.h

Lines changed: 4 additions & 2 deletions
@@ -62,7 +62,8 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
   //! Query if lowering is in progress
   static bool hasCurrent();
 
-  ConcretizedBroadcastDomains& concretizedBroadcastDomains() {
+  std::shared_ptr<const ConcretizedBroadcastDomains>
+  concretizedBroadcastDomains() {
     return concretized_broadcast_domains_;
   }
 
@@ -194,7 +195,8 @@ class TORCH_CUDA_CU_API GpuLower : public NonCopyable {
   // would be safer to wrap all of these in unique pointers and remove the build
   // interface and default constructor. That way they couldn't be accessed
   // without being initialized.
-  ConcretizedBroadcastDomains concretized_broadcast_domains_;
+  std::shared_ptr<const ConcretizedBroadcastDomains>
+      concretized_broadcast_domains_;
   ThreadPredicateMap thread_pred_map_;
   PredicateElimination pred_elimination_;
   std::shared_ptr<ComputeAtMap> compute_at_map_;
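
Because the accessor now returns the shared_ptr by value instead of a mutable reference, the call sites in the files below switch from member access (.) to pointer access (->); the query API itself is unchanged. A hypothetical caller, not part of this commit and shown only to illustrate the new dereference style, assuming the usual nvFuser lowering headers are included:

// Hypothetical helper mirroring the call sites changed below.
bool isConcretizedBroadcast(IterDomain* id) {
  return id->isBroadcast() &&
      GpuLower::current()->concretizedBroadcastDomains()->isConcretized(id);
}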

torch/csrc/jit/codegen/cuda/lower_sync_information.cpp

Lines changed: 4 additions & 4 deletions
@@ -26,7 +26,7 @@ void validateParallelizationOfTensor(TensorView* tv) {
     // It doesn't matter if this axis is a non-concretized broadcast
     // TODO: merging broadcast and non-broadcast
     if (axis->isBroadcast() &&
-        !GpuLower::current()->concretizedBroadcastDomains().isConcretized(
+        !GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
             axis)) {
       continue;
     }
@@ -195,7 +195,7 @@ void SyncMap::build(Fusion* fusion) {
             (!parallel_bcast_doms.get(consumer_ptype) ||
              !GpuLower::current()
                   ->concretizedBroadcastDomains()
-                  .isConcretized(consumer_axis))) {
+                  ->isConcretized(consumer_axis))) {
           continue;
         }
 
@@ -421,7 +421,7 @@ void SyncMap::build(Fusion* fusion) {
                 .redundant_types;
 
         if (p_id->isBroadcast() &&
-            GpuLower::current()->concretizedBroadcastDomains().isConcretized(
+            GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
                 p_id) &&
             producer->getMemoryType() == MemoryType::Shared &&
             redundant_preds.hasTID()) {
@@ -436,7 +436,7 @@ void SyncMap::build(Fusion* fusion) {
             (!parallel_bcast_doms.get(producer_ptype) ||
              !GpuLower::current()
                   ->concretizedBroadcastDomains()
-                  .isConcretized(p_id))) {
+                  ->isConcretized(p_id))) {
           continue;
         }

torch/csrc/jit/codegen/cuda/lower_thread_predicate.cpp

Lines changed: 3 additions & 2 deletions
@@ -237,7 +237,7 @@ void ThreadPredicateMap::updateBitSet(const Expr* expr) {
         id_reductions.set(id->getParallelType());
       }
       if (id->isBroadcast() &&
-          GpuLower::current()->concretizedBroadcastDomains().isConcretized(
+          GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
               id)) {
         id_bcasts.set(id->getParallelType());
       }
@@ -575,7 +575,8 @@ ParallelTypeBitmap ThreadPredicateMap::getParallelBroadcastDomains(
 
   for (auto id : iter_domains) {
     if (!id->isBroadcast() ||
-        !GpuLower::current()->concretizedBroadcastDomains().isConcretized(id)) {
+        !GpuLower::current()->concretizedBroadcastDomains()->isConcretized(
+            id)) {
       continue;
     }
     if (id->isBlockDim() || (!output_smem && id->isThreadDim())) {

torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.cpp

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ namespace jit {
 namespace fuser {
 namespace cuda {
 
-void ConcretizedBroadcastDomains::build(Fusion* fusion) {
+ConcretizedBroadcastDomains::ConcretizedBroadcastDomains(Fusion* fusion) {
   exact_map_ = std::make_unique<ExactRootDomainMap>(fusion);
 
   // Initialize the origin map with input broadcast domains

torch/csrc/jit/codegen/cuda/lower_trivial_broadcast.h

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@ namespace cuda {
 //! domains are marked as concretized.
 class TORCH_CUDA_CU_API ConcretizedBroadcastDomains : private IterVisitor {
  public:
-  void build(Fusion* fusion);
+  ConcretizedBroadcastDomains() = delete;
+  ConcretizedBroadcastDomains(Fusion* fusion);
 
   //! Is a domain concretized?
   bool isConcretized(IterDomain* id) const;
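
With build() gone and the default constructor deleted, the analysis can no longer exist in a half-initialized state: it is constructed directly from a Fusion, either on the stack (as scheduler/registry.cpp and the tests below now do) or behind a shared_ptr<const ...> (as GpuLower::lower now does). A hedged sketch of the two construction styles, assuming the relevant nvFuser headers are included:

#include <memory>

void constructionStyles(Fusion* fusion) {
  // Stack-allocated, immutable once built (scheduler/registry.cpp style).
  ConcretizedBroadcastDomains local_info(fusion);

  // Shared const ownership (GpuLower style).
  auto shared_info =
      std::make_shared<const ConcretizedBroadcastDomains>(fusion);

  // ConcretizedBroadcastDomains empty;  // would not compile: default ctor deleted
  (void)local_info;
  (void)shared_info;
}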

torch/csrc/jit/codegen/cuda/scheduler/registry.cpp

Lines changed: 1 addition & 2 deletions
@@ -789,8 +789,7 @@ static bool checkPatternEquivalence(
 // being broadcasted to one size multiple times or different sizes. This is a
 // hard to optimize problem and likely indicates we shouldn't be fusing.
 bool hasNonUniqueBcast(Fusion* fusion) {
-  ConcretizedBroadcastDomains concretize_info;
-  concretize_info.build(fusion);
+  ConcretizedBroadcastDomains concretize_info(fusion);
 
   for (auto tv : ir_utils::allTvs(fusion)) {
     for (auto id : tv->getRootDomain()) {

torch/csrc/jit/codegen/cuda/test/test_gpu.cpp

Lines changed: 3 additions & 4 deletions
@@ -20889,9 +20889,9 @@ TEST_F(NVFuserTest, FusionBroadcastConcretization1_CUDA) {
   }
 
   GpuLower gpulw(&fusion);
-  TORCH_CHECK(!gpulw.concretizedBroadcastDomains().isConcretized(
+  TORCH_CHECK(!gpulw.concretizedBroadcastDomains()->isConcretized(
       loweredTv(tv4, gpulw)->axis(1)));
-  TORCH_CHECK(gpulw.concretizedBroadcastDomains().isConcretized(
+  TORCH_CHECK(gpulw.concretizedBroadcastDomains()->isConcretized(
       loweredTv(tv7, gpulw)->axis(1)));
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
@@ -21079,8 +21079,7 @@ TEST_F(NVFuserTest, FusionBroadcastConcretization5_CUDA) {
   auto tvs3 = Welford(tv17, {1});
   fusion.addOutput(tvs3.avg);
 
-  ConcretizedBroadcastDomains bcast_concretization_info;
-  bcast_concretization_info.build(&fusion);
+  ConcretizedBroadcastDomains bcast_concretization_info(&fusion);
 
   TORCH_CHECK(
       bcast_concretization_info.maybeNonUniquelyConcretized(tv5->axis(1)),

0 commit comments
