@@ -123,32 +123,30 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
123123#if defined(PADDLE_WITH_ASCEND_CL)
124124// return true if found_inf_or_nan or return false;
125125template <typename T>
126- bool CheckNumerics (const framework::ExecutionContext& exe_ctx,
127- aclrtStream stream, const paddle::framework::Tensor* in) {
126+ bool ContainsNan (const framework::ExecutionContext& exe_ctx, aclrtStream stream ,
127+ const paddle::framework::Tensor* in) {
128128 auto & dev_ctx =
129129 exe_ctx.template device_context <paddle::platform::NPUDeviceContext>();
130130 using Tensor = paddle::framework::Tensor;
131131 Tensor out (in->type ());
132- out.Resize (in->dims ());
133- out.mutable_data <T>(dev_ctx.GetPlace ());
134-
135- bool found_inf_data = false ;
136-
137- try {
138- const auto & runner =
139- NpuOpRunner (" CheckNumerics" , {*in}, {out},
140- {{" message" , std::string (" check_numberics" )}});
141- runner.Run (stream);
142- dev_ctx.Wait ();
143- } catch (platform::EnforceNotMet& exception) {
144- LOG (WARNING) << " [check_nan_and_inf] detected contains NaN or INF!!!" ;
145- found_inf_data = true ;
146- } catch (...) {
147- LOG (WARNING) << " [check_nan_and_inf] detected contains NaN or INF!!!" ;
148- found_inf_data = true ;
132+
133+ Tensor mean (in->type ());
134+ mean.Resize ({1 });
135+ mean.mutable_data <T>(dev_ctx.GetPlace ());
136+ std::vector<int > axes;
137+ for (int i = 0 ; i < in->dims ().size (); ++i) {
138+ axes.push_back (i);
149139 }
140+ const auto & runner_mean = NpuOpRunner (" ReduceMeanD" , {*in}, {mean},
141+ {{" axes" , axes}, {" keep_dims" , false }});
142+
143+ std::vector<T> vec;
144+ TensorToVector (mean, exe_ctx.device_context (), &vec);
150145
151- return found_inf_data;
146+ if (std::isnan (static_cast <float >(vec[0 ]))) {
147+ return true ;
148+ }
149+ return false ;
152150}
153151#endif
154152
@@ -216,22 +214,22 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
216214 framework::Tensor tmp;
217215 tmp.mutable_data <float >({8 }, ctx.GetPlace ());
218216
219- bool check_numerics = false ;
217+ bool has_nan = false ;
220218
221219 auto d_type = in->type ();
222220 switch (d_type) {
223221 case framework::proto::VarType::FP16:
224222 case framework::proto::VarType::FP32: {
225- VLOG (4 ) << " prepare to FoundNanInf " ;
226- check_numerics = CheckNumerics <T>(ctx, dev_ctx->stream (), in);
227- VLOG (4 ) << " check_numerics :" << check_numerics ;
223+ VLOG (4 ) << " prepare to check nan " ;
224+ has_nan = ContainsNan <T>(ctx, dev_ctx->stream (), in);
225+ VLOG (4 ) << " ContainsNan :" << has_nan ;
228226 break ;
229227 }
230228 default :
231229 break ;
232230 }
233231
234- if (check_numerics ) {
232+ if (has_nan ) {
235233 T inf = static_cast <T>(std::numeric_limits<float >::infinity ());
236234 VLOG (4 ) << " fill input data constant inf" ;
237235 auto dims = in->dims ();
0 commit comments