Commit af2fc67

Lazy init for MP. Add reshard infer shape.
1 parent 9018371 commit af2fc67

File tree: 6 files changed (+123, -16 lines)

paddle/phi/core/distributed/auto_parallel/dist_tensor.cc

Lines changed: 19 additions & 11 deletions
@@ -174,17 +174,25 @@ DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
   // uninitialized tensor only with dist_tensor_meta_.
   if (IsCurRankInMesh(process_mesh)) {
     if (!dist_attr_.is_replicated()) {
-      value_ = std::make_shared<DenseTensor>();
-      // 1. create replicated global tensor
-      TensorDistAttr replicated_dist_attr(
-          common::vectorize(global_value->dims()));
-      replicated_dist_attr.set_process_mesh(process_mesh);
-      DistTensor replicated_tensor(global_value, replicated_dist_attr);
-
-      // 2. reshard from replicated to other state
-      auto* func = ChooseProperReshardFunction(replicated_tensor, dist_attr_);
-      auto* dev_ctx = DeviceContextPool::Instance().Get(global_value->place());
-      func->Eval(dev_ctx, replicated_tensor, dist_attr_, this);
+      if (global_value->initialized()) {
+        value_ = std::make_shared<DenseTensor>();
+        // 1. create replicated global tensor
+        TensorDistAttr replicated_dist_attr(
+            common::vectorize(global_value->dims()));
+        replicated_dist_attr.set_process_mesh(process_mesh);
+        DistTensor replicated_tensor(global_value, replicated_dist_attr);
+
+        // 2. reshard from replicated to other state
+        auto* func = ChooseProperReshardFunction(replicated_tensor, dist_attr_);
+        auto* dev_ctx =
+            DeviceContextPool::Instance().Get(global_value->place());
+        func->Eval(dev_ctx, replicated_tensor, dist_attr_, this);
+      } else {
+        // For lazy init, the global value is an uninitialized tensor.
+        // Just infer the local shape of the dist tensor.
+        value_ = global_value;
+        value_->Resize(ReshardInferShape(global_value, dist_attr_));
+      }
     } else {
       value_ = global_value;
     }
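
The else-branch above is the new lazy-init path: under LazyGuard the global value is an uninitialized DenseTensor, so instead of building a replicated tensor and resharding it, the constructor only resizes the value to the inferred local shape. A minimal Python-side sketch of the flow that reaches this branch, adapted from the tests added in this commit (the 2-rank mesh and Shard(1) placement are illustrative):

import paddle
import paddle.distributed as dist
from paddle import LazyGuard

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])  # illustrative 2-rank mesh

with LazyGuard():
    # Parameters are created without allocating data, so
    # global_value->initialized() is false in the DistTensor constructor.
    linear = paddle.nn.Linear(10, 10)
    # shard_tensor on the uninitialized parameter only infers the local
    # shape; no memory is allocated and no reshard kernel runs.
    linear.weight = dist.shard_tensor(linear.weight, mesh, [dist.Shard(1)])

for param in linear.parameters():
    param.initialize()  # materializes only the local shard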

paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc

Lines changed: 23 additions & 0 deletions
@@ -180,5 +180,28 @@ phi::DeviceContext* GetDistTensorDeviceContext(
   return phi::DeviceContextPool::Instance().Get(place);
 }
 
+phi::DDim ReshardInferShape(
+    const std::shared_ptr<phi::DenseTensor>& global_value,
+    const TensorDistAttr& dist_attr) {
+  phi::DDim out_dim = global_value->dims();
+  auto coord_id = GetCurRankCoordInMesh(dist_attr.process_mesh());
+  for (int tensor_axis = 0; tensor_axis < global_value->dims().size();
+       ++tensor_axis) {
+    if (dist_attr.is_shard(-1, tensor_axis)) {
+      for (int mesh_axis = 0; mesh_axis < dist_attr.process_mesh().ndim();
+           ++mesh_axis) {
+        if (dist_attr.is_shard(mesh_axis, tensor_axis)) {
+          // handle the shard axis
+          int64_t global_shape = out_dim[tensor_axis];
+          int64_t mesh_size = dist_attr.process_mesh().dim_size(mesh_axis);
+          auto balance_shard = BalancedSplit(global_shape, mesh_size);
+          out_dim[tensor_axis] = balance_shard[coord_id[mesh_axis]];
+        }
+      }
+    }
+  }
+  return out_dim;
+}
+
 }  // namespace distributed
 }  // namespace phi
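
ReshardInferShape walks every tensor axis, finds the mesh axis it is sharded on, and takes the current rank's piece of a balanced split of that dimension. A plain-Python sketch of the same arithmetic (balanced_split, infer_local_shape, and shard_map are illustrative helper names, not Paddle APIs); the [11, 6] / [11, 5] results match the assertions in test_unbalance_mp below:

def balanced_split(total, pieces):
    # Mirrors BalancedSplit: the first (total % pieces) ranks get one extra element.
    base, rem = divmod(total, pieces)
    return [base + 1 if i < rem else base for i in range(pieces)]

def infer_local_shape(global_shape, mesh_shape, coord, shard_map):
    # shard_map maps tensor axis -> mesh axis, e.g. {1: 0} for Shard(1) on mesh axis "x".
    local = list(global_shape)
    for tensor_axis, mesh_axis in shard_map.items():
        pieces = balanced_split(local[tensor_axis], mesh_shape[mesh_axis])
        local[tensor_axis] = pieces[coord[mesh_axis]]
    return local

# An [11, 11] weight with Shard(1) on a 2-rank mesh:
print(infer_local_shape([11, 11], [2], [0], {1: 0}))  # [11, 6] on rank 0
print(infer_local_shape([11, 11], [2], [1], {1: 0}))  # [11, 5] on rank 1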

paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h

Lines changed: 4 additions & 0 deletions
@@ -71,6 +71,10 @@ std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces);
 CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
                                     const std::vector<int64_t>& process_ids);
 
+phi::DDim ReshardInferShape(
+    const std::shared_ptr<phi::DenseTensor>& global_value,
+    const TensorDistAttr& dist_attr);
+
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...) \
   do {                                                     \

python/paddle/nn/initializer/constant.py

Lines changed: 14 additions & 3 deletions
@@ -72,9 +72,20 @@ def forward(self, var, block=None):
         if self._force_cpu:
             place = core.CPUPlace()
         if in_dygraph_mode():
-            _C_ops.full_(
-                var, var.shape, float(self._value), var.dtype, place
-            )
+            if isinstance(var, framework.EagerParamBase) and var.is_dist():
+                out_var = _C_ops.full(
+                    var._local_shape, float(self._value), var.dtype, place
+                )
+                out_var = (
+                    paddle.distributed.auto_parallel.api.dtensor_from_local(
+                        out_var, var.process_mesh, var.placements
+                    )
+                )
+                out_var._share_underline_tensor_to(var)
+            else:
+                _C_ops.full_(
+                    var, var.shape, float(self._value), var.dtype, place
+                )
             return None
         else:
             return _C_ops.full(
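
For a distributed lazy parameter, the initializer now fills a tensor of the parameter's local shape and wraps it via dtensor_from_local, so each rank allocates only its own shard and no reshard is needed. A hedged usage sketch, adapted from the test_constant case added in this commit (mesh and placements are illustrative):

import paddle
import paddle.distributed as dist
from paddle import LazyGuard

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
weight_attr = paddle.framework.ParamAttr(
    initializer=paddle.nn.initializer.Constant(2.0)
)

with LazyGuard():
    linear = paddle.nn.Linear(10, 10, weight_attr=weight_attr)
    linear.weight = dist.shard_tensor(linear.weight, mesh, [dist.Shard(1)])

linear.weight.initialize()  # fills only the local [10, 5] shard with 2.0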

test/auto_parallel/semi_auto_parallel_lazy_init.py

Lines changed: 62 additions & 1 deletion
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import logging
 import os
 
 import paddle
@@ -34,6 +34,11 @@ def __init__(self):
             self._mesh_bias = dist.ProcessMesh([1], dim_names=["x"])
             self._placements_weight = [dist.Replicate()]
             self._placements_bias = [dist.Replicate()]
+        elif self._placements_type == "MP":
+            self._mesh_weight = dist.ProcessMesh([0, 1], dim_names=["x"])
+            self._mesh_bias = dist.ProcessMesh([0, 1], dim_names=["x"])
+            self._placements_weight = [dist.Shard(1)]
+            self._placements_bias = [dist.Shard(0)]
 
     def test_different_xavier(self):
         paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
@@ -53,6 +58,31 @@ def test_different_xavier(self):
             linear.bias = dist.shard_tensor(
                 linear.bias, self._mesh_bias, self._placements_bias
             )
+        for param in linear.parameters():
+            param.initialize()
+            logging.info(param)
+
+    def test_constant(self):
+        paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
+        weight_attr = paddle.framework.ParamAttr(
+            initializer=paddle.nn.initializer.Constant(2.0)
+        )
+        bias_attr = paddle.framework.ParamAttr(
+            initializer=paddle.nn.initializer.Constant(1.0)
+        )
+        with LazyGuard():
+            linear = paddle.nn.Linear(
+                10, 10, weight_attr=weight_attr, bias_attr=bias_attr
+            )
+            linear.weight = dist.shard_tensor(
+                linear.weight, self._mesh_weight, self._placements_weight
+            )
+            linear.bias = dist.shard_tensor(
+                linear.bias, self._mesh_bias, self._placements_bias
+            )
+        for param in linear.parameters():
+            param.initialize()
+            logging.info(param)
 
     def test_placements(self):
         paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
@@ -67,6 +97,7 @@ def test_placements(self):
         for param in linear.parameters():
             assert not param._is_initialized()
             param.initialize()
+            logging.info(param)
 
         if self._placements_type == "DP":
             assert linear.weight._is_initialized()
@@ -93,10 +124,40 @@ def test_placements(self):
             else:
                 assert not linear.weight._is_initialized()
                 assert linear.bias._is_initialized()
+        elif self._placements_type == "MP":
+            assert linear.weight._is_initialized()
+            assert linear.bias._is_initialized()
+            assert linear.weight._local_shape == [10, 5]
+            assert linear.bias._local_shape == [5]
+
+    def test_unbalance_mp(self):
+        paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
+        with LazyGuard():
+            linear = paddle.nn.Linear(11, 11)
+            linear.weight = dist.shard_tensor(
+                linear.weight, self._mesh_weight, self._placements_weight
+            )
+            linear.bias = dist.shard_tensor(
+                linear.bias, self._mesh_bias, self._placements_bias
+            )
+        for param in linear.parameters():
+            assert not param._is_initialized()
+            param.initialize()
+            assert param._is_initialized()
+
+        if dist.get_rank() == 0:
+            assert linear.weight._local_shape == [11, 6]
+            assert linear.bias._local_shape == [6]
+        else:
+            assert linear.weight._local_shape == [11, 5]
+            assert linear.bias._local_shape == [5]
 
     def run_test_case(self):
         self.test_placements()
         self.test_different_xavier()
+        self.test_constant()
+        if self._placements_type == "MP":
+            self.test_unbalance_mp()
 
 
 if __name__ == '__main__':

test/auto_parallel/test_semi_auto_parallel_lazy_init.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ def setUp(self):
         }
         self._changeable_envs = {
             "backend": ["cpu", "gpu"],
-            "_placements_type": ["DP", "PP"],
+            "_placements_type": ["DP", "PP", "MP"],
         }
 
     def test_lazy_init(self):
