Skip to content

Commit 393a0b1

Browse files
authored
[NPU] refine nan check (#34508)
1 parent a6f55e4 commit 393a0b1

File tree

1 file changed

+23
-25
lines changed

1 file changed

+23
-25
lines changed

paddle/fluid/operators/collective/c_allreduce_op.h

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -123,32 +123,30 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
123123
#if defined(PADDLE_WITH_ASCEND_CL)
124124
// return true if found_inf_or_nan or return false;
125125
template <typename T>
126-
bool CheckNumerics(const framework::ExecutionContext& exe_ctx,
127-
aclrtStream stream, const paddle::framework::Tensor* in) {
126+
bool ContainsNan(const framework::ExecutionContext& exe_ctx, aclrtStream stream,
127+
const paddle::framework::Tensor* in) {
128128
auto& dev_ctx =
129129
exe_ctx.template device_context<paddle::platform::NPUDeviceContext>();
130130
using Tensor = paddle::framework::Tensor;
131131
Tensor out(in->type());
132-
out.Resize(in->dims());
133-
out.mutable_data<T>(dev_ctx.GetPlace());
134-
135-
bool found_inf_data = false;
136-
137-
try {
138-
const auto& runner =
139-
NpuOpRunner("CheckNumerics", {*in}, {out},
140-
{{"message", std::string("check_numberics")}});
141-
runner.Run(stream);
142-
dev_ctx.Wait();
143-
} catch (platform::EnforceNotMet& exception) {
144-
LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
145-
found_inf_data = true;
146-
} catch (...) {
147-
LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
148-
found_inf_data = true;
132+
133+
Tensor mean(in->type());
134+
mean.Resize({1});
135+
mean.mutable_data<T>(dev_ctx.GetPlace());
136+
std::vector<int> axes;
137+
for (int i = 0; i < in->dims().size(); ++i) {
138+
axes.push_back(i);
149139
}
140+
const auto& runner_mean = NpuOpRunner("ReduceMeanD", {*in}, {mean},
141+
{{"axes", axes}, {"keep_dims", false}});
142+
143+
std::vector<T> vec;
144+
TensorToVector(mean, exe_ctx.device_context(), &vec);
150145

151-
return found_inf_data;
146+
if (std::isnan(static_cast<float>(vec[0]))) {
147+
return true;
148+
}
149+
return false;
152150
}
153151
#endif
154152

@@ -216,22 +214,22 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
216214
framework::Tensor tmp;
217215
tmp.mutable_data<float>({8}, ctx.GetPlace());
218216

219-
bool check_numerics = false;
217+
bool has_nan = false;
220218

221219
auto d_type = in->type();
222220
switch (d_type) {
223221
case framework::proto::VarType::FP16:
224222
case framework::proto::VarType::FP32: {
225-
VLOG(4) << "prepare to FoundNanInf";
226-
check_numerics = CheckNumerics<T>(ctx, dev_ctx->stream(), in);
227-
VLOG(4) << "check_numerics:" << check_numerics;
223+
VLOG(4) << "prepare to check nan";
224+
has_nan = ContainsNan<T>(ctx, dev_ctx->stream(), in);
225+
VLOG(4) << "ContainsNan:" << has_nan;
228226
break;
229227
}
230228
default:
231229
break;
232230
}
233231

234-
if (check_numerics) {
232+
if (has_nan) {
235233
T inf = static_cast<T>(std::numeric_limits<float>::infinity());
236234
VLOG(4) << "fill input data constant inf";
237235
auto dims = in->dims();

0 commit comments

Comments
 (0)