Skip to content

Commit 36c5a65

Browse files
authored
add WeightedRandomSampler. test=develop (#2872)
1 parent c93028b commit 36c5a65

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
.. _cn_api_io_cn_WeightedRandomSampler:
2+
3+
WeightedRandomSampler
4+
-------------------------------
5+
6+
.. py:class:: paddle.io.WeightedRandomSampler(weights, num_samples, replacement=True)
7+
8+
通过制定的权重随机采样,采样下标范围在 ``[0, len(weights) - 1]`` , 如果 ``replacement`` 为 ``True`` ,则下标可被采样多次
9+
10+
参数:
11+
- **weights** (numpy.ndarray|paddle.Tensor|tuple|list) - 权重序列,需要是numpy数组,paddle.Tensor,list或者tuple类型。
12+
- **num_samples** (int) - 采样样本数。
13+
- **replacement** (bool) - 是否采用有放回的采样,默认值为True
14+
15+
返回: 返回根据权重随机采样下标的采样器
16+
17+
返回类型: WeightedRandomSampler
18+
19+
**代码示例**
20+
21+
.. code-block:: python
22+
23+
from paddle.io import WeightedRandomSampler
24+
25+
sampler = WeightedRandomSampler(weights=[0.1, 0.3, 0.5, 0.7, 0.2],
26+
num_samples=5,
27+
replacement=True)
28+
29+
for index in sampler:
30+
print(index)
31+

0 commit comments

Comments
 (0)