Skip to content

Commit 0aa87d1

Browse files
【pir cinn】Fix some bugs in the DPA model (#70123)
* Modify the DPA model * Modify ir_printer
1 parent c1cdbc7 commit 0aa87d1

File tree

7 files changed

+75
-59
lines changed

7 files changed

+75
-59
lines changed

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
2424
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
2525
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
26+
#include "paddle/fluid/pir/utils/general_functions.h"
2627
#include "paddle/pir/include/core/builtin_dialect.h"
2728
#include "paddle/pir/include/core/builtin_op.h"
2829
#include "paddle/pir/include/pass/pass.h"
@@ -522,11 +523,11 @@ class PowOpPattern : public pir::OpRewritePattern<paddle::dialect::PowOp> {
522523
void Rewrite(paddle::dialect::PowOp op,
523524
pir::PatternRewriter &rewriter) const override {
524525
auto factor = op->attribute("y").dyn_cast<pir::FloatAttribute>().data();
525-
auto full_op =
526-
rewriter.Build<paddle::dialect::FullOp>(std::vector<int64_t>({1}),
527-
factor,
528-
phi::DataType::FLOAT32,
529-
phi::CPUPlace());
526+
auto full_op = rewriter.Build<paddle::dialect::FullOp>(
527+
std::vector<int64_t>({1}),
528+
factor,
529+
pir::GetValueDtype(op->operand_source(0)),
530+
phi::CPUPlace());
530531

531532
auto elementwise_pow = rewriter.Build<paddle::dialect::ElementwisePowOp>(
532533
op->operand_source(0), full_op->result(0));

paddle/fluid/pir/utils/general_functions.cc

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,4 +279,35 @@ bool ValueIsPersistable(const pir::Value& value) {
279279
return true;
280280
}
281281

282+
phi::DataType GetTensorDtype(pir::Type type) {
283+
if (!type) {
284+
PADDLE_THROW(
285+
common::errors::InvalidArgument("The type of value is nullptr."));
286+
}
287+
if (auto dense_tensor_type = type.dyn_cast<pir::DenseTensorType>()) {
288+
return paddle::dialect::TransToPhiDataType(dense_tensor_type.dtype());
289+
} else if (auto sparse_coo_tensor_type =
290+
type.dyn_cast<paddle::dialect::SparseCooTensorType>()) {
291+
return paddle::dialect::TransToPhiDataType(sparse_coo_tensor_type.dtype());
292+
} else if (auto sparse_csr_tensor_type =
293+
type.dyn_cast<paddle::dialect::SparseCsrTensorType>()) {
294+
return paddle::dialect::TransToPhiDataType(sparse_csr_tensor_type.dtype());
295+
} else if (auto select_rows =
296+
type.dyn_cast<paddle::dialect::SelectedRowsType>()) {
297+
return paddle::dialect::TransToPhiDataType(select_rows.dtype());
298+
} else if (auto dense_array =
299+
type.dyn_cast<paddle::dialect::DenseTensorArrayType>()) {
300+
return paddle::dialect::TransToPhiDataType(dense_array.dtype());
301+
} else {
302+
PADDLE_THROW(common::errors::InvalidArgument(
303+
"Currently, we can only get phi::DataType from DenseTensorType and "
304+
"SelectedRowsType, DenseTensorArrayType,SparseCooTensorType or "
305+
"SparseCsrTensorType."));
306+
}
307+
}
308+
309+
phi::DataType GetValueDtype(const pir::Value& val) {
310+
return GetTensorDtype(val.type());
311+
}
312+
282313
} // namespace pir

paddle/fluid/pir/utils/general_functions.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "paddle/phi/common/data_type.h"
2121
#include "paddle/phi/common/place.h"
2222
#include "paddle/pir/include/core/type.h"
23+
#include "paddle/pir/include/core/value.h"
2324
#include "paddle/pir/include/pass/pass.h"
2425

2526
namespace paddle {
@@ -247,4 +248,7 @@ std::vector<Value> GetUsedExternalValue(const Block& block);
247248
*/
248249
bool ValueIsPersistable(const pir::Value& value);
249250

251+
phi::DataType GetTensorDtype(pir::Type type);
252+
phi::DataType GetValueDtype(const pir::Value& val);
253+
250254
} // namespace pir

paddle/fluid/pybind/pir.cc

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/pybind/pir.h"
16+
#include "paddle/fluid/pybind/pir_utils.h"
1617

1718
#include <Python.h>
1819
#include <algorithm>
@@ -210,35 +211,6 @@ std::string GetValueInfo(Value v) {
210211
return ss.str();
211212
}
212213

213-
phi::DataType GetTensorDtype(Type type) {
214-
if (!type) {
215-
PADDLE_THROW(
216-
common::errors::InvalidArgument("The type of value is nullptr."));
217-
}
218-
if (auto dense_tensor_type = type.dyn_cast<DenseTensorType>()) {
219-
return dialect::TransToPhiDataType(dense_tensor_type.dtype());
220-
} else if (auto sparse_coo_tensor_type =
221-
type.dyn_cast<SparseCooTensorType>()) {
222-
return dialect::TransToPhiDataType(sparse_coo_tensor_type.dtype());
223-
} else if (auto sparse_csr_tensor_type =
224-
type.dyn_cast<SparseCsrTensorType>()) {
225-
return dialect::TransToPhiDataType(sparse_csr_tensor_type.dtype());
226-
} else if (auto select_rows = type.dyn_cast<SelectedRowsType>()) {
227-
return dialect::TransToPhiDataType(select_rows.dtype());
228-
} else if (auto dense_array = type.dyn_cast<DenseTensorArrayType>()) {
229-
return dialect::TransToPhiDataType(dense_array.dtype());
230-
} else {
231-
PADDLE_THROW(common::errors::InvalidArgument(
232-
"Currently, we can only get phi::DataType from DenseTensorType and "
233-
"SelectedRowsType, DenseTensorArrayType,SparseCooTensorType or "
234-
"SparseCsrTensorType."));
235-
}
236-
}
237-
238-
phi::DataType GetValueDtype(Value value) {
239-
return GetTensorDtype(value.type());
240-
}
241-
242214
py::object Clone(const Program &self, IrMapping *p_mapper = nullptr) {
243215
IrMapping mapper;
244216
if (p_mapper == nullptr) {
@@ -274,7 +246,7 @@ pir::Value AppendDataOp(pir::Block *block,
274246
paddle::dialect::IntArrayAttribute::get(
275247
ctx, phi::IntArray(phi::vectorize(GetValueDims(value))))},
276248
{"dtype",
277-
paddle::dialect::DataTypeAttribute::get(ctx, GetValueDtype(value))},
249+
paddle::dialect::DataTypeAttribute::get(ctx, pir::GetValueDtype(value))},
278250
{"place", PlaceAttribute::get(ctx, phi::Place())}};
279251
std::vector<pir::Type> output_types{value.type()};
280252
pir::Operation *operation =
@@ -1427,7 +1399,7 @@ void BindValue(py::module *m) {
14271399
})
14281400
.def_property(
14291401
"dtype",
1430-
[](Value self) { return GetValueDtype(self); },
1402+
[](Value self) { return pir::GetValueDtype(self); },
14311403
[](Value self, phi::DataType dtype) {
14321404
PADDLE_THROW(common::errors::InvalidArgument(
14331405
"can't set dtype when building static graph"));
@@ -2299,7 +2271,7 @@ static void inline CreateVariableIfNotExist(
22992271
phi::DeviceContextPool &pool = phi::DeviceContextPool::Instance();
23002272
const phi::DeviceContext *dev_ctx = nullptr;
23012273
dev_ctx = pool.Get(exe->GetPlace());
2302-
dev_ctx->Alloc(tensor_temp, GetValueDtype(value));
2274+
dev_ctx->Alloc(tensor_temp, pir::GetValueDtype(value));
23032275
}
23042276
}
23052277
return;

paddle/pir/src/core/ir_printer.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,10 +389,9 @@ void IrPrinter::PrintOpReturnType(const Operation& op) {
389389

390390
void IrPrinter::AddValueAlias(Value v, const std::string& alias) {
391391
const void* key = v.impl();
392-
PADDLE_ENFORCE_EQ(aliases_.find(key),
393-
aliases_.end(),
394-
common::errors::InvalidArgument("Value already has alias"));
395-
aliases_[key] = alias;
392+
if (aliases_.find(key) == aliases_.end()) {
393+
aliases_[key] = alias;
394+
}
396395
}
397396

398397
class CustomPrinter : public IrPrinter {

python/paddle/autograd/backward_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,11 @@ def copy(self, new_block):
330330
def _check_vjp_dynamic_shape(op, inputs):
331331
for items in inputs:
332332
for item in items:
333-
if item.initialized() and -1 in item.shape:
333+
if (
334+
item.is_dense_tensor_type()
335+
and item.initialized()
336+
and -1 in item.shape
337+
):
334338
return True
335339

336340

python/paddle/autograd/ir_backward.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -696,24 +696,29 @@ def append_yield(
696696

697697
if op.name() == "cf.tuple_push":
698698
stackop = op.operand_source(0).get_defining_op()
699-
with dynamic_shape_prim_vjp_guard(op, inputs):
700-
copy_out = paddle.framework.core.call_vjp(
701-
op,
702-
inputs,
703-
outputs,
704-
output_grads,
705-
input_grad_stopgradients,
706-
)
699+
if stackop.result(2).use_empty():
700+
with dynamic_shape_prim_vjp_guard(op, inputs):
701+
copy_out = paddle.framework.core.call_vjp(
702+
op,
703+
inputs,
704+
outputs,
705+
output_grads,
706+
input_grad_stopgradients,
707+
)
707708

708-
pop_op = bwd_block.ops[-1]
709-
while_tuple_ops.append(pop_op)
710-
while_tuple_ops.append(op)
711-
while_tuple_ops.append(stackop)
712-
bwd_ops = [pop_op]
713-
for output, copy_output in zip(inputs[1:], copy_out[1:]):
714-
control_flow_value_to_copyvalue_map[output[0]] = (
715-
copy_output[0]
716-
)
709+
pop_op = bwd_block.ops[-1]
710+
while_tuple_ops.append(pop_op)
711+
while_tuple_ops.append(op)
712+
while_tuple_ops.append(stackop)
713+
bwd_ops = [pop_op]
714+
for output, copy_output in zip(
715+
inputs[1:], copy_out[1:]
716+
):
717+
control_flow_value_to_copyvalue_map[output[0]] = (
718+
copy_output[0]
719+
)
720+
else:
721+
bwd_ops = [stackop.result(2).first_use().owner()]
717722
else:
718723
# all(zero_flag) support this op has no contribution for grad
719724
# should be delete (prune sub_graph)

0 commit comments

Comments
 (0)