Skip to content

Commit ece1e4c

Browse files
authored
Add weighted random sampler (#28545)
* add WeightedRandomSampler. test=develop
1 parent 2cb71c0 commit ece1e4c

File tree

3 files changed

+171
-11
lines changed

3 files changed

+171
-11
lines changed

python/paddle/fluid/dataloader/sampler.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@
1616
from __future__ import division
1717

1818
import numpy as np
19+
from .. import core
1920

20-
__all__ = ["Sampler", "SequenceSampler", "RandomSampler"]
21+
__all__ = [
22+
"Sampler", "SequenceSampler", "RandomSampler", "WeightedRandomSampler"
23+
]
2124

2225

2326
class Sampler(object):
@@ -234,3 +237,85 @@ def __iter__(self):
234237

235238
def __len__(self):
236239
return self.num_samples
240+
241+
242+
def _weighted_sample(weights, num_samples, replacement=True):
243+
if isinstance(weights, core.LoDTensor):
244+
weights = weights.numpy()
245+
if isinstance(weights, (list, tuple)):
246+
weights = np.array(weights)
247+
assert isinstance(weights, np.ndarray), \
248+
"weights should be paddle.Tensor, numpy.ndarray, list or tuple"
249+
assert len(weights.shape) <= 2, \
250+
"weights should be a 1-D or 2-D array"
251+
weights = weights.reshape((-1, weights.shape[-1]))
252+
assert np.all(weights >= 0.), \
253+
"weights should be positive value"
254+
assert not np.any(weights == np.inf), \
255+
"weights shoule not be INF"
256+
assert not np.any(weights == np.nan), \
257+
"weights shoule not be NaN"
258+
259+
non_zeros = np.sum(weights > 0., axis=1)
260+
assert np.all(non_zeros > 0), \
261+
"weights should have positive values"
262+
if not replacement:
263+
assert np.all(non_zeros >= num_samples), \
264+
"weights positive value number should not " \
265+
"less than num_samples when replacement=False"
266+
267+
weights = weights / weights.sum(axis=1)
268+
rets = []
269+
for i in range(weights.shape[0]):
270+
ret = np.random.choice(weights.shape[1], num_samples, replacement,
271+
weights[i])
272+
rets.append(ret)
273+
return np.array(rets)
274+
275+
276+
class WeightedRandomSampler(Sampler):
277+
"""
278+
Random sample with given weights (probabilities), sampe index will be in range
279+
[0, len(weights) - 1], if :attr:`replacement` is True, index can be sampled
280+
multiple times.
281+
282+
Args:
283+
weights(numpy.ndarray|paddle.Tensor|list|tuple): sequence of weights,
284+
should be numpy array, paddle.Tensor, list or tuple
285+
num_samples(int): set sample number to draw from sampler.
286+
replacement(bool): Whether to draw sample with replacements, default True
287+
288+
Returns:
289+
Sampler: a Sampler yield sample index randomly by given weights
290+
291+
Examples:
292+
293+
.. code-block:: python
294+
295+
from paddle.io import WeightedRandomSampler
296+
297+
sampler = WeightedRandomSampler(weights=[0.1, 0.3, 0.5, 0.7, 0.2],
298+
num_samples=5,
299+
replacement=True)
300+
301+
for index in sampler:
302+
print(index)
303+
"""
304+
305+
def __init__(self, weights, num_samples, replacement=True):
306+
if not isinstance(num_samples, int) or num_samples <= 0:
307+
raise ValueError("num_samples should be a positive integer")
308+
if not isinstance(replacement, bool):
309+
raise ValueError("replacement should be a boolean value")
310+
self.weights = weights
311+
self.num_samples = num_samples
312+
self.replacement = replacement
313+
314+
def __iter__(self):
315+
idxs = _weighted_sample(self.weights, self.num_samples,
316+
self.replacement)
317+
return iter(idxs.reshape((-1)).tolist())
318+
319+
def __len__(self):
320+
mul = np.prod(self.weights.shape) // self.weights.shape[-1]
321+
return self.num_samples * mul

python/paddle/fluid/tests/unittests/test_batch_sampler.py

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@
1616

1717
import unittest
1818

19+
import numpy as np
1920
import paddle.fluid as fluid
20-
from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, RandomSampler
21+
from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, \
22+
RandomSampler, WeightedRandomSampler
2123
from paddle.io import DistributedBatchSampler
2224

2325

@@ -195,14 +197,86 @@ def test_main(self):
195197
pass
196198

197199

198-
class TestDistributedBatchSamplerWithSampler(TestBatchSampler):
199-
def init_batch_sampler(self):
200-
dataset = RandomDataset(1000, 10)
201-
bs = DistributedBatchSampler(
202-
dataset=dataset,
203-
batch_size=self.batch_size,
204-
drop_last=self.drop_last)
205-
return bs
200+
class TestWeightedRandomSampler(unittest.TestCase):
201+
def init_probs(self, total, pos):
202+
pos_probs = np.random.random((pos, )).astype('float32')
203+
probs = np.zeros((total, )).astype('float32')
204+
probs[:pos] = pos_probs
205+
np.random.shuffle(probs)
206+
return probs
207+
208+
def test_replacement(self):
209+
probs = self.init_probs(20, 10)
210+
sampler = WeightedRandomSampler(probs, 30, True)
211+
assert len(sampler) == 30
212+
for idx in iter(sampler):
213+
assert probs[idx] > 0.
214+
215+
def test_no_replacement(self):
216+
probs = self.init_probs(20, 10)
217+
sampler = WeightedRandomSampler(probs, 10, False)
218+
assert len(sampler) == 10
219+
idxs = []
220+
for idx in iter(sampler):
221+
assert probs[idx] > 0.
222+
idxs.append(idx)
223+
assert len(set(idxs)) == len(idxs)
224+
225+
def test_assert(self):
226+
# all zeros
227+
probs = np.zeros((10, )).astype('float32')
228+
sampler = WeightedRandomSampler(probs, 10, True)
229+
try:
230+
for idx in iter(sampler):
231+
pass
232+
self.assertTrue(False)
233+
except AssertionError:
234+
self.assertTrue(True)
235+
236+
# not enough pos
237+
probs = self.init_probs(10, 5)
238+
sampler = WeightedRandomSampler(probs, 10, False)
239+
try:
240+
for idx in iter(sampler):
241+
pass
242+
self.assertTrue(False)
243+
except AssertionError:
244+
self.assertTrue(True)
245+
246+
# neg probs
247+
probs = -1.0 * np.ones((10, )).astype('float32')
248+
sampler = WeightedRandomSampler(probs, 10, True)
249+
try:
250+
for idx in iter(sampler):
251+
pass
252+
self.assertTrue(False)
253+
except AssertionError:
254+
self.assertTrue(True)
255+
256+
def test_raise(self):
257+
# float num_samples
258+
probs = self.init_probs(10, 5)
259+
try:
260+
sampler = WeightedRandomSampler(probs, 2.3, True)
261+
self.assertTrue(False)
262+
except ValueError:
263+
self.assertTrue(True)
264+
265+
# neg num_samples
266+
probs = self.init_probs(10, 5)
267+
try:
268+
sampler = WeightedRandomSampler(probs, -1, True)
269+
self.assertTrue(False)
270+
except ValueError:
271+
self.assertTrue(True)
272+
273+
# no-bool replacement
274+
probs = self.init_probs(10, 5)
275+
try:
276+
sampler = WeightedRandomSampler(probs, 5, 5)
277+
self.assertTrue(False)
278+
except ValueError:
279+
self.assertTrue(True)
206280

207281

208282
if __name__ == '__main__':

python/paddle/io/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,10 @@
2727
'Sampler',
2828
'SequenceSampler',
2929
'RandomSampler',
30+
'WeightedRandomSampler',
3031
]
3132

3233
from ..fluid.io import DataLoader
3334
from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info, \
3435
TensorDataset, Sampler, SequenceSampler, RandomSampler, DistributedBatchSampler, \
35-
ComposeDataset, ChainDataset
36+
ComposeDataset, ChainDataset, WeightedRandomSampler

0 commit comments

Comments
 (0)