Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 86 additions & 1 deletion python/paddle/fluid/dataloader/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
from __future__ import division

import numpy as np
from .. import core

__all__ = ["Sampler", "SequenceSampler", "RandomSampler"]
__all__ = [
"Sampler", "SequenceSampler", "RandomSampler", "WeightedRandomSampler"
]


class Sampler(object):
Expand Down Expand Up @@ -234,3 +237,85 @@ def __iter__(self):

def __len__(self):
return self.num_samples


def _weighted_sample(weights, num_samples, replacement=True):
if isinstance(weights, core.LoDTensor):
weights = weights.numpy()
if isinstance(weights, (list, tuple)):
weights = np.array(weights)
assert isinstance(weights, np.ndarray), \
"weights should be paddle.Tensor, numpy.ndarray, list or tuple"
assert len(weights.shape) <= 2, \
"weights should be a 1-D or 2-D array"
weights = weights.reshape((-1, weights.shape[-1]))
assert np.all(weights >= 0.), \
"weights should be positive value"
assert not np.any(weights == np.inf), \
"weights shoule not be INF"
assert not np.any(weights == np.nan), \
"weights shoule not be NaN"

non_zeros = np.sum(weights > 0., axis=1)
assert np.all(non_zeros > 0), \
"weights should have positive values"
if not replacement:
assert np.all(non_zeros >= num_samples), \
"weights positive value number should not " \
"less than num_samples when replacement=False"

weights = weights / weights.sum(axis=1)
rets = []
for i in range(weights.shape[0]):
ret = np.random.choice(weights.shape[1], num_samples, replacement,
weights[i])
rets.append(ret)
return np.array(rets)


class WeightedRandomSampler(Sampler):
"""
Random sample with given weights (probabilities), sampe index will be in range
[0, len(weights) - 1], if :attr:`replacement` is True, index can be sampled
multiple times.

Args:
weights(numpy.ndarray|paddle.Tensor|list|tuple): sequence of weights,
should be numpy array, paddle.Tensor, list or tuple
num_samples(int): set sample number to draw from sampler.
replacement(bool): Whether to draw sample with replacements, default True

Returns:
Sampler: a Sampler yield sample index randomly by given weights

Examples:

.. code-block:: python

from paddle.io import WeightedRandomSampler

sampler = WeightedRandomSampler(weights=[0.1, 0.3, 0.5, 0.7, 0.2],
num_samples=5,
replacement=True)

for index in sampler:
print(index)
"""

def __init__(self, weights, num_samples, replacement=True):
if not isinstance(num_samples, int) or num_samples <= 0:
raise ValueError("num_samples should be a positive integer")
if not isinstance(replacement, bool):
raise ValueError("replacement should be a boolean value")
self.weights = weights
self.num_samples = num_samples
self.replacement = replacement

def __iter__(self):
idxs = _weighted_sample(self.weights, self.num_samples,
self.replacement)
return iter(idxs.reshape((-1)).tolist())

def __len__(self):
mul = np.prod(self.weights.shape) // self.weights.shape[-1]
return self.num_samples * mul
92 changes: 83 additions & 9 deletions python/paddle/fluid/tests/unittests/test_batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

import unittest

import numpy as np
import paddle.fluid as fluid
from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, RandomSampler
from paddle.io import BatchSampler, Dataset, Sampler, SequenceSampler, \
RandomSampler, WeightedRandomSampler
from paddle.io import DistributedBatchSampler


Expand Down Expand Up @@ -195,14 +197,86 @@ def test_main(self):
pass


class TestDistributedBatchSamplerWithSampler(TestBatchSampler):
def init_batch_sampler(self):
dataset = RandomDataset(1000, 10)
bs = DistributedBatchSampler(
dataset=dataset,
batch_size=self.batch_size,
drop_last=self.drop_last)
return bs
class TestWeightedRandomSampler(unittest.TestCase):
def init_probs(self, total, pos):
pos_probs = np.random.random((pos, )).astype('float32')
probs = np.zeros((total, )).astype('float32')
probs[:pos] = pos_probs
np.random.shuffle(probs)
return probs

def test_replacement(self):
probs = self.init_probs(20, 10)
sampler = WeightedRandomSampler(probs, 30, True)
assert len(sampler) == 30
for idx in iter(sampler):
assert probs[idx] > 0.

def test_no_replacement(self):
probs = self.init_probs(20, 10)
sampler = WeightedRandomSampler(probs, 10, False)
assert len(sampler) == 10
idxs = []
for idx in iter(sampler):
assert probs[idx] > 0.
idxs.append(idx)
assert len(set(idxs)) == len(idxs)

def test_assert(self):
# all zeros
probs = np.zeros((10, )).astype('float32')
sampler = WeightedRandomSampler(probs, 10, True)
try:
for idx in iter(sampler):
pass
self.assertTrue(False)
except AssertionError:
self.assertTrue(True)

# not enough pos
probs = self.init_probs(10, 5)
sampler = WeightedRandomSampler(probs, 10, False)
try:
for idx in iter(sampler):
pass
self.assertTrue(False)
except AssertionError:
self.assertTrue(True)

# neg probs
probs = -1.0 * np.ones((10, )).astype('float32')
sampler = WeightedRandomSampler(probs, 10, True)
try:
for idx in iter(sampler):
pass
self.assertTrue(False)
except AssertionError:
self.assertTrue(True)

def test_raise(self):
# float num_samples
probs = self.init_probs(10, 5)
try:
sampler = WeightedRandomSampler(probs, 2.3, True)
self.assertTrue(False)
except ValueError:
self.assertTrue(True)

# neg num_samples
probs = self.init_probs(10, 5)
try:
sampler = WeightedRandomSampler(probs, -1, True)
self.assertTrue(False)
except ValueError:
self.assertTrue(True)

# no-bool replacement
probs = self.init_probs(10, 5)
try:
sampler = WeightedRandomSampler(probs, 5, 5)
self.assertTrue(False)
except ValueError:
self.assertTrue(True)


if __name__ == '__main__':
Expand Down
3 changes: 2 additions & 1 deletion python/paddle/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@
'Sampler',
'SequenceSampler',
'RandomSampler',
'WeightedRandomSampler',
]

from ..fluid.io import DataLoader
from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info, \
TensorDataset, Sampler, SequenceSampler, RandomSampler, DistributedBatchSampler, \
ComposeDataset, ChainDataset
ComposeDataset, ChainDataset, WeightedRandomSampler