Add basic hook classes for dygraph & implement reduce hook #28584
Conversation
Thanks for your contribution!
```cpp
 }

 private:
  std::vector<std::unique_ptr<GradAccumulatorPostHook>> hooks_;
```
Maybe it can be called `leaf_var_hooks_`, so it is better distinguished from `backward_hooks_`. After all, both of them are hooks for backward. Isn't `backward_hooks_` here for Allreduce/Reduce only?
- My opinion: the class name `LeafVarHookPackage` already holds the leaf-var info, and the hooks in `LeafVarHookPackage` are the leaf var hooks, so a long member name like `leaf_var_hooks_` causes information redundancy and also makes the interface names longer, such as `LeafVarHookPackage.add_leaf_var_hook()`.
- `backward_hooks_` means the hooks of the whole backward process; because it relies on the leaf var, we can only put it here for now. Maybe we should add an `AccumulateGrad` dummy OpNode and move `backward_hooks_` outside. I will improve the comments here.
And `backward_hooks_` may be used not only for Allreduce/Reduce; we should keep scalability here.
Ok
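To make the naming discussion concrete, here is a minimal sketch of the two-container design being discussed. Only `GradAccumulatorPostHook` and `hooks_` come from the diff; every other name and the interface itself are assumptions for illustration, not the PR's actual code:

```cpp
#include <memory>
#include <utility>
#include <vector>

// Name taken from the diff; the virtual interface is an assumption.
class GradAccumulatorPostHook {
 public:
  virtual ~GradAccumulatorPostHook() = default;
  virtual void operator()() = 0;
};

// Hypothetical accumulator keeping the two hook roles in separate,
// clearly named containers, as suggested above.
class GradientAccumulator {
 public:
  // Hooks tied to one leaf variable's accumulated gradient.
  void AddPostHook(std::unique_ptr<GradAccumulatorPostHook> hook) {
    hooks_.push_back(std::move(hook));
  }
  // Hooks for the whole backward process (e.g. Allreduce/Reduce).
  void AddBackwardHook(std::unique_ptr<GradAccumulatorPostHook> hook) {
    backward_hooks_.push_back(std::move(hook));
  }

 private:
  std::vector<std::unique_ptr<GradAccumulatorPostHook>> hooks_;
  std::vector<std::unique_ptr<GradAccumulatorPostHook>> backward_hooks_;
};
```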
```cpp
      << ref_cnt_;
  // After all tmp gradients have been accumulated into the grad var, run hooks
  if (AccumulateCompleted() && HasPostHooks()) {
    CallBackwardPostHooks();
```
Here we call `backward_hooks_`. How about: when `AccumulateCompleted`, first call `hooks_`, then do gradient accumulation between batches, and last call `backward_hooks_`?
So we must have two functions: `CallPostHooks` and `CallBackwardPostHooks`. And this can be changed after this PR is merged.
yes, I agree
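Continuing the hypothetical `GradientAccumulator` sketch above, a sketch of the agreed ordering; only `AccumulateCompleted`, `HasPostHooks`, and `CallBackwardPostHooks` appear in the diff, the rest is assumed:

```cpp
// Hypothetical dispatch illustrating the proposed two-phase ordering.
void GradientAccumulator::RunHooks() {
  if (!AccumulateCompleted()) return;
  // 1. First run the per-variable post hooks (hooks_).
  if (HasPostHooks()) {
    CallPostHooks();
  }
  // 2. Gradient accumulation between batches would happen here.
  // 3. Finally run the whole-backward hooks (backward_hooks_),
  //    e.g. Allreduce/Reduce of the accumulated gradient.
  CallBackwardPostHooks();
}
```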
```cpp
 }

 private:
  std::vector<std::unique_ptr<GradAccumulatorPostHook>> hooks_;
```
Ok
LGTM
LGTM for
PR types
New features
PR changes
Others
Describe
Add basic hook classes for dygraph & implement reduce hook
Execution logic design
Get the forward VariableWrapper from the forward VarBase, and register the LeafGradHook through VariableWrapper's interface.
When the backward Engine prepares the execution environment, it attaches the hook to the GradientAccumulator.
When gradient accumulation completes during backward execution, the attached hooks are run.
A simple hook example
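The original example did not survive extraction; what follows is a hedged reconstruction. Only `GradAccumulatorPostHook` comes from the diff; `ReduceHook`, its callback body, and the registration/trigger code are illustrative assumptions:

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Base class name taken from the diff; interface assumed.
class GradAccumulatorPostHook {
 public:
  virtual ~GradAccumulatorPostHook() = default;
  virtual void operator()() = 0;
};

// Hypothetical reduce-style hook: runs once a leaf var's gradient
// has been fully accumulated during backward.
class ReduceHook : public GradAccumulatorPostHook {
 public:
  void operator()() override {
    // A real reduce hook would launch an Allreduce/Reduce of the
    // accumulated gradient here; we only log for illustration.
    std::cout << "gradient ready, running reduce" << std::endl;
  }
};

int main() {
  // Register the hook (stands in for VariableWrapper's interface).
  std::vector<std::unique_ptr<GradAccumulatorPostHook>> hooks;
  hooks.push_back(std::make_unique<ReduceHook>());

  // Simulate: backward finished accumulating the gradient, so the
  // engine runs the attached hooks.
  for (auto& hook : hooks) (*hook)();
  return 0;
}
```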