7979 DataCollatorWithPadding ,
8080 DistDataLoader ,
8181 default_data_collator ,
82+ init_dataloader_comm_group ,
8283)
8384from ..peft import LoRAModel , PrefixModelForCausalLM , VeRAModel
8485
@@ -440,6 +441,10 @@ def fn(layer):
440441
441442 model .apply (fn )
442443
444+ self ._pp_data_group = None
445+ if self .args .pipeline_parallel_degree > 1 and self .args .distributed_dataloader :
446+ self ._pp_data_group = init_dataloader_comm_group ()
447+
443448 default_label_names = (
444449 ["start_positions" , "end_positions" ]
445450 if "QusetionAnswering" in type (self .model ).__name__ or "UIE" in type (self .model ).__name__
@@ -1537,6 +1542,7 @@ def get_train_dataloader(self):
15371542 train_dataset = self ._remove_unused_columns (train_dataset , description = "training" )
15381543 _DataLoader = DistDataLoader if self .args .distributed_dataloader else DataLoader
15391544
1545+ additional_configs = {}
15401546 if is_iterable_dataset : # For iterable dataset
15411547 if self .args .dataset_world_size > 1 and train_dataset is not None :
15421548 train_dataset = IterableDatasetShard (
@@ -1549,9 +1555,7 @@ def get_train_dataloader(self):
15491555
15501556 if self .args .distributed_dataloader :
15511557 logger .info ("Training using DistDataLoader." )
1552- additional_configs = {"is_iterable_dataset" : True }
1553- else :
1554- additional_configs = {}
1558+ additional_configs = {"is_iterable_dataset" : True , "pp_data_group" : self ._pp_data_group }
15551559 return _DataLoader (
15561560 train_dataset ,
15571561 batch_size = self .args .per_device_train_batch_size ,
@@ -1563,11 +1567,13 @@ def get_train_dataloader(self):
15631567 train_sampler = self ._get_train_sampler ()
15641568 if self .args .distributed_dataloader :
15651569 logger .info ("Training using DistDataLoader." )
1570+ additional_configs = {"pp_data_group" : self ._pp_data_group }
15661571 return _DataLoader (
15671572 train_dataset ,
15681573 batch_sampler = train_sampler ,
15691574 collate_fn = self .data_collator ,
15701575 num_workers = self .args .dataloader_num_workers ,
1576+ ** additional_configs ,
15711577 )
15721578
15731579 def _get_eval_sampler (self , eval_dataset : Dataset ):
@@ -1623,6 +1629,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
16231629 eval_dataset = self ._remove_unused_columns (eval_dataset , description = "evaluation" )
16241630 _DataLoader = DistDataLoader if self .args .distributed_dataloader else DataLoader
16251631
1632+ additional_configs = {}
16261633 if is_iterable_dataset :
16271634 if self .args .dataset_world_size > 1 and eval_dataset is not None :
16281635 eval_dataset = IterableDatasetShard (
@@ -1632,11 +1639,10 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
16321639 num_processes = self .args .dataset_world_size ,
16331640 process_index = self .args .dataset_rank ,
16341641 )
1642+
16351643 if self .args .distributed_dataloader :
16361644 logger .info ("Eval using DistDataLoader." )
1637- additional_configs = {"eval" : True , "is_iterable_dataset" : True }
1638- else :
1639- additional_configs = {}
1645+ additional_configs = {"eval" : True , "is_iterable_dataset" : True , "pp_data_group" : self ._pp_data_group }
16401646 return _DataLoader (
16411647 eval_dataset ,
16421648 batch_size = self .args .per_device_eval_batch_size ,
@@ -1648,9 +1654,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
16481654 eval_sampler = self ._get_eval_sampler (eval_dataset )
16491655 if self .args .distributed_dataloader :
16501656 logger .info ("Eval using DistDataLoader." )
1651- additional_configs = {"eval" : True }
1652- else :
1653- additional_configs = {}
1657+ additional_configs = {"eval" : True , "pp_data_group" : self ._pp_data_group }
16541658 return _DataLoader (
16551659 eval_dataset ,
16561660 batch_sampler = eval_sampler ,
@@ -1683,6 +1687,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
16831687 test_dataset = self ._remove_unused_columns (test_dataset , description = "test" )
16841688 _DataLoader = DistDataLoader if self .args .distributed_dataloader else DataLoader
16851689
1690+ additional_config = {}
16861691 if is_iterable_dataset :
16871692 if self .args .dataset_world_size > 1 and test_dataset is not None :
16881693 test_dataset = IterableDatasetShard (
@@ -1695,9 +1700,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
16951700
16961701 if self .args .distributed_dataloader :
16971702 logger .info ("Test using DistDataLoader." )
1698- additional_config = {"eval" : True , "is_iterable_dataset" : True }
1699- else :
1700- additional_config = {}
1703+ additional_config = {"eval" : True , "is_iterable_dataset" : True , "pp_data_group" : self ._pp_data_group }
17011704 return _DataLoader (
17021705 test_dataset ,
17031706 batch_size = self .args .per_device_eval_batch_size * self .world_size ,
@@ -1709,9 +1712,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
17091712 test_sampler = self ._get_eval_sampler (test_dataset )
17101713 if self .args .distributed_dataloader :
17111714 logger .info ("Test using DistDataLoader." )
1712- additional_config = {"eval" : True }
1713- else :
1714- additional_config = {}
1715+ additional_config = {"eval" : True , "pp_data_group" : self ._pp_data_group }
17151716 # We use the same batch_size as for eval.
17161717 return _DataLoader (
17171718 test_dataset ,
0 commit comments