31 changes: 20 additions & 11 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -174,17 +174,26 @@ DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
// uninitialized tensor only with dist_tensor_meta_.
if (IsCurRankInMesh(process_mesh)) {
if (!dist_attr_.is_replicated()) {
value_ = std::make_shared<DenseTensor>();
// 1. create replicated global tensor
TensorDistAttr replicated_dist_attr(
common::vectorize(global_value->dims()));
replicated_dist_attr.set_process_mesh(process_mesh);
DistTensor replicated_tensor(global_value, replicated_dist_attr);

// 2. reshard from replicated to other state
auto* func = ChooseProperReshardFunction(replicated_tensor, dist_attr_);
auto* dev_ctx = DeviceContextPool::Instance().Get(global_value->place());
func->Eval(dev_ctx, replicated_tensor, dist_attr_, this);
if (global_value->initialized()) {
Contributor: Does the path above, where dist_attr is initialized from global_value, also need to support lazy init?

Contributor (Author): The other method can't be reached here, can it?

value_ = std::make_shared<DenseTensor>();
// 1. create replicated global tensor
TensorDistAttr replicated_dist_attr(
common::vectorize(global_value->dims()));
replicated_dist_attr.set_process_mesh(process_mesh);
DistTensor replicated_tensor(global_value, replicated_dist_attr);

// 2. reshard from replicated to other state
auto* func = ChooseProperReshardFunction(replicated_tensor, dist_attr_);
auto* dev_ctx =
DeviceContextPool::Instance().Get(global_value->place());
func->Eval(dev_ctx, replicated_tensor, dist_attr_, this);
} else {
// For lazy init, the global value is an uninitialized tensor.
// Just infer the local shape of the dist tensor.
value_ = global_value;
value_->Resize(
InferShapeForReshardFromReplicate(global_value, dist_attr_));
}
} else {
value_ = global_value;
}
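As a rough illustration of the branching this constructor now performs (a minimal Python sketch, not Paddle code; all names here are invented): an initialized global tensor is materialized by resharding from its replicated form, while an uninitialized one, i.e. the lazy-init case, only has its local shape recorded and is filled later when the parameter's initializer runs.

```python
# Minimal sketch of the eager vs. lazy-init branch; `reshard` and
# `infer_local_shape` stand in for ChooseProperReshardFunction/Eval and
# InferShapeForReshardFromReplicate in the C++ above.
class FakeGlobalTensor:
    def __init__(self, shape, data=None):
        self.shape = list(shape)
        self.data = data  # None means "uninitialized" (lazy init)

    def initialized(self):
        return self.data is not None


def build_local_value(global_value, dist_attr, reshard, infer_local_shape):
    if global_value.initialized():
        # Eager path: reshard the replicated global data to the requested
        # placement and return the resulting local shard.
        return reshard(global_value, dist_attr)
    # Lazy path: keep the uninitialized tensor and only resize it to the
    # inferred local shape; no data is moved or allocated here.
    global_value.shape = infer_local_shape(global_value.shape, dist_attr)
    return global_value
```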
23 changes: 23 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc
@@ -180,5 +180,28 @@ phi::DeviceContext* GetDistTensorDeviceContext(
return phi::DeviceContextPool::Instance().Get(place);
}

phi::DDim InferShapeForReshardFromReplicate(
const std::shared_ptr<phi::DenseTensor>& global_value,
const TensorDistAttr& dist_attr) {
phi::DDim out_dim = global_value->dims();
auto coord_id = GetCurRankCoordInMesh(dist_attr.process_mesh());
for (int tensor_axis = 0; tensor_axis < global_value->dims().size();
++tensor_axis) {
if (dist_attr.is_shard(-1, tensor_axis)) {
for (int mesh_axis = 0; mesh_axis < dist_attr.process_mesh().ndim();
++mesh_axis) {
if (dist_attr.is_shard(mesh_axis, tensor_axis)) {
// handle the shard axis
int64_t global_shape = out_dim[tensor_axis];
int64_t mesh_size = dist_attr.process_mesh().dim_size(mesh_axis);
auto balance_shard = BalancedSplit(global_shape, mesh_size);
out_dim[tensor_axis] = balance_shard[coord_id[mesh_axis]];
}
}
}
}
return out_dim;
}

} // namespace distributed
} // namespace phi
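For reference, a small self-contained Python sketch of the shape inference above (illustrative only, not the Paddle implementation): a balanced split hands one extra element to the first `total % pieces` ranks, which is why an axis of size 11 sharded over a 2-rank mesh ends up with local sizes 6 and 5, as asserted in the unbalanced MP test below.

```python
# Illustrative sketch of InferShapeForReshardFromReplicate; the function
# names and the shard_map representation are made up for this example.
def balanced_split(total, pieces):
    # The first `total % pieces` pieces get one extra element.
    base, rem = divmod(total, pieces)
    return [base + 1 if i < rem else base for i in range(pieces)]

def infer_local_shape(global_shape, shard_map, mesh_shape, rank_coord):
    # shard_map: tensor axis -> mesh axis it is sharded on (replicated axes
    # are simply absent); rank_coord: this rank's coordinate in the mesh,
    # the role GetCurRankCoordInMesh plays in the C++ code above.
    local = list(global_shape)
    for tensor_axis, mesh_axis in shard_map.items():
        sizes = balanced_split(global_shape[tensor_axis], mesh_shape[mesh_axis])
        local[tensor_axis] = sizes[rank_coord[mesh_axis]]
    return local

# Unbalanced MP case from the test: Linear(11, 11) with the weight sharded
# along axis 1 over a 2-rank mesh.
assert infer_local_shape([11, 11], {1: 0}, [2], [0]) == [11, 6]  # rank 0
assert infer_local_shape([11, 11], {1: 0}, [2], [1]) == [11, 5]  # rank 1
```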
paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.h
@@ -71,6 +71,10 @@ std::vector<int64_t> BalancedSplit(int64_t total_nums, int64_t num_of_pieces);
CommContext* CreateOrGetCommContext(const DeviceContext& dev_ctx,
const std::vector<int64_t>& process_ids);

phi::DDim InferShapeForReshardFromReplicate(
const std::shared_ptr<phi::DenseTensor>& global_value,
const TensorDistAttr& dist_attr);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#define RESHARD_FUNCTOR_IMPL(dev_ctx, fn_name, dtype, ...) \
do { \
17 changes: 14 additions & 3 deletions python/paddle/nn/initializer/constant.py
@@ -72,9 +72,20 @@ def forward(self, var, block=None):
if self._force_cpu:
place = core.CPUPlace()
if in_dygraph_mode():
_C_ops.full_(
var, var.shape, float(self._value), var.dtype, place
)
if isinstance(var, framework.EagerParamBase) and var.is_dist():
out_var = _C_ops.full(
var._local_shape, float(self._value), var.dtype, place
)
out_var = (
paddle.distributed.auto_parallel.api.dtensor_from_local(
out_var, var.process_mesh, var.placements
)
)
out_var._share_underline_tensor_to(var)
else:
_C_ops.full_(
var, var.shape, float(self._value), var.dtype, place
)
return None
else:
return _C_ops.full(
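To illustrate why the dist branch above fills a tensor of the local shape rather than calling the in-place `full_` on the global shape (a framework-free mock; `local_constant_fill` is an invented name, not a Paddle API): each rank only owns its shard, and a constant fill needs no communication because every shard holds the same value.

```python
import numpy as np

def local_constant_fill(local_shape, value, dtype="float32"):
    # Stand-in for `_C_ops.full(var._local_shape, value, ...)`: build the
    # per-rank shard directly at the local shape.
    return np.full(local_shape, value, dtype=dtype)

# A 2-way tensor-parallel Linear(10, 10) with the weight sharded along axis 1:
# each rank holds a [10, 5] shard (matching the MP assertions in the test).
shard_rank0 = local_constant_fill([10, 5], 2.0)
shard_rank1 = local_constant_fill([10, 5], 2.0)
assert shard_rank0.shape == (10, 5) and shard_rank1.shape == (10, 5)
# In Paddle, each local shard is then wrapped with
# dtensor_from_local(out_var, var.process_mesh, var.placements) and shared
# back into the parameter via _share_underline_tensor_to(var).
```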
63 changes: 62 additions & 1 deletion test/auto_parallel/semi_auto_parallel_lazy_init.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import paddle
@@ -34,6 +34,11 @@ def __init__(self):
self._mesh_bias = dist.ProcessMesh([1], dim_names=["x"])
self._placements_weight = [dist.Replicate()]
self._placements_bias = [dist.Replicate()]
elif self._placements_type == "MP":
self._mesh_weight = dist.ProcessMesh([0, 1], dim_names=["x"])
self._mesh_bias = dist.ProcessMesh([0, 1], dim_names=["x"])
self._placements_weight = [dist.Shard(1)]
self._placements_bias = [dist.Shard(0)]

def test_different_xavier(self):
paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
@@ -53,6 +58,31 @@ def test_different_xavier(self):
linear.bias = dist.shard_tensor(
linear.bias, self._mesh_bias, self._placements_bias
)
for param in linear.parameters():
param.initialize()
logging.info(param)

def test_constant(self):
paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
weight_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(2.0)
)
bias_attr = paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(1.0)
)
with LazyGuard():
linear = paddle.nn.Linear(
10, 10, weight_attr=weight_attr, bias_attr=bias_attr
)
linear.weight = dist.shard_tensor(
linear.weight, self._mesh_weight, self._placements_weight
)
linear.bias = dist.shard_tensor(
linear.bias, self._mesh_bias, self._placements_bias
)
for param in linear.parameters():
param.initialize()
logging.info(param)

def test_placements(self):
paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
@@ -67,6 +97,7 @@ def test_placements(self):
for param in linear.parameters():
assert not param._is_initialized()
param.initialize()
logging.info(param)
Contributor: What is the logging here meant to align with?

Contributor (Author): It can trigger a curious bug: under lazy init, the matching reshard function cannot be found.


if self._placements_type == "DP":
assert linear.weight._is_initialized()
@@ -93,10 +124,40 @@
else:
assert not linear.weight._is_initialized()
assert linear.bias._is_initialized()
elif self._placements_type == "MP":
assert linear.weight._is_initialized()
assert linear.bias._is_initialized()
assert linear.weight._local_shape == [10, 5]
assert linear.bias._local_shape == [5]

def test_unbalance_mp(self):
paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
with LazyGuard():
linear = paddle.nn.Linear(11, 11)
linear.weight = dist.shard_tensor(
linear.weight, self._mesh_weight, self._placements_weight
)
linear.bias = dist.shard_tensor(
linear.bias, self._mesh_bias, self._placements_bias
)
for param in linear.parameters():
assert not param._is_initialized()
param.initialize()
assert param._is_initialized()

if dist.get_rank() == 0:
assert linear.weight._local_shape == [11, 6]
assert linear.bias._local_shape == [6]
else:
assert linear.weight._local_shape == [11, 5]
assert linear.bias._local_shape == [5]

def run_test_case(self):
self.test_placements()
self.test_different_xavier()
self.test_constant()
if self._placements_type == "MP":
self.test_unbalance_mp()


if __name__ == '__main__':
2 changes: 1 addition & 1 deletion test/auto_parallel/test_semi_auto_parallel_lazy_init.py
@@ -29,7 +29,7 @@ def setUp(self):
}
self._changeable_envs = {
"backend": ["cpu", "gpu"],
"_placements_type": ["DP", "PP"],
"_placements_type": ["DP", "PP", "MP"],
}

def test_lazy_init(self):