@@ -36,6 +36,8 @@ class TORCH_CUDA_CU_API FullOp : public Expr {
3636
3737 FullOp (const FullOp* src, IrCloner* ir_cloner);
3838
39+ Expr* shallowCopy () const override ;
40+
3941 bool sameAs (const Statement* other) const override ;
4042
4143 DataType dtype () const {
@@ -64,6 +66,8 @@ class TORCH_CUDA_CU_API ARangeOp : public Expr {
6466
6567 ARangeOp (const ARangeOp* src, IrCloner* ir_cloner);
6668
69+ Expr* shallowCopy () const override ;
70+
6771 bool sameAs (const Statement* other) const override ;
6872
6973 DataType dtype () const {
@@ -127,6 +131,8 @@ class TORCH_CUDA_CU_API EyeOp : public Expr {
127131
128132 EyeOp (const EyeOp* src, IrCloner* ir_cloner);
129133
134+ Expr* shallowCopy () const override ;
135+
130136 bool sameAs (const Statement* other) const override ;
131137
132138 DataType dtype () const {
@@ -172,6 +178,8 @@ class TORCH_CUDA_CU_API UnaryOp : public Expr {
172178
173179 UnaryOp (const UnaryOp* src, IrCloner* ir_cloner);
174180
181+ Expr* shallowCopy () const override ;
182+
175183 Val* out () const {
176184 return out_;
177185 }
@@ -201,6 +209,8 @@ class TORCH_CUDA_CU_API BinaryOp : public Expr {
201209
202210 BinaryOp (const BinaryOp* src, IrCloner* ir_cloner);
203211
212+ Expr* shallowCopy () const override ;
213+
204214 Val* out () const {
205215 return out_;
206216 }
@@ -239,6 +249,8 @@ class TORCH_CUDA_CU_API RNGOp : public Expr {
239249
240250 RNGOp (const RNGOp* src, IrCloner* ir_cloner);
241251
252+ Expr* shallowCopy () const override ;
253+
242254 RNGOpType getRNGOpType () const {
243255 return rng_op_type_;
244256 }
@@ -298,6 +310,8 @@ class TORCH_CUDA_CU_API BroadcastOp : public Expr {
298310
299311 BroadcastOp (const BroadcastOp* src, IrCloner* ir_cloner);
300312
313+ Expr* shallowCopy () const override ;
314+
301315 Val* out () const {
302316 return out_;
303317 }
@@ -346,6 +360,8 @@ class TORCH_CUDA_CU_API ReductionOp : public Expr {
346360
347361 ReductionOp (const ReductionOp* src, IrCloner* ir_cloner);
348362
363+ Expr* shallowCopy () const override ;
364+
349365 Val* out () const {
350366 return out_;
351367 }
@@ -394,6 +410,8 @@ class TORCH_CUDA_CU_API GroupedReductionOp : public Expr {
394410
395411 GroupedReductionOp (const GroupedReductionOp* src, IrCloner* ir_cloner);
396412
413+ Expr* shallowCopy () const override ;
414+
397415 // ! Number of expressions grouped horizontally. It does not reflect
398416 // ! iteration grouping.
399417 size_t numExprs () const {
@@ -580,6 +598,8 @@ class TORCH_CUDA_CU_API WelfordOp : public Expr {
580598
581599 WelfordOp (const WelfordOp* src, IrCloner* ir_cloner);
582600
601+ Expr* shallowCopy () const override ;
602+
583603 Val* out () const {
584604 return output ().avg ();
585605 }
@@ -675,6 +695,8 @@ class TORCH_CUDA_CU_API GroupedWelfordOp : public Expr {
675695
676696 GroupedWelfordOp (const GroupedWelfordOp* src, IrCloner* ir_cloner);
677697
698+ Expr* shallowCopy () const override ;
699+
678700 // ! Number of expressions grouped horizontally. It does not reflect
679701 // ! iteration grouping. As horizontal grouping is not supported,
680702 // ! this always returns 1.
@@ -798,6 +820,8 @@ class TORCH_CUDA_CU_API MmaOp : public Expr {
798820
799821 MmaOp (const MmaOp* src, IrCloner* ir_cloner);
800822
823+ Expr* shallowCopy () const override ;
824+
801825 Val* out () const {
802826 return out_;
803827 }
@@ -856,6 +880,8 @@ class TORCH_CUDA_CU_API TransposeOp : public Expr {
856880
857881 TransposeOp (const TransposeOp* src, IrCloner* ir_cloner);
858882
883+ Expr* shallowCopy () const override ;
884+
859885 TensorView* out () const {
860886 return out_;
861887 }
@@ -886,6 +912,8 @@ class TORCH_CUDA_CU_API ExpandOp : public Expr {
886912
887913 ExpandOp (const ExpandOp* src, IrCloner* ir_cloner);
888914
915+ Expr* shallowCopy () const override ;
916+
889917 TensorView* out () const {
890918 return out_;
891919 }
@@ -916,6 +944,8 @@ class TORCH_CUDA_CU_API TernaryOp : public Expr {
916944
917945 TernaryOp (const TernaryOp* src, IrCloner* ir_cloner);
918946
947+ Expr* shallowCopy () const override ;
948+
919949 Val* out () const {
920950 return out_;
921951 }
@@ -959,6 +989,8 @@ class TORCH_CUDA_CU_API ShiftOp : public Expr {
959989
960990 ShiftOp (const ShiftOp* src, IrCloner* ir_cloner);
961991
992+ Expr* shallowCopy () const override ;
993+
962994 Val* out () const {
963995 return out_;
964996 }
@@ -1008,6 +1040,8 @@ class TORCH_CUDA_CU_API GatherOp : public Expr {
10081040
10091041 GatherOp (const GatherOp* src, IrCloner* ir_cloner);
10101042
1043+ Expr* shallowCopy () const override ;
1044+
10111045 Val* out () const {
10121046 return out_;
10131047 }
@@ -1054,6 +1088,8 @@ class TORCH_CUDA_CU_API ViewAsScalar : public Expr {
10541088
10551089 ViewAsScalar (const ViewAsScalar* src, IrCloner* ir_cloner);
10561090
1091+ Expr* shallowCopy () const override ;
1092+
10571093 Val* out () const {
10581094 return out_;
10591095 }
@@ -1087,6 +1123,8 @@ class TORCH_CUDA_CU_API ViewOp : public Expr {
10871123
10881124 ViewOp (const ViewOp* src, IrCloner* ir_cloner);
10891125
1126+ Expr* shallowCopy () const override ;
1127+
10901128 TensorView* out () const {
10911129 return out_;
10921130 }
@@ -1112,6 +1150,8 @@ class TORCH_CUDA_CU_API LoadStoreOp : public Expr {
11121150
11131151 LoadStoreOp (const LoadStoreOp* src, IrCloner* ir_cloner);
11141152
1153+ Expr* shallowCopy () const override ;
1154+
11151155 Val* out () const {
11161156 return out_;
11171157 }
@@ -1691,6 +1731,8 @@ class TORCH_CUDA_CU_API Split : public Expr {
16911731
16921732 Split (const Split* src, IrCloner* ir_cloner);
16931733
1734+ Expr* shallowCopy () const override ;
1735+
16941736 IterDomain* outer () const {
16951737 return outer_;
16961738 }
@@ -1751,6 +1793,8 @@ class TORCH_CUDA_CU_API Merge : public Expr {
17511793
17521794 Merge (const Merge* src, IrCloner* ir_cloner);
17531795
1796+ Expr* shallowCopy () const override ;
1797+
17541798 IterDomain* out () const {
17551799 return out_;
17561800 }
@@ -1783,6 +1827,8 @@ class TORCH_CUDA_CU_API Swizzle2D : public Expr {
17831827
17841828 Swizzle2D (const Swizzle2D* src, IrCloner* ir_cloner);
17851829
1830+ Expr* shallowCopy () const override ;
1831+
17861832 IterDomain* outX () const {
17871833 return out_x_;
17881834 }
0 commit comments