@@ -73,12 +73,24 @@ TEST(Hook_intermidiate, Sigmoid) {
       const paddle::experimental::Tensor&)>
       hook = &hook_function;
 
+  VLOG(6) << "Make ReduceHook function";
+  auto reduce_hook = [&](void) -> void {
+    auto* t_ptr = std::dynamic_pointer_cast<pten::DenseTensor>(tensor.impl())
+                      ->data<float>();
+    for (int i = 0; i < tensor.numel(); i++) {
+      t_ptr[i] = 100.0;  // set to 100.0
+    }
+  };
+
   VLOG(6) << "Retain Grad for Tensor";
   egr_utils_api::RetainGradForTensor(tensor);
 
   VLOG(6) << "Register GradientHook for Tensor";
   egr_utils_api::RegisterGradientHookForTensor(tensor, hook);
 
+  VLOG(6) << "Register ReduceHook for Tensor";
+  egr_utils_api::RegisterReduceHookForTensor(tensor, reduce_hook);
+
   VLOG(6) << "Runing Forward";
   auto output_tensor = sigmoid_dygraph_function(tensor, {});
   VLOG(6) << "Finish Forward";
@@ -92,6 +104,13 @@ TEST(Hook_intermidiate, Sigmoid) {
   VLOG(6) << "Finish Backward";
 
   eager_test::CompareGradTensorWithValue<float>(tensor, 0.25 + 3);
+
+  VLOG(6) << "Checking ReduceHook results";
+  for (int i = 0; i < tensor.numel(); i++) {
+    CHECK_EQ(std::dynamic_pointer_cast<pten::DenseTensor>(tensor.impl())
+                 ->data<float>()[i],
+             static_cast<float>(100.0f));
+  }
   VLOG(6) << "After Tests";
 }
 
@@ -118,8 +137,17 @@ TEST(Hook_intermidiate, ElementwiseAdd) {
       const paddle::experimental::Tensor&)>
       hook = &hook_function;
 
+  auto reduce_hook = [&](void) -> void {
+    auto* t_ptr =
+        std::dynamic_pointer_cast<pten::DenseTensor>(Y.impl())->data<float>();
+    for (int i = 0; i < Y.numel(); i++) {
+      t_ptr[i] = 100.0;  // set to 100.0
+    }
+  };
+
   egr_utils_api::RetainGradForTensor(Y);
   egr_utils_api::RegisterGradientHookForTensor(Y, hook);
+  egr_utils_api::RegisterReduceHookForTensor(Y, reduce_hook);
 
   auto output_tensor = elementwise_add_dygraph_function(X, Y, {});
 
@@ -130,6 +158,13 @@ TEST(Hook_intermidiate, ElementwiseAdd) {
 
   eager_test::CompareGradTensorWithValue<float>(X, 1.0);
   eager_test::CompareGradTensorWithValue<float>(Y, 4.0);
+
+  // Checking ReduceHook results
+  for (int i = 0; i < Y.numel(); i++) {
+    CHECK_EQ(std::dynamic_pointer_cast<pten::DenseTensor>(Y.impl())
+                 ->data<float>()[i],
+             static_cast<float>(100.0f));
+  }
 }
 
 TEST(Hook_intermidiate, Matmul_v2) {
@@ -155,8 +190,17 @@ TEST(Hook_intermidiate, Matmul_v2) {
       const paddle::experimental::Tensor&)>
       hook = &hook_function;
 
+  auto reduce_hook = [&](void) -> void {
+    auto* t_ptr =
+        std::dynamic_pointer_cast<pten::DenseTensor>(Y.impl())->data<float>();
+    for (int i = 0; i < Y.numel(); i++) {
+      t_ptr[i] = 100.0;  // set to 100.0
+    }
+  };
+
   egr_utils_api::RetainGradForTensor(Y);
   egr_utils_api::RegisterGradientHookForTensor(Y, hook);
+  egr_utils_api::RegisterReduceHookForTensor(Y, reduce_hook);
 
   auto output_tensor = matmul_v2_dygraph_function(
       X, Y, {{"trans_x", false}, {"trans_y", false}});
@@ -168,8 +212,14 @@ TEST(Hook_intermidiate, Matmul_v2) {
 
   eager_test::CompareGradTensorWithValue<float>(X, 2.0 * 20);
   eager_test::CompareGradTensorWithValue<float>(Y, 3.0 * 4 + 3);
-}
 
+  // Checking ReduceHook results
+  for (int i = 0; i < Y.numel(); i++) {
+    CHECK_EQ(std::dynamic_pointer_cast<pten::DenseTensor>(Y.impl())
+                 ->data<float>()[i],
+             static_cast<float>(100.0f));
+  }
+}
 }  // namespace egr
 
 USE_OP(sigmoid);
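The reduce hooks added above are plain void() callbacks that capture the leaf tensor by reference and overwrite its buffer once backward has finished, separately from the gradient-transforming hook. The standalone sketch below illustrates that pattern with a toy tensor and hook registry; it is only an analogy under assumed names (ToyTensor, ToyAutograd) and does not use Paddle's real egr_utils_api.

// Minimal, self-contained sketch of the gradient-hook / reduce-hook pattern
// exercised by the tests above. Toy types only; not Paddle's actual API.
#include <cassert>
#include <functional>
#include <vector>

struct ToyTensor {  // hypothetical stand-in for a dense tensor
  std::vector<float> data;
  std::vector<float> grad;
};

struct ToyAutograd {  // hypothetical stand-in for autograd meta
  // Gradient hooks transform the incoming gradient; reduce hooks are
  // void callbacks fired once after backward finishes.
  std::vector<std::function<std::vector<float>(const std::vector<float>&)>>
      grad_hooks;
  std::vector<std::function<void()>> reduce_hooks;

  void Backward(ToyTensor* t, std::vector<float> grad_in) {
    for (auto& h : grad_hooks) grad_in = h(grad_in);  // apply gradient hooks
    t->grad = grad_in;                                // "accumulate" gradient
    for (auto& h : reduce_hooks) h();                 // then fire reduce hooks
  }
};

int main() {
  ToyTensor y{{3.0f, 3.0f}, {}};
  ToyAutograd meta;

  // Gradient hook: add 3 to every gradient element (mirrors hook_function's +3).
  meta.grad_hooks.emplace_back([](const std::vector<float>& g) {
    std::vector<float> out(g);
    for (auto& v : out) v += 3.0f;
    return out;
  });

  // Reduce hook: mutate the tensor's own data in place after backward,
  // just as the tests set every element to 100.0 and CHECK_EQ it afterwards.
  meta.reduce_hooks.emplace_back([&y]() {
    for (auto& v : y.data) v = 100.0f;
  });

  meta.Backward(&y, {1.0f, 1.0f});

  assert(y.grad[0] == 4.0f);    // 1.0 + 3.0 from the gradient hook
  assert(y.data[0] == 100.0f);  // overwritten by the reduce hook
  return 0;
}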