1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -61,6 +61,7 @@ message MpConfig {
message PpConfig {
optional bool dp_comm_overlap = 1 [ default = false ];
optional bool delay_scale_loss = 2 [ default = false ];
optional bool enable_timer = 3 [ default = false ];
}

message HybridConfig {
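For context, the new enable_timer switch sits next to dp_comm_overlap in PpConfig and defaults to off. A rough usage sketch for turning it on from user code (hedged: it assumes the pp_configs message is writable through the same strategy.hybrid_configs accessor the pipeline runner reads below, and the parallel degrees are placeholders):

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,  # placeholder degrees, illustration only
    "mp_degree": 1,
    "pp_degree": 2,
}
# Hypothetical: flip the new flag so the pipeline scheduler and the p2p
# layer record per-phase timings.
strategy.hybrid_configs["pp_configs"].enable_timer = True
fleet.init(is_collective=True, strategy=strategy)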
@@ -15,6 +15,7 @@
from paddle import framework

from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from ..utils import timer_helper as timer
from ..utils.hybrid_parallel_util import (
broadcast_dp_parameters,
broadcast_mp_parameters,
@@ -73,13 +74,24 @@ def __init__(self, layers, hcg, strategy):
self._dp_comm_overlap = self._strategy.hybrid_configs[
"pp_configs"
].dp_comm_overlap
self._enable_timer = self._strategy.hybrid_configs[
"pp_configs"
].enable_timer
self._dp_comm_buffers = []

if self._dp_comm_overlap:
assert self.use_data_parallel and self.num_stages > 1

if self._enable_timer:
if not timer.is_timer_initialized():
timer.set_timers()
self.timers = timer.get_timers()

p2p.initialize_p2p_groups(
-    hcg, self._using_cache, self._enable_partial_send_recv
+    hcg,
+    self._using_cache,
+    self._enable_partial_send_recv,
+    self._enable_timer,
)

self.global_rank = self._hcg.get_global_rank()
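The trainer only relies on a small surface of timer_helper: is_timer_initialized(), set_timers(), get_timers(), per-name timers with start()/stop(), and a log() over the registered names. A minimal sketch of such an interface (illustrative only; Paddle's actual timer_helper may differ in details such as device synchronization before reading the clock):

import time

class _Timer:
    # Hypothetical named timer that accumulates wall-clock time across
    # start()/stop() pairs.
    def __init__(self, name):
        self.name = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = None

    def start(self):
        assert not self.started_, f"timer {self.name} already started"
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        assert self.started_, f"timer {self.name} is not running"
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def elapsed(self, reset=True):
        value = self.elapsed_
        if reset:
            self.elapsed_ = 0.0
        return value

class _Timers:
    # Hypothetical registry keyed by timer name; timers("foo") creates the
    # timer on first use, log(names) prints the accumulated times.
    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def log(self, names):
        print(" | ".join(
            f"{n}: {self.timers[n].elapsed() * 1000.0:.2f}ms" for n in names
        ))

_GLOBAL_TIMERS = None

def is_timer_initialized():
    return _GLOBAL_TIMERS is not None

def set_timers():
    global _GLOBAL_TIMERS
    _GLOBAL_TIMERS = _Timers()

def get_timers():
    return _GLOBAL_TIMERS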
@@ -153,6 +165,12 @@ def register_allreduce_overlap_hook(self, model, comm_group, acc_steps):
for param in parameters:
param._register_backward_hook(self.bw_hook_func(buffer, param))

def timer_printer(self):
if not self._enable_timer:
return
all_flag_names = self.timers.timers.keys()
self.timers.log(all_flag_names)

def forward_backward_pipeline(self, data, scaler=None):
# use the 1f1b scheduling strategy.
# this strategy is inspired by:
@@ -236,9 +254,17 @@ def forward_backward_pipeline(self, data, scaler=None):
for buffer in self._dp_comm_buffers:
buffer.scale_and_split_grads()

if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").start()
self._layers.allreduce_shared_weight_gradients()
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").stop()
self.timers("broadcast_final_loss").start()
with paddle.amp.auto_cast(enable=False):
train_loss = self._broadcast_final_loss()
if self._enable_timer:
self.timers("broadcast_final_loss").stop()
self.timer_printer()
return train_loss
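The if self._enable_timer guards around each start()/stop() pair recur throughout the scheduler. Purely as an illustrative alternative, not part of this change, the same guard-and-time pattern can be written once as a small context manager:

from contextlib import contextmanager

@contextmanager
def _maybe_timed(timers, name, enabled):
    # Hypothetical helper: time the enclosed region only when timing is
    # enabled, otherwise fall through with no extra work.
    if enabled:
        timers(name).start()
    try:
        yield
    finally:
        if enabled:
            timers(name).stop()

# Usage sketch inside forward_backward_pipeline:
#     with _maybe_timed(self.timers, "broadcast_final_loss", self._enable_timer):
#         with paddle.amp.auto_cast(enable=False):
#             train_loss = self._broadcast_final_loss()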

def _prepare_training(self, data, optimizer, lr_scheduler):
@@ -334,6 +360,8 @@ def eval_batch(self, data, compute_loss=False):
return self.train_loss

def _forward_step(self, input_tensor, chunk_id=None):
if self._enable_timer:
self.timers("forward_step").start()
if self.is_pipeline_first_stage():
input_tensor = self._load_micro_batch(self.micro_batch_id)

@@ -365,9 +393,13 @@ def _forward_step(self, input_tensor, chunk_id=None):
# Only increase micro batch id at virtual first/last pp stage.
# The micro batch id is used to load data, therefore, only increase it when load data.
self.micro_batch_id += 1
if self._enable_timer:
self.timers("forward_step").stop()
return output_tensor

def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
if self._enable_timer:
self.timers("backward_step").start()
with paddle.amp.auto_cast(enable=False):
if self.is_pipeline_last_stage():
assert output_tensor_grad is None
@@ -397,6 +429,8 @@ def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
)
else:
input_tensor_grad = input_tensor.grad
if self._enable_timer:
self.timers("backward_step").stop()
return input_tensor_grad

def _check_data_vaild(self, data):
@@ -807,16 +841,25 @@ def forward_backward_pipeline(
for buffer in self._dp_comm_buffers:
buffer.scale_and_split_grads()

if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").start()
self._layers.allreduce_shared_weight_gradients()
if self._enable_timer:
self.timers("allreduce_shared_weight_gradients").stop()

if compute_loss:
# return loss if compute loss
if self._enable_timer:
self.timers("broadcast_final_loss").start()
with paddle.amp.auto_cast(enable=False):
train_loss = self._broadcast_final_loss()
if self._enable_timer:
self.timers("broadcast_final_loss").stop()
else:
# else just return all intermediate output tensor for all micro steps
train_loss = self.output_tensors

self.timer_printer()
return train_loss

def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
@@ -19,12 +19,14 @@
import paddle
from paddle import framework

from ...utils import timer_helper as timer
from ...utils.log_util import logger
from .utils import number_2_dtype, paddle_2_number

_hcg = None
_use_cache = False
_enable_partial_send_recv = True
_timers = None

_xpu_comm_group_started = False

@@ -50,11 +52,15 @@ def _xpu_comm_group_end():
_xpu_comm_group_started = False


-def initialize_p2p_groups(hcg, use_cache=True, enable_partial_send_recv=True):
-    global _hcg, _use_cache, _enable_partial_send_recv
+def initialize_p2p_groups(
+    hcg, use_cache=True, enable_partial_send_recv=True, enable_timer=False
+):
+    global _hcg, _use_cache, _enable_partial_send_recv, _timers
_hcg = hcg
_use_cache = use_cache
_enable_partial_send_recv = enable_partial_send_recv
if enable_timer:
_timers = timer.get_timers()
(
send_next_group,
send_prev_group,
@@ -683,6 +689,9 @@ def _p2p_helper(


def recv_forward(pp_first_stage, sync_recv=True):
global _timers
if _timers is not None:
_timers("recv_forward").start()
if pp_first_stage:
input_tensor = None
else:
@@ -697,10 +706,15 @@ def recv_forward(pp_first_stage, sync_recv=True):
recv_next=False,
sync_recv=sync_recv,
)
if _timers is not None:
_timers("recv_forward").stop()
return input_tensor


def recv_backward(pp_last_stage, sync_recv=True):
global _timers
if _timers is not None:
_timers("recv_backward").start()
if pp_last_stage:
output_tensor_grad = None
else:
@@ -711,10 +725,15 @@ def recv_backward(pp_last_stage, sync_recv=True):
recv_next=True,
sync_recv=sync_recv,
)
if _timers is not None:
_timers("recv_backward").stop()
return output_tensor_grad


def send_forward(output_tensor, pp_last_stage):
global _timers
if _timers is not None:
_timers("send_forward").start()
if not pp_last_stage:
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
@@ -727,19 +746,29 @@ def send_forward(output_tensor, pp_last_stage):
recv_prev=False,
recv_next=False,
)
if _timers is not None:
_timers("send_forward").stop()


def send_backward(input_tensor_grad, pp_first_stage):
global _timers
if _timers is not None:
_timers("send_backward").start()
if not pp_first_stage:
_p2p_helper(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=False,
)
if _timers is not None:
_timers("send_backward").stop()


def send_forward_recv_backward(output_tensor, pp_last_stage):
global _timers
if _timers is not None:
_timers("send_forward_recv_backward").start()
if pp_last_stage:
output_tensor_grad = None
else:
@@ -749,10 +778,15 @@ def send_forward_recv_backward(output_tensor, pp_last_stage):
recv_prev=False,
recv_next=True,
)
if _timers is not None:
_timers("send_forward_recv_backward").stop()
return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
global _timers
if _timers is not None:
_timers("send_backward_recv_forward").start()
if pp_first_stage:
input_tensor = None
else:
@@ -762,13 +796,18 @@ def send_backward_recv_forward(input_tensor_grad, pp_first_stage):
recv_prev=True,
recv_next=False,
)
if _timers is not None:
_timers("send_backward_recv_forward").stop()
return input_tensor


def send_forward_backward_recv_forward_backward(
output_tensor, input_tensor_grad, recv_prev, recv_next
):
# always have to send dtype info to downstream
global _timers
if _timers is not None:
_timers("send_forward_backward_recv_forward_backward").start()
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
_send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
@@ -783,11 +822,16 @@ def send_forward_backward_recv_forward_backward(
recv_next=recv_next,
sync_recv=False,
)
if _timers is not None:
_timers("send_forward_backward_recv_forward_backward").stop()
return input_tensor, output_tensor_grad


def send_forward_recv_forward(output_tensor, recv_prev):
# always have to send dtype info to downstream
global _timers
if _timers is not None:
_timers("send_forward_recv_forward").start()
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
_send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
@@ -803,16 +847,22 @@ def send_forward_recv_forward(output_tensor, recv_prev):
recv_next=False,
sync_recv=False,
)

if _timers is not None:
_timers("send_forward_recv_forward").stop()
return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next):
global _timers
if _timers is not None:
_timers("send_backward_recv_backward").start()
_, output_tensor_grad = _p2p_helper(
tensor_send_next=None,
tensor_send_prev=input_tensor_grad,
recv_prev=False,
recv_next=recv_next,
sync_recv=False,
)
if _timers is not None:
_timers("send_backward_recv_backward").stop()
return output_tensor_grad
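Every p2p entry point above follows the same shape: start a timer named after the operation when the module-level _timers registry has been set by initialize_p2p_groups, run the transfer, then stop the timer. As an illustrative sketch only (the file itself keeps the explicit guards), the pattern could also be factored into a decorator:

import functools

def _timed(func):
    # Hypothetical: time func under its own name when _timers has been
    # initialized, otherwise call it directly with no bookkeeping.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if _timers is None:
            return func(*args, **kwargs)
        _timers(func.__name__).start()
        try:
            return func(*args, **kwargs)
        finally:
            _timers(func.__name__).stop()
    return wrapper

# e.g.
# @_timed
# def recv_forward(pp_first_stage, sync_recv=True):
#     ...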