Skip to content

Commit 3ed9de3

Browse files
authored
[PIR] Fix shape mismatch bug in Save&Load (#69153)
* fix shape mismatch bug in while loop * Update builtin_op.cc * Update builtin_op.cc * Update builtin_op.cc * fix * update * update * rerun ci * Update builtin_op.cc
1 parent 65e5eb4 commit 3ed9de3

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

paddle/pir/src/core/builtin_op.cc

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,27 @@ namespace pir {
2424

2525
const char *ModuleOp::attributes_name[attributes_num] = {"program"}; // NOLINT
2626

27+
bool IsDynamicShapeTypeEqual(Type type1, Type type2) {
28+
// Only support DenseTensorType now
29+
bool are_equal = false;
30+
if (type1.isa<DenseTensorType>() && type2.isa<DenseTensorType>()) {
31+
auto type_l = type1.dyn_cast<DenseTensorType>();
32+
auto type_r = type2.dyn_cast<DenseTensorType>();
33+
auto vec1 = type_l.dims();
34+
auto vec2 = type_r.dims();
35+
if (vec1.size() != vec2.size()) return false;
36+
for (auto i = 0; i < vec1.size(); ++i) {
37+
are_equal = ((vec1[i] == -1 || vec2[i] == -1) || (vec1[i] == vec2[i])) |
38+
are_equal;
39+
}
40+
return static_cast<bool>(type_l.dtype() == type_r.dtype() &&
41+
type_l.data_layout() == type_r.data_layout() &&
42+
type_l.lod() == type_r.lod() &&
43+
type_l.offset() == type_r.offset() && are_equal);
44+
}
45+
return are_equal;
46+
}
47+
2748
void PassStopGradientsDefaultly(OperationArgument &argument) { // NOLINT
2849
VLOG(10) << "Builder construction stop gradient for OpResults.";
2950
bool stop_gradient = true;
@@ -340,11 +361,12 @@ void CombineOp::VerifySig() const {
340361
input_num));
341362

342363
// forall i in inputs.size(): inputs[i].type == outputs[0][i].type
343-
for (size_t i = 0; i < input_num; ++i) {
364+
for (uint64_t i = 0; i < input_num; ++i) {
344365
auto type = (*this)->operand(i).type();
345366
PADDLE_ENFORCE_EQ(
346-
output_type[i],
347-
type,
367+
(output_type[i] == type ||
368+
IsDynamicShapeTypeEqual(output_type[i], type)),
369+
true,
348370
common::errors::InvalidArgument("The type %s of outputs[0][%d] must be "
349371
"equal to type %s of inputs[%d].",
350372
output_type[i],

0 commit comments

Comments
 (0)