1818
1919#include " paddle/common/errors.h"
2020#include " paddle/fluid/framework/phi_utils.h"
21+ #include " paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
2122#include " paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
2223#include " paddle/fluid/pir/dialect/operator/ir/manual_op.h"
2324#include " paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
@@ -57,43 +58,43 @@ enum class AttrType {
5758 DOUBLE,
5859
5960 ARRAY,
61+ STRING,
62+ TENSOR_NAME,
63+ DATA_TYPE,
6064 INT_ARRAY,
65+ PLACE,
66+ TensorDist,
6167
6268 SCALAR,
63- DATA_TYPE,
6469 DATA_LAYOUT,
65- PLACE,
66-
67- STRING,
68-
69- TENSOR_NAME,
70-
7170 NUM_ATTR_TYPES,
7271};
7372
7473static inline AttrType GetAttributeType (const pir::Attribute& attr) {
7574 if (attr.isa <pir::BoolAttribute>()) {
7675 return AttrType::BOOL;
77- } else if (attr.isa <pir::FloatAttribute>()) {
78- return AttrType::FLOAT;
79- } else if (attr.isa <pir::DoubleAttribute>()) {
80- return AttrType::DOUBLE;
8176 } else if (attr.isa <pir::Int32Attribute>()) {
8277 return AttrType::INT32;
8378 } else if (attr.isa <pir::Int64Attribute>()) {
8479 return AttrType::INT64;
80+ } else if (attr.isa <pir::FloatAttribute>()) {
81+ return AttrType::FLOAT;
82+ } else if (attr.isa <pir::DoubleAttribute>()) {
83+ return AttrType::DOUBLE;
8584 } else if (attr.isa <pir::ArrayAttribute>()) {
8685 return AttrType::ARRAY;
8786 } else if (attr.isa <pir::StrAttribute>()) {
8887 return AttrType::STRING;
89- } else if (attr.isa <paddle::dialect::IntArrayAttribute >()) {
90- return AttrType::INT_ARRAY ;
88+ } else if (attr.isa <pir::TensorNameAttribute >()) {
89+ return AttrType::TENSOR_NAME ;
9190 } else if (attr.isa <paddle::dialect::DataTypeAttribute>()) {
9291 return AttrType::DATA_TYPE;
92+ } else if (attr.isa <paddle::dialect::IntArrayAttribute>()) {
93+ return AttrType::INT_ARRAY;
9394 } else if (attr.isa <paddle::dialect::PlaceAttribute>()) {
9495 return AttrType::PLACE;
95- } else if (attr.isa <pir::TensorNameAttribute >()) {
96- return AttrType::TENSOR_NAME ;
96+ } else if (attr.isa <paddle::dialect::TensorDistAttribute >()) {
97+ return AttrType::TensorDist ;
9798 } else {
9899 PADDLE_THROW (common::errors::Unimplemented (
99100 " Unsupported ir Attribute type when casting it into "
@@ -110,14 +111,6 @@ static std::function<T(const pir::Attribute& attr)> GetAttrCast(
110111 [](const pir::Attribute& attr) {
111112 return T{attr.dyn_cast <pir::BoolAttribute>().data ()};
112113 }},
113- {AttrType::FLOAT,
114- [](const pir::Attribute& attr) {
115- return T{attr.dyn_cast <pir::FloatAttribute>().data ()};
116- }},
117- {AttrType::DOUBLE,
118- [](const pir::Attribute& attr) {
119- return T{attr.dyn_cast <pir::DoubleAttribute>().data ()};
120- }},
121114 {AttrType::INT32,
122115 [](const pir::Attribute& attr) {
123116 return T{attr.dyn_cast <pir::Int32Attribute>().data ()};
@@ -126,28 +119,13 @@ static std::function<T(const pir::Attribute& attr)> GetAttrCast(
126119 [](const pir::Attribute& attr) {
127120 return T{attr.dyn_cast <pir::Int64Attribute>().data ()};
128121 }},
129- {AttrType::INT_ARRAY,
130- [](const pir::Attribute& attr) {
131- return T{attr.dyn_cast <paddle::dialect::IntArrayAttribute>()
132- .data ()
133- .GetData ()};
134- }},
135- {AttrType::STRING,
136- [](const pir::Attribute& attr) {
137- return T{attr.dyn_cast <pir::StrAttribute>().AsString ()};
138- }},
139- {AttrType::DATA_TYPE,
140- [](const pir::Attribute& attr) {
141- return T{
142- attr.dyn_cast <paddle::dialect::DataTypeAttribute>().data ()};
143- }},
144- {AttrType::PLACE,
122+ {AttrType::FLOAT,
145123 [](const pir::Attribute& attr) {
146- return T{attr.dyn_cast <paddle::dialect::PlaceAttribute >().data ()};
124+ return T{attr.dyn_cast <pir::FloatAttribute >().data ()};
147125 }},
148- {AttrType::TENSOR_NAME ,
126+ {AttrType::DOUBLE ,
149127 [](const pir::Attribute& attr) {
150- return T{attr.dyn_cast <pir::TensorNameAttribute >().data ()};
128+ return T{attr.dyn_cast <pir::DoubleAttribute >().data ()};
151129 }},
152130 {AttrType::ARRAY,
153131 [](const pir::Attribute& attr) {
@@ -211,7 +189,33 @@ static std::function<T(const pir::Attribute& attr)> GetAttrCast(
211189 " vector." ));
212190 }
213191 }},
214- };
192+ {AttrType::STRING,
193+ [](const pir::Attribute& attr) {
194+ return T{attr.dyn_cast <pir::StrAttribute>().AsString ()};
195+ }},
196+
197+ {AttrType::TENSOR_NAME,
198+ [](const pir::Attribute& attr) {
199+ return T{attr.dyn_cast <pir::TensorNameAttribute>().data ()};
200+ }},
201+ {AttrType::DATA_TYPE,
202+ [](const pir::Attribute& attr) {
203+ return T{
204+ attr.dyn_cast <paddle::dialect::DataTypeAttribute>().data ()};
205+ }},
206+ {AttrType::INT_ARRAY,
207+ [](const pir::Attribute& attr) {
208+ return T{attr.dyn_cast <paddle::dialect::IntArrayAttribute>()
209+ .data ()
210+ .GetData ()};
211+ }},
212+ {AttrType::PLACE,
213+ [](const pir::Attribute& attr) {
214+ return T{attr.dyn_cast <paddle::dialect::PlaceAttribute>().data ()};
215+ }},
216+ {AttrType::TensorDist, [](const pir::Attribute& attr) {
217+ return T{attr.dyn_cast <paddle::dialect::TensorDistAttribute>()};
218+ }}};
215219 return kAttrCastMap [attr_type];
216220}
217221
0 commit comments