112 changes: 48 additions & 64 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -557,6 +557,9 @@ def __init__(self, layers, hcg, strategy):
assert (
framework.in_dygraph_mode()
), "virtual pipeline stage with interleave only support eager dygraph mode"
assert (
self.accumulate_steps % self.num_stages == 0
), "accumulate_steps should be evenly divisible by num_stages for pipeline with interleave"
# setup for interleave scheduler
self.num_model_chunks = layers.get_num_virtual_stages()
self.model_chunks = layers.get_model_chunks()
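Why the new assert belongs here: in the interleaved schedule each rank walks through its model chunks in blocks of num_stages micro-batches, so accumulate_steps has to split into whole blocks. A small self-contained illustration, with made-up numbers rather than values from this PR:

# Hypothetical configuration mirroring the new assertion above.
num_stages = 4          # pipeline_parallel_size
accumulate_steps = 8    # number of micro-batches per optimizer step
num_model_chunks = 2    # virtual pipeline stages per rank

assert accumulate_steps % num_stages == 0, (
    "accumulate_steps should be evenly divisible by num_stages "
    "for pipeline with interleave"
)
# Each rank then runs accumulate_steps * num_model_chunks = 16 micro-steps,
# grouped into whole blocks of num_stages, one block per chunk visit.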
@@ -583,12 +586,9 @@ def _forward_step_helper(self, micro_step):
assert hasattr(self, 'output_tensors')
if not self._forward_only:
assert hasattr(self, 'output_tensor_grads')

if self.is_pipeline_first_stage():
if len(self.input_tensors[virtual_pp_rank]) == len(
self.output_tensors[virtual_pp_rank]
):
self.input_tensors[virtual_pp_rank].append(None)
assert len(self.input_tensors[virtual_pp_rank]) == (
len(self.output_tensors[virtual_pp_rank]) + 1
)
input_tensor = self.input_tensors[virtual_pp_rank][-1]
output_tensor = self._forward_step(input_tensor, virtual_pp_rank)
self.output_tensors[virtual_pp_rank].append(output_tensor)
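The branch removed above used to patch in a None input on the first stage; with the unconditional appends introduced further down in this PR that case cannot occur, which is what the new assert states. A minimal sketch of the invariant, using plain Python lists rather than Paddle tensors:

# The receive path appends an entry for every micro-step, even a None on the
# first stage, so before each forward there is exactly one more cached input
# than cached output for the active chunk.
input_tensors = [[None], []]    # chunk 0 already holds its (possibly None) input
output_tensors = [[], []]

virtual_pp_rank = 0
assert len(input_tensors[virtual_pp_rank]) == len(output_tensors[virtual_pp_rank]) + 1
input_tensor = input_tensors[virtual_pp_rank][-1]
output_tensors[virtual_pp_rank].append(("forward", input_tensor))  # stand-in for _forward_step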
@@ -609,9 +609,12 @@ def _backward_step_helper(self, micro_step):
assert hasattr(self, 'output_tensors')
assert hasattr(self, 'output_tensor_grads')

if self.is_pipeline_last_stage():
if len(self.output_tensor_grads[virtual_pp_rank]) == 0:
self.output_tensor_grads[virtual_pp_rank].append(None)
assert (
len(self.output_tensor_grads[virtual_pp_rank]) == 1
), f"output_tensor_grads is empty for virtual_pp_rank {virtual_pp_rank}"

assert len(self.input_tensors[virtual_pp_rank]) > 0
assert len(self.output_tensors[virtual_pp_rank]) > 0

input_tensor = self.input_tensors[virtual_pp_rank].pop(0)
output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
@@ -646,18 +649,17 @@ def forward_backward_pipeline(
self.output_tensor_grads = [[] for _ in range(self.num_model_chunks)]

num_steps = self.accumulate_steps * self.num_model_chunks
all_startup_steps = False
if forward_only:
# Forward-only run: there is no backward pass, so every step is a startup step
startup_steps = num_steps
else:
if self.accumulate_steps == self.num_stages:
startup_steps = num_steps
all_startup_steps = True
else:
startup_steps = (self.num_stages - self.stage_id - 1) * 2
startup_steps += (self.num_model_chunks - 1) * self.num_stages
startup_steps = min(startup_steps, num_steps)
# startup_steps is the sum of two quantities:
# first_forward_cross_to_end = (self.num_stages - self.stage_id - 1) + (self.num_model_chunks - 1) * self.num_stages
# end_to_first_backward_cross = (self.num_stages - self.stage_id - 1)
# startup_steps = first_forward_cross_to_end + end_to_first_backward_cross
startup_steps = (self.num_stages - self.stage_id - 1) * 2
startup_steps += (self.num_model_chunks - 1) * self.num_stages
startup_steps = min(startup_steps, num_steps)

steady_steps = num_steps - startup_steps
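Plugging hypothetical numbers into the two-part formula from the comment above (4 stages, 2 virtual chunks, 8 micro-batches, viewed from stage 1) makes the warm-up accounting concrete:

# Illustrative values only; none of them come from this PR.
num_stages, num_model_chunks, accumulate_steps, stage_id = 4, 2, 8, 1
num_steps = accumulate_steps * num_model_chunks                # 16 micro-steps in total

first_forward_cross_to_end = (num_stages - stage_id - 1) + (num_model_chunks - 1) * num_stages   # 2 + 4 = 6
end_to_first_backward_cross = num_stages - stage_id - 1                                          # 2
startup_steps = min(first_forward_cross_to_end + end_to_first_backward_cross, num_steps)         # 8
steady_steps = num_steps - startup_steps                        # 8 steps of 1F1B
print(startup_steps, steady_steps)                              # 8 8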

Expand Down Expand Up @@ -687,11 +689,7 @@ def forward_backward_pipeline(
if self.is_pipeline_last_stage():
output_tensor = None

if (
micro_step == (startup_steps - 1)
and not forward_only
and not all_startup_steps
):
if micro_step == (startup_steps - 1) and not forward_only:
input_tensor_grad = None
recv_next = True
if self.is_pipeline_last_stage(ignore_virtual=True):
@@ -707,13 +705,16 @@ def forward_backward_pipeline(
recv_prev=recv_prev,
recv_next=recv_next,
)
# output_tensor_grad is not None when recv_next is True
# append output_tensor_grad unconditionally, even when it is None
self.output_tensor_grads[self.num_model_chunks - 1].append(
output_tensor_grad
)
else:
input_tensor = p2p.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev
)
# append input_tensor unconditionally, even when it is None
self.input_tensors[next_virtual_pp_rank].append(input_tensor)

# run 1f1b steady steps
@@ -752,38 +753,29 @@ def forward_backward_pipeline(

# determine whether to recv input tensor from upstream
recv_prev = True
if self.is_pipeline_first_stage(ignore_virtual=True):
next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
forward_micro_step_id - (self.num_stages - 1), forward=True
)
if next_forward_virtual_pp_rank == (self.num_model_chunks - 1):
# first pp stage and first virtual stage
recv_prev = False
next_forward_virtual_pp_rank += 1
else:
next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
forward_micro_step_id + 1, forward=True
)
next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
forward_micro_step_id + 1, forward=True
)
if self.is_pipeline_first_stage(ignore_virtual=True) and (
next_forward_virtual_pp_rank == 0
):
# first pp stage and first virtual stage
recv_prev = False

# last iteration doesn't need recv from upstream
if micro_step == (steady_steps - 1):
recv_prev = False

# determine whether to recv grad from downstream
recv_next = True
if self.is_pipeline_last_stage(ignore_virtual=True):
next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
backward_micro_step_id - (self.num_stages - 1),
forward=False,
)
if next_backward_virtual_pp_rank == 0:
# last pp stage and last virtual stage
recv_next = False
next_backward_virtual_pp_rank -= 1
else:
next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
backward_micro_step_id + 1, forward=False
)
next_backward_virtual_pp_rank = self._get_virtual_pp_rank(
backward_micro_step_id + 1, forward=False
)
if self.is_pipeline_last_stage(ignore_virtual=True) and (
next_backward_virtual_pp_rank == (self.num_model_chunks - 1)
):
# last pp stage and last virtual stage
recv_next = False

(
input_tensor,
@@ -794,25 +786,17 @@ def forward_backward_pipeline(
recv_prev=recv_prev,
recv_next=recv_next,
)

if recv_prev:
self.input_tensors[next_forward_virtual_pp_rank].append(
input_tensor
)
if recv_next:
self.output_tensor_grads[next_backward_virtual_pp_rank].append(
output_tensor_grad
)
# append input_tensor unconditionally, even when it is None
self.input_tensors[next_forward_virtual_pp_rank].append(
input_tensor
)
# append output_tensor_grad unconditionally, even when it is None
self.output_tensor_grads[next_backward_virtual_pp_rank].append(
output_tensor_grad
)

# remaining backward steps
if not forward_only:
if all_startup_steps:
self.output_tensor_grads[self.num_model_chunks - 1].append(
p2p.recv_backward(
self.is_pipeline_last_stage(), sync_recv=False
)
)

for micro_step in range(steady_steps, num_steps):
# cooldown loop
input_tensor_grad = self._backward_step_helper(micro_step)
@@ -829,7 +813,7 @@ def forward_backward_pipeline(

if micro_step == (num_steps - 1):
recv_next = False

# append output_tensor_grad unconditionally, even when it is None
self.output_tensor_grads[next_backward_virtual_pp_rank].append(
p2p.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next
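The hunks above lean on _get_virtual_pp_rank to decide which model chunk a micro-step belongs to, and hence when to flip recv_prev and recv_next. The sketch below shows the mapping interleaved 1F1B schedulers typically use; it is an assumption for illustration, not a copy of the Paddle implementation:

# Hedged sketch: micro-steps are consumed in groups of num_stages, each group
# belonging to one model chunk, cycling through all chunks; backward walks the
# chunks in reverse order.
def get_virtual_pp_rank(micro_step, num_stages, num_model_chunks, forward):
    group = micro_step % (num_stages * num_model_chunks)
    virtual_pp_rank = group // num_stages
    if not forward:
        virtual_pp_rank = num_model_chunks - virtual_pp_rank - 1
    return virtual_pp_rank

# With 4 stages and 2 chunks, forward micro-steps 0-3 map to chunk 0,
# 4-7 to chunk 1, 8-11 back to chunk 0, and so on.
print([get_virtual_pp_rank(i, 4, 2, True) for i in range(12)])
# [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]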
@@ -60,6 +60,7 @@ def setUp(self):
"mp_degree": 1,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {"accumulate_steps": 2}
fleet.init(is_collective=True, strategy=strategy)
self.rank = fleet.worker_index()
self.hcg = fleet.get_hybrid_communicate_group()
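The test-side change feeds accumulate_steps through pipeline_configs so the interleaved schedule, and with it the new divisibility assert, is exercised with a fixed micro-batch count. A hedged, self-contained version of that setup; the pp_degree value of 2 and the dp_degree entry are assumptions, not visible in the truncated hunk:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,   # assumed; not shown in the hunk above
    "mp_degree": 1,
    "pp_degree": 2,   # stands in for self.pipeline_parallel_size (value assumed)
}
strategy.pipeline_configs = {"accumulate_steps": 2}   # the line this PR adds
fleet.init(is_collective=True, strategy=strategy)

rank = fleet.worker_index()
hcg = fleet.get_hybrid_communicate_group()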