Commit bf6ddb8

revert shard_dataloader

1 parent 40c1df6 commit bf6ddb8
File tree

6 files changed: +4 -88 lines changed

paddle/fluid/pybind/dist_api.cc

Lines changed: 0 additions & 7 deletions
@@ -81,13 +81,6 @@ void BindTensorDistAttribute(py::module *m) {
           [](TensorDistAttribute &self) {
             return self.process_mesh_attr().process_mesh();
           })
-      .def_property_readonly(
-          "process_mesh_attr",
-          [](TensorDistAttribute &self) { return self.process_mesh_attr(); })
-      .def_property_readonly("process_mesh_name",
-                             [](TensorDistAttribute &self) {
-                               return self.process_mesh_attr().dim_names();
-                             })
       .def_property_readonly(
           "dims_mapping",
           [](TensorDistAttribute &self) { return self.dims_mapping(); })

paddle/fluid/pybind/pir.cc

Lines changed: 0 additions & 4 deletions
@@ -1532,10 +1532,6 @@ void BindAttribute(py::module *m) {
             }
             return py::cast<py::none>(Py_None);
           })
-      .def("as_int64",
-           [](Attribute &self) {
-             return reinterpret_cast<int64_t>(static_cast<const void *>(self));
-           })
       .def("as_array_attr", [](Attribute &self) -> py::object {
         if (auto array_attr = self.dyn_cast<ArrayAttribute>()) {
           return py::cast(array_attr);

python/paddle/distributed/auto_parallel/api.py

Lines changed: 2 additions & 68 deletions
@@ -14,7 +14,6 @@
 from __future__ import annotations
 
 import copy
-import typing
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Literal, TypedDict
 
@@ -340,11 +339,9 @@ def forward(
         if local_tensor.is_dist():
             local_mesh = local_tensor.process_mesh
             local_val = local_tensor._local_value()
-            # local_placement = local_tensor.placements[0]
         else:
             local_val = local_tensor
             local_mesh = None
-            # local_placement = dist.Replicate()
 
         ctx.global_mesh = copy.deepcopy(mesh)
         ctx.placements = placements
@@ -2766,15 +2763,7 @@ def __init__(
         dataloader: paddle.io.DataLoader,
         meshes: ProcessMesh | list[ProcessMesh] | tuple[ProcessMesh],
         input_keys: list[str] | tuple[str] | None = None,
-        shard_dims: (
-            list
-            | tuple
-            | str
-            | int
-            | list[dist.Placement]
-            | list[list[dist.Placement]]
-            | None
-        ) = None,
+        shard_dims: list | tuple | str | int | None = None,
         is_dataset_splitted: bool = False,
     ):
         # do some check
@@ -2850,7 +2839,6 @@ def __init__(
         self._dataloader.pin_memory = False
 
     def _process_shard_dims(self, shard_dims):
-        shard_dims = self._convert_shard_dim_type(shard_dims)
         if isinstance(shard_dims, (int, str)) or shard_dims is None:
             res = []
             for i in range(len(self._meshes)):
@@ -2866,52 +2854,6 @@ def _process_shard_dims(self, shard_dims):
             )
         return shard_dims
 
-    def _convert_placements_to_mesh_dim(self, placements):
-        mesh_dim = None
-        for i, placement in enumerate(placements):
-            if placement.is_shard():
-                shard_dim = typing.cast(dist.Shard, placement).get_dim()
-                assert (
-                    shard_dim == 0
-                ), "Only the 0th dim of the input can be sharded."
-                assert (
-                    mesh_dim is None
-                ), "The input placements can only contain one Shard(0)."
-                mesh_dim = i
-            else:
-                assert (
-                    placement.is_replicate()
-                ), "The input placement must be Replicate or Shard(0)."
-        assert (
-            mesh_dim is not None
-        ), "Failed to convert placements to a mesh_dim."
-        return mesh_dim
-
-    def _convert_shard_dim_type(self, shard_dims):
-        if not isinstance(shard_dims, list) or not isinstance(
-            shard_dims[0], dist.Placement
-        ):
-            # if the input shard_dims is not Placement type,
-            # no need to convert it
-            return shard_dims
-        if isinstance(shard_dims[0], dist.Placement):
-            # if the input shard_dims is a list of Placement,
-            # convert it to a mesh_dim value
-            mesh_dim = self._convert_placements_to_mesh_dim(shard_dims)
-            return mesh_dim
-        elif isinstance(shard_dims[0], list):
-            # if the input shard_dims is a list of List(Placement),
-            # convert each placements to a mesh_dim value
-            res = []
-            for shard_dim in shard_dims:
-                mesh_dim = self._convert_placements_to_mesh_dim(shard_dim)
-                res.append(mesh_dim)
-            return res
-        else:
-            raise TypeError(
-                f"shard_dims must be Placements or list/tuple of Placements, but got {type(shard_dims)}"
-            )
-
     def _get_mesh_and_shard_dim(self, process_id):
         for i in range(len(self._meshes)):
             if isinstance(self._meshes[i], (list, tuple)):
@@ -3075,15 +3017,7 @@ def shard_dataloader(
     dataloader: paddle.io.DataLoader,
     meshes: ProcessMesh | list[ProcessMesh] | tuple[ProcessMesh],
     input_keys: list[str] | tuple[str] | None = None,
-    shard_dims: (
-        list
-        | tuple
-        | str
-        | int
-        | list[dist.Placement]
-        | list[list[dist.Placement]]
-        | None
-    ) = None,
+    shard_dims: list | tuple | str | int | None = None,
     is_dataset_splitted: bool = False,
 ) -> ShardDataloader:
     """

python/paddle/distributed/auto_parallel/static/pir_pass.py

Lines changed: 0 additions & 2 deletions
@@ -91,8 +91,6 @@ def reshard_combine_value(program, op, operand, attr):
 
 def apply_partition_pass(program):
     for op in program.global_block().ops:
-        # if op.name() == "pd_op.matmul_grad":
-        #     breakpoint()
         if op.name() in partition_skip_op_list:
             continue
 
test/auto_parallel/pir/semi_auto_parallel_simple_net_ep.py

Lines changed: 2 additions & 5 deletions
@@ -31,11 +31,8 @@ def __init__(self):
         self.hidden_size = 16
         self.class_num = 10
         self.run_ep = False
-        # self.mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
         self.mesh = dist.ProcessMesh([0, 1])
         self.expert_mesh_list = []
-        # self.expert_mesh_list.append(dist.ProcessMesh([0], dim_names=["x"]))
-        # self.expert_mesh_list.append(dist.ProcessMesh([1], dim_names=["x"]))
         self.expert_mesh_list.append(dist.ProcessMesh([0]))
         self.expert_mesh_list.append(dist.ProcessMesh([1]))
 
@@ -204,7 +201,7 @@ def run_ep(self):
         model, train_dataloader, criterion, optimizer = self.build(config)
 
         dist_dataloader = dist.shard_dataloader(
-            train_dataloader, config.mesh, shard_dims=[dist.Shard(0)]
+            train_dataloader, config.mesh, shard_dims=0
         )
         loss = self.train(config, model, dist_dataloader, criterion, optimizer)
 
@@ -226,7 +223,7 @@ def run_dy2st(self):
         model, train_dataloader, criterion, optimizer = self.build(config)
 
         dist_dataloader = dist.shard_dataloader(
-            train_dataloader, config.mesh, shard_dims="d0"
+            train_dataloader, config.mesh, shard_dims=0
         )
 
         mode = "train"

test/auto_parallel/pir/test_semi_auto_parallel_simple_net_ep.py

Lines changed: 0 additions & 2 deletions
@@ -11,11 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import sys
 import tempfile
 import unittest
 
-sys.path.append("..")
 import collective.test_communication_api_base as test_base
 
 