1313# limitations under the License.
1414
1515import logging
16- import os
1716
1817import paddle
1918from paddle .base import core
2827 AutoParallelStreamType ,
2928 _add_event_dependency ,
3029 _program_for_fthenb_and_1f1b ,
31- _split_program_into_forward_backward_optimize ,
3230 forward_complete_op_role ,
3331 split_program ,
3432)
@@ -179,14 +177,6 @@ def _backward_forward_overlap(self, backward_program, forward_program):
179177 )
180178
181179 def _create_job_list (self ):
182-
183- if os .getenv ("FLAGS_enable_p2p_comm_opt" , 0 ) in [
184- 'True' ,
185- 'true' ,
186- '1' ,
187- ]:
188- return self ._create_job_list_send_recv ()
189-
190180 num_micro_batches = self .get_attr ("num_micro_batches" )
191181 pp_stage = self .get_attr ("pp_stage" )
192182 pp_degree = self .get_attr ("pp_degree" )
@@ -201,18 +191,24 @@ def _create_job_list(self):
201191
202192 forward_micro_batch_id = 0
203193 for i in range (micro_batch_in_warmup ):
194+ recv_fwd_job = core .Job (RECV_FORWARD )
195+ recv_fwd_job .set_micro_batch_id (forward_micro_batch_id )
196+ job_list .append (recv_fwd_job )
197+
204198 forward_job = core .Job (FORWARD )
205199 forward_job .set_micro_batch_id (forward_micro_batch_id )
206200 job_list .append (forward_job )
207201 forward_micro_batch_id += 1
208202
209203 backward_micro_batch_id = 0
204+ jobs_in_stable_phase = [BACKWARD , RECV_FORWARD , SEND_BACKWARD , FORWARD ]
210205 for i in range (micro_batch_in_1f1b ):
211- for job_type in self . jobs_in_stable_phase :
206+ for job_type in jobs_in_stable_phase :
212207 job = core .Job (job_type )
213208 micro_batch_id = (
214209 forward_micro_batch_id
215210 if job_type .startswith (FORWARD )
211+ or job_type .startswith (RECV_FORWARD )
216212 else backward_micro_batch_id
217213 )
218214 job .set_micro_batch_id (micro_batch_id )
@@ -224,6 +220,11 @@ def _create_job_list(self):
224220 backward_job = core .Job (BACKWARD )
225221 backward_job .set_micro_batch_id (backward_micro_batch_id )
226222 job_list .append (backward_job )
223+
224+ send_bwd_job = core .Job (SEND_BACKWARD )
225+ send_bwd_job .set_micro_batch_id (backward_micro_batch_id )
226+ job_list .append (send_bwd_job )
227+
227228 backward_micro_batch_id += 1
228229
229230 opt_job = core .Job (OPT )
@@ -362,20 +363,9 @@ def _partial_pir_programs(self, program):
362363 not enable_send_recv_overlap
363364 ), "PIR does not support 1F1B with enable_send_recv_overlap yet."
364365
365- if os .getenv ("FLAGS_enable_p2p_comm_opt" , 0 ) in [
366- 'True' ,
367- 'true' ,
368- '1' ,
369- ]:
370- types = [RECV_FORWARD , FORWARD , BACKWARD , SEND_BACKWARD , OPT ]
371- prog_splitter = ProgramSplitter (program , types )
372- sub_program_list = prog_splitter ._split_programs ()
373-
374- else :
375- types = [FORWARD , BACKWARD , OPT ]
376- sub_program_list = _split_program_into_forward_backward_optimize (
377- program , enable_send_recv_overlap
378- )
366+ types = [RECV_FORWARD , FORWARD , BACKWARD , SEND_BACKWARD , OPT ]
367+ prog_splitter = ProgramSplitter (program , types )
368+ sub_program_list = prog_splitter ._split_programs ()
379369
380370 for i in range (len (types )):
381371 logger .debug (
@@ -406,62 +396,6 @@ def is_comm_op_valid_to_overlap(self, op):
406396 == AutoParallelStreamType .CALC_STREAM .value
407397 )
408398
409- def _create_job_list_send_recv (self ):
410- num_micro_batches = self .get_attr ("num_micro_batches" )
411- pp_stage = self .get_attr ("pp_stage" )
412- pp_degree = self .get_attr ("pp_degree" )
413-
414- job_list = []
415- assert (
416- pp_degree <= num_micro_batches
417- ), "Num of micro batches should larger than or equal to pp degree."
418-
419- micro_batch_in_warmup = pp_degree - pp_stage
420- micro_batch_in_1f1b = num_micro_batches - micro_batch_in_warmup
421-
422- forward_micro_batch_id = 0
423- for i in range (micro_batch_in_warmup ):
424- recv_fwd_job = core .Job (RECV_FORWARD )
425- recv_fwd_job .set_micro_batch_id (forward_micro_batch_id )
426- job_list .append (recv_fwd_job )
427-
428- forward_job = core .Job (FORWARD )
429- forward_job .set_micro_batch_id (forward_micro_batch_id )
430- job_list .append (forward_job )
431- forward_micro_batch_id += 1
432-
433- backward_micro_batch_id = 0
434- jobs_in_stable_phase = [BACKWARD , RECV_FORWARD , SEND_BACKWARD , FORWARD ]
435- for i in range (micro_batch_in_1f1b ):
436- for job_type in jobs_in_stable_phase :
437- job = core .Job (job_type )
438- micro_batch_id = (
439- forward_micro_batch_id
440- if job_type .startswith (FORWARD )
441- or job_type .startswith (RECV_FORWARD )
442- else backward_micro_batch_id
443- )
444- job .set_micro_batch_id (micro_batch_id )
445- job_list .append (job )
446- forward_micro_batch_id += 1
447- backward_micro_batch_id += 1
448-
449- for i in range (micro_batch_in_warmup ):
450- backward_job = core .Job (BACKWARD )
451- backward_job .set_micro_batch_id (backward_micro_batch_id )
452- job_list .append (backward_job )
453-
454- send_bwd_job = core .Job (SEND_BACKWARD )
455- send_bwd_job .set_micro_batch_id (backward_micro_batch_id )
456- job_list .append (send_bwd_job )
457-
458- backward_micro_batch_id += 1
459-
460- opt_job = core .Job (OPT )
461- opt_job .set_micro_batch_id (0 )
462- job_list .append (opt_job )
463- return job_list
464-
465399
466400class ProgramSplitter :
467401
0 commit comments