Skip to content

Commit 9ef93d1

Browse files
fix bug in deserialize VOID_DATA (#71339)
1 parent 4705ebe commit 9ef93d1

File tree

1 file changed

+6
-12
lines changed

1 file changed

+6
-12
lines changed

paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,11 @@ pir::FloatAttribute deserializeAttrFromJson<pir::FloatAttribute, float>(
9090
if (attr_json->contains(VOID_DATA)) {
9191
auto string = attr_json->at(VOID_DATA).template get<std::string>();
9292
if (string == "NAN") {
93-
return pir::FloatAttribute::get(ctx, std::nanf(""));
93+
return pir::FloatAttribute::get(ctx, NAN);
9494
} else if (string == "INF") {
95-
return pir::FloatAttribute::get(ctx, FLT_MAX);
95+
return pir::FloatAttribute::get(ctx, INFINITY);
9696
} else if (string == "-INF") {
97-
return pir::FloatAttribute::get(ctx, FLT_MIN);
97+
return pir::FloatAttribute::get(ctx, -INFINITY);
9898
}
9999
}
100100

@@ -108,11 +108,11 @@ pir::DoubleAttribute deserializeAttrFromJson<pir::DoubleAttribute, double>(
108108
if (attr_json->contains(VOID_DATA)) {
109109
auto string = attr_json->at(VOID_DATA).template get<std::string>();
110110
if (string == "NAN") {
111-
return pir::DoubleAttribute::get(ctx, std::nanf(""));
111+
return pir::DoubleAttribute::get(ctx, NAN);
112112
} else if (string == "INF") {
113-
return pir::DoubleAttribute::get(ctx, DBL_MAX);
113+
return pir::DoubleAttribute::get(ctx, INFINITY);
114114
} else if (string == "-INF") {
115-
return pir::DoubleAttribute::get(ctx, DBL_MIN);
115+
return pir::DoubleAttribute::get(ctx, -INFINITY);
116116
}
117117
}
118118
double data = attr_json->at(DATA).template get<double>();
@@ -640,12 +640,6 @@ pir::Type AttrTypeReader::ReadBuiltInType(const std::string type_name,
640640
} else if (type_name == pir::IndexType::name()) {
641641
VLOG(8) << "Parse IndexType ... ";
642642
return pir::deserializeTypeFromJson<pir::IndexType>(type_json, ctx);
643-
} else if (type_name == pir::Float8E4M3FNType::name()) {
644-
VLOG(8) << "Parse IndexType ... ";
645-
return pir::deserializeTypeFromJson<pir::Float8E4M3FNType>(type_json, ctx);
646-
} else if (type_name == pir::Float8E5M2Type::name()) {
647-
VLOG(8) << "Parse IndexType ... ";
648-
return pir::deserializeTypeFromJson<pir::Float8E5M2Type>(type_json, ctx);
649643
} else if (type_name == pir::Complex64Type::name()) {
650644
VLOG(8) << "Parse Complex64Type ... ";
651645
return pir::deserializeTypeFromJson<pir::Complex64Type>(type_json, ctx);

0 commit comments

Comments
 (0)