@@ -123,30 +123,32 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
123123#if defined(PADDLE_WITH_ASCEND_CL)
124124// return true if found_inf_or_nan or return false;
125125template <typename T>
126- bool ContainsNan (const framework::ExecutionContext& exe_ctx, aclrtStream stream ,
127- const paddle::framework::Tensor* in) {
126+ bool CheckNumerics (const framework::ExecutionContext& exe_ctx,
127+ aclrtStream stream, const paddle::framework::Tensor* in) {
128128 auto & dev_ctx =
129129 exe_ctx.template device_context <paddle::platform::NPUDeviceContext>();
130130 using Tensor = paddle::framework::Tensor;
131131 Tensor out (in->type ());
132-
133- Tensor mean (in->type ());
134- mean.Resize ({1 });
135- mean.mutable_data <T>(dev_ctx.GetPlace ());
136- std::vector<int > axes;
137- for (int i = 0 ; i < in->dims ().size (); ++i) {
138- axes.push_back (i);
132+ out.Resize (in->dims ());
133+ out.mutable_data <T>(dev_ctx.GetPlace ());
134+
135+ bool found_inf_data = false ;
136+
137+ try {
138+ const auto & runner =
139+ NpuOpRunner (" CheckNumerics" , {*in}, {out},
140+ {{" message" , std::string (" check_numberics" )}});
141+ runner.Run (stream);
142+ dev_ctx.Wait ();
143+ } catch (platform::EnforceNotMet& exception) {
144+ LOG (WARNING) << " [check_nan_and_inf] detected contains NaN or INF!!!" ;
145+ found_inf_data = true ;
146+ } catch (...) {
147+ LOG (WARNING) << " [check_nan_and_inf] detected contains NaN or INF!!!" ;
148+ found_inf_data = true ;
139149 }
140- const auto & runner_mean = NpuOpRunner (" ReduceMeanD" , {*in}, {mean},
141- {{" axes" , axes}, {" keep_dims" , false }});
142-
143- std::vector<T> vec;
144- TensorToVector (mean, exe_ctx.device_context (), &vec);
145150
146- if (std::isnan (static_cast <float >(vec[0 ]))) {
147- return true ;
148- }
149- return false ;
151+ return found_inf_data;
150152}
151153#endif
152154
@@ -214,22 +216,22 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
214216 framework::Tensor tmp;
215217 tmp.mutable_data <float >({8 }, ctx.GetPlace ());
216218
217- bool has_nan = false ;
219+ bool check_numerics = false ;
218220
219221 auto d_type = in->type ();
220222 switch (d_type) {
221223 case framework::proto::VarType::FP16:
222224 case framework::proto::VarType::FP32: {
223- VLOG (4 ) << " prepare to check nan " ;
224- has_nan = ContainsNan <T>(ctx, dev_ctx->stream (), in);
225- VLOG (4 ) << " ContainsNan :" << has_nan ;
225+ VLOG (4 ) << " prepare to FoundNanInf " ;
226+ check_numerics = CheckNumerics <T>(ctx, dev_ctx->stream (), in);
227+ VLOG (4 ) << " check_numerics :" << check_numerics ;
226228 break ;
227229 }
228230 default :
229231 break ;
230232 }
231233
232- if (has_nan ) {
234+ if (check_numerics ) {
233235 T inf = static_cast <T>(std::numeric_limits<float >::infinity ());
234236 VLOG (4 ) << " fill input data constant inf" ;
235237 auto dims = in->dims ();
0 commit comments