@@ -3637,11 +3637,117 @@ bool RmsNormOpInferSymbolicShape(
36373637 return true ;
36383638}
36393639
3640- // bool RnnOpInferSymbolicShape(pir::Operation *op,
3641- // pir::InferSymbolicShapeContext *infer_context) {
3642- // // pass
3643- // return true;
3644- // }
3640+ bool RnnOpInferSymbolicShape (pir::Operation *op,
3641+ pir::InferSymbolicShapeContext *infer_context) {
3642+ const auto &x_shape_or_data =
3643+ infer_context->GetShapeOrDataForValue (op->operand_source (0 ));
3644+ const auto &pre_state_shape_or_data_list =
3645+ infer_context->GetShapeOrDataForValue (op->operand_source (1 ))
3646+ .dyn_cast <symbol::TensorListShapeOrDataDimExprs>();
3647+ const auto &sequence_length_shape_or_data =
3648+ infer_context->GetShapeOrDataForValue (op->operand_source (3 ));
3649+
3650+ const std::string &mode = op->attribute <pir::StrAttribute>(" mode" ).AsString ();
3651+ const bool &is_bidirec =
3652+ op->attribute <pir::BoolAttribute>(" is_bidirec" ).data ();
3653+ const int &hidden_size =
3654+ op->attribute <pir::Int32Attribute>(" hidden_size" ).data ();
3655+
3656+ const auto &x_shape = x_shape_or_data.shape ();
3657+ PADDLE_ENFORCE_EQ (x_shape.size (),
3658+ 3 ,
3659+ common::errors::InvalidArgument (
3660+ " The rank of Input in RNN must be 3. But "
3661+ " received Input's rank is %d." ,
3662+ x_shape.size ()));
3663+
3664+ if (!sequence_length_shape_or_data.isa <symbol::NullShapeOrDataDimExpr>()) {
3665+ const auto &sequence_length_shape = sequence_length_shape_or_data.shape ();
3666+ infer_context->AddEqualCstr (x_shape[1 ], sequence_length_shape[0 ]);
3667+ }
3668+
3669+ PADDLE_ENFORCE_EQ (pre_state_shape_or_data_list[0 ].shape ().size (),
3670+ 3 ,
3671+ common::errors::InvalidArgument (
3672+ " The rank of PreState in RNN must be 3. But "
3673+ " the received rank is %d." ,
3674+ pre_state_shape_or_data_list[0 ].shape ().size ()));
3675+ for (size_t i = 0 ; i < 3 ; ++i) {
3676+ details::BuildCstrEqForTensorListAlongAxis (
3677+ infer_context, pre_state_shape_or_data_list, i);
3678+ }
3679+ size_t i = 0 ;
3680+ for (; i < pre_state_shape_or_data_list.size (); ++i) {
3681+ infer_context->AddEqualCstr (x_shape[1 ],
3682+ pre_state_shape_or_data_list[i].shape ()[1 ]);
3683+ }
3684+ size_t num_state = mode == " LSTM" ? 2 : 1 ;
3685+ PADDLE_ENFORCE_EQ (i,
3686+ num_state,
3687+ common::errors::InvalidArgument (
3688+ " The number of tensors in PreState of %s should be %d, "
3689+ " but received %d." ,
3690+ mode,
3691+ 2 ,
3692+ i));
3693+ std::vector<symbol::DimExpr> out_shape = x_shape;
3694+ out_shape[2 ] = is_bidirec
3695+ ? symbol::DimExpr (static_cast <int64_t >(hidden_size) * 2 )
3696+ : symbol::DimExpr (static_cast <int64_t >(hidden_size));
3697+ infer_context->SetShapeOrDataForValue (
3698+ op->result (0 ),
3699+ symbol::ShapeOrDataDimExprs{
3700+ symbol::TensorShapeOrDataDimExprs (out_shape)});
3701+
3702+ size_t state_num = pre_state_shape_or_data_list.size ();
3703+ symbol::TensorListShapeOrDataDimExprs state_shape_or_data_list;
3704+ for (size_t i = 0 ; i < state_num; ++i) {
3705+ state_shape_or_data_list.emplace_back (
3706+ pre_state_shape_or_data_list[i].shape ());
3707+ }
3708+ infer_context->SetShapeOrDataForValue (
3709+ op->result (2 ), symbol::ShapeOrDataDimExprs{state_shape_or_data_list});
3710+
3711+ int gate_num = 4 ;
3712+ if (mode == " RNN_RELU" || mode == " RNN_TANH" ) {
3713+ gate_num = 1 ;
3714+ } else if (mode == " GRU" ) {
3715+ gate_num = 3 ;
3716+ }
3717+ const int &num_layers =
3718+ op->attribute <pir::Int32Attribute>(" num_layers" ).data ();
3719+
3720+ int hidden_date_idx = num_layers - 1 ;
3721+ if (mode == " LSTM" ) {
3722+ hidden_date_idx += (gate_num + 2 ) * num_layers;
3723+ } else if (mode == " GRU" ) {
3724+ hidden_date_idx += (gate_num + 1 ) * num_layers;
3725+ } else {
3726+ hidden_date_idx += gate_num * num_layers;
3727+ }
3728+ symbol::DimExpr block_size =
3729+ symbol::DimExpr (static_cast <int64_t >(num_state)) * x_shape[0 ] *
3730+ x_shape[1 ] * symbol::DimExpr (hidden_size);
3731+ std::vector<symbol::DimExpr> reserve_shape = {symbol::DimExpr (hidden_size),
3732+ block_size};
3733+ infer_context->SetShapeOrDataForValue (
3734+ op->result (3 ),
3735+ symbol::ShapeOrDataDimExprs{
3736+ symbol::TensorShapeOrDataDimExprs (reserve_shape)});
3737+
3738+ symbol::DimExpr dropout_state_shape = infer_context->GetNextSymName ();
3739+
3740+ infer_context->SetShapeOrDataForValue (
3741+ op->result (1 ),
3742+ symbol::ShapeOrDataDimExprs{
3743+ symbol::TensorShapeOrDataDimExprs ({dropout_state_shape})});
3744+ return true ;
3745+ }
3746+
3747+ bool Rnn_OpInferSymbolicShape (pir::Operation *op,
3748+ pir::InferSymbolicShapeContext *infer_context) {
3749+ return RnnOpInferSymbolicShape (op, infer_context);
3750+ }
36453751
36463752// bool RoiPoolOpInferSymbolicShape(pir::Operation *op,
36473753// pir::InferSymbolicShapeContext
0 commit comments