Feature/backward #3068
Changes from all commits
@@ -0,0 +1,178 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/backward.h"
#include <list>
#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {

static bool AllInSet(const std::vector<std::string>& names,
                     const std::string& suffix,
                     const std::unordered_set<std::string>& set) {
  for (auto& name : names) {
    if (set.find(name + suffix) == set.end()) {
      return false;
    }
  }
  return true;
}

static std::shared_ptr<OperatorBase> NOP() {
  auto net_op = std::make_shared<NetOp>();
  net_op->type_ = "@NOP@";
  net_op->CompleteAddOp();
  return net_op;
}

// Get the backward operator for a forward operator (recursive implementation).
//
// no_grad_names: names of the gradient variables that do not need to be
// calculated.
//
// uniq_id: a unique index used inside the recursive calls to
// BackwardRecursive. Use `uid = uniq_id++;` to get the unique index, and pass
// `uniq_id` through the recursive calls.
//
// Returns the backward operator. For a simple situation it is a plain
// operator; for a complex situation it is a NetOp.
//
// See Backward.h for details.
static std::shared_ptr<OperatorBase> BackwardRecursive(
    const OperatorBase& forwardOp,
    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
std::shared_ptr<OperatorBase> BackwardRecursive(
    const OperatorBase& forwardOp,
    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
  // If none of the forward operator's input gradients needs to be calculated,
  // just return a NOP. We do not return a null pointer because a NOP costs
  // almost nothing to run, and it simplifies the calling logic.
  if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
               no_grad_names)) {
    return NOP();
  }

  // If none of the forward operator's output gradients needs to be calculated,
  // then none of its input gradients can be computed at all; put them into the
  // `no_grad_names` set and return a NOP.
  if (AllInSet(forwardOp.outputs_, OperatorBase::GRAD_VAR_SUFFIX(),
               no_grad_names)) {
    for (auto& name : forwardOp.inputs_) {
      // Mark every input gradient as not needed.
      no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
    }
    return NOP();
  }

  // The returned gradient network.
  auto net = std::make_shared<NetOp>();

  if (forwardOp.IsNetOp()) {
    // Because forwardOp is a net op, it can safely be static_cast to NetOp.
    auto& forwardNet = static_cast<const NetOp&>(forwardOp);

    // Map from an output gradient variable name to the indices (in the
    // backward net) of the operators that generate that variable.
    std::unordered_map<std::string, std::vector<size_t>> dup_output_ops;

    size_t local_op_id = 0;
    // Traverse forwardNet in reverse order.
    for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
         ++it, ++local_op_id) {
      auto fwd = *it;
      auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
      net->AddOp(bwd);
      for (auto& out : bwd->outputs_) {
        dup_output_ops[out].emplace_back(local_op_id);
      }
    }
    // Get a unique ID for this invocation.
    auto uid = uniq_id++;
    // TODO(dzh): more comment
    using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
    std::list<Pos> insert_position;
    for (auto& dup_output_op : dup_output_ops) {
      const std::string& name = dup_output_op.first;
      auto& dup_op = dup_output_op.second;
      if (dup_op.size() == 1) continue;
      std::vector<std::string> dup_outputs;

      for (size_t i = 0; i < dup_op.size(); ++i) {
        auto op_offset = dup_op[i];
        dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
                              std::to_string(i));
        net->ops_[op_offset]->Rename(name, dup_outputs.back());
      }
      insert_position.push_back(
          {dup_op.back(),
           OpRegistry::CreateOp(
               "add", {dup_outputs}, {name},
               {{"input_format",
                 std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})});
    }

Review comment (on the `add` op above): Isn't this `add` op still unimplemented at the moment?
Reply: Yes. Implementing that op does not affect the implementation of the Backward algorithm or its unit tests.

    insert_position.sort(
        [](const Pos& l, const Pos& r) { return l.first > r.first; });

    for (auto& pos : insert_position) {
      net->InsertOp(pos.first + 1, pos.second);
    }

  } else {
    std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
    for (std::string& grad_input : grad_op->inputs_) {
      if (no_grad_names.count(grad_input)) {
        std::string prefix = grad_input.substr(
            0, grad_input.size() - OperatorBase::GRAD_VAR_SUFFIX().size());
        grad_input = prefix + OperatorBase::ZERO_VAR_SUFFIX();

        // If part of this operator's input gradients is not to be calculated,
        // feed a zero-filled variable to that gradient input instead.
        net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix},
                                        {grad_input}, {}));
      }
    }

    for (std::string& grad_output : grad_op->outputs_) {
      if (no_grad_names.count(grad_output)) {
        grad_output = OperatorBase::EMPTY_VAR_NAME();
      }
    }

    if (net->ops_.empty()) {  // Currently no auxiliary op has been added.
      return grad_op;
    }
    net->AddOp(grad_op);
  }
  net->type_ = "@GENERATED_BACKWARD@";
  net->CompleteAddOp();
  return net;
}

// See header for comments
std::shared_ptr<OperatorBase> Backward(
    const OperatorBase& forwardOp,
    const std::unordered_set<std::string>& no_grad_vars) {
  std::unordered_set<std::string> no_grad_names;
  no_grad_names.reserve(no_grad_vars.size());

  for (auto& name : no_grad_vars) {
    no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
  }
  size_t uid = 0;
  return BackwardRecursive(forwardOp, no_grad_names, uid);
}
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,27 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <unordered_set>
#include "operator.h"
namespace paddle {
namespace framework {

// Create the backward operator from a forward operator.
// TODO(yuyang18): Add more API reference comment.
extern std::shared_ptr<OperatorBase> Backward(
    const OperatorBase& forwardOp,
    const std::unordered_set<std::string>& no_grad_vars);
}  // namespace framework
}  // namespace paddle
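To make the `Backward` API concrete, here is a minimal usage sketch. It is not part of this PR: the `mul` operator type and its variable names are assumptions for illustration; only the `OpRegistry::CreateOp` and `Backward` calls that appear in this diff are relied on.

```cpp
#include <unordered_set>

#include "paddle/framework/backward.h"
#include "paddle/framework/op_registry.h"

void BackwardUsageSketch() {
  using namespace paddle::framework;  // OperatorBase, OpRegistry, Backward

  // Build a forward operator. A "mul" operator with inputs {X, W} and
  // output {Out} is assumed to be registered elsewhere.
  auto fwd = OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {});

  // Gradients of variables listed here are not computed: where such a
  // gradient would be consumed, a fill_zeros_like op feeds zeros instead;
  // where it would be produced, the output becomes the empty variable name.
  std::unordered_set<std::string> no_grad_vars{"W"};

  // For a simple operator the result is a single gradient operator; for a
  // NetOp it is a generated NetOp containing the whole backward pass.
  std::shared_ptr<OperatorBase> bwd = Backward(*fwd, no_grad_vars);
}
```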
@@ -0,0 +1,38 @@
## Operator/expression's Backward

### Motivation

In a neural network, the backpropagation algorithm follows the chain rule, so we need to compose the fundamental gradient operators/expressions together according to the chain rule. Every forward network needs a backward network to construct the full computation graph; the operator/expression's Backward feature generates the backward pass with respect to the forward pass.

Review comment: computation lineage ==> computation graph? I can't find any definition of "computation lineage".

### Implementation: gradient operator registry

| | forward operator | backward operator |
| ---------------------- | ---------------- | -------------------------------- |
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |

Inputs/Outputs are the inputs/outputs of the operator, and InputGradients/OutputGradients are the gradients with respect to the forward operator's inputs/outputs. A forward operator and its backward operator are isomorphic; each one stores what it needs in its member attributes.
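As an illustration (not part of this PR), consider a hypothetical `mul` operator with inputs `{X, W}` and output `{Out}`, and assume the gradient variable suffix is `@GRAD`. Using the `OpRegistry::CreateOp`/`CreateGradOp` calls from `backward.cc` above, the table maps out as follows:

```cpp
#include "paddle/framework/op_registry.h"

void GradOpShapeSketch() {
  using namespace paddle::framework;

  // Hypothetical "mul" operator; the op type, variable names and the "@GRAD"
  // suffix are illustrative assumptions.
  auto fwd = OpRegistry::CreateOp("mul", /*inputs=*/{"X", "W"},
                                  /*outputs=*/{"Out"}, /*attrs=*/{});
  auto grad = OpRegistry::CreateGradOp(*fwd);

  // Following the table above:
  //   grad->inputs_  : {X, W, Out, Out@GRAD}  (Inputs, Outputs, OutputGradients)
  //   grad->outputs_ : {X@GRAD, W@GRAD}       (InputGradients)
}
```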

We use a global hash map to record the available gradient operators, following the philosophy of a minimal core that keeps every operator a pluggable unit. Each gradient is itself an operator and needs to register itself.

Review comment: I think we need not emphasize that we use
Review comment: regist ==> register

grad_op_builder(fengjiayi)
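
Below is a minimal sketch of the registry idea. It is not the actual `op_registry` code from this PR; the names `OpCreator`, `GradOpRegistry`, and `GradOpRegistrar` are made up for illustration, and only `OperatorBase` from this codebase is assumed.

```cpp
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>

#include "paddle/framework/operator.h"  // OperatorBase

namespace paddle {
namespace framework {

// Creator callback that builds the gradient operator for a forward op type.
using OpCreator = std::function<std::shared_ptr<OperatorBase>()>;

// The global hash map: forward op type -> gradient op creator.
std::unordered_map<std::string, OpCreator>& GradOpRegistry() {
  static std::unordered_map<std::string, OpCreator> registry;
  return registry;
}

// Each gradient operator registers itself, typically through a macro that
// instantiates one of these registrars at static-initialization time.
struct GradOpRegistrar {
  GradOpRegistrar(const std::string& fwd_type, OpCreator creator) {
    GradOpRegistry()[fwd_type] = std::move(creator);
  }
};

}  // namespace framework
}  // namespace paddle
```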

### Implementation: Backward network

Given a forward network, it generates the backward network. We only care about the gradients: `OutputGradients` and `InputGradients`.

1. bla bla bla (yuyang)

2. NetOp

When the input forward network is a NetOp, it needs to call the backward functions of its sub NetOps/Operators recursively and make sure they are all handled. During this process, we need to collect the names of the `OutputGradients`.

We share variables within the same scope; as a result, operators with duplicate `OutputGradients` will overwrite the duplicated variable.

Review comment: overwirte => overwrite, then => the

![duplicate op](./images/duplicate_op)

Sharing a variable between operators, or using the same input variable in multiple operators, leads to duplicate gradient variables. As the demo above shows, we need to rename the gradient names recursively and insert a generic add operator instead (see the sketch after this section).

![duplicate op 2](./images/duplicate_op2)

Then collect the subgraph's `OutputGradients`/`InputGradients` as the NetOp's own, and return it.
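
To make the rename-and-add step concrete, here is a hedged sketch written against the `Rename`, `InsertOp`, and `CreateOp("add", ...)` calls shown in `backward.cc` above. The variable name `W@GRAD`, the operator indices, and the function name are illustrative assumptions.

```cpp
#include <vector>

#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {

// Two backward ops in `net` both write W@GRAD (uid == 0 here). Rename their
// outputs to unique intermediate variables and sum them back with an "add" op.
void CombineDuplicateGradientSketch(NetOp* net, size_t first, size_t second) {
  net->ops_[first]->Rename("W@GRAD", "W@GRAD@RENAME@0@0");
  net->ops_[second]->Rename("W@GRAD", "W@GRAD@RENAME@0@1");
  // Insert the add op right after the last producer; "input_format" mirrors
  // the {0, N} attribute set in backward.cc above.
  net->InsertOp(second + 1,
                OpRegistry::CreateOp(
                    "add", {"W@GRAD@RENAME@0@0", "W@GRAD@RENAME@0@1"},
                    {"W@GRAD"}, {{"input_format", std::vector<int>{0, 2}}}));
}

}  // namespace framework
}  // namespace paddle
```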

Review comment (on the `static` functions in the first file): Why do we need to use `static`?
Reply: Because we do not export them to global symbols.