
Commit 9d5cd38

[AutoParallel] Add op_role_guard for PIR (#68796)
* c++ end
* python api
* unitest
* style
* optimize graph compile time
* code style
* style
* revise name
* revise name
* update style
* unitest cmake

Co-authored-by: winter-wang <[email protected]>
1 parent 7a5db03 commit 9d5cd38

9 files changed: +198 additions, −10 deletions

paddle/fluid/pir/dialect/operator/ir/api_builder.h

Lines changed: 3 additions & 0 deletions
@@ -75,6 +75,9 @@ class ApiBuilder {
   // pop the insertion point and set it to the current insertion point.
   void LoadInsertionPoint();
 
+  void SetOpRole(int op_role) { builder_->set_op_role(op_role); }
+  int GetOpRole() const { return builder_->op_role(); }
+
  private:
   ApiBuilder();

paddle/fluid/pybind/pir.cc

Lines changed: 3 additions & 0 deletions
@@ -2132,6 +2132,9 @@ void BindUtils(pybind11::module *m) {
          []() { ApiBuilder::Instance().ResetInsertionPointToStart(); });
   m->def("reset_insertion_point_to_end",
          []() { ApiBuilder::Instance().ResetInsertionPointToEnd(); });
+  m->def("set_op_role",
+         [](int op_role) { ApiBuilder::Instance().SetOpRole(op_role); });
+  m->def("get_op_role", []() { return ApiBuilder::Instance().GetOpRole(); });
   m->def("register_paddle_dialect", []() {
     pir::IrContext::Instance()
         ->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
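These bindings expose the builder's op_role state to Python; python/paddle/pir/__init__.py below re-exports them as paddle.pir.set_op_role / paddle.pir.get_op_role. A minimal sketch of this raw, non-scoped API, assuming PIR mode (FLAGS_enable_pir_api=1) and static graph:

import paddle
from paddle import pir

paddle.enable_static()
with paddle.pir_utils.IrGuard():
    prog = paddle.base.Program()
    with paddle.base.program_guard(prog):
        pir.set_op_role(1)          # tag every op built from here on with 1
        x = paddle.static.data(name='x', shape=[4, 8])
        y = paddle.nn.functional.relu(x)
        assert pir.get_op_role() == 1
        pir.set_op_role(-1)         # restore the "untagged" default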

paddle/pir/include/core/builder.h

Lines changed: 7 additions & 1 deletion
@@ -115,7 +115,9 @@ class Builder {
         common::errors::PreconditionNotMet("argument of block is nullptr"));
     set_insertion_point(block, block->end());
   }
-
+  /// Set/Get the op_role.
+  void set_op_role(int op_role) { op_role_ = op_role; }
+  int op_role() const { return op_role_; }
   IrContext *ir_context() const { return context_; }
 
   Block *block() const { return insertion_point_.first; }
@@ -172,6 +174,10 @@ class Builder {
   InsertionPoint insertion_point_;
 
   bool forbid_insert_without_position_;
+
+  // For now, op_role is used by auto-parallel to distinguish ops in the
+  // forward, backward, and optimizer regions.
+  int op_role_ = -1;
 };
 
 template <typename OpTy, typename... Args>

paddle/pir/src/core/builder.cc

Lines changed: 7 additions & 1 deletion
@@ -23,7 +23,13 @@
 namespace pir {
 /// Create an operation given the fields represented as an OperationState.
 Operation *Builder::Build(OperationArgument &&argument) {
-  return Insert(Operation::Create(std::move(argument)));
+  Operation *op = Insert(Operation::Create(std::move(argument)));
+  // TODO(ljz): Generalize this into a hook function in the future.
+  // Add the op_role attribute only when it is not equal to -1.
+  if (op_role_ != -1) {
+    op->set_attribute("op_role", Int32Attribute::get(context_, op_role_));
+  }
+  return op;
 }
 
 /// Creates an operation with the given fields.
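Taken together with the Builder changes above, the behavior is: ops created while op_role_ is -1 carry no attribute (and read back as -1), while any other value is stamped onto each new op as an Int32 "op_role" attribute, visible from Python as op.op_role. A small sketch under the same PIR-mode assumption:

import paddle
from paddle import pir

paddle.enable_static()
with paddle.pir_utils.IrGuard():
    prog = paddle.base.Program()
    with paddle.base.program_guard(prog):
        x = paddle.static.data(name='x', shape=[4, 8])  # built at role -1
        pir.set_op_role(3)
        y = paddle.nn.functional.relu(x)                # stamped with role 3
        pir.set_op_role(-1)
    roles = [f"{op.name()}:{op.op_role}" for op in prog.global_block().ops]
    print(roles)  # e.g. ['pd_op.data:-1', 'pd_op.relu:3']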

python/paddle/base/framework.py

Lines changed: 16 additions & 0 deletions
@@ -8461,3 +8461,19 @@ def set_op_roles(block, op_role, always_forward_ops):
     if paddle.framework.in_pir_mode() and is_dist_block(block):
         always_forward_ops = ["pd_op.data", "builtin.parameter"]
         set_op_roles(block, op_role, always_forward_ops)
+
+
+# Set the op_role on every op added through the ApiBuilder.
+# pir_op_role_guard cannot distinguish "always_forward_ops"; if your
+# region may contain always_forward_ops, use "auto_complete_op_role".
+@signature_safe_contextmanager
+def pir_op_role_guard(op_role: int = -1) -> Generator[None, None, None]:
+
+    if paddle.framework.in_pir_mode():
+        original_op_role = pir.get_op_role()
+        pir.set_op_role(op_role)
+    try:
+        yield
+    finally:
+        if paddle.framework.in_pir_mode():
+            pir.set_op_role(original_op_role)
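A usage sketch of the guard (PIR mode assumed): guards nest, the innermost value wins, and the previous role is restored on exit, so a tag never leaks past its region. The role values themselves are plain ints; auto-parallel assigns them meaning (forward/backward/optimizer):

import paddle
from paddle.base.framework import pir_op_role_guard

paddle.enable_static()
with paddle.pir_utils.IrGuard():
    prog = paddle.base.Program()
    with paddle.base.program_guard(prog):
        x = paddle.static.data(name='x', shape=[4, 8])  # role -1 (untagged)
        with pir_op_role_guard(1):
            y = paddle.nn.functional.relu(x)            # role 1
            with pir_op_role_guard(3):
                z = paddle.nn.functional.relu(y)        # role 3
            w = paddle.add(x, y)                        # role 1 again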

python/paddle/distributed/auto_parallel/static/pir_pass.py

Lines changed: 8 additions & 8 deletions
@@ -20,7 +20,7 @@
 import paddle
 from paddle import pir
 from paddle.autograd.backward_utils import ValueDict
-from paddle.base.framework import auto_complete_op_role
+from paddle.base.framework import pir_op_role_guard
 from paddle.base.log_helper import get_logger
 from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from paddle.distributed.passes.pass_base import PassContext, new_pass
@@ -65,7 +65,7 @@ def reshard_single_value(program, op, operand, attr):
     if prev_var.is_dist() and prev_var.dist_attr() != attr:
         operand_attr = attr.as_tensor_dist_attr()
         paddle.pir.set_insertion_point(op)
-        with auto_complete_op_role(program, op.op_role):
+        with pir_op_role_guard(op.op_role):
             # fold reshard
             if prev_var.get_defining_op().name() == 'dist_op.reshard':
                 prev_reshard = prev_var.get_defining_op()
@@ -101,7 +101,7 @@ def reshard_combine_value(program, op, operand, attr):
     )
 
     paddle.pir.set_insertion_point(op)
-    with auto_complete_op_role(program, op.op_role):
+    with pir_op_role_guard(op.op_role):
         combine_value = paddle._C_ops.builtin_combine(reshard_vars)
     return combine_value
 
@@ -136,7 +136,7 @@ def apply_partition_pass(program):
 
             # reshard input
             paddle.pir.set_insertion_point(op)
-            with auto_complete_op_role(program, ref_op_role):
+            with pir_op_role_guard(ref_op_role):
                 reshard_var = paddle._C_ops.reshard_v2(prev_var, operand_attr)
                 operand.set_source(reshard_var)
 
@@ -151,7 +151,7 @@ def apply_partition_pass(program):
             old_dist_attr = result.dist_attr()
             result.update_dist_attr(result_attr)
 
-            with auto_complete_op_role(program, ref_op_role):
+            with pir_op_role_guard(ref_op_role):
                 prev_op = prev_var.get_defining_op()
 
                 # reshard output to assign out input
@@ -171,7 +171,7 @@ def apply_partition_pass(program):
 
             reshard_var_2 = reshard_var_1
             if old_dist_attr != reshard_var_1.dist_attr():
-                with auto_complete_op_role(program, ref_op_role):
+                with pir_op_role_guard(ref_op_role):
                     reshard_var_2 = paddle._C_ops.reshard_v2(
                         result, old_dist_attr
                     )
@@ -201,7 +201,7 @@ def apply_partition_pass(program):
                 var.update_dist_attr(attr.as_tensor_dist_attr())
 
                 # insert reshard
-                with auto_complete_op_role(program, op.op_role):
+                with pir_op_role_guard(op.op_role):
                     reshard_var = paddle._C_ops.reshard_v2(var, old_dist_attr)
                     var.replace_all_uses_with(reshard_var)
                     reshard_var.get_defining_op().operand(0).set_source(var)
@@ -266,7 +266,7 @@ def reshard_op_pass(dist_program, params_grads=[]):
         paddle.pir.set_insertion_point(op)
         ref_op_role = op.op_role
 
-        with auto_complete_op_role(dist_program, ref_op_role):
+        with pir_op_role_guard(ref_op_role):
             out_value = reshard_func.reshard(
                 src_dist_attr,
                 dst_dist_attr,
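Every call site above follows the same pattern: the pass already knows the role of the op it is rewriting (op.op_role), so the guard tags the newly built reshard ops directly instead of relying on auto_complete_op_role's program-wide completion. Sketched with a hypothetical insert_reshard_before helper (not part of this commit):

import paddle
from paddle.base.framework import pir_op_role_guard

def insert_reshard_before(op, var, dist_attr):
    # Hypothetical helper, for illustration only: new ops built next to
    # `op` inherit its role through the guard.
    paddle.pir.set_insertion_point(op)
    with pir_op_role_guard(op.op_role):
        return paddle._C_ops.reshard_v2(var, dist_attr)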

python/paddle/pir/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -26,6 +26,7 @@
     create_shaped_type,
     fake_value,
     get_current_insertion_point,
+    get_op_role,
     is_fake_value,
     parse_program,
     register_dist_dialect,
@@ -35,6 +36,7 @@
     set_insertion_point,
     set_insertion_point_after,
     set_insertion_point_to_block_end,
+    set_op_role,
     translate_to_pir,
     translate_to_pir_with_param_map,
 )
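With these re-exports the raw bindings are importable straight from the package root, which is what pir_op_role_guard above relies on:

from paddle.pir import get_op_role, set_op_role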

test/auto_parallel/pir/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -5,6 +5,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_ir_dist_attr MODULES test_ir_dist_attr ENVS
                   FLAGS_enable_pir_api=1)
   py_test_modules(test_static_pir_program MODULES test_static_pir_program)
+  py_test_modules(test_op_role MODULES test_op_role)
+  set_tests_properties(test_op_role PROPERTIES ENVIRONMENT
+                                               "FLAGS_enable_pir_api=1")
   py_test_modules(test_pir_elementwise_spmd MODULES test_elementwise_spmd_rule
                   ENVS FLAGS_enable_pir_api=1)
   py_test_modules(test_pir_relu_spmd MODULES test_relu_spmd_rule ENVS
test/auto_parallel/pir/test_op_role.py

Lines changed: 149 additions & 0 deletions (new file)

# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
import paddle.distributed as dist
from paddle.base.framework import auto_complete_op_role, pir_op_role_guard
from paddle.distributed import Replicate, Shard
from paddle.distributed.auto_parallel.static.mix_to_dist_pass import (
    apply_mix2dist_pass,
)
from paddle.distributed.auto_parallel.static.pir_pass import (
    ReshardPasses,
    apply_partition_pass,
)


class TestOpRole(unittest.TestCase):
    def test_single(self):
        paddle.enable_static()
        with paddle.pir_utils.IrGuard():
            main_program = paddle.base.Program()
            with paddle.base.program_guard(main_program):
                # op_role = -1
                x0 = paddle.static.data(name='x0', shape=[1, 128, 512])
                x1 = paddle.nn.functional.relu(x0)
                x2 = paddle.nn.functional.relu(x1)

                with pir_op_role_guard(1):
                    y0 = paddle.static.data(name='y0', shape=[1, 128, 512])
                    y1 = paddle.nn.functional.relu(y0)
                    z0 = paddle.add(y1, x2)
                    z0 = z0 * 3.0
                    with pir_op_role_guard(3):
                        z1 = paddle.nn.functional.relu(z0)
                        z2 = paddle.add(y0, z1)
                        z4 = paddle.split(
                            z0, num_or_sections=[8, 100, 20], axis=1
                        )

                    with pir_op_role_guard(0):
                        z3 = paddle.add(y1, z2)

                # op_role = -1
                z4 = paddle.add(y0, z3)

            # check op roles
            std_ops = [
                "pd_op.data:-1",
                "pd_op.data:1",
                "pd_op.relu:-1",
                "pd_op.relu:-1",
                "pd_op.relu:1",
                "pd_op.add:1",
                "pd_op.full:1",
                "pd_op.scale:1",
                "pd_op.relu:3",
                "pd_op.add:3",
                "pd_op.full_int_array:3",
                "pd_op.full:3",
                "pd_op.split:3",
                "builtin.split:3",
                "pd_op.add:0",
                "pd_op.add:-1",
            ]

            cur_ops = [
                f"{op.name()}:{op.op_role}"
                for op in main_program.global_block().ops
            ]
            self.assertEqual(cur_ops, std_ops)

    def test_dist(self):
        paddle.enable_static()

        mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

        with paddle.pir_utils.IrGuard():
            main_program = paddle.base.Program()
            with paddle.base.program_guard(main_program):
                with auto_complete_op_role(main_program, 0):
                    x0 = paddle.static.data(name='x0', shape=[1, 128, 512])
                    x0 = dist.shard_tensor(
                        x0, mesh, [Shard(1), Replicate()], stop_gradient=False
                    )
                    x1 = x0 / 2.0

                    with pir_op_role_guard(3):
                        x2 = dist.reshard(x1, mesh, [Shard(2), Replicate()])
                        with pir_op_role_guard(1):
                            x3 = dist.reshard(
                                x2, mesh, [Replicate(), Replicate()]
                            )
                        x4 = dist.reshard(x3, mesh, [Shard(1), Replicate()])

                    x5 = dist.reshard(x4, mesh, [Replicate(), Replicate()])

            apply_mix2dist_pass(main_program)
            apply_partition_pass(main_program)
            ReshardPasses.apply_reshard_pass(main_program, [])

            std_ops = [
                'pd_op.data:0',
                'pd_op.full:0',
                'pd_op.scale:0',
                'pd_op.all_gather:3',
                'pd_op.full:3',
                'pd_op.split_with_num:3',
                'pd_op.full:3',
                'pd_op.concat:3',
                'pd_op.full_int_array:3',
                'pd_op.full_int_array:3',
                'pd_op.slice:3',
                'pd_op.all_gather:1',
                'pd_op.full:1',
                'pd_op.split_with_num:1',
                'pd_op.full:1',
                'pd_op.concat:1',
                'pd_op.full_int_array:3',
                'pd_op.full_int_array:3',
                'pd_op.slice:3',
                'pd_op.all_gather:0',
                'pd_op.full:0',
                'pd_op.split_with_num:0',
                'pd_op.full:0',
                'pd_op.concat:0',
            ]

            cur_ops = [
                f"{op.name()}:{op.op_role}"
                for op in main_program.global_block().ops
            ]
            self.assertEqual(cur_ops, std_ops)


if __name__ == "__main__":
    unittest.main()
