@@ -1447,24 +1447,41 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
14471447                    process_index = self .args .dataset_rank ,
14481448                )
14491449
1450-             return  _DataLoader (
1451-                 eval_dataset ,
1452-                 batch_size = self .args .per_device_eval_batch_size ,
1453-                 collate_fn = self .data_collator ,
1454-                 num_workers = self .args .dataloader_num_workers ,
1455-             )
1450+             if  self .args .distributed_dataloader :
1451+                 return  _DataLoader (
1452+                     eval_dataset ,
1453+                     batch_size = self .args .per_device_eval_batch_size ,
1454+                     collate_fn = self .data_collator ,
1455+                     num_workers = self .args .dataloader_num_workers ,
1456+                     eval = True ,
1457+                 )
1458+             else :
1459+                 return  _DataLoader (
1460+                     eval_dataset ,
1461+                     batch_size = self .args .per_device_eval_batch_size ,
1462+                     collate_fn = self .data_collator ,
1463+                     num_workers = self .args .dataloader_num_workers ,
1464+                 )
14561465
14571466        eval_sampler  =  self ._get_eval_sampler (eval_dataset )
14581467
14591468        if  self .args .distributed_dataloader :
14601469            logger .info ("Eval using DistDataLoader." )
14611470
1462-         return  _DataLoader (
1463-             eval_dataset ,
1464-             batch_sampler = eval_sampler ,
1465-             collate_fn = self .data_collator ,
1466-             num_workers = self .args .dataloader_num_workers ,
1467-         )
1471+             return  _DataLoader (
1472+                 eval_dataset ,
1473+                 batch_sampler = eval_sampler ,
1474+                 collate_fn = self .data_collator ,
1475+                 num_workers = self .args .dataloader_num_workers ,
1476+                 eval = True ,
1477+             )
1478+         else :
1479+             return  _DataLoader (
1480+                 eval_dataset ,
1481+                 batch_sampler = eval_sampler ,
1482+                 collate_fn = self .data_collator ,
1483+                 num_workers = self .args .dataloader_num_workers ,
1484+             )
14681485
14691486    def  get_test_dataloader (self , test_dataset : Dataset ) ->  DataLoader :
14701487        """ 
@@ -1497,25 +1514,42 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
14971514                    process_index = self .args .dataset_rank ,
14981515                )
14991516
1500-             return  _DataLoader (
1501-                 test_dataset ,
1502-                 batch_size = self .args .per_device_eval_batch_size  *  self .world_size ,
1503-                 collate_fn = self .data_collator ,  # _get_collator_with_removed_columns 
1504-                 num_workers = self .args .dataloader_num_workers ,
1505-             )
1517+             if  self .args .distributed_dataloader :
1518+                 return  _DataLoader (
1519+                     test_dataset ,
1520+                     batch_size = self .args .per_device_eval_batch_size  *  self .world_size ,
1521+                     collate_fn = self .data_collator ,  # _get_collator_with_removed_columns 
1522+                     num_workers = self .args .dataloader_num_workers ,
1523+                     eval = True ,
1524+                 )
1525+             else :
1526+                 return  _DataLoader (
1527+                     test_dataset ,
1528+                     batch_size = self .args .per_device_eval_batch_size  *  self .world_size ,
1529+                     collate_fn = self .data_collator ,  # _get_collator_with_removed_columns 
1530+                     num_workers = self .args .dataloader_num_workers ,
1531+                 )
15061532
15071533        test_sampler  =  self ._get_eval_sampler (test_dataset )
15081534
15091535        if  self .args .distributed_dataloader :
15101536            logger .info ("Test using DistDataLoader." )
15111537
1512-         # We use the same batch_size as for eval. 
1513-         return  _DataLoader (
1514-             test_dataset ,
1515-             batch_sampler = test_sampler ,
1516-             collate_fn = self .data_collator ,
1517-             drop_last = self .args .dataloader_drop_last ,
1518-         )
1538+             # We use the same batch_size as for eval. 
1539+             return  _DataLoader (
1540+                 test_dataset ,
1541+                 batch_sampler = test_sampler ,
1542+                 collate_fn = self .data_collator ,
1543+                 drop_last = self .args .dataloader_drop_last ,
1544+                 eval = True ,
1545+             )
1546+         else :
1547+             return  _DataLoader (
1548+                 test_dataset ,
1549+                 batch_sampler = test_sampler ,
1550+                 collate_fn = self .data_collator ,
1551+                 drop_last = self .args .dataloader_drop_last ,
1552+             )
15191553
15201554    def  create_optimizer_and_scheduler (self , num_training_steps : int ):
15211555        """ 
0 commit comments