
Commit 5c60a90

modify DPA model
1 parent c6b2115 commit 5c60a90

File tree

7 files changed, +72 -58 lines

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc

Lines changed: 6 additions & 5 deletions

@@ -23,6 +23,7 @@
 #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
 #include "paddle/fluid/pir/dialect/operator/utils/utils.h"
 #include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
+#include "paddle/fluid/pir/utils/general_functions.h"
 #include "paddle/pir/include/core/builtin_dialect.h"
 #include "paddle/pir/include/core/builtin_op.h"
 #include "paddle/pir/include/pass/pass.h"
@@ -522,11 +523,11 @@ class PowOpPattern : public pir::OpRewritePattern<paddle::dialect::PowOp> {
   void Rewrite(paddle::dialect::PowOp op,
                pir::PatternRewriter &rewriter) const override {
     auto factor = op->attribute("y").dyn_cast<pir::FloatAttribute>().data();
-    auto full_op =
-        rewriter.Build<paddle::dialect::FullOp>(std::vector<int64_t>({1}),
-                                                factor,
-                                                phi::DataType::FLOAT32,
-                                                phi::CPUPlace());
+    auto full_op = rewriter.Build<paddle::dialect::FullOp>(
+        std::vector<int64_t>({1}),
+        factor,
+        pir::GetValueDtype(op->operand_source(0)),
+        phi::CPUPlace());

     auto elementwise_pow = rewriter.Build<paddle::dialect::ElementwisePowOp>(
         op->operand_source(0), full_op->result(0));
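The substantive change here: the scalar exponent tensor that PowOpPattern builds for `elementwise_pow` now inherits the operand's dtype via `pir::GetValueDtype` instead of being hard-coded to `phi::DataType::FLOAT32`, so non-float32 inputs no longer hit a dtype mismatch. A minimal sketch of the Python-visible behavior this preserves (assuming a build where this CINN lowering pass actually runs):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    # With a float64 operand, the full op feeding elementwise_pow must now
    # also be float64 rather than the previously hard-coded float32.
    x = paddle.static.data("x", shape=[4], dtype="float64")
    y = paddle.pow(x, 2.0)
    print(y.dtype)  # expected: float64
```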

paddle/fluid/pir/utils/general_functions.cc

Lines changed: 31 additions & 0 deletions

@@ -279,4 +279,35 @@ bool ValueIsPersistable(const pir::Value& value) {
   return true;
 }

+phi::DataType GetTensorDtype(pir::Type type) {
+  if (!type) {
+    PADDLE_THROW(
+        common::errors::InvalidArgument("The type of value is nullptr."));
+  }
+  if (auto dense_tensor_type = type.dyn_cast<pir::DenseTensorType>()) {
+    return paddle::dialect::TransToPhiDataType(dense_tensor_type.dtype());
+  } else if (auto sparse_coo_tensor_type =
+                 type.dyn_cast<paddle::dialect::SparseCooTensorType>()) {
+    return paddle::dialect::TransToPhiDataType(sparse_coo_tensor_type.dtype());
+  } else if (auto sparse_csr_tensor_type =
+                 type.dyn_cast<paddle::dialect::SparseCsrTensorType>()) {
+    return paddle::dialect::TransToPhiDataType(sparse_csr_tensor_type.dtype());
+  } else if (auto select_rows =
+                 type.dyn_cast<paddle::dialect::SelectedRowsType>()) {
+    return paddle::dialect::TransToPhiDataType(select_rows.dtype());
+  } else if (auto dense_array =
+                 type.dyn_cast<paddle::dialect::DenseTensorArrayType>()) {
+    return paddle::dialect::TransToPhiDataType(dense_array.dtype());
+  } else {
+    PADDLE_THROW(common::errors::InvalidArgument(
+        "Currently, we can only get phi::DataType from DenseTensorType, "
+        "SelectedRowsType, DenseTensorArrayType, SparseCooTensorType or "
+        "SparseCsrTensorType."));
+  }
+}
+
+phi::DataType GetValueDtype(const pir::Value& val) {
+  return GetTensorDtype(val.type());
+}
+
 } // namespace pir
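These helpers hoist the Type-to-phi::DataType unwrapping out of pybind (see the deletion in pir.cc below) into a reusable `pir` utility. The dispatch is reachable from Python through the `dtype` property bound on `Value`; a rough sketch, assuming the PIR API is switched on (e.g. via `FLAGS_enable_pir_api=1` -- an assumption about the build configuration, not something this commit adds):

```python
import os
os.environ["FLAGS_enable_pir_api"] = "1"  # assumption: opt in to PIR mode

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data("x", shape=[2, 3], dtype="float32")
    # Reading .dtype on a pir Value ends up in pir::GetValueDtype, which
    # unwraps DenseTensorType (or the selected-rows/sparse/array variants)
    # into a phi::DataType.
    print(x.dtype)  # expected: float32
```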

paddle/fluid/pir/utils/general_functions.h

Lines changed: 4 additions & 0 deletions

@@ -20,6 +20,7 @@
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/pir/include/core/type.h"
+#include "paddle/pir/include/core/value.h"
 #include "paddle/pir/include/pass/pass.h"

 namespace paddle {
@@ -247,4 +248,7 @@ std::vector<Value> GetUsedExternalValue(const Block& block);
 */
 bool ValueIsPersistable(const pir::Value& value);

+phi::DataType GetTensorDtype(pir::Type type);
+phi::DataType GetValueDtype(const pir::Value& val);
+
 } // namespace pir

paddle/fluid/pybind/pir.cc

Lines changed: 4 additions & 32 deletions

@@ -13,6 +13,7 @@
 // limitations under the License.

 #include "paddle/fluid/pybind/pir.h"
+#include "paddle/fluid/pybind/pir_utils.h"

 #include <Python.h>
 #include <algorithm>
@@ -207,35 +208,6 @@ std::string GetValueInfo(Value v) {
   return ss.str();
 }

-phi::DataType GetTensorDtype(Type type) {
-  if (!type) {
-    PADDLE_THROW(
-        common::errors::InvalidArgument("The type of value is nullptr."));
-  }
-  if (auto dense_tensor_type = type.dyn_cast<DenseTensorType>()) {
-    return dialect::TransToPhiDataType(dense_tensor_type.dtype());
-  } else if (auto sparse_coo_tensor_type =
-                 type.dyn_cast<SparseCooTensorType>()) {
-    return dialect::TransToPhiDataType(sparse_coo_tensor_type.dtype());
-  } else if (auto sparse_csr_tensor_type =
-                 type.dyn_cast<SparseCsrTensorType>()) {
-    return dialect::TransToPhiDataType(sparse_csr_tensor_type.dtype());
-  } else if (auto select_rows = type.dyn_cast<SelectedRowsType>()) {
-    return dialect::TransToPhiDataType(select_rows.dtype());
-  } else if (auto dense_array = type.dyn_cast<DenseTensorArrayType>()) {
-    return dialect::TransToPhiDataType(dense_array.dtype());
-  } else {
-    PADDLE_THROW(common::errors::InvalidArgument(
-        "Currently, we can only get phi::DataType from DenseTensorType and "
-        "SelectedRowsType, DenseTensorArrayType,SparseCooTensorType or "
-        "SparseCsrTensorType."));
-  }
-}
-
-phi::DataType GetValueDtype(Value value) {
-  return GetTensorDtype(value.type());
-}
-
 py::object Clone(const Program &self, IrMapping *p_mapper = nullptr) {
   IrMapping mapper;
   if (p_mapper == nullptr) {
@@ -271,7 +243,7 @@ pir::Value AppendDataOp(pir::Block *block,
        paddle::dialect::IntArrayAttribute::get(
            ctx, phi::IntArray(phi::vectorize(GetValueDims(value))))},
       {"dtype",
-       paddle::dialect::DataTypeAttribute::get(ctx, GetValueDtype(value))},
+       paddle::dialect::DataTypeAttribute::get(ctx, pir::GetValueDtype(value))},
       {"place", PlaceAttribute::get(ctx, phi::Place())}};
   std::vector<pir::Type> output_types{value.type()};
   pir::Operation *operation =
@@ -1369,7 +1341,7 @@ void BindValue(py::module *m) {
       })
       .def_property(
          "dtype",
-          [](Value self) { return GetValueDtype(self); },
+          [](Value self) { return pir::GetValueDtype(self); },
          [](Value self, phi::DataType dtype) {
            PADDLE_THROW(common::errors::InvalidArgument(
                "can't set dtype when building static graph"));
@@ -2241,7 +2213,7 @@ static void inline CreateVariableIfNotExist(
      phi::DeviceContextPool &pool = phi::DeviceContextPool::Instance();
      const phi::DeviceContext *dev_ctx = nullptr;
      dev_ctx = pool.Get(exe->GetPlace());
-      dev_ctx->Alloc(tensor_temp, GetValueDtype(value));
+      dev_ctx->Alloc(tensor_temp, pir::GetValueDtype(value));
    }
  }
  return;
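The rebinding keeps `dtype` read-only on the Python side: the getter now delegates to `pir::GetValueDtype`, while the setter still throws unconditionally. A hedged sketch of that surface (assuming PIR mode, and assuming the `InvalidArgument` error surfaces as a Python exception such as `ValueError`):

```python
import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data("x", shape=[4], dtype="float32")
    try:
        x.dtype = paddle.float64  # the bound setter always throws in static graphs
    except (ValueError, TypeError) as err:
        print(err)  # e.g. "can't set dtype when building static graph"
```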

paddle/pir/src/core/ir_printer.cc

Lines changed: 0 additions & 3 deletions

@@ -389,9 +389,6 @@ void IrPrinter::PrintOpReturnType(const Operation& op) {

 void IrPrinter::AddValueAlias(Value v, const std::string& alias) {
   const void* key = v.impl();
-  PADDLE_ENFORCE_EQ(aliases_.find(key),
-                    aliases_.end(),
-                    common::errors::InvalidArgument("Value already has alias"));
   aliases_[key] = alias;
 }
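Note on behavior: with the enforce removed, `AddValueAlias` no longer rejects a value that already has an alias; registering a second alias now silently overwrites the first via the plain `aliases_[key] = alias;` map assignment.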

python/paddle/autograd/backward_utils.py

Lines changed: 5 additions & 1 deletion

@@ -328,7 +328,11 @@ def copy(self, new_block):
 def _check_vjp_dynamic_shape(op, inputs):
     for items in inputs:
         for item in items:
-            if item.initialized() and -1 in item.shape:
+            if (
+                item.is_dense_tensor_type()
+                and item.initialized()
+                and -1 in item.shape
+            ):
                 return True
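The added `is_dense_tensor_type()` guard means only initialized dense tensors are probed for a dynamic (-1) dimension; other value kinds (e.g. tensor arrays) are skipped rather than asked for a static shape. A runnable restatement of the predicate with a toy stand-in for the Value interface (everything here is hypothetical except the three members taken from the diff):

```python
from dataclasses import dataclass, field

@dataclass
class FakeValue:
    # Toy stand-in exposing only the members _check_vjp_dynamic_shape uses.
    dense: bool = True
    init: bool = True
    shape: list = field(default_factory=list)

    def is_dense_tensor_type(self) -> bool:
        return self.dense

    def initialized(self) -> bool:
        return self.init

def has_dynamic_dense_input(inputs) -> bool:
    for items in inputs:
        for item in items:
            # Only initialized dense tensors are checked for a -1 (dynamic) dim.
            if item.is_dense_tensor_type() and item.initialized() and -1 in item.shape:
                return True
    return False

print(has_dynamic_dense_input([[FakeValue(shape=[-1, 8])]]))            # True
print(has_dynamic_dense_input([[FakeValue(dense=False, shape=[-1])]]))  # False
```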

python/paddle/autograd/ir_backward.py

Lines changed: 22 additions & 17 deletions

@@ -696,24 +696,29 @@ def append_yield(

         if op.name() == "cf.tuple_push":
             stackop = op.operand_source(0).get_defining_op()
-            with dynamic_shape_prim_vjp_guard(op, inputs):
-                copy_out = paddle.framework.core.call_vjp(
-                    op,
-                    inputs,
-                    outputs,
-                    output_grads,
-                    input_grad_stopgradients,
-                )
+            if stackop.result(2).use_empty():
+                with dynamic_shape_prim_vjp_guard(op, inputs):
+                    copy_out = paddle.framework.core.call_vjp(
+                        op,
+                        inputs,
+                        outputs,
+                        output_grads,
+                        input_grad_stopgradients,
+                    )

-            pop_op = bwd_block.ops[-1]
-            while_tuple_ops.append(pop_op)
-            while_tuple_ops.append(op)
-            while_tuple_ops.append(stackop)
-            bwd_ops = [pop_op]
-            for output, copy_output in zip(inputs[1:], copy_out[1:]):
-                control_flow_value_to_copyvalue_map[output[0]] = (
-                    copy_output[0]
-                )
+                pop_op = bwd_block.ops[-1]
+                while_tuple_ops.append(pop_op)
+                while_tuple_ops.append(op)
+                while_tuple_ops.append(stackop)
+                bwd_ops = [pop_op]
+                for output, copy_output in zip(
+                    inputs[1:], copy_out[1:]
+                ):
+                    control_flow_value_to_copyvalue_map[output[0]] = (
+                        copy_output[0]
+                    )
+            else:
+                bwd_ops = [stackop.result(2).first_use().owner()]
         else:
             # all(zero_flag) support this op has no contribution for grad
             # should be delete (prune sub_graph)
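The new `use_empty()` test avoids materializing a second `cf.tuple_pop` for a stack that is already consumed: `call_vjp` (which appends the pop to `bwd_block`) only runs when `result(2)` of the stack op has no uses yet; otherwise the existing consumer is reused as the backward op. A runnable toy model of just that dispatch (all classes are hypothetical stand-ins; only `result(2)`, `use_empty()`, and `first_use().owner()` mirror the diff):

```python
class Use:
    def __init__(self, owner):
        self._owner = owner
    def owner(self):
        return self._owner

class Result:
    def __init__(self):
        self.uses = []
    def use_empty(self):
        return not self.uses
    def first_use(self):
        return self.uses[0]

class StackOp:
    # Stand-in for the op defining the tuple stack; result(2) is its pop result.
    def __init__(self):
        self._results = [Result(), Result(), Result()]
    def result(self, i):
        return self._results[i]

def backward_ops_for_push(stackop, build_pop):
    if stackop.result(2).use_empty():
        # No pop consumes the stack yet: build one (call_vjp does this in
        # the real pass) and use it as the backward op.
        return [build_pop()]
    # A pop already exists; reuse its owning op instead of adding another.
    return [stackop.result(2).first_use().owner()]

s = StackOp()
print(backward_ops_for_push(s, lambda: "new_pop"))  # ['new_pop']
s.result(2).uses.append(Use("existing_pop"))
print(backward_ops_for_push(s, lambda: "new_pop"))  # ['existing_pop']
```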

0 commit comments
