Skip to content

Commit f263820

Browse files
authored
Poisson: split Knuth/Rejection methods (#1493)
1 parent ef052ec commit f263820

File tree

3 files changed

+191
-183
lines changed

3 files changed

+191
-183
lines changed

benches/benches/distr.rs

Lines changed: 70 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -11,95 +11,46 @@
1111
// Rustfmt splits macro invocations to shorten lines; in this case longer-lines are more readable
1212
#![rustfmt::skip]
1313

14-
const RAND_BENCH_N: u64 = 1000;
15-
1614
use criterion::{criterion_group, criterion_main, Criterion, Throughput};
1715
use criterion_cycles_per_byte::CyclesPerByte;
1816

19-
use core::mem::size_of;
20-
2117
use rand::prelude::*;
2218
use rand_distr::*;
2319

2420
// At this time, distributions are optimised for 64-bit platforms.
2521
use rand_pcg::Pcg64Mcg;
2622

23+
const ITER_ELTS: u64 = 100;
24+
2725
macro_rules! distr_int {
2826
($group:ident, $fnn:expr, $ty:ty, $distr:expr) => {
29-
$group.throughput(Throughput::Bytes(
30-
size_of::<$ty>() as u64 * RAND_BENCH_N));
3127
$group.bench_function($fnn, |c| {
3228
let mut rng = Pcg64Mcg::from_os_rng();
3329
let distr = $distr;
3430

35-
c.iter(|| {
36-
let mut accum: $ty = 0;
37-
for _ in 0..RAND_BENCH_N {
38-
let x: $ty = distr.sample(&mut rng);
39-
accum = accum.wrapping_add(x);
40-
}
41-
accum
42-
});
31+
c.iter(|| distr.sample(&mut rng));
4332
});
4433
};
4534
}
4635

4736
macro_rules! distr_float {
4837
($group:ident, $fnn:expr, $ty:ty, $distr:expr) => {
49-
$group.throughput(Throughput::Bytes(
50-
size_of::<$ty>() as u64 * RAND_BENCH_N));
5138
$group.bench_function($fnn, |c| {
5239
let mut rng = Pcg64Mcg::from_os_rng();
5340
let distr = $distr;
5441

55-
c.iter(|| {
56-
let mut accum = 0.;
57-
for _ in 0..RAND_BENCH_N {
58-
let x: $ty = distr.sample(&mut rng);
59-
accum += x;
60-
}
61-
accum
62-
});
63-
});
64-
};
65-
}
66-
67-
macro_rules! distr {
68-
($group:ident, $fnn:expr, $ty:ty, $distr:expr) => {
69-
$group.throughput(Throughput::Bytes(
70-
size_of::<$ty>() as u64 * RAND_BENCH_N));
71-
$group.bench_function($fnn, |c| {
72-
let mut rng = Pcg64Mcg::from_os_rng();
73-
let distr = $distr;
74-
75-
c.iter(|| {
76-
let mut accum: u32 = 0;
77-
for _ in 0..RAND_BENCH_N {
78-
let x: $ty = distr.sample(&mut rng);
79-
accum = accum.wrapping_add(x as u32);
80-
}
81-
accum
82-
});
42+
c.iter(|| Distribution::<$ty>::sample(&distr, &mut rng));
8343
});
8444
};
8545
}
8646

8747
macro_rules! distr_arr {
8848
($group:ident, $fnn:expr, $ty:ty, $distr:expr) => {
89-
$group.throughput(Throughput::Bytes(
90-
size_of::<$ty>() as u64 * RAND_BENCH_N));
9149
$group.bench_function($fnn, |c| {
9250
let mut rng = Pcg64Mcg::from_os_rng();
9351
let distr = $distr;
9452

95-
c.iter(|| {
96-
let mut accum: u32 = 0;
97-
for _ in 0..RAND_BENCH_N {
98-
let x: $ty = distr.sample(&mut rng);
99-
accum = accum.wrapping_add(x[0] as u32);
100-
}
101-
accum
102-
});
53+
c.iter(|| Distribution::<$ty>::sample(&distr, &mut rng));
10354
});
10455
};
10556
}
@@ -111,122 +62,126 @@ macro_rules! sample_binomial {
11162
}
11263

11364
fn bench(c: &mut Criterion<CyclesPerByte>) {
114-
{
11565
let mut g = c.benchmark_group("exp");
11666
distr_float!(g, "exp", f64, Exp::new(1.23 * 4.56).unwrap());
11767
distr_float!(g, "exp1_specialized", f64, Exp1);
11868
distr_float!(g, "exp1_general", f64, Exp::new(1.).unwrap());
119-
}
69+
g.finish();
12070

121-
{
12271
let mut g = c.benchmark_group("normal");
12372
distr_float!(g, "normal", f64, Normal::new(-1.23, 4.56).unwrap());
12473
distr_float!(g, "standardnormal_specialized", f64, StandardNormal);
12574
distr_float!(g, "standardnormal_general", f64, Normal::new(0., 1.).unwrap());
12675
distr_float!(g, "log_normal", f64, LogNormal::new(-1.23, 4.56).unwrap());
127-
g.throughput(Throughput::Bytes(size_of::<f64>() as u64 * RAND_BENCH_N));
76+
g.throughput(Throughput::Elements(ITER_ELTS));
12877
g.bench_function("iter", |c| {
12978
use core::f64::consts::{E, PI};
13079
let mut rng = Pcg64Mcg::from_os_rng();
13180
let distr = Normal::new(-E, PI).unwrap();
132-
let mut iter = distr.sample_iter(&mut rng);
13381

13482
c.iter(|| {
135-
let mut accum = 0.0;
136-
for _ in 0..RAND_BENCH_N {
137-
accum += iter.next().unwrap();
138-
}
139-
accum
83+
distr.sample_iter(&mut rng)
84+
.take(ITER_ELTS as usize)
85+
.fold(0.0, |a, r| a + r)
14086
});
14187
});
142-
}
88+
g.finish();
14389

144-
{
14590
let mut g = c.benchmark_group("skew_normal");
14691
distr_float!(g, "shape_zero", f64, SkewNormal::new(0.0, 1.0, 0.0).unwrap());
14792
distr_float!(g, "shape_positive", f64, SkewNormal::new(0.0, 1.0, 100.0).unwrap());
14893
distr_float!(g, "shape_negative", f64, SkewNormal::new(0.0, 1.0, -100.0).unwrap());
149-
}
94+
g.finish();
15095

151-
{
15296
let mut g = c.benchmark_group("gamma");
153-
distr_float!(g, "gamma_large_shape", f64, Gamma::new(10., 1.0).unwrap());
154-
distr_float!(g, "gamma_small_shape", f64, Gamma::new(0.1, 1.0).unwrap());
155-
distr_float!(g, "beta_small_param", f64, Beta::new(0.1, 0.1).unwrap());
156-
distr_float!(g, "beta_large_param_similar", f64, Beta::new(101., 95.).unwrap());
157-
distr_float!(g, "beta_large_param_different", f64, Beta::new(10., 1000.).unwrap());
158-
distr_float!(g, "beta_mixed_param", f64, Beta::new(0.5, 100.).unwrap());
159-
}
97+
distr_float!(g, "large_shape", f64, Gamma::new(10., 1.0).unwrap());
98+
distr_float!(g, "small_shape", f64, Gamma::new(0.1, 1.0).unwrap());
99+
g.finish();
100+
101+
let mut g = c.benchmark_group("beta");
102+
distr_float!(g, "small_param", f64, Beta::new(0.1, 0.1).unwrap());
103+
distr_float!(g, "large_param_similar", f64, Beta::new(101., 95.).unwrap());
104+
distr_float!(g, "large_param_different", f64, Beta::new(10., 1000.).unwrap());
105+
distr_float!(g, "mixed_param", f64, Beta::new(0.5, 100.).unwrap());
106+
g.finish();
160107

161-
{
162108
let mut g = c.benchmark_group("cauchy");
163109
distr_float!(g, "cauchy", f64, Cauchy::new(4.2, 6.9).unwrap());
164-
}
110+
g.finish();
165111

166-
{
167112
let mut g = c.benchmark_group("triangular");
168113
distr_float!(g, "triangular", f64, Triangular::new(0., 1., 0.9).unwrap());
169-
}
114+
g.finish();
170115

171-
{
172116
let mut g = c.benchmark_group("geometric");
173117
distr_int!(g, "geometric", u64, Geometric::new(0.5).unwrap());
174118
distr_int!(g, "standard_geometric", u64, StandardGeometric);
175-
}
119+
g.finish();
176120

177-
{
178121
let mut g = c.benchmark_group("weighted");
179-
distr_int!(g, "weighted_i8", usize, WeightedIndex::new([1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
180-
distr_int!(g, "weighted_u32", usize, WeightedIndex::new([1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
181-
distr_int!(g, "weighted_f64", usize, WeightedIndex::new([1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
182-
distr_int!(g, "weighted_large_set", usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());
183-
distr_int!(g, "weighted_alias_method_i8", usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
184-
distr_int!(g, "weighted_alias_method_u32", usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
185-
distr_int!(g, "weighted_alias_method_f64", usize, WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
186-
distr_int!(g, "weighted_alias_method_large_set", usize, WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap());
187-
}
122+
distr_int!(g, "i8", usize, WeightedIndex::new([1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
123+
distr_int!(g, "u32", usize, WeightedIndex::new([1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
124+
distr_int!(g, "f64", usize, WeightedIndex::new([1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
125+
distr_int!(g, "large_set", usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());
126+
distr_int!(g, "alias_method_i8", usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
127+
distr_int!(g, "alias_method_u32", usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
128+
distr_int!(g, "alias_method_f64", usize, WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
129+
distr_int!(g, "alias_method_large_set", usize, WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap());
130+
g.finish();
188131

189-
{
190132
let mut g = c.benchmark_group("binomial");
191-
sample_binomial!(g, "binomial", 20, 0.7);
192-
sample_binomial!(g, "binomial_small", 1_000_000, 1e-30);
193-
sample_binomial!(g, "binomial_1", 1, 0.9);
194-
sample_binomial!(g, "binomial_10", 10, 0.9);
195-
sample_binomial!(g, "binomial_100", 100, 0.99);
196-
sample_binomial!(g, "binomial_1000", 1000, 0.01);
197-
sample_binomial!(g, "binomial_1e12", 1_000_000_000_000, 0.2);
198-
}
133+
sample_binomial!(g, "small", 1_000_000, 1e-30);
134+
sample_binomial!(g, "1", 1, 0.9);
135+
sample_binomial!(g, "10", 10, 0.9);
136+
sample_binomial!(g, "100", 100, 0.99);
137+
sample_binomial!(g, "1000", 1000, 0.01);
138+
sample_binomial!(g, "1e12", 1_000_000_000_000, 0.2);
139+
g.finish();
199140

200-
{
201141
let mut g = c.benchmark_group("poisson");
202-
distr_float!(g, "poisson", f64, Poisson::new(4.0).unwrap());
142+
for lambda in [1f64, 4.0, 10.0, 100.0].into_iter() {
143+
let name = format!("{lambda}");
144+
distr_float!(g, name, f64, Poisson::new(lambda).unwrap());
203145
}
146+
g.throughput(Throughput::Elements(ITER_ELTS));
147+
g.bench_function("variable", |c| {
148+
let mut rng = Pcg64Mcg::from_os_rng();
149+
let ldistr = Uniform::new(0.1, 10.0).unwrap();
150+
151+
c.iter(|| {
152+
let l = rng.sample(ldistr);
153+
let distr = Poisson::new(l * l).unwrap();
154+
Distribution::<f64>::sample_iter(&distr, &mut rng)
155+
.take(ITER_ELTS as usize)
156+
.fold(0.0, |a, r| a + r)
157+
})
158+
});
159+
g.finish();
204160

205-
{
206161
let mut g = c.benchmark_group("zipf");
207162
distr_float!(g, "zipf", f64, Zipf::new(10, 1.5).unwrap());
208163
distr_float!(g, "zeta", f64, Zeta::new(1.5).unwrap());
209-
}
164+
g.finish();
210165

211-
{
212166
let mut g = c.benchmark_group("bernoulli");
213-
distr!(g, "bernoulli", bool, Bernoulli::new(0.18).unwrap());
214-
}
167+
g.bench_function("bernoulli", |c| {
168+
let mut rng = Pcg64Mcg::from_os_rng();
169+
let distr = Bernoulli::new(0.18).unwrap();
170+
c.iter(|| distr.sample(&mut rng))
171+
});
172+
g.finish();
215173

216-
{
217-
let mut g = c.benchmark_group("circle");
174+
let mut g = c.benchmark_group("unit");
218175
distr_arr!(g, "circle", [f64; 2], UnitCircle);
219-
}
220-
221-
{
222-
let mut g = c.benchmark_group("sphere");
223176
distr_arr!(g, "sphere", [f64; 3], UnitSphere);
224-
}
177+
g.finish();
225178
}
226179

227180
criterion_group!(
228181
name = benches;
229-
config = Criterion::default().with_measurement(CyclesPerByte);
182+
config = Criterion::default().with_measurement(CyclesPerByte)
183+
.warm_up_time(core::time::Duration::from_secs(1))
184+
.measurement_time(core::time::Duration::from_secs(2));
230185
targets = bench
231186
);
232187
criterion_main!(benches);

rand_distr/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ mod normal;
211211
mod normal_inverse_gaussian;
212212
mod pareto;
213213
mod pert;
214-
mod poisson;
214+
pub(crate) mod poisson;
215215
mod skew_normal;
216216
mod student_t;
217217
mod triangular;

0 commit comments

Comments
 (0)