
Commit 85cbada

Floor num_samples so that prefetching does not run past the maximum length of the dataset... (#8690)
1 parent cf57f86 commit 85cbada
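
The rationale in the commit title can be seen with a small worked example (hypothetical sizes, not taken from the patch): rounding the per-rank sample count up makes num_samples * nranks exceed len(dataset) whenever the dataset does not divide evenly across ranks, so a prefetching dataloader can ask for indices past the end; rounding down keeps the total within range.

    import math

    dataset_len, nranks = 10, 4  # hypothetical sizes, not from the commit

    # old behaviour: round the per-rank count up
    ceil_num_samples = int(math.ceil(dataset_len * 1.0 / nranks))  # 3
    print(ceil_num_samples * nranks)  # 12 -> indices 10 and 11 would be out of range

    # new behaviour: round the per-rank count down
    floor_num_samples = int(dataset_len * 1.0 / nranks)  # 2
    print(floor_num_samples * nranks)  # 8 -> always a valid index range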

File tree

1 file changed (+2, -4 lines)


paddlenlp/utils/batch_sampler.py

Lines changed: 2 additions & 4 deletions
@@ -14,8 +14,6 @@
 
 from __future__ import division, print_function
 
-import math
-
 import paddle
 
 __all__ = ["DistributedBatchSampler"]
@@ -110,7 +108,7 @@ def __init__(
             # In pre-training mode when using distributed dataloader, the input dataset can be None. We should handle this situation.
             self.num_samples = 0
         else:
-            self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
+            self.num_samples = int(len(self.dataset) * 1.0 / self.nranks)
         self.total_size = self.num_samples * self.nranks
 
     def get_start_end_idx(self):
@@ -125,7 +123,7 @@ def __iter__(self):
                 self.consumed_samples,
                 self.nranks,
             )
-            self.remain_num_samples = int(math.ceil((len(self.dataset) - self.consumed_samples) * 1.0 / self.nranks))
+            self.remain_num_samples = int((len(self.dataset) - self.consumed_samples) * 1.0 / self.nranks)
             self.remain_total_size = self.remain_num_samples * self.nranks
             self.batch_size_times_rank_size = self.batch_size * self.nranks

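A similar sketch for the __iter__ hunk, again with hypothetical numbers: after the change, remain_num_samples is floored, so remain_total_size never exceeds the number of unconsumed samples and the highest global index handed to the prefetcher stays below len(dataset). The variable names mirror the diff; the index construction below is only illustrative, not the sampler's real logic.

    dataset_len, nranks, consumed_samples = 10, 4, 3  # hypothetical, not from the commit

    # mirrors the patched computation in __iter__
    remain_num_samples = int((dataset_len - consumed_samples) * 1.0 / nranks)  # floor(7 / 4) = 1
    remain_total_size = remain_num_samples * nranks  # 4

    # the remaining global indices this epoch never reach dataset_len
    indices = list(range(consumed_samples, consumed_samples + remain_total_size))
    assert max(indices) < dataset_len  # 6 < 10; with math.ceil the bound could be violated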