@@ -802,8 +802,8 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
802802 pir::IrContext *ctx = pir::IrContext::Instance ();
803803 ctx->GetOrRegisterDialect <cinn::dialect::OperatorDialect>();
804804 ctx->GetOrRegisterDialect <pir::shape::ShapeDialect>();
805- auto pass_manager =
806- std::make_shared<:: pir::PassManager>(:: pir:: IrContext::Instance (), 2 );
805+ auto pass_manager = std::make_shared<::pir::PassManager>(
806+ :: pir::IrContext::Instance (), config_.pm_opt_level_ );
807807 if (!config_.glog_info_disabled ()) {
808808 pass_manager->EnablePrintStatistics ();
809809 }
@@ -882,12 +882,20 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
882882 std::make_unique<pir::PassManager::IRPrinterOption>(
883883 ir_printing_conditions, ir_printing_conditions));
884884 }
885+ // set attr
886+ for (const auto &pass : pass_pm.passes ()) {
887+ if (pass->name () == " matmul_add_act_fuse_pass" ||
888+ pass->name () == " conv2d_add_act_fuse_pass" ||
889+ pass->name () == " conv2d_add_fuse_pass" ) {
890+ pass->Set (" use_cutlass" , new bool (config_.use_cutlass_ ));
891+ }
892+ }
885893 pass_pm.Run (pir_program_.get ());
886894
887895 // Apply some basic passes required by the framework
888896 ::pir::PassManager basic_pass_pm (::pir::IrContext::Instance (),
889897 config_.pm_opt_level_ );
890-
898+ basic_pass_pm. AddPass (:: pir::CreateCommonSubexpressionEliminationPass ());
891899 auto params_sync_among_devices_pass =
892900 ::pir::CreateParamsSyncAmongDevicesPass ();
893901 params_sync_among_devices_pass->SetNotOwned (pir::Pass::kPlaceAttr , &place_);
@@ -918,6 +926,9 @@ void AnalysisPredictor::OptimizeInferencePirProgram() {
918926 paddle::dialect::PdOpLowerToKernelPass (pir_program_.get (), place_);
919927
920928 ::pir::PassManager lowered_pm (::pir::IrContext::Instance (), 3 );
929+ auto remove_shadow_feed_pass = ::pir::CreateRemoveShadowFeedPass ();
930+ remove_shadow_feed_pass->Set (" used_for_inference" , new bool (true ));
931+ lowered_pm.AddPass (std::move (remove_shadow_feed_pass));
921932 if (FLAGS_pir_apply_inplace_pass) {
922933 lowered_pm.AddPass (::pir::CreateInplacePass ());
923934 }
@@ -1081,9 +1092,10 @@ bool AnalysisPredictor::PrepareProgram(
10811092 executor_->CreateVariables (*inference_program_, 0 , false , sub_scope_);
10821093
10831094 if (config_.new_ir_enabled ()) {
1084- if (pir_program_ != nullptr ) {
1085- PADDLE_FATAL (" pir_program_ must be nullptr" );
1086- }
1095+ PADDLE_ENFORCE_EQ (
1096+ pir_program_,
1097+ nullptr ,
1098+ platform::errors::Fatal (" Here, pir_program must be a nullptr!" ));
10871099 pir_program_ = paddle::TranslateLegacyProgramToProgram (*inference_program_);
10881100 OptimizeInferencePirProgram ();
10891101 }
0 commit comments