Commit 86bba6e

[Hackathon 5th No.52] Add SPMD sharding inference rules for unsqueeze to Paddle -part (PaddlePaddle#58296)

* add unsqueeze spmd rules
* fix bugs
* fix bugs
* modify the code based on the first review
* fix bugs
1 parent 594cc4b commit 86bba6e

File tree

6 files changed: +589 −2 lines

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 8 additions & 1 deletion

Note: before this change, the default data-parallel rule was registered under the placeholder name `unsqueeze`; this change renames that registration to `default_data_parallel` and registers `unsqueeze` with its own dedicated inference functions.

@@ -30,6 +30,7 @@ limitations under the License. */
 #include "paddle/phi/infermeta/spmd_rules/softmax.h"
 #include "paddle/phi/infermeta/spmd_rules/split.h"
 #include "paddle/phi/infermeta/spmd_rules/transpose.h"
+#include "paddle/phi/infermeta/spmd_rules/unsqueeze.h"

 /**
  * Design Notes:
@@ -71,7 +72,7 @@ PD_REGISTER_SPMD_RULE(

 // default data parallel rule
 PD_REGISTER_SPMD_RULE(
-    unsqueeze,
+    default_data_parallel,
     PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmd),
     PD_INFER_SPMD(phi::distributed::DefaultDataParallelInferSpmdReverse));
 PD_REGISTER_SPMD_RULE(
@@ -85,6 +86,12 @@ PD_REGISTER_SPMD_RULE(
     PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmd),
     PD_INFER_SPMD(phi::distributed::ReplicatedInferSpmdReverse));

+// unsqueeze rule
+PD_REGISTER_SPMD_RULE(
+    unsqueeze,
+    PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmd),
+    PD_INFER_SPMD(phi::distributed::UnsqueezeInferSpmdReverse));
+
 // elementwise unary rule
 PD_REGISTER_SPMD_RULE(
     assign,
paddle/phi/infermeta/spmd_rules/unsqueeze.cc

Lines changed: 206 additions & 0 deletions (new file)

/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/unsqueeze.h"
#include <algorithm>
#include <numeric>

#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi {
namespace distributed {

using phi::distributed::auto_parallel::str_join;

std::vector<DimTrans*> MakeUnsqueezeDimTrans(
    const std::vector<int64_t>& x_shape,
    std::vector<int64_t>* out_shape,
    const std::vector<int64_t>& axis) {
  int64_t n = static_cast<int64_t>(x_shape.size() + axis.size());
  std::vector<DimTrans*> ret;
  ret.resize(n);
  out_shape->resize(n);
  fill(ret.begin(), ret.end(), new Singleton());
  fill(out_shape->begin(), out_shape->end(), 1);

  for (int64_t i = 0, j = 0; i < n; i++) {
    auto it = find(axis.begin(), axis.end(), i);

    if (it == axis.end()) {
      // Output axis i is not an inserted axis: it maps to input axis j,
      // unless that input axis has extent 1 (then it stays a Singleton).
      if (x_shape[j] != 1) {
        ret[i] = new InputDim(j);
        (*out_shape)[i] = x_shape[j];
      }

      j++;
    }
  }

  return ret;
}

std::vector<DimTrans*> MakeUnsqueezeDimTransReverse(
    const std::vector<int64_t>& out_shape,
    const std::vector<int64_t>& axis,
    const int& x_ndim,
    const int& out_ndim) {
  std::vector<DimTrans*> ret;
  ret.resize(x_ndim);
  fill(ret.begin(), ret.end(), new Singleton());

  for (int64_t i = 0, j = 0; i < out_ndim; i++) {
    auto it = find(axis.begin(), axis.end(), i);

    if (it == axis.end()) {
      if (out_shape[i] != 1) {
        ret[j] = new InputDim(i);
      }

      j++;
    }
  }

  return ret;
}

SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x,
                            const std::vector<int64_t>& axis) {
  // Step0: Verify input args based on unsqueeze logic
  auto x_shape = phi::vectorize(x.dims());
  int x_ndim = x_shape.size();
  auto x_dist_attr_src = x.dist_attr();
  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      x_ndim,
      x_dims_mapping.size(),
      phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's "
                                   "dims_mapping size [%d] are not matched.",
                                   x_ndim,
                                   x_dims_mapping.size()));

  // Step1: Build the transformation from
  // the original shape to the target shape

  std::vector<int64_t> out_shape;
  std::vector<int64_t> axis_copy(axis);

  // Normalize negative axes against the output rank (x_ndim + 1).
  for (int64_t i = 0; i < static_cast<int64_t>(axis_copy.size()); i++) {
    if (axis_copy[i] < 0) {
      axis_copy[i] += x_ndim + 1;
    }
  }

  std::vector<DimTrans*> trans =
      MakeUnsqueezeDimTrans(x_shape, &out_shape, axis_copy);

  // Step2: Infer the dims mapping of input (if reshard is
  // needed) and output from the dimension transformation.
  std::vector<std::vector<int64_t>> dims_mapping_vec =
      InferFromDimTrans(x, trans);

  // Step3: Update the dist attributes of input
  // and output with the inferred dims mapping.
  TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
  x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
  TensorDistAttr out_dist_attr(x_dist_attr_src);
  out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

  VLOG(4) << "UnsqueezeInferSpmd: X shape: [" << str_join(x_shape)
          << "] Out shape: [" << str_join(out_shape) << "]";
  VLOG(4) << "Transformation from input to output:";
  for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
    DimTrans* t = trans[i];
    VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string();
  }
  VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping)
          << "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0])
          << "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1])
          << "]\n\n";

  CleanUp();

  return {{x_dist_attr_dst}, {out_dist_attr}};
}

SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x,
                                   const DistMetaTensor& out,
                                   const std::vector<int64_t>& axis) {
  // Step0: Verify input args based on unsqueeze logic
  auto x_shape = phi::vectorize(x.dims());
  int x_ndim = x_shape.size();
  auto out_shape = phi::vectorize(out.dims());
  int out_ndim = out_shape.size();
  auto out_dist_attr_src = out.dist_attr();
  std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
  PADDLE_ENFORCE_EQ(
      out_ndim,
      out_dims_mapping.size(),
      phi::errors::InvalidArgument("The Tensor Out's rank [%d] and Out's "
                                   "dims_mapping size [%d] are not matched.",
                                   out_ndim,
                                   out_dims_mapping.size()));

  // Step1: Build the transformation from the output shape
  // to the original shape. This function infers the dims mapping
  // from output to input: we first get the transformation
  // from output to input so that we can infer the dims mapping
  // with the map from output axes to input axes.

  std::vector<int64_t> axis_copy(axis);

  for (int64_t i = 0; i < static_cast<int64_t>(axis_copy.size()); i++) {
    if (axis_copy[i] < 0) {
      axis_copy[i] += x_ndim + 1;
    }
  }

  std::vector<DimTrans*> trans =
      MakeUnsqueezeDimTransReverse(out_shape, axis_copy, x_ndim, out_ndim);

  // Step2: Infer the dims mapping of input with
  // output's dims_mapping and the transformation.
  std::vector<std::vector<int64_t>> dims_mapping_vec =
      InferFromDimTrans(out, trans);

  // Step3: Update the dist attributes of input
  // and output with the inferred dims mapping
  TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
  out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
  TensorDistAttr x_dist_attr(x.dist_attr());
  x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

  VLOG(4) << "UnsqueezeInferSpmdReverse: Out shape: [" << str_join(out_shape)
          << "] X shape: [" << str_join(x_shape) << "]";
  VLOG(4) << "Transformation from output to input:";
  for (int64_t i = 0, n = trans.size(); i < n; i++) {
    DimTrans* t = trans[i];
    VLOG(4) << "\tX axis[" << i << "]: " << t->to_string();
  }
  VLOG(4) << "Out dims_mapping_src: [" << str_join(out_dims_mapping) << "] "
          << "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
  VLOG(4) << "X dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";

  CleanUp();

  return {{x_dist_attr}, {out_dist_attr_dst}};
}

}  // namespace distributed
}  // namespace phi
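
To make the forward rule concrete: for a hypothetical x_shape = [6, 8] and axis = [0, 2], MakeUnsqueezeDimTrans yields out_shape = [1, 6, 1, 8] with the transformation [Singleton, InputDim(0), Singleton, InputDim(1)]. The minimal standalone sketch below mirrors that loop with plain integers (-1 standing in for Singleton); it is an illustration with made-up example values, not Paddle code, and uses only the standard library.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> x_shape = {6, 8};  // hypothetical input shape
  std::vector<int64_t> axis = {0, 2};     // hypothetical unsqueeze axes
  int64_t n = static_cast<int64_t>(x_shape.size() + axis.size());

  std::vector<int64_t> mapping(n, -1);    // -1 plays the role of Singleton
  std::vector<int64_t> out_shape(n, 1);
  for (int64_t i = 0, j = 0; i < n; i++) {
    if (std::find(axis.begin(), axis.end(), i) == axis.end()) {
      if (x_shape[j] != 1) {
        mapping[i] = j;                   // out axis i comes from input axis j
        out_shape[i] = x_shape[j];
      }
      j++;
    }
  }
  // Prints "1 6 1 8" and "-1 0 -1 1".
  for (auto v : out_shape) std::cout << v << ' ';
  std::cout << '\n';
  for (auto v : mapping) std::cout << v << ' ';
  std::cout << '\n';
}

The x_shape[j] != 1 guard is why size-1 input dimensions map to Singleton rather than InputDim, presumably because a dimension of extent 1 cannot carry a meaningful shard, so its dims mapping is dropped rather than propagated.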
paddle/phi/infermeta/spmd_rules/unsqueeze.h

Lines changed: 32 additions & 0 deletions (new file)

/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <vector>

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {

SpmdInfo UnsqueezeInferSpmd(const DistMetaTensor& x,
                            const std::vector<int64_t>& axis);

SpmdInfo UnsqueezeInferSpmdReverse(const DistMetaTensor& x,
                                   const DistMetaTensor& out,
                                   const std::vector<int64_t>& axis);
}  // namespace distributed
}  // namespace phi
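
For the reverse declaration above, the same style of standalone sketch (again plain integers standing in for DimTrans*, with the same hypothetical example values) traces each input axis back to the output axis whose sharding it inherits. With out_shape = [1, 6, 1, 8], axis = [0, 2], and x_ndim = 2, MakeUnsqueezeDimTransReverse produces [InputDim(1), InputDim(3)]:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> out_shape = {1, 6, 1, 8};  // hypothetical output shape
  std::vector<int64_t> axis = {0, 2};             // axes that were inserted
  int x_ndim = 2;                                 // rank of the original input
  int out_ndim = static_cast<int>(out_shape.size());

  std::vector<int64_t> mapping(x_ndim, -1);       // -1 plays the role of Singleton
  for (int64_t i = 0, j = 0; i < out_ndim; i++) {
    if (std::find(axis.begin(), axis.end(), i) == axis.end()) {
      if (out_shape[i] != 1) {
        mapping[j] = i;  // input axis j follows output axis i
      }
      j++;
    }
  }
  // Prints "1 3": input axis 0 maps to output axis 1, input axis 1 to axis 3.
  for (auto v : mapping) std::cout << v << ' ';
  std::cout << '\n';
}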

test/auto_parallel/spmd_rules/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@ if(WITH_DISTRIBUTE)
   py_test_modules(test_layer_norm_rule MODULES test_layer_norm_rule)
   py_test_modules(test_slice_rule MODULES test_slice_rule)
   py_test_modules(test_flatten_rule MODULES test_flatten_rule)
+  py_test_modules(test_unsqueeze_rule MODULES test_unsqueeze_rule)
   py_test_modules(test_concat_rule MODULES test_concat_rule)
   # End of unittests WITH single card WITHOUT timeout

test/auto_parallel/spmd_rules/test_default_data_parallel_rule.py

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ class TestDefaultDataParallelSPMDRule(unittest.TestCase):
     def setUp(self):
         # After replaced all spmd rules by phi impl, we can recover the
         # api name to `get_spmd_rule`
-        self.rule = core.get_phi_spmd_rule("unsqueeze")
+        self.rule = core.get_phi_spmd_rule("default_data_parallel")

         x_shape = [10, 10, 32, 48]
         y_shape = [32, 48]
