Skip to content

Commit c8eb839

Browse files
committed
add cf saveload and flag (PaddlePaddle#68628)
1 parent f210620 commit c8eb839

File tree

7 files changed

+124
-39
lines changed

7 files changed

+124
-39
lines changed

paddle/common/flags.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1965,3 +1965,7 @@ PHI_DEFINE_EXPORTED_bool(fused_multi_transformer_op_use_mbfmha,
19651965
PHI_DEFINE_EXPORTED_int64(multi_block_attention_min_partition_size,
19661966
1024,
19671967
"The minimum partition size for flash decoding");
1968+
1969+
PHI_DEFINE_EXPORTED_bool(save_cf_stack_op,
1970+
false,
1971+
"Save cf stack op for higher-order derivatives.");

paddle/fluid/pir/serialize_deserialize/include/deserialize_utils.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "paddle/phi/common/data_type.h"
3030
#include "paddle/pir/include/core/builtin_attribute.h"
3131
#include "paddle/pir/include/core/builtin_type.h"
32+
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"
3233
#include "paddle/utils/flat_hash_map.h"
3334

3435
namespace pir {
@@ -66,6 +67,10 @@ class AttrTypeReader {
6667
static pir::Attribute ReadPaddleDistAttr(const std::string attr_name,
6768
Json* attr_json,
6869
pir::IrContext* ctx);
70+
71+
static pir::Type ReadControlFlowType(const std::string type_name,
72+
Json* type_json,
73+
pir::IrContext* ctx);
6974
};
7075

7176
template <typename T>
@@ -237,6 +242,9 @@ pir::Type parseType(Json* type_json) {
237242
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
238243
paddle::dialect::DistDialect::name()) {
239244
return AttrTypeReader::ReadPaddleDistType(name.second, type_json, ctx);
245+
} else if (DECOMPRESS_DIALECT_ID(name.first) ==
246+
pir::ControlFlowDialect::name()) {
247+
return AttrTypeReader::ReadControlFlowType(name.second, type_json, ctx);
240248
} else {
241249
PADDLE_ENFORCE(
242250
false,
@@ -695,4 +703,25 @@ pir::Type AttrTypeReader::ReadPaddleDistType(const std::string type_name,
695703
}
696704
}
697705

706+
pir::Type AttrTypeReader::ReadControlFlowType(const std::string type_name,
707+
Json* type_json,
708+
pir::IrContext* ctx) {
709+
if (type_name == pir::StackType::name()) {
710+
VLOG(8) << "Parse StackType ... ";
711+
return pir::deserializeTypeFromJson<pir::StackType>(type_json, ctx);
712+
} else if (type_name == pir::InletType::name()) {
713+
VLOG(8) << "Parse InletType ... ";
714+
return pir::deserializeTypeFromJson<pir::InletType>(type_json, ctx);
715+
} else if (type_name == pir::OutletType::name()) {
716+
VLOG(8) << "Parse OutletType ... ";
717+
return pir::deserializeTypeFromJson<pir::OutletType>(type_json, ctx);
718+
} else {
719+
PADDLE_ENFORCE(
720+
false,
721+
common::errors::InvalidArgument(
722+
"Unknown Type %s for parse controlflow dialect type", type_name));
723+
return pir::Type();
724+
}
725+
}
726+
698727
} // namespace pir

paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "paddle/phi/common/data_type.h"
2929
#include "paddle/pir/include/core/builtin_attribute.h"
3030
#include "paddle/pir/include/core/builtin_type.h"
31+
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"
3132

3233
namespace pir {
3334
#define COMPRESS_DIALECT_NAME(attr_template) \
@@ -53,6 +54,8 @@ class AttrTypeWriter {
5354
static Json WritePaddleDistType(const pir::Type& type);
5455

5556
static Json WritePaddleDistAttr(const pir::Attribute& attr);
57+
58+
static Json WriteControlFlowType(const pir::Type& type);
5659
};
5760
/** serializeTypeToJson is a template function to serialize
5861
* a pir type to a json object. a pir type may have value or no value
@@ -245,6 +248,9 @@ Json writeType(const pir::Type& type) {
245248
} else if (type.dialect().name() == paddle::dialect::DistDialect::name()) {
246249
VLOG(6) << "write PaddleDistType ... ";
247250
return AttrTypeWriter::WritePaddleDistType(type);
251+
} else if (type.dialect().name() == pir::ControlFlowDialect::name()) {
252+
VLOG(6) << "write ControlFlowDialect ... ";
253+
return AttrTypeWriter::WriteControlFlowType(type);
248254
} else {
249255
PADDLE_ENFORCE(
250256
false,
@@ -723,4 +729,26 @@ Json AttrTypeWriter::WritePaddleDistAttr(const pir::Attribute& attr) {
723729
return Json::object();
724730
}
725731

732+
Json AttrTypeWriter::WriteControlFlowType(const pir::Type& type) {
733+
Json type_json = Json::object();
734+
if (type.isa<pir::StackType>()) {
735+
VLOG(8) << "Write StackType ... ";
736+
return pir::serializeTypeToJson<pir::StackType>(
737+
type.dyn_cast<pir::StackType>());
738+
} else if (type.isa<pir::InletType>()) {
739+
VLOG(8) << "Write InletType ... ";
740+
return pir::serializeTypeToJson<pir::InletType>(
741+
type.dyn_cast<pir::InletType>());
742+
} else if (type.isa<pir::OutletType>()) {
743+
VLOG(8) << "Write OutletType ... ";
744+
return pir::serializeTypeToJson<pir::OutletType>(
745+
type.dyn_cast<pir::OutletType>());
746+
} else {
747+
PADDLE_ENFORCE(false,
748+
common::errors::InvalidArgument(
749+
"Unknown Type when write controlflow dialect type"));
750+
}
751+
return type_json;
752+
}
753+
726754
} // namespace pir

paddle/fluid/pir/serialize_deserialize/src/ir_serialize.cc

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/pir/serialize_deserialize/include/ir_serialize.h"
16+
#include "paddle/common/flags.h"
1617
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
1718
#include "paddle/fluid/pir/serialize_deserialize/include/serialize_utils.h"
1819
#include "paddle/pir/include/core/dialect.h"
1920
#include "paddle/pir/include/core/operation.h"
2021
#include "paddle/pir/include/dialect/control_flow/ir/cf_dialect.h"
2122
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
2223

24+
COMMON_DECLARE_bool(save_cf_stack_op);
2325
namespace pir {
2426

2527
Json ProgramWriter::GetProgramJson(const pir::Program* program) {
@@ -99,29 +101,31 @@ Json ProgramWriter::WriteBlock(pir::Block* block,
99101
Json ops_json = Json::array();
100102

101103
/* delete cf.stack_create / cf.tuple_push */
102-
std::vector<pir::Operation*> delete_ops;
103-
for (auto op : block->ops()) {
104-
if (op->isa<pir::StackCreateOp>()) {
105-
delete_ops.push_back(op);
106-
}
107-
}
108-
VLOG(6) << "program before delete stack op :" << *(block->parent_program());
109-
for (auto op : delete_ops) {
110-
VLOG(0) << "Delete cf.stack_create / cf.tuple_push.";
111-
auto stack_op = op->dyn_cast<pir::StackCreateOp>();
112-
if (stack_op.inlet().HasOneUse()) {
113-
auto tuple_push_op = stack_op.tuple_push_op();
114-
auto block_in = tuple_push_op->GetParent();
115-
block_in->erase(*tuple_push_op);
104+
if (!FLAGS_save_cf_stack_op) {
105+
std::vector<pir::Operation*> delete_ops;
106+
for (auto op : block->ops()) {
107+
if (op->isa<pir::StackCreateOp>()) {
108+
delete_ops.push_back(op);
109+
}
116110
}
117-
if (stack_op.outlet().HasOneUse()) {
118-
auto tuple_pop_op = stack_op.tuple_pop_op();
119-
auto block_in = tuple_pop_op->GetParent();
120-
block_in->erase(*tuple_pop_op);
111+
VLOG(6) << "program before delete stack op :" << *(block->parent_program());
112+
for (auto op : delete_ops) {
113+
VLOG(0) << "Delete cf.stack_create / cf.tuple_push.";
114+
auto stack_op = op->dyn_cast<pir::StackCreateOp>();
115+
if (stack_op.inlet().HasOneUse()) {
116+
auto tuple_push_op = stack_op.tuple_push_op();
117+
auto block_in = tuple_push_op->GetParent();
118+
block_in->erase(*tuple_push_op);
119+
}
120+
if (stack_op.outlet().HasOneUse()) {
121+
auto tuple_pop_op = stack_op.tuple_pop_op();
122+
auto block_in = tuple_pop_op->GetParent();
123+
block_in->erase(*tuple_pop_op);
124+
}
125+
block->erase(*op);
121126
}
122-
block->erase(*op);
127+
VLOG(6) << "program after delete stack op :" << *(block->parent_program());
123128
}
124-
VLOG(6) << "program after delete stack op :" << *(block->parent_program());
125129
for (auto op : block->ops()) {
126130
auto op_json = WriteOp(*op);
127131
ops_json.emplace_back(op_json);

paddle/fluid/pir/serialize_deserialize/src/patch_util.cc

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "paddle/phi/common/data_type.h"
2525
#include "paddle/pir/include/core/builtin_attribute.h"
2626
#include "paddle/pir/include/core/builtin_type.h"
27+
#include "paddle/pir/include/dialect/control_flow/ir/cf_type.h"
2728

2829
namespace pir {
2930

@@ -168,9 +169,15 @@ std::string GetTypeName(const YAML::Node &action) {
168169

169170
Json GetTypeJson(const YAML::Node &action) {
170171
Json json;
171-
std::string dialect = DialectIdMap::Instance()->GetCompressDialectId(
172-
pir::BuiltinDialect::name()) +
173-
".";
172+
std::string builtin_dialect = DialectIdMap::Instance()->GetCompressDialectId(
173+
pir::BuiltinDialect::name()) +
174+
".";
175+
std::string op_dialect = DialectIdMap::Instance()->GetCompressDialectId(
176+
paddle::dialect::OperatorDialect::name()) +
177+
".";
178+
std::string cf_dialect = DialectIdMap::Instance()->GetCompressDialectId(
179+
pir::ControlFlowDialect::name()) +
180+
".";
174181
std::string type_name = "";
175182
if (action.IsScalar()) {
176183
type_name = action.as<std::string>();
@@ -181,54 +188,54 @@ Json GetTypeJson(const YAML::Node &action) {
181188
}
182189
if (type_name == "pir::BoolType") {
183190
VLOG(8) << "Get BoolType name.";
184-
json[ID] = dialect + pir::BoolType::name();
191+
json[ID] = builtin_dialect + pir::BoolType::name();
185192
} else if (type_name == "pir::BFloat16Type") {
186193
VLOG(8) << "Get BFloat16Type name.";
187-
json[ID] = dialect + pir::BFloat16Type::name();
194+
json[ID] = builtin_dialect + pir::BFloat16Type::name();
188195
} else if (type_name == "pir::Float16Type") {
189196
VLOG(8) << "Get Float16Type name.";
190-
json[ID] = dialect + pir::Float16Type::name();
197+
json[ID] = builtin_dialect + pir::Float16Type::name();
191198
} else if (type_name == "pir::Float32Type") {
192199
VLOG(8) << "Get Float32Type name.";
193-
json[ID] = dialect + pir::Float32Type::name();
200+
json[ID] = builtin_dialect + pir::Float32Type::name();
194201
} else if (type_name == "pir::Float64Type") {
195202
VLOG(8) << "Get Float64Type name.";
196-
json[ID] = dialect + pir::Float64Type::name();
203+
json[ID] = builtin_dialect + pir::Float64Type::name();
197204
} else if (type_name == "pir::Int8Type") {
198205
VLOG(8) << "Get Int8Type name.";
199-
json[ID] = dialect + pir::Int8Type::name();
206+
json[ID] = builtin_dialect + pir::Int8Type::name();
200207
} else if (type_name == "pir::UInt8Type") {
201208
VLOG(8) << "Get UInt8Type name.";
202-
json[ID] = dialect + pir::UInt8Type::name();
209+
json[ID] = builtin_dialect + pir::UInt8Type::name();
203210
} else if (type_name == "pir::Int16Type") {
204211
VLOG(8) << "Get Int16Type name.";
205-
json[ID] = dialect + pir::Int16Type::name();
212+
json[ID] = builtin_dialect + pir::Int16Type::name();
206213
} else if (type_name == "pir::Int32Type") {
207214
VLOG(8) << "Get Int32Type name.";
208-
json[ID] = dialect + pir::Int32Type::name();
215+
json[ID] = builtin_dialect + pir::Int32Type::name();
209216
} else if (type_name == "pir::Int64Type") {
210217
VLOG(8) << "Get Int64Type name.";
211-
json[ID] = dialect + pir::Int64Type::name();
218+
json[ID] = builtin_dialect + pir::Int64Type::name();
212219
} else if (type_name == "pir::IndexType") {
213220
VLOG(8) << "Get IndexType name.";
214-
json[ID] = dialect + pir::IndexType::name();
221+
json[ID] = builtin_dialect + pir::IndexType::name();
215222
} else if (type_name == "pir::Complex64Type") {
216223
VLOG(8) << "Get Complex64Type name.";
217-
json[ID] = dialect + pir::Complex64Type::name();
224+
json[ID] = builtin_dialect + pir::Complex64Type::name();
218225
} else if (type_name == "pir::Complex128Type") {
219226
VLOG(8) << "Get Complex128Type name.";
220-
json[ID] = dialect + pir::Complex128Type::name();
227+
json[ID] = builtin_dialect + pir::Complex128Type::name();
221228
} else if (type_name == "pir::VectorType") {
222229
VLOG(8) << "Get VectorType name.";
223-
json[ID] = dialect + pir::VectorType::name();
230+
json[ID] = builtin_dialect + pir::VectorType::name();
224231
json[DATA] = Json::array();
225232
for (size_t i = 0; i < action["default"].size(); i++) {
226233
YAML::Node array_value = action["default"][i];
227234
json[DATA].push_back(BuildTypeJsonPatch(array_value));
228235
}
229236
} else if (type_name == "pir::DenseTensorType") {
230237
VLOG(8) << "Get DenseTensorType name.";
231-
json[ID] = dialect + pir::DenseTensorType::name();
238+
json[ID] = builtin_dialect + pir::DenseTensorType::name();
232239
Json content = Json::array();
233240
YAML::Node tensor_value = action["default"];
234241
content.push_back(BuildTypeJsonPatch(tensor_value[0]));
@@ -242,6 +249,15 @@ Json GetTypeJson(const YAML::Node &action) {
242249

243250
content.push_back(tensor_value[4].as<int>()); // offset
244251
json[DATA] = content;
252+
} else if (type_name == "pir::StackType") {
253+
VLOG(8) << "Get StackType name.";
254+
json[ID] = cf_dialect + pir::StackType::name();
255+
} else if (type_name == "pir::InletType") {
256+
VLOG(8) << "Get InletType name.";
257+
json[ID] = cf_dialect + pir::InletType::name();
258+
} else if (type_name == "pir::OutletType") {
259+
VLOG(8) << "Get OutletType name.";
260+
json[ID] = cf_dialect + pir::OutletType::name();
245261
}
246262
return json;
247263
}

paddle/pir/include/core/storage_manager_support.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ class StorageHelperBase : public BaseT {
6565
using InterfaceList =
6666
typename Filter<TypeInterfaceBase, std::tuple<TraitOrInterface...>>::Type;
6767

68-
static ConcreteT dyn_cast_impl(BaseT type) {
68+
template <typename T>
69+
static ConcreteT dyn_cast_impl(T type) {
6970
if (type && type.type_id() == TypeId::get<ConcreteT>()) {
7071
return ConcreteT(type.storage());
7172
}

paddle/pir/include/dialect/control_flow/ir/cf_type.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,16 +30,19 @@ class IR_API StackType
3030
: public Type::TypeBase<StackType, ContainerType, TypeStorage> {
3131
public:
3232
using Base::Base;
33+
static std::string name() { return "t_stack"; }
3334
};
3435

3536
class IR_API InletType : public Type::TypeBase<InletType, Type, TypeStorage> {
3637
public:
3738
using Base::Base;
39+
static std::string name() { return "t_inlet"; }
3840
};
3941

4042
class IR_API OutletType : public Type::TypeBase<OutletType, Type, TypeStorage> {
4143
public:
4244
using Base::Base;
45+
static std::string name() { return "t_outlet"; }
4346
};
4447

4548
} // namespace pir

0 commit comments

Comments
 (0)