@@ -268,6 +268,21 @@ def set_skip_gc_vars(num_micro_batches, job_types, sub_programs, jobs):
268268 thus a sub_program's vars might be used as the op's input of the later sub_program,
269269 and these vars cannot be gc after executing current sub_program.
270270 """
271+ if paddle .base .framework .get_flags ("FLAGS_enable_pir_api" )[
272+ "FLAGS_enable_pir_api"
273+ ]:
274+ return _set_skip_gc_vars_in_pir (
275+ num_micro_batches , job_types , sub_programs , jobs
276+ )
277+ else :
278+ return _set_skip_gc_vars_in_old_ir (
279+ num_micro_batches , job_types , sub_programs , jobs
280+ )
281+
282+
283+ def _set_skip_gc_vars_in_old_ir (
284+ num_micro_batches , job_types , sub_programs , jobs
285+ ):
271286 assert num_micro_batches >= 1 , "num_micro_batches needs to be >= 1"
272287 type_to_program = dict (zip (job_types , sub_programs ))
273288
@@ -300,32 +315,54 @@ def set_skip_gc_vars(num_micro_batches, job_types, sub_programs, jobs):
300315 return type_to_program
301316
302317
def _set_skip_gc_vars_in_pir(num_micro_batches, job_types, sub_programs, jobs):
    """Set the skip-gc variable names of every job (PIR code path).

    The work is done in two passes:

    1. For each job type, gather the names its sub-program needs: the
       global block's kwargs plus every named operand/result of its ops,
       minus the no-need-buffer names reported for that job type and minus
       the names of persistable values.
    2. Walk ``jobs`` from last to first. A job must keep (skip gc for) the
       names it needs that are also needed by a later job of the same
       micro batch.

    Args:
        num_micro_batches: Number of micro batches; must be >= 1.
        job_types: Job type names, paired positionally with ``sub_programs``.
        sub_programs: One program per entry of ``job_types``.
        jobs: Ordered sequence of jobs; each provides ``type()``,
            ``micro_batch_id()`` and ``set_skip_gc_vars()``.

    Returns:
        dict mapping each job type to its sub-program.

    Raises:
        AssertionError: If ``num_micro_batches`` < 1, or if a
            ``send_backward`` / ``backward_w`` job ends up with a non-empty
            skip-gc set.
    """
    assert num_micro_batches >= 1, "num_micro_batches needs to be >= 1"
    type_to_program = dict(zip(job_types, sub_programs))

    # Pass 1: names required by each job type.
    no_need_buffer_vars = core.get_no_need_buffer_values(type_to_program)
    type_to_required_vars = {}
    for job_type, program in type_to_program.items():
        block = program.global_block()
        needed = set(block.kwargs())
        persistable = set()

        def _record(values):
            # Only named values can be tracked; unnamed ones are ignored.
            for value in values:
                if not value.has_name:
                    continue
                needed.add(value.name)
                if value.persistable:
                    persistable.add(value.name)

        for op in block.ops:
            _record(op.operands_source())
            _record(op.results())

        if job_type in no_need_buffer_vars:
            needed -= no_need_buffer_vars[job_type]
        needed -= persistable
        type_to_required_vars[job_type] = needed

    # Pass 2: iterate backwards so that, per micro batch, we always know the
    # union of names needed by the jobs that run after the current one.
    needed_later = [set() for _ in range(num_micro_batches)]
    for job in reversed(jobs):
        job_type = job.type()
        micro_batch_id = job.micro_batch_id()
        required_vars = type_to_required_vars[job_type]
        skip_gc_vars = required_vars & needed_later[micro_batch_id]
        logger.debug(
            f"Skip gc vars for { job_type } -({ micro_batch_id } ): { skip_gc_vars } "
        )

        if job_type in ("send_backward", "backward_w"):
            assert (
                not skip_gc_vars
            ), f"When enabling pipeline parallelism strategy, the skip_gc_vars for { job_type } subprogram must be empty, but it is { skip_gc_vars } ."

        job.set_skip_gc_vars(skip_gc_vars)
        needed_later[micro_batch_id] |= required_vars

    return type_to_program
331368
0 commit comments