@@ -83,9 +83,9 @@ void IRElementwiseSchedule(ir::IRSchedule &ir_sch, // NOLINT
8383 if (size <= target.max_num_threads ()) {
8484 ir_sch.Bind (loop, " threadIdx.x" );
8585 } else {
86- auto splited = ir_sch.Split (loop, {-1 , target.max_num_threads ()});
87- ir_sch.Bind (splited [0 ], " blockIdx.x" );
88- ir_sch.Bind (splited [1 ], " threadIdx.x" );
86+ auto split = ir_sch.Split (loop, {-1 , target.max_num_threads ()});
87+ ir_sch.Bind (split [0 ], " blockIdx.x" );
88+ ir_sch.Bind (split [1 ], " threadIdx.x" );
8989 }
9090 };
9191 target.arch .Match (
@@ -117,9 +117,9 @@ void IRInjectiveSchedule(ir::IRSchedule &ir_sch, // NOLINT
117117 if (size <= target.max_num_threads ()) {
118118 ir_sch.Bind (loop, " threadIdx.x" );
119119 } else {
120- auto splited = ir_sch.Split (loop, {-1 , target.max_num_threads ()});
121- ir_sch.Bind (splited [0 ], " blockIdx.x" );
122- ir_sch.Bind (splited [1 ], " threadIdx.x" );
120+ auto split = ir_sch.Split (loop, {-1 , target.max_num_threads ()});
121+ ir_sch.Bind (split [0 ], " blockIdx.x" );
122+ ir_sch.Bind (split [1 ], " threadIdx.x" );
123123 }
124124 };
125125 target.arch .Match (
@@ -172,10 +172,10 @@ void IRScheduleInjectiveCPU(ir::IRSchedule &ir_sch, // NOLINT
172172 auto loops = ir_sch.GetLoops(all_blocks[0]);
173173 int last_shape = ir::GetLoopExtent(loops.back());
174174 factor = GetVectorizeFactor(last_shape, factor);
175- auto splited = ir_sch.Split(loops.back(), {-1, factor});
176- ir_sch.Vectorize(splited [1], factor);
175+ auto split = ir_sch.Split(loops.back(), {-1, factor});
176+ ir_sch.Vectorize(split [1], factor);
177177 if (dims == 1) {
178- ir_sch.Parallel(splited [0]);
178+ ir_sch.Parallel(split [0]);
179179 }
180180 } */
181181 VLOG (3 ) << " After IRScheduleInjectiveCPU, new ir is : "
@@ -195,9 +195,9 @@ void IRGpuScheduleInjective(ir::IRSchedule &ir_sch, // NOLINT
195195 int prod_size = std::accumulate (
196196 output_shape.begin (), output_shape.end (), 1 , std::multiplies<int >());
197197 if (prod_size > num_thread) {
198- auto splited = ir_sch.Split (fused, {-1 , num_thread});
199- ir_sch.Bind (splited [0 ], " blockIdx.x" );
200- ir_sch.Bind (splited [1 ], " threadIdx.x" );
198+ auto split = ir_sch.Split (fused, {-1 , num_thread});
199+ ir_sch.Bind (split [0 ], " blockIdx.x" );
200+ ir_sch.Bind (split [1 ], " threadIdx.x" );
201201 } else {
202202 ir_sch.Bind (fused, " threadIdx.x" );
203203 }
@@ -242,9 +242,9 @@ std::vector<cinn::common::CINNValue> IRGpuScheduleMatMul(
242242 auto loops = ir_sch.GetLoops (init_block);
243243 if (loops.size () == 1 ) {
244244 if (ir::GetLoopExtent (loops[0 ]) > num_thread) {
245- auto splited = ir_sch.Split (loops[0 ], {-1 , num_thread});
246- ir_sch.Bind (splited [0 ], " blockIdx.x" );
247- ir_sch.Bind (splited [1 ], " threadIdx.x" );
245+ auto split = ir_sch.Split (loops[0 ], {-1 , num_thread});
246+ ir_sch.Bind (split [0 ], " blockIdx.x" );
247+ ir_sch.Bind (split [1 ], " threadIdx.x" );
248248 } else {
249249 ir_sch.Bind (loops[0 ], " threadIdx.x" );
250250 }
@@ -273,7 +273,7 @@ void IRCudaScheduleMul(ir::IRSchedule &ir_sch, // NOLINT
273273 2U ,
274274 ::common::errors::InvalidArgument (
275275 " The size of loops should be greater than 2." ));
276- auto splited = ir_sch.Split (loops[1 ], {-1 , 2 });
276+ auto split = ir_sch.Split (loops[1 ], {-1 , 2 });
277277 all_blocks = ir_sch.GetAllBlocks ();
278278 loops = ir_sch.GetLoops (all_blocks.back ());
279279 ir_sch.Bind (loops[0 ], " blockIdx.x" );
@@ -349,15 +349,14 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, // NOLINT
349349
350350 if (tsize > target.max_num_threads ()) {
351351 // split [-1, 256]
352- auto splited = ir_sch.Split (ir_sch.GetLoops (block_name)[0 ],
353- {-1 , target.max_num_threads () / 4 });
354- ir_sch.Bind (splited [0 ], " blockIdx.x" );
355- ir_sch.Bind (splited [1 ], " threadIdx.x" );
352+ auto split = ir_sch.Split (ir_sch.GetLoops (block_name)[0 ],
353+ {-1 , target.max_num_threads () / 4 });
354+ ir_sch.Bind (split [0 ], " blockIdx.x" );
355+ ir_sch.Bind (split [1 ], " threadIdx.x" );
356356 } else {
357- auto splited =
358- ir_sch.Split (ir_sch.GetLoops (block_name)[0 ], {1 , tsize});
359- ir_sch.Bind (splited[0 ], " blockIdx.x" );
360- ir_sch.Bind (splited[1 ], " threadIdx.x" );
357+ auto split = ir_sch.Split (ir_sch.GetLoops (block_name)[0 ], {1 , tsize});
358+ ir_sch.Bind (split[0 ], " blockIdx.x" );
359+ ir_sch.Bind (split[1 ], " threadIdx.x" );
361360 }
362361 }
363362 } else {
@@ -373,15 +372,15 @@ void IRCudaSplitSchedule(ir::IRSchedule &ir_sch, // NOLINT
373372 auto tsize = first_loop.As <ir::For>()->extent .as_int32 ();
374373 if (tsize > target.max_num_threads ()) {
375374 // split [-1, 256]
376- auto splited = ir_sch.Split (ir_sch.GetLoops (block_names[idx])[0 ],
377- {-1 , target.max_num_threads () / 4 });
378- ir_sch.Bind (splited [0 ], " blockIdx.x" );
379- ir_sch.Bind (splited [1 ], " threadIdx.x" );
375+ auto split = ir_sch.Split (ir_sch.GetLoops (block_names[idx])[0 ],
376+ {-1 , target.max_num_threads () / 4 });
377+ ir_sch.Bind (split [0 ], " blockIdx.x" );
378+ ir_sch.Bind (split [1 ], " threadIdx.x" );
380379 } else {
381- auto splited =
380+ auto split =
382381 ir_sch.Split (ir_sch.GetLoops (block_names[idx])[0 ], {1 , tsize});
383- ir_sch.Bind (splited [0 ], " blockIdx.x" );
384- ir_sch.Bind (splited [1 ], " threadIdx.x" );
382+ ir_sch.Bind (split [0 ], " blockIdx.x" );
383+ ir_sch.Bind (split [1 ], " threadIdx.x" );
385384 }
386385 }
387386 }
@@ -1180,9 +1179,9 @@ void IRPoolScheduleGPU(ir::IRSchedule &ir_sch, // NOLINT
11801179 // Blocks were changed after Fuse, so we have to get all blocks again.
11811180 all_blocks = ir_sch.GetAllBlocks ();
11821181 loops = ir_sch.GetLoops (all_blocks[0 ]);
1183- auto splited = ir_sch.Split (loops[0 ], {-1 , 1024 });
1184- ir_sch.Bind (splited [0 ], " blockIdx.x" );
1185- ir_sch.Bind (splited [1 ], " threadIdx.x" );
1182+ auto split = ir_sch.Split (loops[0 ], {-1 , 1024 });
1183+ ir_sch.Bind (split [0 ], " blockIdx.x" );
1184+ ir_sch.Bind (split [1 ], " threadIdx.x" );
11861185 VLOG (3 ) << " End IRPoolScheduleGPU: " << ir_sch.GetModule ().GetExprs ().at (0 );
11871186}
11881187
@@ -1198,14 +1197,14 @@ void IRGlobalPoolScheduleGPU(ir::IRSchedule &ir_sch, // NOLINT
11981197 auto loops = ir_sch.GetLoops (all_blocks[1 ]);
11991198 if (loops.size () > 1 ) {
12001199 auto fused = ir_sch.Fuse (all_blocks[0 ], {0 , 1 });
1201- auto splited = ir_sch.Split (fused, {-1 , 32 });
1200+ auto split = ir_sch.Split (fused, {-1 , 32 });
12021201 all_blocks = ir_sch.GetAllBlocks ();
12031202 fused = ir_sch.Fuse (all_blocks[1 ], {0 , 1 });
1204- splited = ir_sch.Split (fused, {-1 , 32 });
1205- ir_sch.Bind (splited [0 ], " blockIdx.x" );
1206- ir_sch.Bind (splited [1 ], " threadIdx.y" );
1203+ split = ir_sch.Split (fused, {-1 , 32 });
1204+ ir_sch.Bind (split [0 ], " blockIdx.x" );
1205+ ir_sch.Bind (split [1 ], " threadIdx.y" );
12071206 all_blocks = ir_sch.GetAllBlocks ();
1208- ir_sch.SimpleComputeAt (all_blocks[0 ], splited [1 ]);
1207+ ir_sch.SimpleComputeAt (all_blocks[0 ], split [1 ]);
12091208 all_blocks = ir_sch.GetAllBlocks ();
12101209 ir_sch.SetBuffer (all_blocks[0 ], " local" , true );
12111210 loops = ir_sch.GetLoops (all_blocks[0 ]);
@@ -1218,15 +1217,15 @@ void IRGlobalPoolScheduleGPU(ir::IRSchedule &ir_sch, // NOLINT
12181217 ir_sch.Bind (loops[2 ], " threadIdx.x" );
12191218 } else {
12201219 loops = ir_sch.GetLoops (all_blocks[0 ]);
1221- auto splited = ir_sch.Split (loops[0 ], {-1 , 32 });
1220+ auto split = ir_sch.Split (loops[0 ], {-1 , 32 });
12221221 all_blocks = ir_sch.GetAllBlocks ();
12231222 loops = ir_sch.GetLoops (all_blocks[1 ]);
1224- splited = ir_sch.Split (loops[0 ], {-1 , 32 });
1225- ir_sch.Bind (splited [0 ], " blockIdx.x" );
1226- ir_sch.Bind (splited [1 ], " threadIdx.y" );
1223+ split = ir_sch.Split (loops[0 ], {-1 , 32 });
1224+ ir_sch.Bind (split [0 ], " blockIdx.x" );
1225+ ir_sch.Bind (split [1 ], " threadIdx.y" );
12271226 all_blocks = ir_sch.GetAllBlocks ();
1228- splited = ir_sch.GetLoops (all_blocks[1 ]);
1229- ir_sch.SimpleComputeAt (all_blocks[0 ], splited [1 ]);
1227+ split = ir_sch.GetLoops (all_blocks[1 ]);
1228+ ir_sch.SimpleComputeAt (all_blocks[0 ], split [1 ]);
12301229 all_blocks = ir_sch.GetAllBlocks ();
12311230 ir_sch.SetBuffer (all_blocks[0 ], " local" , true );
12321231 loops = ir_sch.GetLoops (all_blocks[0 ]);
0 commit comments