[auto parallel] Lazy init for MP. Add reshard infer shape. #60563
Changes from all commits
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import paddle
@@ -34,6 +34,11 @@ def __init__(self):
            self._mesh_bias = dist.ProcessMesh([1], dim_names=["x"])
            self._placements_weight = [dist.Replicate()]
            self._placements_bias = [dist.Replicate()]
        elif self._placements_type == "MP":
            self._mesh_weight = dist.ProcessMesh([0, 1], dim_names=["x"])
            self._mesh_bias = dist.ProcessMesh([0, 1], dim_names=["x"])
            self._placements_weight = [dist.Shard(1)]
            self._placements_bias = [dist.Shard(0)]

    def test_different_xavier(self):
        paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
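For context on the new "MP" branch above, a minimal sketch (not part of the PR) of what these placements imply for per-rank storage on the two-rank mesh [0, 1]: Shard(1) splits the weight along its second dimension, Shard(0) splits the bias along its only dimension, and Replicate keeps the full tensor on every rank. The helper below is hypothetical and assumes the sharded dimension divides evenly.

def local_shape(global_shape, mesh_size, shard_dim=None):
    # Hypothetical helper: shrink shard_dim by the mesh size (even split assumed);
    # shard_dim=None models Replicate, which keeps the full shape on every rank.
    shape = list(global_shape)
    if shard_dim is not None:
        shape[shard_dim] //= mesh_size
    return shape

# Linear(10, 10) on a 2-rank mesh, matching the asserts added to test_placements below:
assert local_shape([10, 10], 2, shard_dim=1) == [10, 5]  # weight with Shard(1)
assert local_shape([10], 2, shard_dim=0) == [5]          # bias with Shard(0)
assert local_shape([10, 10], 2) == [10, 10]              # Replicate, e.g. the DP weight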
@@ -53,6 +58,31 @@ def test_different_xavier(self):
            linear.bias = dist.shard_tensor(
                linear.bias, self._mesh_bias, self._placements_bias
            )
        for param in linear.parameters():
            param.initialize()
            logging.info(param)

    def test_constant(self):
        paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
        weight_attr = paddle.framework.ParamAttr(
            initializer=paddle.nn.initializer.Constant(2.0)
        )
        bias_attr = paddle.framework.ParamAttr(
            initializer=paddle.nn.initializer.Constant(1.0)
        )
        with LazyGuard():
            linear = paddle.nn.Linear(
                10, 10, weight_attr=weight_attr, bias_attr=bias_attr
            )
            linear.weight = dist.shard_tensor(
                linear.weight, self._mesh_weight, self._placements_weight
            )
            linear.bias = dist.shard_tensor(
                linear.bias, self._mesh_bias, self._placements_bias
            )
        for param in linear.parameters():
            param.initialize()
            logging.info(param)

    def test_placements(self):
        paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
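The two tests above follow the same lazy-init pattern: build the layer under LazyGuard, re-wrap its parameters with dist.shard_tensor, and only then call param.initialize(). As a conceptual stand-in (this is not Paddle's LazyGuard implementation; LazyParam and init_fn are hypothetical names), deferred initialization can be pictured like this:

import numpy as np

class LazyParam:
    """Conceptual stand-in: remember how to build the value, build it on demand."""

    def __init__(self, shape, init_fn):
        self._shape = shape
        self._init_fn = init_fn  # e.g. a Constant(2.0)-style initializer
        self._value = None       # nothing allocated yet

    def _is_initialized(self):
        return self._value is not None

    def initialize(self):
        if self._value is None:
            self._value = self._init_fn(self._shape)
        return self._value

# Mirrors test_constant: weight filled with 2.0, bias with 1.0, both deferred.
weight = LazyParam((10, 10), lambda shape: np.full(shape, 2.0))
bias = LazyParam((10,), lambda shape: np.full(shape, 1.0))
assert not weight._is_initialized()
weight.initialize()
bias.initialize()
assert weight._is_initialized() and float(weight._value[0, 0]) == 2.0

Under MP, initialize() must additionally materialize only the local shard of each parameter, which is presumably where the reshard shape inference named in the PR title comes in.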
@@ -67,6 +97,7 @@ def test_placements(self):
        for param in linear.parameters():
            assert not param._is_initialized()
            param.initialize()
            logging.info(param)

Review thread on the added logging.info(param) line:
Comment: What is this logging call meant to check here?
Reply: It can trigger a curious bug: under lazy init, the corresponding reshard function cannot be found.

        if self._placements_type == "DP":
            assert linear.weight._is_initialized()
@@ -93,10 +124,40 @@ def test_placements(self):
            else:
                assert not linear.weight._is_initialized()
                assert linear.bias._is_initialized()
        elif self._placements_type == "MP":
            assert linear.weight._is_initialized()
            assert linear.bias._is_initialized()
            assert linear.weight._local_shape == [10, 5]
            assert linear.bias._local_shape == [5]

    def test_unbalance_mp(self):
        paddle.distributed.auto_parallel.parallel_manual_seed(self._seed)
        with LazyGuard():
            linear = paddle.nn.Linear(11, 11)
            linear.weight = dist.shard_tensor(
                linear.weight, self._mesh_weight, self._placements_weight
            )
            linear.bias = dist.shard_tensor(
                linear.bias, self._mesh_bias, self._placements_bias
            )
        for param in linear.parameters():
            assert not param._is_initialized()
            param.initialize()
            assert param._is_initialized()

        if dist.get_rank() == 0:
            assert linear.weight._local_shape == [11, 6]
            assert linear.bias._local_shape == [6]
        else:
            assert linear.weight._local_shape == [11, 5]
            assert linear.bias._local_shape == [5]

    def run_test_case(self):
        self.test_placements()
        self.test_different_xavier()
        self.test_constant()
        if self._placements_type == "MP":
            self.test_unbalance_mp()


if __name__ == '__main__':
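test_unbalance_mp covers the case where the sharded dimension does not divide evenly across the mesh. A small sketch (helper name hypothetical) of the split the asserts expect, with earlier ranks absorbing the remainder:

def shard_sizes(global_dim, num_shards):
    # Hypothetical helper: earlier ranks take one extra element each until the
    # remainder is used up, so a dimension of 11 over 2 ranks becomes [6, 5].
    base, rem = divmod(global_dim, num_shards)
    return [base + 1 if rank < rem else base for rank in range(num_shards)]

assert shard_sizes(11, 2) == [6, 5]  # rank 0 -> 6, rank 1 -> 5, as asserted above
assert shard_sizes(10, 2) == [5, 5]  # the balanced case from test_placements

Since the MP cases use a two-rank mesh, the test script would be launched on two devices (for example via python -m paddle.distributed.launch); the exact launch command is not part of this diff.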
Review thread:
Comment: Does the dist_attr path above, which initializes from global_value, also need to support lazy init?
Reply: That other method cannot be reached from here.