Skip to content

Commit e040676

Browse files
authored
Use withPredicate to replace setPredicate to maintain Exprs immutable (pytorch#2025)
1 parent 197221b commit e040676

File tree

15 files changed

+616
-115
lines changed

15 files changed

+616
-115
lines changed

torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,12 @@ void Expr::setPredicate(kir::Predicate* predicate) {
341341
predicate_ = predicate;
342342
}
343343

344+
// Returns a new expression obtained from shallowCopy() with `predicate`
// attached to it via setPredicate(). The predicate is set on the copy; this
// function does not assign to this expression's own predicate_.
// NOTE(review): who owns/registers the returned copy is not visible here --
// confirm against the shallowCopy() implementations.
Expr* Expr::withPredicate(kir::Predicate* predicate) {
345+
auto result = shallowCopy();
346+
result->setPredicate(predicate);
347+
return result;
348+
}
349+
344350
kir::Predicate* Expr::writePredicate() const {
345351
TORCH_INTERNAL_ASSERT(
346352
container()->isA<kir::Kernel>(), "Function invalid for fusion.");
@@ -353,6 +359,19 @@ void Expr::setWritePredicate(kir::Predicate* write_predicate) {
353359
write_predicate_ = write_predicate;
354360
}
355361

362+
// Returns a new expression obtained from shallowCopy() with `predicate`
// attached to it as the write predicate via setWritePredicate(). The write
// predicate is set on the copy; this function does not assign to this
// expression's own write_predicate_.
// NOTE(review): who owns/registers the returned copy is not visible here --
// confirm against the shallowCopy() implementations.
Expr* Expr::withWritePredicate(kir::Predicate* predicate) {
363+
auto result = shallowCopy();
364+
result->setWritePredicate(predicate);
365+
return result;
366+
}
367+
368+
// Copies both predicate_ and write_predicate_ from `expr` into this
// expression, but only when this expression's container is a kir::Kernel;
// for any other container the call is a no-op.
// `expr` is dereferenced without a null check, so it must be non-null
// whenever the container is a kernel.
void Expr::copyPredicatesFrom(const Expr* expr) {
369+
if (container()->isA<kir::Kernel>()) {
370+
predicate_ = expr->predicate_;
371+
write_predicate_ = expr->write_predicate_;
372+
}
373+
}
374+
356375
} // namespace cuda
357376
} // namespace fuser
358377
} // namespace jit

torch/csrc/jit/codegen/cuda/ir_base_nodes.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,10 @@ class TORCH_CUDA_CU_API Expr : public Statement {
426426

427427
Expr(const Expr* src, IrCloner* ir_cloner);
428428

429+
// Creates a new instance of the expression with all its fields copied.
430+
// Note that unlike IrCloner, this function only does a shallow copy.
431+
virtual Expr* shallowCopy() const = 0;
432+
429433
c10::optional<ExprType> getExprType() const override {
430434
return etype_;
431435
}
@@ -466,16 +470,27 @@ class TORCH_CUDA_CU_API Expr : public Statement {
466470
// TODO: Protect based on being in kernel container
467471
kir::Predicate* predicate() const;
468472

473+
// Creates a shallow copy of the expression with the given predicate attached.
469474
// TODO: Protect based on being in kernel container
470-
void setPredicate(kir::Predicate* predicate);
475+
Expr* withPredicate(kir::Predicate* predicate);
471476

472477
// TODO: Protect based on being in kernel container
473478
kir::Predicate* writePredicate() const;
474479

480+
// Creates a shallow copy of the expression with the given write-predicate
481+
// attached.
475482
// TODO: Protect based on being in kernel container
476-
void setWritePredicate(kir::Predicate* write_predicate);
483+
Expr* withWritePredicate(kir::Predicate* write_predicate);
477484

478485
protected:
486+
// TODO: Protect based on being in kernel container
487+
void setPredicate(kir::Predicate* predicate);
488+
489+
// TODO: Protect based on being in kernel container
490+
void setWritePredicate(kir::Predicate* write_predicate);
491+
492+
void copyPredicatesFrom(const Expr* expr);
493+
479494
// TODO: Add Fusion passkey
480495
void addInput(Val* input) {
481496
TORCH_INTERNAL_ASSERT(input != nullptr);

torch/csrc/jit/codegen/cuda/ir_internal_nodes.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ class TORCH_CUDA_CU_API FullOp : public Expr {
3636

3737
FullOp(const FullOp* src, IrCloner* ir_cloner);
3838

39+
Expr* shallowCopy() const override;
40+
3941
bool sameAs(const Statement* other) const override;
4042

4143
DataType dtype() const {
@@ -64,6 +66,8 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr {
6466

6567
ARangeOp(const ARangeOp* src, IrCloner* ir_cloner);
6668

69+
Expr* shallowCopy() const override;
70+
6771
bool sameAs(const Statement* other) const override;
6872

6973
DataType dtype() const {
@@ -127,6 +131,8 @@ class TORCH_CUDA_CU_API EyeOp : public Expr {
127131

128132
EyeOp(const EyeOp* src, IrCloner* ir_cloner);
129133

134+
Expr* shallowCopy() const override;
135+
130136
bool sameAs(const Statement* other) const override;
131137

132138
DataType dtype() const {
@@ -172,6 +178,8 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr {
172178

173179
UnaryOp(const UnaryOp* src, IrCloner* ir_cloner);
174180

181+
Expr* shallowCopy() const override;
182+
175183
Val* out() const {
176184
return out_;
177185
}
@@ -201,6 +209,8 @@ class TORCH_CUDA_CU_API BinaryOp : public Expr {
201209

202210
BinaryOp(const BinaryOp* src, IrCloner* ir_cloner);
203211

212+
Expr* shallowCopy() const override;
213+
204214
Val* out() const {
205215
return out_;
206216
}
@@ -239,6 +249,8 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
239249

240250
RNGOp(const RNGOp* src, IrCloner* ir_cloner);
241251

252+
Expr* shallowCopy() const override;
253+
242254
RNGOpType getRNGOpType() const {
243255
return rng_op_type_;
244256
}
@@ -298,6 +310,8 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr {
298310

299311
BroadcastOp(const BroadcastOp* src, IrCloner* ir_cloner);
300312

313+
Expr* shallowCopy() const override;
314+
301315
Val* out() const {
302316
return out_;
303317
}
@@ -346,6 +360,8 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr {
346360

347361
ReductionOp(const ReductionOp* src, IrCloner* ir_cloner);
348362

363+
Expr* shallowCopy() const override;
364+
349365
Val* out() const {
350366
return out_;
351367
}
@@ -394,6 +410,8 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr {
394410

395411
GroupedReductionOp(const GroupedReductionOp* src, IrCloner* ir_cloner);
396412

413+
Expr* shallowCopy() const override;
414+
397415
//! Number of expressions grouped horizontally. It does not reflect
398416
//! iteration grouping.
399417
size_t numExprs() const {
@@ -580,6 +598,8 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr {
580598

581599
WelfordOp(const WelfordOp* src, IrCloner* ir_cloner);
582600

601+
Expr* shallowCopy() const override;
602+
583603
Val* out() const {
584604
return output().avg();
585605
}
@@ -675,6 +695,8 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr {
675695

676696
GroupedWelfordOp(const GroupedWelfordOp* src, IrCloner* ir_cloner);
677697

698+
Expr* shallowCopy() const override;
699+
678700
//! Number of expressions grouped horizontally. It does not reflect
679701
//! iteration grouping. As horizontal grouping is not supported,
680702
//! this always returns 1.
@@ -798,6 +820,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
798820

799821
MmaOp(const MmaOp* src, IrCloner* ir_cloner);
800822

823+
Expr* shallowCopy() const override;
824+
801825
Val* out() const {
802826
return out_;
803827
}
@@ -856,6 +880,8 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr {
856880

857881
TransposeOp(const TransposeOp* src, IrCloner* ir_cloner);
858882

883+
Expr* shallowCopy() const override;
884+
859885
TensorView* out() const {
860886
return out_;
861887
}
@@ -886,6 +912,8 @@ class TORCH_CUDA_CU_API ExpandOp : public Expr {
886912

887913
ExpandOp(const ExpandOp* src, IrCloner* ir_cloner);
888914

915+
Expr* shallowCopy() const override;
916+
889917
TensorView* out() const {
890918
return out_;
891919
}
@@ -916,6 +944,8 @@ class TORCH_CUDA_CU_API TernaryOp : public Expr {
916944

917945
TernaryOp(const TernaryOp* src, IrCloner* ir_cloner);
918946

947+
Expr* shallowCopy() const override;
948+
919949
Val* out() const {
920950
return out_;
921951
}
@@ -959,6 +989,8 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr {
959989

960990
ShiftOp(const ShiftOp* src, IrCloner* ir_cloner);
961991

992+
Expr* shallowCopy() const override;
993+
962994
Val* out() const {
963995
return out_;
964996
}
@@ -1008,6 +1040,8 @@ class TORCH_CUDA_CU_API GatherOp : public Expr {
10081040

10091041
GatherOp(const GatherOp* src, IrCloner* ir_cloner);
10101042

1043+
Expr* shallowCopy() const override;
1044+
10111045
Val* out() const {
10121046
return out_;
10131047
}
@@ -1054,6 +1088,8 @@ class TORCH_CUDA_CU_API ViewAsScalar : public Expr {
10541088

10551089
ViewAsScalar(const ViewAsScalar* src, IrCloner* ir_cloner);
10561090

1091+
Expr* shallowCopy() const override;
1092+
10571093
Val* out() const {
10581094
return out_;
10591095
}
@@ -1087,6 +1123,8 @@ class TORCH_CUDA_CU_API ViewOp : public Expr {
10871123

10881124
ViewOp(const ViewOp* src, IrCloner* ir_cloner);
10891125

1126+
Expr* shallowCopy() const override;
1127+
10901128
TensorView* out() const {
10911129
return out_;
10921130
}
@@ -1112,6 +1150,8 @@ class TORCH_CUDA_CU_API LoadStoreOp : public Expr {
11121150

11131151
LoadStoreOp(const LoadStoreOp* src, IrCloner* ir_cloner);
11141152

1153+
Expr* shallowCopy() const override;
1154+
11151155
Val* out() const {
11161156
return out_;
11171157
}
@@ -1691,6 +1731,8 @@ class TORCH_CUDA_CU_API Split : public Expr {
16911731

16921732
Split(const Split* src, IrCloner* ir_cloner);
16931733

1734+
Expr* shallowCopy() const override;
1735+
16941736
IterDomain* outer() const {
16951737
return outer_;
16961738
}
@@ -1751,6 +1793,8 @@ class TORCH_CUDA_CU_API Merge : public Expr {
17511793

17521794
Merge(const Merge* src, IrCloner* ir_cloner);
17531795

1796+
Expr* shallowCopy() const override;
1797+
17541798
IterDomain* out() const {
17551799
return out_;
17561800
}
@@ -1783,6 +1827,8 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
17831827

17841828
Swizzle2D(const Swizzle2D* src, IrCloner* ir_cloner);
17851829

1830+
Expr* shallowCopy() const override;
1831+
17861832
IterDomain* outX() const {
17871833
return out_x_;
17881834
}

0 commit comments

Comments
 (0)