Skip to content

Commit 0187c10

Browse files
committed
add doc and example codes
1 parent 8296d9e commit 0187c10

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

python/paddle/distributed/auto_parallel/local_layer.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,58 @@
2626

2727

2828
class LocalLayer(Layer):
29+
"""
30+
The `LocalLayer` class is a specialized `Layer` for managing distributed tensors during
31+
forward and backward passes in a parallelized training environment. It converts distributed tensors
32+
to local tensors for computation and then back to distributed tensors as output, ensuring seamless
33+
integration with distributed parallelism frameworks.
34+
35+
Args:
36+
out_dist_attrs (list[tuple[ProcessMesh, list[Placement]]]):
37+
A list where each entry is a tuple containing the `ProcessMesh` and the list of `Placement`
38+
attributes for the corresponding output tensor. These attributes define the distribution
39+
strategy for the outputs.
40+
41+
Examples:
42+
.. code-block:: python
43+
44+
import paddle
45+
import paddle.distributed as dist
46+
from paddle import nn
from paddle.distributed.auto_parallel.local_layer import LocalLayer
47+
48+
class CustomLayer(LocalLayer):
49+
def __init__(self, mesh):
50+
super().__init__(
51+
out_dist_attrs=[(mesh, [dist.Partial(dist.ReduceType.kRedSum)])]
52+
)
53+
self.fc = nn.Linear(16, 8)
54+
55+
def forward(self, x):
56+
return self.fc(x)
57+
58+
# doctest: +REQUIRES(env:DISTRIBUTED)
59+
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
60+
custom_layer = CustomLayer(mesh)
61+
input_tensor = dist.auto_parallel.api.dtensor_from_local(
62+
paddle.randn([4, 16]), mesh, [dist.Replicate()]
63+
)
64+
65+
output_tensor = custom_layer(input_tensor)
66+
print(output_tensor)
67+
"""
68+
2969
def __init__(
    self,
    out_dist_attrs: list[tuple[ProcessMesh, list[Placement]]],
):
    """
    Initialize the base `Layer` and record the output distribution attributes.

    Args:
        out_dist_attrs (list[tuple[ProcessMesh, list[Placement]]]):
            A list of `(ProcessMesh, placements)` tuples. The list object is
            stored on the instance as `self.out_dist_attrs` without copying.
    """
    super().__init__()
    self.out_dist_attrs = out_dist_attrs
3474

3575
def __call__(self, *inputs: Any, **kwargs: Any) -> Any:
76+
"""
77+
Overrides the base `Layer`'s `__call__` method. Transforms distributed tensors to local tensors
78+
before computation, invokes the parent class's `__call__` method, and then transforms the
79+
outputs back to distributed tensors based on the specified distribution attributes.
80+
"""
3681
inputs = list(inputs)
3782
for idx in range(len(inputs)):
3883
if inputs[idx].is_dist():

0 commit comments

Comments
 (0)