
Commit 2ecf7ef

[DistDataloader] fix eval new_group (#9438)
1 parent 0429149 commit 2ecf7ef

File tree: 2 files changed (+30 −29 lines)


paddlenlp/data/dist_dataloader.py

Lines changed: 14 additions & 14 deletions

@@ -66,6 +66,7 @@ def __init__(
 
         eval = kwargs.pop("eval", False)
         is_iterable_dataset = kwargs.pop("is_iterable_dataset", False)
+        self._pp_data_group = kwargs.pop("pp_data_group", None)
 
         if dataset is None:
             dataset = DummyDataset() if not is_iterable_dataset else IterableDummyDataset()
@@ -78,10 +79,8 @@ def __init__(
 
         # Init pp data comm group.
         if self._hcg.get_pipe_parallel_world_size() > 1:
-            self._pp_data_group = self._init_dataloader_comm_group()
             self._pp_group = self._hcg.get_pipe_parallel_group()
         else:
-            self._pp_data_group = None
             self._pp_group = None
 
         self.mp_group = self._hcg.get_model_parallel_group()
@@ -132,18 +131,6 @@ def __len__(self):
         else:
             raise ValueError("raise error for `paddlenlp.trainer.trainer_utils.has_length`")
 
-    def _init_dataloader_comm_group(self):
-        topo = self._hcg._topo
-        parallel_comm_group = None
-        parallel_groups = topo.get_comm_list("pipe")
-
-        for group in parallel_groups:
-            ranks = [group[0], group[-1]]
-            comm_group = paddle.distributed.new_group(ranks=ranks)
-            if paddle.distributed.get_rank() in ranks:
-                parallel_comm_group = comm_group
-        return parallel_comm_group
-
     def __iter__(self):
         return self
 
@@ -212,3 +199,16 @@ def __next__(self):
                 logger.debug(e)
         data = self._broadcast_data(data)
         return data
+
+
+def init_dataloader_comm_group():
+    hcg = fleet.get_hybrid_communicate_group()
+    topo = hcg._topo
+    parallel_groups = topo.get_comm_list("pipe")
+
+    for group in parallel_groups:
+        ranks = [group[0], group[-1]]
+        comm_group = paddle.distributed.new_group(ranks=ranks)
+        if paddle.distributed.get_rank() in ranks:
+            parallel_comm_group = comm_group
+    return parallel_comm_group
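For orientation: this file's change replaces the per-instance `_init_dataloader_comm_group` method with the module-level `init_dataloader_comm_group()` helper, so the pipe-ends communication group is built once by the caller and handed to `DistDataLoader` through the new `pp_data_group` kwarg. A minimal sketch of the new calling pattern follows; it assumes fleet is already initialized with a pipeline-parallel degree greater than 1, that the helper is re-exported from `paddlenlp.data` the way trainer.py imports it, and `train_dataset` / `eval_dataset` / `data_collator` are placeholder names, not part of this commit.

from paddle.distributed import fleet
from paddlenlp.data import DistDataLoader, init_dataloader_comm_group

# Assumes fleet.init(...) with pipeline_parallel_degree > 1 already ran on every rank.
hcg = fleet.get_hybrid_communicate_group()

# Build the group linking the first and last pipeline stage exactly once,
# collectively on all ranks, instead of inside each DistDataLoader.__init__.
pp_data_group = None
if hcg.get_pipe_parallel_world_size() > 1:
    pp_data_group = init_dataloader_comm_group()

# Every loader (train and eval) now reuses the same group via the new kwarg.
train_loader = DistDataLoader(
    train_dataset,                 # placeholder dataset
    batch_size=8,
    collate_fn=data_collator,      # placeholder collator
    pp_data_group=pp_data_group,
)
eval_loader = DistDataLoader(
    eval_dataset,                  # placeholder dataset
    batch_size=8,
    collate_fn=data_collator,
    eval=True,
    pp_data_group=pp_data_group,
)

Reusing one pre-built group avoids issuing extra paddle.distributed.new_group calls when an evaluation loader is constructed, which appears to be the mismatch the commit title refers to.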

paddlenlp/trainer/trainer.py

Lines changed: 16 additions & 15 deletions

@@ -79,6 +79,7 @@
     DataCollatorWithPadding,
     DistDataLoader,
     default_data_collator,
+    init_dataloader_comm_group,
 )
 from ..peft import LoRAModel, PrefixModelForCausalLM, VeRAModel
 
@@ -440,6 +441,10 @@ def fn(layer):
 
             model.apply(fn)
 
+        self._pp_data_group = None
+        if self.args.pipeline_parallel_degree > 1 and self.args.distributed_dataloader:
+            self._pp_data_group = init_dataloader_comm_group()
+
         default_label_names = (
             ["start_positions", "end_positions"]
             if "QusetionAnswering" in type(self.model).__name__ or "UIE" in type(self.model).__name__
@@ -1537,6 +1542,7 @@ def get_train_dataloader(self):
             train_dataset = self._remove_unused_columns(train_dataset, description="training")
         _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
+        additional_configs = {}
         if is_iterable_dataset:  # For iterable dataset
             if self.args.dataset_world_size > 1 and train_dataset is not None:
                 train_dataset = IterableDatasetShard(
@@ -1549,9 +1555,7 @@ def get_train_dataloader(self):
 
             if self.args.distributed_dataloader:
                 logger.info("Training using DistDataLoader.")
-                additional_configs = {"is_iterable_dataset": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 train_dataset,
                 batch_size=self.args.per_device_train_batch_size,
@@ -1563,11 +1567,13 @@ def get_train_dataloader(self):
             train_sampler = self._get_train_sampler()
             if self.args.distributed_dataloader:
                 logger.info("Training using DistDataLoader.")
+                additional_configs = {"pp_data_group": self._pp_data_group}
             return _DataLoader(
                 train_dataset,
                 batch_sampler=train_sampler,
                 collate_fn=self.data_collator,
                 num_workers=self.args.dataloader_num_workers,
+                **additional_configs,
             )
 
     def _get_eval_sampler(self, eval_dataset: Dataset):
@@ -1623,6 +1629,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
             eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")
         _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
+        additional_configs = {}
         if is_iterable_dataset:
             if self.args.dataset_world_size > 1 and eval_dataset is not None:
                 eval_dataset = IterableDatasetShard(
@@ -1632,11 +1639,10 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
                     num_processes=self.args.dataset_world_size,
                     process_index=self.args.dataset_rank,
                 )
+
             if self.args.distributed_dataloader:
                 logger.info("Eval using DistDataLoader.")
-                additional_configs = {"eval": True, "is_iterable_dataset": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"eval": True, "is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 eval_dataset,
                 batch_size=self.args.per_device_eval_batch_size,
@@ -1648,9 +1654,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
             eval_sampler = self._get_eval_sampler(eval_dataset)
             if self.args.distributed_dataloader:
                 logger.info("Eval using DistDataLoader.")
-                additional_configs = {"eval": True}
-            else:
-                additional_configs = {}
+                additional_configs = {"eval": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 eval_dataset,
                 batch_sampler=eval_sampler,
@@ -1683,6 +1687,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
             test_dataset = self._remove_unused_columns(test_dataset, description="test")
         _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
 
+        additional_config = {}
         if is_iterable_dataset:
             if self.args.dataset_world_size > 1 and test_dataset is not None:
                 test_dataset = IterableDatasetShard(
@@ -1695,9 +1700,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
 
             if self.args.distributed_dataloader:
                 logger.info("Test using DistDataLoader.")
-                additional_config = {"eval": True, "is_iterable_dataset": True}
-            else:
-                additional_config = {}
+                additional_config = {"eval": True, "is_iterable_dataset": True, "pp_data_group": self._pp_data_group}
             return _DataLoader(
                 test_dataset,
                 batch_size=self.args.per_device_eval_batch_size * self.world_size,
@@ -1709,9 +1712,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
             test_sampler = self._get_eval_sampler(test_dataset)
            if self.args.distributed_dataloader:
                 logger.info("Test using DistDataLoader.")
-                additional_config = {"eval": True}
-            else:
-                additional_config = {}
+                additional_config = {"eval": True, "pp_data_group": self._pp_data_group}
             # We use the same batch_size as for eval.
             return _DataLoader(
                 test_dataset,
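The trainer side mirrors the dataloader change: because paddle.distributed.new_group is a collective call, every rank should issue it the same number of times, so the group is now created once in Trainer.__init__ and threaded into each train/eval/test loader through additional_configs. Below is a condensed, hypothetical sketch of that wiring, not the real Trainer class; the class name TinyTrainer and its constructor arguments are placeholders, while the args fields and loader kwargs come from the diff above.

from paddle.io import DataLoader
from paddlenlp.data import DistDataLoader, init_dataloader_comm_group

class TinyTrainer:
    """Illustration only: the create-once / pass-everywhere pattern from this commit."""

    def __init__(self, args, train_dataset, eval_dataset, data_collator):
        self.args = args
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.data_collator = data_collator

        # Created once, on all ranks, at construction time (mirrors the __init__ hunk above).
        self._pp_data_group = None
        if args.pipeline_parallel_degree > 1 and args.distributed_dataloader:
            self._pp_data_group = init_dataloader_comm_group()

    def get_eval_dataloader(self):
        _DataLoader = DistDataLoader if self.args.distributed_dataloader else DataLoader
        additional_configs = {}
        if self.args.distributed_dataloader:
            # The already-built group rides along; no new_group call happens here.
            additional_configs = {"eval": True, "pp_data_group": self._pp_data_group}
        return _DataLoader(
            self.eval_dataset,
            batch_size=self.args.per_device_eval_batch_size,
            collate_fn=self.data_collator,
            **additional_configs,
        )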
