@@ -24,6 +24,27 @@ namespace pir {
2424
2525const char *ModuleOp::attributes_name[attributes_num] = {" program" }; // NOLINT
2626
27+ bool IsDynamicShapeTypeEqual (Type type1, Type type2) {
28+ // Only support DenseTensorType now
29+ bool are_equal = false ;
30+ if (type1.isa <DenseTensorType>() && type2.isa <DenseTensorType>()) {
31+ auto type_l = type1.dyn_cast <DenseTensorType>();
32+ auto type_r = type2.dyn_cast <DenseTensorType>();
33+ auto vec1 = type_l.dims ();
34+ auto vec2 = type_r.dims ();
35+ if (vec1.size () != vec2.size ()) return false ;
36+ for (auto i = 0 ; i < vec1.size (); ++i) {
37+ are_equal = ((vec1[i] == -1 || vec2[i] == -1 ) || (vec1[i] == vec2[i])) |
38+ are_equal;
39+ }
40+ return static_cast <bool >(type_l.dtype () == type_r.dtype () &&
41+ type_l.data_layout () == type_r.data_layout () &&
42+ type_l.lod () == type_r.lod () &&
43+ type_l.offset () == type_r.offset () && are_equal);
44+ }
45+ return are_equal;
46+ }
47+
2748void PassStopGradientsDefaultly (OperationArgument &argument) { // NOLINT
2849 VLOG (10 ) << " Builder construction stop gradient for OpResults." ;
2950 bool stop_gradient = true ;
@@ -340,11 +361,12 @@ void CombineOp::VerifySig() const {
340361 input_num));
341362
342363 // forall i in inputs.size(): inputs[i].type == outputs[0][i].type
343- for (size_t i = 0 ; i < input_num; ++i) {
364+ for (uint64_t i = 0 ; i < input_num; ++i) {
344365 auto type = (*this )->operand (i).type ();
345366 PADDLE_ENFORCE_EQ (
346- output_type[i],
347- type,
367+ (output_type[i] == type ||
368+ IsDynamicShapeTypeEqual (output_type[i], type)),
369+ true ,
348370 common::errors::InvalidArgument (" The type %s of outputs[0][%d] must be "
349371 " equal to type %s of inputs[%d]." ,
350372 output_type[i],
0 commit comments