Skip to content

Commit f67d7b6

Browse files
committed
rm FLAGS_enable_p2p_comm_opt
1 parent 829d5de commit f67d7b6

File tree

2 files changed

+16
-120
lines changed

2 files changed

+16
-120
lines changed

python/paddle/distributed/passes/pass_utils.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,6 @@
4646
]
4747

4848
logger = get_logger(logging.INFO)
49-
from paddle.distributed.utils.stream_utils import (
50-
ExecutionStreamType,
51-
)
5249

5350

5451
# NOTE: Here stream is just a presentation with different name,
@@ -242,27 +239,6 @@ def _get_required_vars_of_program(program):
242239
"""
243240
Get all vars in the program that are non-persistable and not in op's no_need_buffer.
244241
"""
245-
if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
246-
"FLAGS_enable_pir_api"
247-
]:
248-
return _get_required_vars_of_program_in_pir(program)
249-
else:
250-
return _get_required_vars_of_program_in_old_ir(program)
251-
252-
253-
def _get_required_vars_of_program_in_pir(program):
254-
required_vars = set(program.list_vars())
255-
no_need_buffer_vars = core.get_no_need_buffer_values(program)
256-
required_vars -= no_need_buffer_vars
257-
persistable_vars = set()
258-
for var in required_vars:
259-
if var.persistable:
260-
persistable_vars.add(var)
261-
required_vars -= persistable_vars
262-
return required_vars
263-
264-
265-
def _get_required_vars_of_program_in_old_ir(program):
266242
required_vars = set()
267243
for block in program.blocks:
268244
for op in block.ops:
@@ -603,21 +579,7 @@ def _pir_overlap_send_recv(program):
603579
elif op.name() == "pd_op.recv_v2":
604580
op.set_bool_attr("dynamic_shape", False)
605581
op.set_bool_attr("use_calc_stream", True)
606-
if os.getenv("FLAGS_enable_p2p_comm_opt", 0) in [
607-
'True',
608-
'true',
609-
'1',
610-
]:
611-
# and os.getenv("FLAGS_1f1b", 0) in [
612-
# 'True',
613-
# 'true',
614-
# '1',
615-
# ]:
616-
op.set_execution_stream(
617-
ExecutionStreamType.DefaultStream.value
618-
)
619-
else:
620-
op.set_execution_stream("recv_stream")
582+
op.set_execution_stream("recv_stream")
621583
op.set_scheduling_priority(0)
622584

623585

python/paddle/distributed/passes/pipeline_scheduler_pass/pipeline_1f1b.py

Lines changed: 15 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import logging
16-
import os
1716

1817
import paddle
1918
from paddle.base import core
@@ -28,7 +27,6 @@
2827
AutoParallelStreamType,
2928
_add_event_dependency,
3029
_program_for_fthenb_and_1f1b,
31-
_split_program_into_forward_backward_optimize,
3230
forward_complete_op_role,
3331
split_program,
3432
)
@@ -179,14 +177,6 @@ def _backward_forward_overlap(self, backward_program, forward_program):
179177
)
180178

181179
def _create_job_list(self):
182-
183-
if os.getenv("FLAGS_enable_p2p_comm_opt", 0) in [
184-
'True',
185-
'true',
186-
'1',
187-
]:
188-
return self._create_job_list_send_recv()
189-
190180
num_micro_batches = self.get_attr("num_micro_batches")
191181
pp_stage = self.get_attr("pp_stage")
192182
pp_degree = self.get_attr("pp_degree")
@@ -201,18 +191,24 @@ def _create_job_list(self):
201191

202192
forward_micro_batch_id = 0
203193
for i in range(micro_batch_in_warmup):
194+
recv_fwd_job = core.Job(RECV_FORWARD)
195+
recv_fwd_job.set_micro_batch_id(forward_micro_batch_id)
196+
job_list.append(recv_fwd_job)
197+
204198
forward_job = core.Job(FORWARD)
205199
forward_job.set_micro_batch_id(forward_micro_batch_id)
206200
job_list.append(forward_job)
207201
forward_micro_batch_id += 1
208202

209203
backward_micro_batch_id = 0
204+
jobs_in_stable_phase = [BACKWARD, RECV_FORWARD, SEND_BACKWARD, FORWARD]
210205
for i in range(micro_batch_in_1f1b):
211-
for job_type in self.jobs_in_stable_phase:
206+
for job_type in jobs_in_stable_phase:
212207
job = core.Job(job_type)
213208
micro_batch_id = (
214209
forward_micro_batch_id
215210
if job_type.startswith(FORWARD)
211+
or job_type.startswith(RECV_FORWARD)
216212
else backward_micro_batch_id
217213
)
218214
job.set_micro_batch_id(micro_batch_id)
@@ -224,6 +220,11 @@ def _create_job_list(self):
224220
backward_job = core.Job(BACKWARD)
225221
backward_job.set_micro_batch_id(backward_micro_batch_id)
226222
job_list.append(backward_job)
223+
224+
send_bwd_job = core.Job(SEND_BACKWARD)
225+
send_bwd_job.set_micro_batch_id(backward_micro_batch_id)
226+
job_list.append(send_bwd_job)
227+
227228
backward_micro_batch_id += 1
228229

229230
opt_job = core.Job(OPT)
@@ -362,20 +363,9 @@ def _partial_pir_programs(self, program):
362363
not enable_send_recv_overlap
363364
), "PIR does not support 1F1B with enable_send_recv_overlap yet."
364365

365-
if os.getenv("FLAGS_enable_p2p_comm_opt", 0) in [
366-
'True',
367-
'true',
368-
'1',
369-
]:
370-
types = [RECV_FORWARD, FORWARD, BACKWARD, SEND_BACKWARD, OPT]
371-
prog_splitter = ProgramSplitter(program, types)
372-
sub_program_list = prog_splitter._split_programs()
373-
374-
else:
375-
types = [FORWARD, BACKWARD, OPT]
376-
sub_program_list = _split_program_into_forward_backward_optimize(
377-
program, enable_send_recv_overlap
378-
)
366+
types = [RECV_FORWARD, FORWARD, BACKWARD, SEND_BACKWARD, OPT]
367+
prog_splitter = ProgramSplitter(program, types)
368+
sub_program_list = prog_splitter._split_programs()
379369

380370
for i in range(len(types)):
381371
logger.debug(
@@ -406,62 +396,6 @@ def is_comm_op_valid_to_overlap(self, op):
406396
== AutoParallelStreamType.CALC_STREAM.value
407397
)
408398

409-
def _create_job_list_send_recv(self):
410-
num_micro_batches = self.get_attr("num_micro_batches")
411-
pp_stage = self.get_attr("pp_stage")
412-
pp_degree = self.get_attr("pp_degree")
413-
414-
job_list = []
415-
assert (
416-
pp_degree <= num_micro_batches
417-
), "Num of micro batches should larger than or equal to pp degree."
418-
419-
micro_batch_in_warmup = pp_degree - pp_stage
420-
micro_batch_in_1f1b = num_micro_batches - micro_batch_in_warmup
421-
422-
forward_micro_batch_id = 0
423-
for i in range(micro_batch_in_warmup):
424-
recv_fwd_job = core.Job(RECV_FORWARD)
425-
recv_fwd_job.set_micro_batch_id(forward_micro_batch_id)
426-
job_list.append(recv_fwd_job)
427-
428-
forward_job = core.Job(FORWARD)
429-
forward_job.set_micro_batch_id(forward_micro_batch_id)
430-
job_list.append(forward_job)
431-
forward_micro_batch_id += 1
432-
433-
backward_micro_batch_id = 0
434-
jobs_in_stable_phase = [BACKWARD, RECV_FORWARD, SEND_BACKWARD, FORWARD]
435-
for i in range(micro_batch_in_1f1b):
436-
for job_type in jobs_in_stable_phase:
437-
job = core.Job(job_type)
438-
micro_batch_id = (
439-
forward_micro_batch_id
440-
if job_type.startswith(FORWARD)
441-
or job_type.startswith(RECV_FORWARD)
442-
else backward_micro_batch_id
443-
)
444-
job.set_micro_batch_id(micro_batch_id)
445-
job_list.append(job)
446-
forward_micro_batch_id += 1
447-
backward_micro_batch_id += 1
448-
449-
for i in range(micro_batch_in_warmup):
450-
backward_job = core.Job(BACKWARD)
451-
backward_job.set_micro_batch_id(backward_micro_batch_id)
452-
job_list.append(backward_job)
453-
454-
send_bwd_job = core.Job(SEND_BACKWARD)
455-
send_bwd_job.set_micro_batch_id(backward_micro_batch_id)
456-
job_list.append(send_bwd_job)
457-
458-
backward_micro_batch_id += 1
459-
460-
opt_job = core.Job(OPT)
461-
opt_job.set_micro_batch_id(0)
462-
job_list.append(opt_job)
463-
return job_list
464-
465399

466400
class ProgramSplitter:
467401

0 commit comments

Comments
 (0)