Skip to content

Commit fdc38b2

Browse files
authored
[DRR] change namespace pir::drr:: to paddle::drr:: (#60432)
1 parent 51dc031 commit fdc38b2

35 files changed

+597
-529
lines changed

paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ namespace cinn {
3131
namespace dialect {
3232
namespace ir {
3333

34-
class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
34+
class SumOpPattern : public paddle::drr::DrrPatternBase<SumOpPattern> {
3535
public:
36-
void operator()(pir::drr::DrrPatternContext *ctx) const override {
36+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
3737
// Source Pattern
38-
pir::drr::SourcePattern pattern = ctx->SourcePattern();
38+
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
3939
const auto &full_int_array =
4040
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
4141
{{"value", pattern.Attr("axis_info")},
@@ -48,7 +48,7 @@ class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
4848
pattern.Tensor("ret") = sum(pattern.Tensor("arg0"), full_int_array());
4949

5050
// Result patterns
51-
pir::drr::ResultPattern res = pattern.ResultPattern();
51+
paddle::drr::ResultPattern res = pattern.ResultPattern();
5252
const auto &cinn_reduce_sum =
5353
res.Op(cinn::dialect::ReduceSumOp::name(),
5454
{{"dim", pattern.Attr("axis_info")},
@@ -57,11 +57,11 @@ class SumOpPattern : public pir::drr::DrrPatternBase<SumOpPattern> {
5757
}
5858
};
5959

60-
class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
60+
class MaxOpPattern : public paddle::drr::DrrPatternBase<MaxOpPattern> {
6161
public:
62-
void operator()(pir::drr::DrrPatternContext *ctx) const override {
62+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
6363
// Source Pattern
64-
pir::drr::SourcePattern pattern = ctx->SourcePattern();
64+
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
6565
const auto &full_int_array =
6666
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
6767
{{"value", pattern.Attr("axis_info")},
@@ -73,7 +73,7 @@ class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
7373
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());
7474

7575
// Result patterns
76-
pir::drr::ResultPattern res = pattern.ResultPattern();
76+
paddle::drr::ResultPattern res = pattern.ResultPattern();
7777
const auto &cinn_reduce_max =
7878
res.Op(cinn::dialect::ReduceMaxOp::name(),
7979
{{"dim", pattern.Attr("axis_info")},
@@ -82,11 +82,11 @@ class MaxOpPattern : public pir::drr::DrrPatternBase<MaxOpPattern> {
8282
}
8383
};
8484

85-
class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
85+
class MinOpPattern : public paddle::drr::DrrPatternBase<MinOpPattern> {
8686
public:
87-
void operator()(pir::drr::DrrPatternContext *ctx) const override {
87+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
8888
// Source Pattern
89-
pir::drr::SourcePattern pattern = ctx->SourcePattern();
89+
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
9090
const auto &full_int_array =
9191
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
9292
{{"value", pattern.Attr("axis_info")},
@@ -98,7 +98,7 @@ class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
9898
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());
9999

100100
// Result patterns
101-
pir::drr::ResultPattern res = pattern.ResultPattern();
101+
paddle::drr::ResultPattern res = pattern.ResultPattern();
102102
const auto &cinn_reduce_max =
103103
res.Op(cinn::dialect::ReduceMinOp::name(),
104104
{{"dim", pattern.Attr("axis_info")},
@@ -107,11 +107,11 @@ class MinOpPattern : public pir::drr::DrrPatternBase<MinOpPattern> {
107107
}
108108
};
109109

110-
class ProdOpPattern : public pir::drr::DrrPatternBase<ProdOpPattern> {
110+
class ProdOpPattern : public paddle::drr::DrrPatternBase<ProdOpPattern> {
111111
public:
112-
void operator()(pir::drr::DrrPatternContext *ctx) const override {
112+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
113113
// Source Pattern
114-
pir::drr::SourcePattern pattern = ctx->SourcePattern();
114+
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
115115
const auto &full_int_array =
116116
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
117117
{{"value", pattern.Attr("axis_info")},
@@ -123,7 +123,7 @@ class ProdOpPattern : public pir::drr::DrrPatternBase<ProdOpPattern> {
123123
pattern.Tensor("ret") = pd_max(pattern.Tensor("arg0"), full_int_array());
124124

125125
// Result patterns
126-
pir::drr::ResultPattern res = pattern.ResultPattern();
126+
paddle::drr::ResultPattern res = pattern.ResultPattern();
127127
const auto &cinn_reduce_max =
128128
res.Op(cinn::dialect::ReduceProdOp::name(),
129129
{{"dim", pattern.Attr("axis_info")},
@@ -552,11 +552,11 @@ class SplitWithNumOpPattern
552552
}
553553
};
554554

555-
class UniformOpPattern : public pir::drr::DrrPatternBase<UniformOpPattern> {
555+
class UniformOpPattern : public paddle::drr::DrrPatternBase<UniformOpPattern> {
556556
public:
557-
void operator()(pir::drr::DrrPatternContext *ctx) const override {
557+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
558558
// Source Pattern
559-
pir::drr::SourcePattern pattern = ctx->SourcePattern();
559+
paddle::drr::SourcePattern pattern = ctx->SourcePattern();
560560
const auto &full_int_array =
561561
pattern.Op(paddle::dialect::FullIntArrayOp::name(),
562562
{{"value", pattern.Attr("axis_info")},
@@ -585,7 +585,7 @@ class UniformOpPattern : public pir::drr::DrrPatternBase<UniformOpPattern> {
585585
// int64_t[] shape, float min, float max, int seed, DataType dtype, int
586586
// diag_num, int diag_step, float diag_val)
587587
// Result patterns
588-
pir::drr::ResultPattern res = pattern.ResultPattern();
588+
paddle::drr::ResultPattern res = pattern.ResultPattern();
589589
const auto &cinn_uniform =
590590
res.Op(cinn::dialect::UniformRandomOp::name(),
591591
{{"shape", pattern.Attr("axis_info")},

paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,22 +27,22 @@
2727
{op_header}
2828
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
2929
30-
namespace pir {{
30+
namespace paddle {{
3131
namespace drr {{
3232
3333
void OperationFactory::Register{dialect}GeneratedOpCreator() {{
3434
{body}
3535
}}
3636
3737
}} // namespace drr
38-
}} // namespace pir
38+
}} // namespace paddle
3939
4040
"""
4141

4242
NORMAL_FUNCTION_TEMPLATE = """
4343
RegisterOperationCreator(
4444
"{op_name}",
45-
[](const std::vector<Value>& inputs,
45+
[](const std::vector<pir::Value>& inputs,
4646
const pir::AttributeMap& attrs,
4747
pir::PatternRewriter& rewriter) {{
4848
return rewriter.Build<{namespace}::{op_class_name}>(
@@ -53,7 +53,7 @@
5353
MUTABLE_ATTR_FUNCTION_TEMPLATE = """
5454
RegisterOperationCreator(
5555
"{op_name}",
56-
[](const std::vector<Value>& inputs,
56+
[](const std::vector<pir::Value>& inputs,
5757
const pir::AttributeMap& attrs,
5858
pir::PatternRewriter& rewriter) {{
5959
// mutable_attr is tensor

paddle/fluid/pir/drr/README.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ Taking PASS to eliminate redundant CastOp as an example, the code example develo
1010
~~~ c++
1111
// 1. Inherit specialized template class from DrPatternBase
1212
class RemoveRedundentCastPattern
13-
: public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
13+
: public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> {
1414
// 2. Overload operator()
15-
void operator()(pir::drr::DrrPatternContext *ctx) const override {
15+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
1616
// 3. Define a SourcePattern containing two consecutive CastOps using Op, Tensor, and Attribute
1717
auto pat = ctx->SourcePattern();
1818

@@ -55,7 +55,7 @@ Developers only need to define `SourcePattern`, `Constrains` and `ResultPattern`
5555
<tr>
5656
<td rowspan="1">DrrPatternBase</td>
5757
<td> <pre> virtual void operator()(
58-
pir::drr::DrrPatternContext* ctx) const </pre></td>
58+
paddle::drr::DrrPatternContext* ctx) const </pre></td>
5959
<td> Implement the entry function of DRR PASS </td>
6060
<td> ctx: Context parameters required to create Patten</td>
6161
</tr>
@@ -165,11 +165,11 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const</pre></td>
165165
## 3 Example
166166
Example 1: Matmul + Add -> FusedGemmEpilogue
167167
~~~ c++
168-
class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
168+
class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> {
169169
public:
170-
void operator()(pir::drr::DrrPatternContext *ctx) const override {
170+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
171171
// Define SourcePattern
172-
pir::drr::SourcePattern pat = ctx->SourcePattern();
172+
paddle::drr::SourcePattern pat = ctx->SourcePattern();
173173
const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
174174
{{"transpose_x", pat.Attr("trans_x")},
175175
{"transpose_y", pat.Attr("trans_y")}});
@@ -179,10 +179,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
179179
pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));
180180
181181
// Define ResultPattern
182-
pir::drr::ResultPattern res = pat.ResultPattern();
182+
paddle::drr::ResultPattern res = pat.ResultPattern();
183183
// Define Constrain
184184
const auto &act_attr =
185-
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
185+
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any {
186186
return "none";
187187
});
188188
const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
@@ -199,11 +199,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
199199
Example 2: Full + Expand -> Full
200200
~~~ c++
201201
class FoldExpandToConstantPattern
202-
: public pir::drr::DrrPatternBase<FoldExpandToConstantPattern> {
202+
: public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> {
203203
public:
204-
void operator()(pir::drr::DrrPatternContext *ctx) const override {
204+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
205205
// Define SourcePattern
206-
pir::drr::SourcePattern pat = ctx->SourcePattern();
206+
paddle::drr::SourcePattern pat = ctx->SourcePattern();
207207
const auto &full1 = pat.Op(paddle::dialect::FullOp::name(),
208208
{{"shape", pat.Attr("shape_1")},
209209
{"value", pat.Attr("value_1")},
@@ -218,7 +218,7 @@ class FoldExpandToConstantPattern
218218
pat.Tensor("ret") = expand(full1(), full_int_array1());
219219

220220
// Define ResultPattern
221-
pir::drr::ResultPattern res = pat.ResultPattern();
221+
paddle::drr::ResultPattern res = pat.ResultPattern();
222222
const auto &full2 = res.Op(paddle::dialect::FullOp::name(),
223223
{{"shape", pat.Attr("expand_shape_value")},
224224
{"value", pat.Attr("value_1")},

paddle/fluid/pir/drr/README_cn.md

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 P
1010
~~~ c++
1111
// 1. 继承 DrrPatternBase 的特化模板类
1212
class RemoveRedundentCastPattern
13-
: public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
13+
: public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> {
1414
// 2. 重载 operator()
15-
void operator()(pir::drr::DrrPatternContext *ctx) const override {
15+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
1616
// 3. 使用 Op、Tensor 和 Attribute 定义一个包含两个连续 CastOp 的 SourcePattern
1717
auto pat = ctx->SourcePattern();
1818

@@ -56,7 +56,7 @@ DRR PASS 包含以下三个部分:
5656
<tr>
5757
<td rowspan="1">DrrPatternBase</td>
5858
<td> <pre> virtual void operator()(
59-
pir::drr::DrrPatternContext* ctx) const </pre></td>
59+
paddle::drr::DrrPatternContext* ctx) const </pre></td>
6060
<td> 实现 DRR PASS 的入口函数 </td>
6161
<td> ctx: 创建 Patten 所需要的 Context 参数</td>
6262
</tr>
@@ -168,11 +168,11 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const</pre></td>
168168
## 3 使用示例
169169
Example 1: Matmul + Add -> FusedGemmEpilogue
170170
~~~ c++
171-
class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
171+
class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> {
172172
public:
173-
void operator()(pir::drr::DrrPatternContext *ctx) const override {
173+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
174174
// 定义 Source Pattern
175-
pir::drr::SourcePattern pat = ctx->SourcePattern();
175+
paddle::drr::SourcePattern pat = ctx->SourcePattern();
176176
const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
177177
{{"transpose_x", pat.Attr("trans_x")},
178178
{"transpose_y", pat.Attr("trans_y")}});
@@ -182,10 +182,10 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
182182
pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias"));
183183
184184
// 定义 Result Pattern
185-
pir::drr::ResultPattern res = pat.ResultPattern();
185+
paddle::drr::ResultPattern res = pat.ResultPattern();
186186
// 定义 Constrain
187187
const auto &act_attr =
188-
res.Attr([](const pir::drr::MatchContext &match_ctx) -> std::any {
188+
res.Attr([](const paddle::drr::MatchContext &match_ctx) -> std::any {
189189
return "none";
190190
});
191191
const auto &fused_gemm_epilogue = res.Op(paddle::dialect::FusedGemmEpilogueOp::name(),
@@ -202,11 +202,11 @@ class FusedLinearPattern : public pir::drr::DrrPatternBase<FusedLinearPattern> {
202202
Example 2: Full + Expand -> Full
203203
~~~ c++
204204
class FoldExpandToConstantPattern
205-
: public pir::drr::DrrPatternBase<FoldExpandToConstantPattern> {
205+
: public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> {
206206
public:
207-
void operator()(pir::drr::DrrPatternContext *ctx) const override {
207+
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
208208
// 定义 Source Pattern
209-
pir::drr::SourcePattern pat = ctx->SourcePattern();
209+
paddle::drr::SourcePattern pat = ctx->SourcePattern();
210210
const auto &full1 = pat.Op(paddle::dialect::FullOp::name(),
211211
{{"shape", pat.Attr("shape_1")},
212212
{"value", pat.Attr("value_1")},
@@ -221,7 +221,7 @@ class FoldExpandToConstantPattern
221221
pat.Tensor("ret") = expand(full1(), full_int_array1());
222222

223223
// 定义 Result Pattern Constrains: 本 Pass 无额外约束规则
224-
pir::drr::ResultPattern res = pat.ResultPattern();
224+
paddle::drr::ResultPattern res = pat.ResultPattern();
225225
const auto &full2 = res.Op(paddle::dialect::FullOp::name(),
226226
{{"shape", pat.Attr("expand_shape_value")},
227227
{"value", pat.Attr("value_1")},

paddle/fluid/pir/drr/api/drr_pattern_base.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include "paddle/fluid/pir/drr/api/drr_pattern_context.h"
1818
#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h"
1919

20-
namespace pir {
20+
namespace paddle {
2121
namespace drr {
2222

2323
template <typename DrrPattern>
@@ -26,7 +26,7 @@ class DrrPatternBase {
2626
virtual ~DrrPatternBase() = default;
2727

2828
// Define the Drr Pattern.
29-
virtual void operator()(pir::drr::DrrPatternContext* ctx) const = 0;
29+
virtual void operator()(paddle::drr::DrrPatternContext* ctx) const = 0;
3030

3131
std::unique_ptr<DrrRewritePattern> Build(
3232
pir::IrContext* ir_context, pir::PatternBenefit benefit = 1) const {
@@ -39,4 +39,4 @@ class DrrPatternBase {
3939
};
4040

4141
} // namespace drr
42-
} // namespace pir
42+
} // namespace paddle

paddle/fluid/pir/drr/api/drr_pattern_context.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#include "paddle/fluid/pir/drr/pattern_graph.h"
1818
#include "paddle/phi/core/enforce.h"
1919

20-
namespace pir {
20+
namespace paddle {
2121
namespace drr {
2222

2323
DrrPatternContext::DrrPatternContext() {
@@ -28,6 +28,7 @@ DrrPatternContext::DrrPatternContext() {
2828
drr::SourcePattern DrrPatternContext::SourcePattern() {
2929
return drr::SourcePattern(this);
3030
}
31+
3132
const Op& DrrPatternContext::SourceOpPattern(
3233
const std::string& op_type,
3334
const std::unordered_map<std::string, Attribute>& attributes) {
@@ -167,4 +168,4 @@ void Tensor::operator=(const Tensor& other) const { // NOLINT
167168
}
168169

169170
} // namespace drr
170-
} // namespace pir
171+
} // namespace paddle

paddle/fluid/pir/drr/api/drr_pattern_context.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
#include "paddle/fluid/pir/drr/api/match_context.h"
2626

27-
namespace pir {
27+
namespace paddle {
2828
namespace drr {
2929

3030
class Op;
@@ -334,4 +334,4 @@ class SourcePattern {
334334
};
335335

336336
} // namespace drr
337-
} // namespace pir
337+
} // namespace paddle

paddle/fluid/pir/drr/api/match_context.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "paddle/fluid/pir/drr/ir_operation.h"
2020
#include "paddle/fluid/pir/drr/match_context_impl.h"
2121

22-
namespace pir {
22+
namespace paddle {
2323
namespace drr {
2424

2525
MatchContext::MatchContext(std::shared_ptr<const MatchContextImpl> impl)
@@ -46,4 +46,4 @@ template std::vector<int64_t> MatchContext::Attr<std::vector<int64_t>>(
4646
const std::string&) const;
4747

4848
} // namespace drr
49-
} // namespace pir
49+
} // namespace paddle

paddle/fluid/pir/drr/api/match_context.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
#include "paddle/fluid/pir/drr/api/tensor_interface.h"
2121
#include "paddle/fluid/pir/drr/ir_operation.h"
2222

23-
namespace pir {
23+
namespace paddle {
2424
namespace drr {
2525

2626
class TensorInterface;
@@ -40,4 +40,4 @@ class MatchContext final {
4040
};
4141

4242
} // namespace drr
43-
} // namespace pir
43+
} // namespace paddle

0 commit comments

Comments
 (0)