Description
Your question
In version 1.1, the steps to implement the overlap-grad-reduce feature are as follows:
1. Within the DistributedDataParallel class, delayed weights are placed in weight_grad_buffers, while the other weights are stored in grad_buffers.
2. Add a DistributedDataParallel.async_reduce_grad method, which can synchronize all grad_buffers or one specified weight_grad_buffer.
3. In the WeightGradStore.clear() method, first call model.async_reduce_grad to trigger start_grad_sync for all weights in grad_buffers. Then perform the backward pass for the delayed weights, synchronizing each weight's gradient immediately after its gradient computation finishes (see the sketch after this list).
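
The following is a minimal, self-contained sketch of the flow described in steps 1–3, not the actual implementation: the Bucket class, the queue layout, and the compute callbacks are hypothetical simplifications; only the names async_reduce_grad, start_grad_sync, grad_buffers, weight_grad_buffers, and WeightGradStore.clear come from the description above.

```python
import queue


class Bucket:
    """Stands in for one gradient buffer/bucket (hypothetical)."""

    def __init__(self, name):
        self.name = name

    def start_grad_sync(self):
        # In the real code this would launch an asynchronous
        # all-reduce / reduce-scatter on the bucket's gradient data.
        print(f"start_grad_sync({self.name})")


class DistributedDataParallel:
    def __init__(self, delayed_params, other_params):
        # Step 1: delayed weights go to weight_grad_buffers,
        # all other weights go to grad_buffers.
        self.weight_grad_buffers = {p: Bucket(p) for p in delayed_params}
        self.grad_buffers = [Bucket(p) for p in other_params]

    def async_reduce_grad(self, param=None):
        # Step 2: synchronize either every grad_buffer, or the buffer
        # of one specific delayed weight.
        if param is None:
            for bucket in self.grad_buffers:
                bucket.start_grad_sync()
        else:
            self.weight_grad_buffers[param].start_grad_sync()


class WeightGradStore:
    # Queued (param, grad_fn) pairs for delayed weight-grad computation.
    cache = queue.Queue()

    @classmethod
    def put(cls, param, grad_fn):
        cls.cache.put((param, grad_fn))

    @classmethod
    def clear(cls, model):
        # Step 3a: kick off async reduction for all non-delayed weights.
        model.async_reduce_grad()
        # Step 3b: run the delayed weight backward passes; right after a
        # weight's gradient is computed, reduce that weight's buffer.
        while not cls.cache.empty():
            param, grad_fn = cls.cache.get()
            grad_fn()                        # compute dW for this weight
            model.async_reduce_grad(param)   # sync only this weight's buffer


if __name__ == "__main__":
    model = DistributedDataParallel(delayed_params=["wgrad_a", "wgrad_b"],
                                    other_params=["other"])
    WeightGradStore.put("wgrad_a", lambda: print("backward(wgrad_a)"))
    WeightGradStore.put("wgrad_b", lambda: print("backward(wgrad_b)"))
    WeightGradStore.clear(model)
```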
My question is: how is it ensured that every weight in weight_grad_buffers has executed the start_grad_sync operation within WeightGradStore.clear()?
As shown in the timeline below, for the last microbatch not all weight backward computations are performed inside WeightGradStore.clear(), which means that not all delayed weights will invoke the async_reduce_grad function.
