1414
1515import logging
1616
17+ import paddle
1718from paddle .base import core
1819from paddle .distributed .auto_parallel .static .cost import calc_time_by_cost_model
20+ from paddle .framework import (
21+ _current_expected_place_ as _get_device ,
22+ )
1923
2024from ...utils .log_utils import get_logger
2125from ..pass_base import register_pass
2226from ..pass_utils import (
2327 AutoParallelStreamType ,
2428 _add_event_dependency ,
2529 _program_for_fthenb_and_1f1b ,
26- _split_program_into_forward_backward_optimize ,
30+ forward_complete_op_role ,
2731 split_program ,
2832)
2933from .pipeline_pass_base import PipelinePassBase
3034
# Job-type name tags attached to the scheduled jobs / sub-programs below.
# RECV_FORWARD and SEND_BACKWARD exist only in PIR mode, where the pipeline
# p2p communication is split out of the forward/backward programs.
RECV_FORWARD = "recv_forward"
FORWARD = "forward"
BACKWARD = "backward"
SEND_BACKWARD = "send_backward"
OPT = "optimizer"

logger = get_logger(logging.INFO)
3642
3743
3844@register_pass ("pipeline_scheduler_1F1B" )
3945class Pipeline1F1BPass (PipelinePassBase ):
46+
    def __init__(self):
        super().__init__()
        # Steady-phase job order for the legacy (non-PIR) executor:
        # one backward then one forward per 1F1B step.
        self.jobs_in_stable_phase = [BACKWARD, FORWARD]
        # Steady-phase job order for PIR mode, where recv_forward /
        # send_backward are standalone jobs scheduled around the compute jobs.
        self.jobs_in_stable_phase_in_pir = [
            BACKWARD,
            RECV_FORWARD,
            SEND_BACKWARD,
            FORWARD,
        ]
        # 0 disables backward-forward overlap splitting by default.
        self.set_attr("enable_backward_forward_overlap", 0)
4457
4558 # Backward-forward overlapping splits and rearranges jobs for pattern Bi-Fj.
@@ -168,6 +181,9 @@ def _backward_forward_overlap(self, backward_program, forward_program):
168181 )
169182
170183 def _create_job_list (self ):
184+ if self ._in_pir_mode :
185+ return self ._create_job_list_in_pir ()
186+
171187 num_micro_batches = self .get_attr ("num_micro_batches" )
172188 pp_stage = self .get_attr ("pp_stage" )
173189 pp_degree = self .get_attr ("pp_degree" )
@@ -212,6 +228,61 @@ def _create_job_list(self):
212228 job_list .append (opt_job )
213229 return job_list
214230
231+ def _create_job_list_in_pir (self ):
232+ num_micro_batches = self .get_attr ("num_micro_batches" )
233+ pp_stage = self .get_attr ("pp_stage" )
234+ pp_degree = self .get_attr ("pp_degree" )
235+
236+ job_list = []
237+ assert (
238+ pp_degree <= num_micro_batches
239+ ), "Num of micro batches should larger than or equal to pp degree."
240+
241+ micro_batch_in_warmup = pp_degree - pp_stage
242+ micro_batch_in_1f1b = num_micro_batches - micro_batch_in_warmup
243+
244+ forward_micro_batch_id = 0
245+ for i in range (micro_batch_in_warmup ):
246+ recv_fwd_job = core .Job (RECV_FORWARD )
247+ recv_fwd_job .set_micro_batch_id (forward_micro_batch_id )
248+ job_list .append (recv_fwd_job )
249+
250+ forward_job = core .Job (FORWARD )
251+ forward_job .set_micro_batch_id (forward_micro_batch_id )
252+ job_list .append (forward_job )
253+ forward_micro_batch_id += 1
254+
255+ backward_micro_batch_id = 0
256+ for i in range (micro_batch_in_1f1b ):
257+ for job_type in self .jobs_in_stable_phase_in_pir :
258+ job = core .Job (job_type )
259+ micro_batch_id = (
260+ forward_micro_batch_id
261+ if job_type .startswith (FORWARD )
262+ or job_type .startswith (RECV_FORWARD )
263+ else backward_micro_batch_id
264+ )
265+ job .set_micro_batch_id (micro_batch_id )
266+ job_list .append (job )
267+ forward_micro_batch_id += 1
268+ backward_micro_batch_id += 1
269+
270+ for i in range (micro_batch_in_warmup ):
271+ backward_job = core .Job (BACKWARD )
272+ backward_job .set_micro_batch_id (backward_micro_batch_id )
273+ job_list .append (backward_job )
274+
275+ send_bwd_job = core .Job (SEND_BACKWARD )
276+ send_bwd_job .set_micro_batch_id (backward_micro_batch_id )
277+ job_list .append (send_bwd_job )
278+
279+ backward_micro_batch_id += 1
280+
281+ opt_job = core .Job (OPT )
282+ opt_job .set_micro_batch_id (0 )
283+ job_list .append (opt_job )
284+ return job_list
285+
215286 def _multistreaming_for_overlapping (self , programs , job_type ):
216287 num_programs = len (programs )
217288 higher_stream_priority = - 1
@@ -343,16 +414,17 @@ def _partial_pir_programs(self, program):
343414 not enable_send_recv_overlap
344415 ), "PIR does not support 1F1B with enable_send_recv_overlap yet."
345416
346- types = [FORWARD , BACKWARD , OPT ]
347- sub_program_list = _split_program_into_forward_backward_optimize (
348- program , enable_send_recv_overlap
349- )
417+ types = [RECV_FORWARD , FORWARD , BACKWARD , SEND_BACKWARD , OPT ]
418+ prog_splitter = ProgramSplitter (program , types )
419+ sub_program_list = prog_splitter .split_programs ()
350420
351421 for i in range (len (types )):
352422 logger .debug (
353423 f"type = { types [i ]} , sub_programs = { sub_program_list [i ]} \n "
354424 )
355- logger .debug (f"jobs_in_stable_phase = { self .jobs_in_stable_phase } " )
425+ logger .debug (
426+ f"jobs_in_stable_phase = { self .jobs_in_stable_phase_in_pir } "
427+ )
356428 return types , sub_program_list
357429
358430 def _split_program_for_overlapping (self , job_type , program , split_points ):
@@ -376,3 +448,188 @@ def is_comm_op_valid_to_overlap(self, op):
376448 and op .dist_attr .execution_stream
377449 == AutoParallelStreamType .CALC_STREAM .value
378450 )
451+
452+
class ProgramSplitter:
    """Split a complete pipeline program into the five 1F1B sub-programs.

    The main program is cloned once per job type; then, walking the op list
    backwards, every op is erased from all clones except the one whose job
    type it belongs to. Values consumed across sub-program boundaries are
    materialized as persistable kwargs so a later sub-program can read
    results produced by an earlier one, with explicit record/wait event
    dependencies added where execution order must be enforced.
    """

    def __init__(self, main_program, job_types):
        # The index slicing in split_programs() hard-codes this exact order.
        assert job_types == [
            RECV_FORWARD,
            FORWARD,
            BACKWARD,
            SEND_BACKWARD,
            OPT,
        ]
        self._overlap_send_recv(main_program)
        forward_complete_op_role(main_program)
        self.job_types = job_types
        self.complete_ops = main_program.global_block().ops
        self.programs = self._clone_programs(main_program)
        # Per-job-type views of the clones' op lists and global blocks.
        # Clones share op indices with the complete program until erased.
        self.ops_dict = {
            key: prog.global_block().ops for key, prog in self.programs.items()
        }
        self.blocks_dict = {
            key: prog.global_block() for key, prog in self.programs.items()
        }

        self.cur_place = self._get_cur_place()

    def _overlap_send_recv(self, program):
        # TODO(liym27): This function should not be in ProgramSplitter, move
        # it to pipeline_pass_base.py after vpp fixed.
        # Route p2p ops onto a dedicated stream so they can overlap with
        # compute; dynamic_shape is disabled and calc-stream semantics kept.
        for block in program.blocks:
            for op in block.ops:
                if op.name() == "pd_op.send_v2":
                    op.set_bool_attr("dynamic_shape", False)
                    op.set_bool_attr("use_calc_stream", True)
                    op.set_execution_stream("send_recv_stream")
                    op.set_scheduling_priority(0)
                elif op.name() == "pd_op.recv_v2":
                    op.set_bool_attr("dynamic_shape", False)
                    op.set_bool_attr("use_calc_stream", True)
                    op.set_execution_stream("send_recv_stream")
                    op.set_scheduling_priority(0)

    def _clone_programs(self, program):
        """Return one full clone of ``program`` per job type."""
        return {job_type: program.clone() for job_type in self.job_types}

    def _get_cur_place(self):
        """Return the current device wrapped as a ``Place`` for kwargs."""
        place = _get_device()
        if isinstance(place, paddle.framework.CUDAPlace):
            # Rebuild with the distributed env's device id so each rank
            # pins kwargs to its own GPU.
            place = paddle.framework.CUDAPlace(
                paddle.distributed.ParallelEnv().dev_id
            )
        cur_place = paddle.base.libpaddle.Place()
        cur_place.set_place(place)
        return cur_place

    def split_programs(self):
        """Erase ops so each clone keeps only its own job type's region.

        Iterates the complete op list backwards so consumers are processed
        before producers. ``op_role`` switches the current region (0=fwd,
        1=bwd, 2=opt; -1 inherits the current region), and send/recv ops
        inside bwd/fwd regions are carved out into the SEND_BACKWARD /
        RECV_FORWARD programs respectively.

        Returns:
            list: sub-programs ordered the same as ``self.job_types``.
        """
        region = "opt"
        for op_idx in range(len(self.complete_ops) - 1, -1, -1):
            op = self.complete_ops[op_idx]
            if op.op_role != -1:
                if op.op_role == 1:
                    region = "bwd"
                elif op.op_role == 0:
                    region = "fwd"
                elif op.op_role == 2:
                    region = "opt"

            if region == "opt":
                self._erase_op_from_other_programs(op_idx, OPT)
            elif region == "bwd" and op.name() == "pd_op.send_v2":
                self._handle_func(op_idx, SEND_BACKWARD, self.job_types[4:])
                self._erase_op_from_other_programs(op_idx, SEND_BACKWARD)
            elif region == "bwd" and op.name() != "pd_op.send_v2":
                self._handle_func(op_idx, BACKWARD, self.job_types[3:])
                self._erase_op_from_other_programs(op_idx, BACKWARD)
            elif region == "fwd" and op.name() != "pd_op.recv_v2":
                self._handle_func(op_idx, FORWARD, self.job_types[2:])
                self._erase_op_from_other_programs(op_idx, FORWARD)
            elif region == "fwd" and op.name() == "pd_op.recv_v2":
                self._handle_func(op_idx, RECV_FORWARD, self.job_types[1:])
                self._erase_op_from_other_programs(op_idx, RECV_FORWARD)
        return [self.programs[job_type] for job_type in self.job_types]

    def _erase_op_from_other_programs(self, op_idx, keep_job_type):
        """Erase op ``op_idx`` from every clone except ``keep_job_type``'s."""
        for job_type in self.job_types:
            if job_type != keep_job_type:
                self.ops_dict[job_type][op_idx].erase()

    def _handle_func(self, op_idx, cur_job_type, suffixed_job_types):
        """Expose results of op ``op_idx`` to later sub-programs.

        For each result still used by any program in ``suffixed_job_types``
        (the job types scheduled after ``cur_job_type``), obtain or create a
        persistable variable name, add an event dependency when required,
        and rewire the later programs to read the value through a kwarg.
        """
        for rst_idx in range(self.complete_ops[op_idx].num_results()):
            if not self._result_is_used(suffixed_job_types, op_idx, rst_idx):
                continue
            var_name = self._get_or_create_var_name(
                self.ops_dict[cur_job_type], op_idx, rst_idx
            )
            for job_type in suffixed_job_types:
                if self._result_is_used([job_type], op_idx, rst_idx):
                    self._add_dependency_if_necessary(
                        cur_job_type, job_type, op_idx, rst_idx, var_name
                    )
                    self._add_kwarg_and_replace(
                        self.blocks_dict[job_type],
                        self.ops_dict[job_type],
                        op_idx,
                        rst_idx,
                        var_name,
                    )

    def _add_dependency_if_necessary(
        self, cur_job_type, next_job_type, op_idx, rst_idx, var_name
    ):
        """Add an event dependency for the two pairs that need ordering.

        Only BACKWARD->SEND_BACKWARD and RECV_FORWARD->FORWARD boundaries
        get an explicit record/wait event; other cross-program reads go
        through the persistable kwarg alone.
        """
        if not (
            cur_job_type == BACKWARD and next_job_type == SEND_BACKWARD
        ) and not (cur_job_type == RECV_FORWARD and next_job_type == FORWARD):
            return

        # Find the earliest consumer in the next program; it is the op that
        # must wait on the recorded event.
        first_used_idx = None
        first_used_op = None
        for used_op in (
            self.ops_dict[next_job_type][op_idx].result(rst_idx).all_used_ops()
        ):
            used_idx = self.ops_dict[next_job_type].index(used_op)
            if first_used_idx is None or used_idx < first_used_idx:
                first_used_idx = used_idx
                first_used_op = used_op
        self._add_dependency(
            self.ops_dict[cur_job_type][op_idx], first_used_op, var_name
        )

    def _add_dependency(self, recorder_op, waiter_op, name):
        '''
        Add the extra event dependency of the two operators.
        This function mainly aims for the cross-programs in pipeline parallelism,
        especial for the 'send_v2' 'recv_v2' etc.
        '''
        if not recorder_op.has_attr("force_record_event"):
            recorder_op.set_bool_attr("force_record_event", True)
        recorder_op.set_str_attr("event_to_record", name)
        waiter_op.set_str_array_attr("events_to_wait", [name])

    def _result_is_used(self, job_types, op_idx, rst_idx):
        """Whether result ``rst_idx`` of op ``op_idx`` has users in any of
        the given job types' sub-programs."""
        return any(
            not self.ops_dict[job_type][op_idx].result(rst_idx).use_empty()
            for job_type in job_types
        )

    def _get_or_create_var_name(self, cur_sub_ops, op_idx, rst_idx):
        """Return a persistable name for a result crossing program boundaries.

        Tries, in order: the result's own name for data/parameter ops; the
        ``output_name`` of a shadow_output consumer in the complete program
        (the last one wins if several exist); otherwise creates a fresh
        persistable value named after the op index, op name and result index.
        """
        var_name = None
        # case1: get var_name in current sub-program
        op = cur_sub_ops[op_idx]
        if op.name() == "pd_op.data" or op.name() == "builtin.parameter":
            var_name = op.result(rst_idx).name
        else:
            # case2: get var_name from shadow_output in complete program
            result_var = self.complete_ops[op_idx].result(rst_idx)
            shadow_output_op = None
            for used_op in result_var.all_used_ops():
                if used_op.name() == "builtin.shadow_output":
                    shadow_output_op = used_op
            if shadow_output_op is not None:
                var_name = shadow_output_op.attrs()["output_name"]

        if var_name is None:
            # case3: create var_name in current sub-program
            paddle.pir.set_insertion_point_after(op)
            var_name = (
                f"var_{op_idx}_{self.complete_ops[op_idx].name()}_{rst_idx}"
            )
            paddle._C_ops.set_persistable_value(op.result(rst_idx), var_name)
        return var_name

    def _add_kwarg_and_replace(self, block, ops, op_idx, rst_idx, var_name):
        """Replace all uses of a result in ``block`` with a kwarg of the
        same type, placed on the current device."""
        ori_result = ops[op_idx].result(rst_idx)
        new_result_var = block.add_kwarg(var_name, ori_result.type())
        new_result_var.place_attr = self.cur_place
        new_result_var.persistable = ori_result.persistable
        ops[op_idx].result(rst_idx).replace_all_uses_with(new_result_var)
0 commit comments