2020#include " paddle/fluid/eager/grad_node_info.h"
2121#include " paddle/fluid/eager/utils.h"
2222#include " paddle/phi/api/lib/utils/allocator.h"
23+ #include " paddle/phi/core/enforce.h"
2324#include " paddle/phi/core/kernel_registry.h"
2425#include " test/cpp/eager/data_structure_tests/grad_node_test.h"
2526#include " test/cpp/eager/test_utils.h"
@@ -53,7 +54,10 @@ TEST(EagerUtils, AutoGradMeta) {
5354
5455 AutogradMeta* unsafe_autograd_meta_after =
5556 EagerUtils::unsafe_autograd_meta (et0);
56- CHECK_NOTNULL (unsafe_autograd_meta_after);
57+ PADDLE_ENFORCE_NOT_NULL (
58+ unsafe_autograd_meta_after,
59+ phi::errors::PreconditionNotMet (
60+ " Unsafe autograd meta after should not be null." ));
5761
5862 // NOTE: Since autograd_meta will be copied make sure it's not null
5963 std::vector<paddle::Tensor> ets = {et0, et1};
@@ -62,8 +66,12 @@ TEST(EagerUtils, AutoGradMeta) {
6266 std::vector<AutogradMeta*> autograd_metas = EagerUtils::autograd_meta (&ets);
6367 std::vector<AutogradMeta*> unsafe_autograd_metas =
6468 EagerUtils::unsafe_autograd_meta (ets);
65- CHECK_NOTNULL (unsafe_autograd_metas[0 ]);
66- CHECK_NOTNULL (unsafe_autograd_metas[1 ]);
69+ PADDLE_ENFORCE_NOT_NULL (unsafe_autograd_metas[0 ],
70+ phi::errors::PreconditionNotMet (
71+ " Unsafe autograd metas should not be null." ));
72+ PADDLE_ENFORCE_NOT_NULL (unsafe_autograd_metas[1 ],
73+ phi::errors::PreconditionNotMet (
74+ " Unsafe autograd metas should not be null." ));
6775
6876 // Set Autograd Meta
6977 autograd_meta0->SetSingleOutRankWithSlot (0 , 1 );
@@ -72,32 +80,76 @@ TEST(EagerUtils, AutoGradMeta) {
7280
7381 // OutRankInfo()
7482 std::pair<size_t , size_t > out_rank_info0 = EagerUtils::OutRankInfo (et0);
75- CHECK_EQ (static_cast <int >(out_rank_info0.first ), 0 );
76- CHECK_EQ (static_cast <int >(out_rank_info0.second ), 1 );
83+ PADDLE_ENFORCE_EQ (
84+ static_cast <int >(out_rank_info0.first ),
85+ 0UL ,
86+ phi::errors::InvalidArgument (" The first element of out rank info "
87+ " mismatch. Expected 0 but received %d." ,
88+ static_cast <int >(out_rank_info0.first )));
89+ PADDLE_ENFORCE_EQ (
90+ static_cast <int >(out_rank_info0.second ),
91+ 1UL ,
92+ phi::errors::InvalidArgument (" The second element of out rank info "
93+ " mismatch. Expected 1 but received %d." ,
94+ static_cast <int >(out_rank_info0.second )));
7795
7896 // grad_node()
7997 std::shared_ptr<GradNodeBase> grad_node0 = EagerUtils::grad_node (et0);
80- CHECK_NOTNULL (grad_node0.get ());
98+ PADDLE_ENFORCE_NOT_NULL (
99+ grad_node0.get (),
100+ phi::errors::PreconditionNotMet (" Grad of node should not be null." ));
81101
82102 EagerUtils::SetHistory (autograd_meta1, test_node);
83103 EagerUtils::SetHistory (autograd_meta1, test_node);
84104 std::shared_ptr<GradNodeBase> grad_node1 = EagerUtils::grad_node (et1);
85- CHECK_NOTNULL (grad_node1.get ());
105+ PADDLE_ENFORCE_NOT_NULL (
106+ grad_node1.get (),
107+ phi::errors::PreconditionNotMet (" Grad of node should not be null." ));
86108
87109 // SetOutRankWithSlot()
88110 EagerUtils::SetOutRankWithSlot (autograd_meta1, 0 );
89111 std::pair<size_t , size_t > out_rank_info1 = EagerUtils::OutRankInfo (et1);
90- CHECK_EQ (static_cast <int >(out_rank_info1.first ), 0 );
91- CHECK_EQ (static_cast <int >(out_rank_info1.second ), 0 );
112+ PADDLE_ENFORCE_EQ (
113+ static_cast <int >(out_rank_info1.first ),
114+ 0UL ,
115+ phi::errors::InvalidArgument (" The first element of out rank info "
116+ " mismatch. Expected 0 but received %d." ,
117+ static_cast <int >(out_rank_info1.first )));
118+ PADDLE_ENFORCE_EQ (
119+ static_cast <int >(out_rank_info1.second ),
120+ 0UL ,
121+ phi::errors::InvalidArgument (" The second element of out rank info "
122+ " mismatch. Expected 0 but received %d." ,
123+ static_cast <int >(out_rank_info1.second )));
92124
93125 EagerUtils::SetOutRankWithSlot (&autograd_metas, 0 );
94126 std::pair<size_t , size_t > out_rank_info2 = EagerUtils::OutRankInfo (et0);
95- CHECK_EQ (static_cast <int >(out_rank_info2.first ), 0 );
96- CHECK_EQ (static_cast <int >(out_rank_info2.second ), 0 );
127+ PADDLE_ENFORCE_EQ (
128+ static_cast <int >(out_rank_info2.first ),
129+ 0UL ,
130+ phi::errors::InvalidArgument (" The first element of out rank info "
131+ " mismatch. Expected 0 but received %d." ,
132+ static_cast <int >(out_rank_info2.first )));
133+ PADDLE_ENFORCE_EQ (
134+ static_cast <int >(out_rank_info2.second ),
135+ 0UL ,
136+ phi::errors::InvalidArgument (" The second element of out rank info "
137+ " mismatch. Expected 0 but received %d." ,
138+ static_cast <int >(out_rank_info2.second )));
97139
98140 std::pair<size_t , size_t > out_rank_info3 = EagerUtils::OutRankInfo (et1);
99- CHECK_EQ (static_cast <int >(out_rank_info3.first ), 0 );
100- CHECK_EQ (static_cast <int >(out_rank_info3.second ), 1 );
141+ PADDLE_ENFORCE_EQ (
142+ static_cast <int >(out_rank_info3.first ),
143+ 0UL ,
144+ phi::errors::InvalidArgument (" The first element of out rank info "
145+ " mismatch. Expected 0 but received %d." ,
146+ static_cast <int >(out_rank_info3.first )));
147+ PADDLE_ENFORCE_EQ (
148+ static_cast <int >(out_rank_info3.second ),
149+ 1UL ,
150+ phi::errors::InvalidArgument (" The second element of out rank info "
151+ " mismatch. Expected 1 but received %d." ,
152+ static_cast <int >(out_rank_info3.second )));
101153}
102154
103155template <typename T>
@@ -122,7 +174,12 @@ TEST(EagerUtils, ComputeRequireGrad) {
122174 auto auto_grad1 = std::make_shared<egr::AutogradMeta>();
123175 auto auto_grad2 = std::make_shared<egr::AutogradMeta>();
124176 auto auto_grad3 = std::make_shared<egr::AutogradMeta>();
125- CHECK_EQ (auto_grad0->NumericStopGradient (), -1 );
177+ PADDLE_ENFORCE_EQ (
178+ auto_grad0->NumericStopGradient (),
179+ -1 ,
180+ phi::errors::InvalidArgument (" The NumericStopGradient of auto grad "
181+ " mismatch. Expected -1 but received %d." ,
182+ auto_grad0->NumericStopGradient ()));
126183 VLOG (6 ) << " Single Test ComputeRequireGrad" ;
127184 auto_grad0->SetStopGradient (true );
128185 CHECK (egr::EagerUtils::ComputeRequireGrad (true , auto_grad0.get ()) == false );
@@ -150,7 +207,12 @@ TEST(EagerUtils, PassStopGradient) {
150207 auto auto_grad1 = std::make_shared<egr::AutogradMeta>();
151208 auto auto_grad2 = std::make_shared<egr::AutogradMeta>();
152209 auto auto_grad3 = std::make_shared<egr::AutogradMeta>();
153- CHECK_EQ (auto_grad0->NumericStopGradient (), -1 );
210+ PADDLE_ENFORCE_EQ (
211+ auto_grad0->NumericStopGradient (),
212+ -1 ,
213+ phi::errors::InvalidArgument (" The NumericStopGradient of auto grad "
214+ " mismatch. Expected -1 but received %d." ,
215+ auto_grad0->NumericStopGradient ()));
154216 VLOG (6 ) << " Test PassStopGradient" ;
155217 egr::EagerUtils::PassStopGradient (false , auto_grad0.get ());
156218 CHECK (auto_grad0->StopGradient () == false );
@@ -176,10 +238,21 @@ TEST(EagerUtils, TrySyncToVar) {
176238
177239 const float * ptr = framework_tensor.data <float >();
178240 VLOG (6 ) << " Check Value for SyncToVarsSingle" ;
179- CHECK_EQ (framework_tensor.numel (), tensor.numel ());
241+ PADDLE_ENFORCE_EQ (framework_tensor.numel (),
242+ tensor.numel (),
243+ phi::errors::InvalidArgument (
244+ " The numel of framework tensor and numel of "
245+ " tensor should be the same, but received %d and %d." ,
246+ framework_tensor.numel (),
247+ tensor.numel ()));
180248
181249 for (int i = 0 ; i < framework_tensor.numel (); i++) {
182- CHECK_EQ (ptr[i], 5 .0f );
250+ PADDLE_ENFORCE_EQ (
251+ ptr[i],
252+ 5 .0f ,
253+ phi::errors::InvalidArgument (" The numel of framework tensor mismatch. "
254+ " Expected 5.0 but received %f." ,
255+ ptr[i]));
183256 }
184257}
185258
@@ -196,10 +269,22 @@ TEST(EagerUtils, TrySyncToVars) {
196269 const auto & framework_tensor = var->Get <phi::DenseTensor>();
197270
198271 const float * ptr = framework_tensor.data <float >();
199- CHECK_EQ (framework_tensor.numel (), tensors[0 ].numel ());
272+ PADDLE_ENFORCE_EQ (
273+ framework_tensor.numel (),
274+ tensors[0 ].numel (),
275+ phi::errors::InvalidArgument (
276+ " The numel of framework tensor and numel "
277+ " of tensor should be the same, but received %d and %d." ,
278+ framework_tensor.numel (),
279+ tensors[0 ].numel ()));
200280
201281 for (int i = 0 ; i < framework_tensor.numel (); i++) {
202- CHECK_EQ (ptr[i], 1.0 );
282+ PADDLE_ENFORCE_EQ (ptr[i],
283+ 1.0 ,
284+ phi::errors::InvalidArgument (
285+ " The numel of framework tensor mismatch. Expected "
286+ " 1.0 but received %f." ,
287+ ptr[i]));
203288 }
204289 }
205290
@@ -209,10 +294,22 @@ TEST(EagerUtils, TrySyncToVars) {
209294
210295 const float * ptr = framework_tensor.data <float >();
211296 VLOG (6 ) << " Check Value for SyncToVarsMultiple" ;
212- CHECK_EQ (framework_tensor.numel (), tensors[0 ].numel ());
297+ PADDLE_ENFORCE_EQ (
298+ framework_tensor.numel (),
299+ tensors[0 ].numel (),
300+ phi::errors::InvalidArgument (
301+ " The numel of framework tensor and numel "
302+ " of tensor should be the same, but received %d and %d." ,
303+ framework_tensor.numel (),
304+ tensors[0 ].numel ()));
213305
214306 for (int i = 0 ; i < framework_tensor.numel (); i++) {
215- CHECK_EQ (ptr[i], 2.0 );
307+ PADDLE_ENFORCE_EQ (ptr[i],
308+ 2.0 ,
309+ phi::errors::InvalidArgument (
310+ " The numel of framework tensor mismatch. Expected "
311+ " 2.0 but received %f." ,
312+ ptr[i]));
216313 }
217314 }
218315}
@@ -221,7 +318,11 @@ TEST(EagerUtils, CreateVars) {
221318 VLOG (6 ) << " Check CreateVars" ;
222319 std::vector<std::shared_ptr<egr::EagerVariable>> outs =
223320 egr::EagerUtils::CreateVars (2 );
224- CHECK_EQ (outs.size (), size_t (2 ));
321+ PADDLE_ENFORCE_EQ (
322+ outs.size (),
323+ 2UL ,
324+ phi::errors::InvalidArgument (
325+ " Size of outs mismatch. Expected 2 but received %d." , outs.size ()));
225326 CHECK (outs[0 ]->Var ().IsInitialized () == false );
226327}
227328
0 commit comments