Commit 06b177c

[Eager Hook] Support ReduceHook in GradNodeAccumulation (#39674)
* [Eager] Support GradientHook before running separate GradNode
* Fix CI issue
* Support eager ReduceHook in accumulation_node
* Fix CI issue
* Add some tests to fix coverage CI issue
1 parent 5df3cd6 commit 06b177c
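
For orientation, this is the usage pattern the change enables: a minimal sketch condensed from the new hook_test_intermidiate.cc tests below. Tensor construction and the forward/backward calls are elided, and `tensor` is assumed to be an eager leaf tensor holding float data.

// Sketch only: a reduce hook is a std::function<void(void)>. Here it
// overwrites the tensor's buffer with a marker value so a later check can
// confirm the hook actually ran during gradient accumulation.
auto reduce_hook = [&](void) -> void {
  auto* t_ptr = std::dynamic_pointer_cast<pten::DenseTensor>(tensor.impl())
                    ->data<float>();
  for (int i = 0; i < tensor.numel(); i++) {
    t_ptr[i] = 100.0;
  }
};

// Registration is only valid for leaf tensors; the hook is stored on the
// tensor's GradNodeAccumulation node and applied when that node runs.
egr_utils_api::RegisterReduceHookForTensor(tensor, reduce_hook);

After backward runs, the tests check that every element of `tensor` equals 100.0, confirming that the hook fired.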

9 files changed, +134 -38 lines changed

paddle/fluid/eager/accumulation/accumulation_node.cc

Lines changed: 10 additions & 0 deletions
@@ -79,4 +79,14 @@ operator()(
   return {{accumulated_grad}};
 }
 
+void GradNodeAccumulation::RegisterReduceHook(
+    const std::function<void(void)>& hook) {
+  reduce_hooks_.emplace_back(hook);
+}
+
+void GradNodeAccumulation::ApplyReduceHooks() {
+  for (auto& hook : reduce_hooks_) {
+    hook();
+  }
+}
 }  // namespace egr

paddle/fluid/eager/accumulation/accumulation_node.h

Lines changed: 13 additions & 0 deletions
@@ -35,12 +35,25 @@ class GradNodeAccumulation : public GradNodeBase {
 
   paddle::experimental::Tensor* Grad() { return &accumulated_grad; }
 
+  /**
+   * Register ReduceHook
+   * **/
+  void RegisterReduceHook(const std::function<void(void)>& hook);
+
+  /**
+   * Apply ReduceHook here
+   * **/
+  inline bool ReduceHooksRegistered() { return reduce_hooks_.size() != 0; }
+  void ApplyReduceHooks();
+
  private:
   paddle::experimental::Tensor accumulated_grad;
 
   std::function<paddle::experimental::Tensor(
       const paddle::experimental::Tensor&)>
       retain_grad_hook_;
+
+  std::vector<std::function<void(void)>> reduce_hooks_;
 };
 
 }  // namespace egr

paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc

Lines changed: 0 additions & 4 deletions
@@ -171,10 +171,6 @@ operator()(
              &out);
   }
 
-  // Apply Reduce Hooks
-  if (ReduceHooksRegistered()) {
-    ApplyReduceHooks();
-  }
   return {{out}};
 }
 

paddle/fluid/eager/api/utils/hook_utils.cc

Lines changed: 15 additions & 4 deletions
@@ -35,10 +35,21 @@ void RegisterGradientHookForTensor(
 
 void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
                                  const std::function<void(void)>& hook) {
-  // Find grad_node and out_rank from AutogradMeta
-  std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(tensor);
-
-  grad_node->RegisterReduceHook(hook);
+  if (IsLeafTensor(tensor)) {
+    VLOG(6) << "Register ReduceHook for leaf tensor";
+    std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(tensor);
+    PADDLE_ENFORCE(
+        grad_node.get() != nullptr,
+        paddle::platform::errors::Fatal("Detected NULL grad_node,"
+                                        "Leaf tensor should have had grad_node "
+                                        "with type: GradNodeAccumulation"));
+    auto accumulation_grad_node =
+        std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node);
+    accumulation_grad_node->RegisterReduceHook(hook);
+  } else {
+    PADDLE_THROW(paddle::platform::errors::Fatal(
+        "Only can register reduce hook for leaf Tensor."));
+  }
 }
 
 void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
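
As a design note, RegisterReduceHook and ApplyReduceHooks now live on GradNodeAccumulation rather than GradNodeBase (see grad_node_info.h below), so the tensor-level helper above has to resolve the leaf tensor's accumulation node before registering. A caller-side sketch of that contract; `leaf_tensor` and `non_leaf_tensor` are hypothetical names, not variables from this diff:

// Illustration only of the leaf-only contract enforced above.
std::function<void(void)> hook = []() { VLOG(6) << "reduce hook fired"; };

// Leaf tensor: EagerUtils::grad_node() yields a GradNodeAccumulation, and the
// hook is forwarded to accumulation_grad_node->RegisterReduceHook(hook).
egr_utils_api::RegisterReduceHookForTensor(leaf_tensor, hook);

// Non-leaf tensor (output of an op): IsLeafTensor() is false, so the call
// raises a Fatal error: "Only can register reduce hook for leaf Tensor."
// egr_utils_api::RegisterReduceHookForTensor(non_leaf_tensor, hook);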

paddle/fluid/eager/grad_node_info.cc

Lines changed: 0 additions & 9 deletions
@@ -214,10 +214,6 @@ void GradNodeBase::RegisterGradientHook(
   gradient_hooks_.emplace_back(std::make_tuple(slot_id, rank, hook));
 }
 
-void GradNodeBase::RegisterReduceHook(const std::function<void(void)>& hook) {
-  reduce_hooks_.emplace_back(hook);
-}
-
 std::vector<std::vector<paddle::experimental::Tensor>>
 GradNodeBase::ApplyGradientHooks(
     const std::vector<std::vector<paddle::experimental::Tensor>>& tensors) {
@@ -267,9 +263,4 @@ GradNodeBase::ApplyGradientHooks(
   return outs;
 }
 
-void GradNodeBase::ApplyReduceHooks() {
-  for (auto& hook : reduce_hooks_) {
-    hook();
-  }
-}
 }  // namespace egr

paddle/fluid/eager/grad_node_info.h

Lines changed: 2 additions & 6 deletions
@@ -133,22 +133,19 @@ class GradNodeBase {
    * **/
   void SetDefaultGradInOutMeta();
   /**
-   * Register GradientHook or ReduceHook
+   * Register GradientHook
    * **/
   void RegisterGradientHook(size_t slot_id, size_t rank,
                             const std::function<paddle::experimental::Tensor(
                                 const paddle::experimental::Tensor&)>& hook);
-  void RegisterReduceHook(const std::function<void(void)>& hook);
 
   /**
-   * Apply GradientHook or ReduceHook
+   * Apply GradientHook
    * **/
   inline bool GradientHooksRegistered() { return gradient_hooks_.size() != 0; }
-  inline bool ReduceHooksRegistered() { return reduce_hooks_.size() != 0; }
 
   std::vector<std::vector<paddle::experimental::Tensor>> ApplyGradientHooks(
       const std::vector<std::vector<paddle::experimental::Tensor>>& tensors);
-  void ApplyReduceHooks();
 
  private:
   // TODO(jiabin): Use SmallVector instead after merge PR from develop
@@ -173,7 +170,6 @@
   /* hook */ std::function<paddle::experimental::Tensor(
                  const paddle::experimental::Tensor&)>>>
       gradient_hooks_;
-  std::vector<std::function<void(void)>> reduce_hooks_;
 };
 
 class Edge {

paddle/fluid/eager/tests/data_structure_tests/accumulation_node_test.cc

Lines changed: 43 additions & 1 deletion
@@ -61,7 +61,7 @@ TEST(AccumulationNode, Tensor) {
   // AccumulationNode
   GradNodeAccumulation node = GradNodeAccumulation();
 
-  // Hook
+  // Hook, RetainGrad
   std::function<paddle::experimental::Tensor(
       const paddle::experimental::Tensor&)>
       hook = [&grad_et](const paddle::experimental::Tensor& t) {
@@ -88,4 +88,46 @@ TEST(AccumulationNode, Tensor) {
       std::dynamic_pointer_cast<pten::DenseTensor>(grad_et.impl())
           ->data<paddle::platform::float16>();
   CHECK_EQ(ret_grad_et_ptr[0], paddle::platform::float16(30.0f));
+
+  // Reduce Hook case 1: Call RegisterReduceHook and run operator()
+  VLOG(6) << "Test Reduce Hook";
+  auto reduce_hook_1 = [&](void) -> void {
+    auto* grad_et_ptr =
+        std::dynamic_pointer_cast<pten::DenseTensor>(grad_et.impl())
+            ->data<paddle::platform::float16>();
+    grad_et_ptr[0] = 36.0;
+    VLOG(6) << "Running Reduce Hook";
+  };
+
+  node.RegisterReduceHook(reduce_hook_1);
+
+  // operator()
+  paddle::experimental::Tensor _ret = node({{et0}})[0][0];
+
+  // Check operator() result, should be 36.0
+  auto* _ret_ptr = std::dynamic_pointer_cast<pten::DenseTensor>(_ret.impl())
+                       ->data<paddle::platform::float16>();
+  CHECK_EQ(_ret_ptr[0], paddle::platform::float16(36.0f));
+
+  // Check Retain Grad, should be 36.0
+  auto* _ret_grad_et_ptr =
+      std::dynamic_pointer_cast<pten::DenseTensor>(grad_et.impl())
+          ->data<paddle::platform::float16>();
+  CHECK_EQ(_ret_grad_et_ptr[0], paddle::platform::float16(36.0f));
+
+  // Reduce Hook case 2: Call RegisterReduceHook and ApplyReduceHooks directly
+  VLOG(6) << "Test Reduce Hook";
+  auto reduce_hook_2 = [&](void) -> void {
+    auto* ret_et0_ptr = std::dynamic_pointer_cast<pten::DenseTensor>(et0.impl())
+                            ->data<paddle::platform::float16>();
+    ret_et0_ptr[0] = 100.0;  // set to 100.0
+    VLOG(6) << "Running Reduce Hook";
+  };
+  node.RegisterReduceHook(reduce_hook_2);
+  node.ApplyReduceHooks();
+
+  // Check ApplyReduceHooks result
+  CHECK_EQ(std::dynamic_pointer_cast<pten::DenseTensor>(et0.impl())
+               ->data<paddle::platform::float16>()[0],
+           paddle::platform::float16(100.0f));
 }

paddle/fluid/eager/tests/data_structure_tests/grad_node_info_test.cc

Lines changed: 0 additions & 13 deletions
@@ -119,19 +119,6 @@ TEST(GradNodeInfo, GradNodeBase) {
       std::dynamic_pointer_cast<pten::DenseTensor>(grad_hook_res[0][0].impl())
           ->data<float>()[0],
       11.0);
-
-  VLOG(6) << "Test Reduce Hook";
-  auto reduce_hook = [&](void) -> void {
-    auto* et_ptr =
-        std::dynamic_pointer_cast<pten::DenseTensor>(et1.impl())->data<float>();
-    et_ptr[0] = 100.0;
-    VLOG(6) << "Running Reduce Hook";
-  };
-  grad_test_node0->RegisterReduceHook(reduce_hook);
-  grad_test_node0->ApplyReduceHooks();
-  CHECK_EQ(std::dynamic_pointer_cast<pten::DenseTensor>(et1.impl())
-               ->data<float>()[0],
-           100.0);
 }
 
 TEST(GradNodeInfo, Edge) {

paddle/fluid/eager/tests/task_tests/hook_test_intermidiate.cc

Lines changed: 51 additions & 1 deletion
@@ -73,12 +73,24 @@ TEST(Hook_intermidiate, Sigmoid) {
       const paddle::experimental::Tensor&)>
       hook = &hook_function;
 
+  VLOG(6) << "Make ReduceHook function";
+  auto reduce_hook = [&](void) -> void {
+    auto* t_ptr = std::dynamic_pointer_cast<pten::DenseTensor>(tensor.impl())
+                      ->data<float>();
+    for (int i = 0; i < tensor.numel(); i++) {
+      t_ptr[i] = 100.0;  // set to 100.0
+    }
+  };
+
   VLOG(6) << "Retain Grad for Tensor";
   egr_utils_api::RetainGradForTensor(tensor);
 
   VLOG(6) << "Register GradientHook for Tensor";
   egr_utils_api::RegisterGradientHookForTensor(tensor, hook);
 
+  VLOG(6) << "Register ReduceHook for Tensor";
+  egr_utils_api::RegisterReduceHookForTensor(tensor, reduce_hook);
+
   VLOG(6) << "Runing Forward";
   auto output_tensor = sigmoid_dygraph_function(tensor, {});
   VLOG(6) << "Finish Forward";
@@ -92,6 +104,13 @@ TEST(Hook_intermidiate, Sigmoid) {
   VLOG(6) << "Finish Backward";
 
   eager_test::CompareGradTensorWithValue<float>(tensor, 0.25 + 3);
+
+  VLOG(6) << "Checking ReduceHook results";
+  for (int i = 0; i < tensor.numel(); i++) {
+    CHECK_EQ(std::dynamic_pointer_cast<pten::DenseTensor>(tensor.impl())
+                 ->data<float>()[i],
+             static_cast<float>(100.0f));
+  }
   VLOG(6) << "After Tests";
 }
 
@@ -118,8 +137,17 @@ TEST(Hook_intermidiate, ElementwiseAdd) {
       const paddle::experimental::Tensor&)>
       hook = &hook_function;
 
+  auto reduce_hook = [&](void) -> void {
+    auto* t_ptr =
+        std::dynamic_pointer_cast<pten::DenseTensor>(Y.impl())->data<float>();
+    for (int i = 0; i < Y.numel(); i++) {
+      t_ptr[i] = 100.0;  // set to 100.0
+    }
+  };
+
   egr_utils_api::RetainGradForTensor(Y);
   egr_utils_api::RegisterGradientHookForTensor(Y, hook);
+  egr_utils_api::RegisterReduceHookForTensor(Y, reduce_hook);
 
   auto output_tensor = elementwise_add_dygraph_function(X, Y, {});
 
@@ -130,6 +158,13 @@ TEST(Hook_intermidiate, ElementwiseAdd) {
 
   eager_test::CompareGradTensorWithValue<float>(X, 1.0);
   eager_test::CompareGradTensorWithValue<float>(Y, 4.0);
+
+  // Checking ReduceHook results
+  for (int i = 0; i < Y.numel(); i++) {
+    CHECK_EQ(std::dynamic_pointer_cast<pten::DenseTensor>(Y.impl())
+                 ->data<float>()[i],
+             static_cast<float>(100.0f));
+  }
 }
 
 TEST(Hook_intermidiate, Matmul_v2) {
@@ -155,8 +190,17 @@ TEST(Hook_intermidiate, Matmul_v2) {
       const paddle::experimental::Tensor&)>
       hook = &hook_function;
 
+  auto reduce_hook = [&](void) -> void {
+    auto* t_ptr =
+        std::dynamic_pointer_cast<pten::DenseTensor>(Y.impl())->data<float>();
+    for (int i = 0; i < Y.numel(); i++) {
+      t_ptr[i] = 100.0;  // set to 100.0
+    }
+  };
+
   egr_utils_api::RetainGradForTensor(Y);
   egr_utils_api::RegisterGradientHookForTensor(Y, hook);
+  egr_utils_api::RegisterReduceHookForTensor(Y, reduce_hook);
 
   auto output_tensor = matmul_v2_dygraph_function(
       X, Y, {{"trans_x", false}, {"trans_y", false}});
@@ -168,8 +212,14 @@ TEST(Hook_intermidiate, Matmul_v2) {
 
   eager_test::CompareGradTensorWithValue<float>(X, 2.0 * 20);
   eager_test::CompareGradTensorWithValue<float>(Y, 3.0 * 4 + 3);
-}
 
+  // Checking ReduceHook results
+  for (int i = 0; i < Y.numel(); i++) {
+    CHECK_EQ(std::dynamic_pointer_cast<pten::DenseTensor>(Y.impl())
+                 ->data<float>()[i],
+             static_cast<float>(100.0f));
+  }
+}
 }  // namespace egr
 
 USE_OP(sigmoid);
