99
1010//! The dirichlet distribution.
1111#![ cfg( feature = "alloc" ) ]
12- use num_traits:: Float ;
13- use crate :: { Distribution , Exp1 , Gamma , Open01 , StandardNormal } ;
12+ use num_traits:: { Float , NumCast } ;
13+ use crate :: { Beta , Distribution , Exp1 , Gamma , Open01 , StandardNormal } ;
1414use rand:: Rng ;
1515use core:: fmt;
1616use alloc:: { boxed:: Box , vec, vec:: Vec } ;
@@ -123,16 +123,56 @@ where
123123 fn sample < R : Rng + ?Sized > ( & self , rng : & mut R ) -> Vec < F > {
124124 let n = self . alpha . len ( ) ;
125125 let mut samples = vec ! [ F :: zero( ) ; n] ;
126- let mut sum = F :: zero ( ) ;
127126
128- for ( s, & a) in samples. iter_mut ( ) . zip ( self . alpha . iter ( ) ) {
129- let g = Gamma :: new ( a, F :: one ( ) ) . unwrap ( ) ;
130- * s = g. sample ( rng) ;
131- sum = sum + ( * s) ;
132- }
133- let invacc = F :: one ( ) / sum;
134- for s in samples. iter_mut ( ) {
135- * s = ( * s) * invacc;
127+ if self . alpha . iter ( ) . all ( |x| * x <= NumCast :: from ( 0.1 ) . unwrap ( ) ) {
128+ // All the values in alpha are less than 0.1.
129+ //
130+ // When all the alpha parameters are sufficiently small, there
131+ // is a nontrivial probability that the samples from the gamma
132+ // distributions used in the other method will all be 0, which
133+ // results in the dirichlet samples being nan. So instead of
134+ // use that method, use the "stick breaking" method based on the
135+ // marginal beta distributions.
136+ //
137+ // Form the right-to-left cumulative sum of alpha, exluding the
138+ // first element of alpha. E.g. if alpha = [a0, a1, a2, a3], then
139+ // after the call to `alpha_sum_rl.reverse()` below, alpha_sum_rl
140+ // will hold [a1+a2+a3, a2+a3, a3].
141+ let mut alpha_sum_rl: Vec < F > = self
142+ . alpha
143+ . iter ( )
144+ . skip ( 1 )
145+ . rev ( )
146+ // scan does the cumulative sum
147+ . scan ( F :: zero ( ) , |sum, x| {
148+ * sum = * sum + * x;
149+ Some ( * sum)
150+ } )
151+ . collect ( ) ;
152+ alpha_sum_rl. reverse ( ) ;
153+ let mut acc = F :: one ( ) ;
154+ for ( ( s, & a) , & b) in samples
155+ . iter_mut ( )
156+ . zip ( self . alpha . iter ( ) )
157+ . zip ( alpha_sum_rl. iter ( ) )
158+ {
159+ let beta = Beta :: new ( a, b) . unwrap ( ) ;
160+ let beta_sample = beta. sample ( rng) ;
161+ * s = acc * beta_sample;
162+ acc = acc * ( F :: one ( ) - beta_sample) ;
163+ }
164+ samples[ n - 1 ] = acc;
165+ } else {
166+ let mut sum = F :: zero ( ) ;
167+ for ( s, & a) in samples. iter_mut ( ) . zip ( self . alpha . iter ( ) ) {
168+ let g = Gamma :: new ( a, F :: one ( ) ) . unwrap ( ) ;
169+ * s = g. sample ( rng) ;
170+ sum = sum + ( * s) ;
171+ }
172+ let invacc = F :: one ( ) / sum;
173+ for s in samples. iter_mut ( ) {
174+ * s = ( * s) * invacc;
175+ }
136176 }
137177 samples
138178 }
@@ -142,6 +182,33 @@ where
142182mod test {
143183 use super :: * ;
144184
185+ //
186+ // Check that the means of the components of n samples from
187+ // the Dirichlet distribution agree with the expected means
188+ // with a relative tolerance of rtol.
189+ //
190+ // This is a crude statistical test, but it will catch egregious
191+ // mistakes. It will also also fail if any samples contain nan.
192+ //
193+ fn check_dirichlet_means ( alpha : & Vec < f64 > , n : i32 , rtol : f64 , seed : u64 ) {
194+ let d = Dirichlet :: new ( & alpha) . unwrap ( ) ;
195+ let alpha_len = d. alpha . len ( ) ;
196+ let mut rng = crate :: test:: rng ( seed) ;
197+ let mut sums = vec ! [ 0.0 ; alpha_len] ;
198+ for _ in 0 ..n {
199+ let samples = d. sample ( & mut rng) ;
200+ for i in 0 ..alpha_len {
201+ sums[ i] += samples[ i] ;
202+ }
203+ }
204+ let sample_mean: Vec < f64 > = sums. iter ( ) . map ( |x| x / n as f64 ) . collect ( ) ;
205+ let alpha_sum: f64 = d. alpha . iter ( ) . sum ( ) ;
206+ let expected_mean: Vec < f64 > = d. alpha . iter ( ) . map ( |x| x / alpha_sum) . collect ( ) ;
207+ for i in 0 ..alpha_len {
208+ assert_almost_eq ! ( sample_mean[ i] , expected_mean[ i] , rtol) ;
209+ }
210+ }
211+
145212 #[ test]
146213 fn test_dirichlet ( ) {
147214 let d = Dirichlet :: new ( & [ 1.0 , 2.0 , 3.0 ] ) . unwrap ( ) ;
@@ -172,6 +239,48 @@ mod test {
172239 . collect ( ) ;
173240 }
174241
242+ #[ test]
243+ fn test_dirichlet_means ( ) {
244+ // Check the means of 20000 samples for several different alphas.
245+ let alpha_set = vec ! [
246+ vec![ 0.5 , 0.25 ] ,
247+ vec![ 123.0 , 75.0 ] ,
248+ vec![ 2.0 , 2.5 , 5.0 , 7.0 ] ,
249+ vec![ 0.1 , 8.0 , 1.0 , 2.0 , 2.0 , 0.85 , 0.05 , 12.5 ] ,
250+ ] ;
251+ let n = 20000 ;
252+ let rtol = 2e-2 ;
253+ let seed = 1317624576693539401 ;
254+ for alpha in alpha_set {
255+ check_dirichlet_means ( & alpha, n, rtol, seed) ;
256+ }
257+ }
258+
259+ #[ test]
260+ fn test_dirichlet_means_very_small_alpha ( ) {
261+ // With values of alpha that are all 0.001, check that the means of the
262+ // components of 10000 samples are within 1% of the expected means.
263+ // With the sampling method based on gamma variates, this test would
264+ // fail, with about 10% of the samples containing nan.
265+ let alpha = vec ! [ 0.001 , 0.001 , 0.001 ] ;
266+ let n = 10000 ;
267+ let rtol = 1e-2 ;
268+ let seed = 1317624576693539401 ;
269+ check_dirichlet_means ( & alpha, n, rtol, seed) ;
270+ }
271+
272+ #[ test]
273+ fn test_dirichlet_means_small_alpha ( ) {
274+ // With values of alpha that are all less than 0.1, check that the
275+ // means of the components of 150000 samples are within 0.1% of the
276+ // expected means.
277+ let alpha = vec ! [ 0.05 , 0.025 , 0.075 , 0.05 ] ;
278+ let n = 150000 ;
279+ let rtol = 1e-3 ;
280+ let seed = 1317624576693539401 ;
281+ check_dirichlet_means ( & alpha, n, rtol, seed) ;
282+ }
283+
175284 #[ test]
176285 #[ should_panic]
177286 fn test_dirichlet_invalid_length ( ) {
0 commit comments