99
1010//! The dirichlet distribution. 
1111#![ cfg( feature = "alloc" ) ]  
12- use  num_traits:: Float ; 
13- use  crate :: { Distribution ,  Exp1 ,  Gamma ,  Open01 ,  StandardNormal } ; 
12+ use  num_traits:: { Float ,   NumCast } ; 
13+ use  crate :: { Beta ,   Distribution ,  Exp1 ,  Gamma ,  Open01 ,  StandardNormal } ; 
1414use  rand:: Rng ; 
1515use  core:: fmt; 
1616use  alloc:: { boxed:: Box ,  vec,  vec:: Vec } ; 
@@ -123,16 +123,56 @@ where
123123    fn  sample < R :  Rng  + ?Sized > ( & self ,  rng :  & mut  R )  -> Vec < F >  { 
124124        let  n = self . alpha . len ( ) ; 
125125        let  mut  samples = vec ! [ F :: zero( ) ;  n] ; 
126-         let  mut  sum = F :: zero ( ) ; 
127126
128-         for  ( s,  & a)  in  samples. iter_mut ( ) . zip ( self . alpha . iter ( ) )  { 
129-             let  g = Gamma :: new ( a,  F :: one ( ) ) . unwrap ( ) ; 
130-             * s = g. sample ( rng) ; 
131-             sum =  sum + ( * s) ; 
132-         } 
133-         let  invacc = F :: one ( )  / sum; 
134-         for  s in  samples. iter_mut ( )  { 
135-             * s = ( * s) * invacc; 
127+         if  self . alpha . iter ( ) . all ( |x| * x <= NumCast :: from ( 0.1 ) . unwrap ( ) )  { 
128+             // All the values in alpha are less than 0.1. 
129+             // 
130+             // When all the alpha parameters are sufficiently small, there 
131+             // is a nontrivial probability that the samples from the gamma 
132+             // distributions used in the other method will all be 0, which 
133+             // results in the dirichlet samples being nan.  So instead of 
134+             // use that method, use the "stick breaking" method based on the 
135+             // marginal beta distributions. 
136+             // 
137+             // Form the right-to-left cumulative sum of alpha, exluding the 
138+             // first element of alpha.  E.g. if alpha = [a0, a1, a2, a3], then 
139+             // after the call to `alpha_sum_rl.reverse()` below, alpha_sum_rl 
140+             // will hold [a1+a2+a3, a2+a3, a3]. 
141+             let  mut  alpha_sum_rl:  Vec < F >  = self 
142+                 . alpha 
143+                 . iter ( ) 
144+                 . skip ( 1 ) 
145+                 . rev ( ) 
146+                 // scan does the cumulative sum 
147+                 . scan ( F :: zero ( ) ,  |sum,  x| { 
148+                     * sum = * sum + * x; 
149+                     Some ( * sum) 
150+                 } ) 
151+                 . collect ( ) ; 
152+             alpha_sum_rl. reverse ( ) ; 
153+             let  mut  acc = F :: one ( ) ; 
154+             for  ( ( s,  & a) ,  & b)  in  samples
155+                 . iter_mut ( ) 
156+                 . zip ( self . alpha . iter ( ) ) 
157+                 . zip ( alpha_sum_rl. iter ( ) ) 
158+             { 
159+                 let  beta = Beta :: new ( a,  b) . unwrap ( ) ; 
160+                 let  beta_sample = beta. sample ( rng) ; 
161+                 * s = acc *  beta_sample; 
162+                 acc = acc *  ( F :: one ( )  - beta_sample) ; 
163+             } 
164+             samples[ n - 1 ]  = acc; 
165+         }  else  { 
166+             let  mut  sum = F :: zero ( ) ; 
167+             for  ( s,  & a)  in  samples. iter_mut ( ) . zip ( self . alpha . iter ( ) )  { 
168+                 let  g = Gamma :: new ( a,  F :: one ( ) ) . unwrap ( ) ; 
169+                 * s = g. sample ( rng) ; 
170+                 sum = sum + ( * s) ; 
171+             } 
172+             let  invacc = F :: one ( )  / sum; 
173+             for  s in  samples. iter_mut ( )  { 
174+                 * s = ( * s)  *  invacc; 
175+             } 
136176        } 
137177        samples
138178    } 
@@ -142,6 +182,33 @@ where
142182mod  test { 
143183    use  super :: * ; 
144184
185+     // 
186+     // Check that the means of the components of n samples from 
187+     // the Dirichlet distribution agree with the expected means 
188+     // with a relative tolerance of rtol. 
189+     // 
190+     // This is a crude statistical test, but it will catch egregious 
191+     // mistakes.  It will also also fail if any samples contain nan. 
192+     // 
193+     fn  check_dirichlet_means ( alpha :  & Vec < f64 > ,  n :  i32 ,  rtol :  f64 ,  seed :  u64 )  { 
194+         let  d = Dirichlet :: new ( & alpha) . unwrap ( ) ; 
195+         let  alpha_len = d. alpha . len ( ) ; 
196+         let  mut  rng = crate :: test:: rng ( seed) ; 
197+         let  mut  sums = vec ! [ 0.0 ;  alpha_len] ; 
198+         for  _ in  0 ..n { 
199+             let  samples = d. sample ( & mut  rng) ; 
200+             for  i in  0 ..alpha_len { 
201+                 sums[ i]  += samples[ i] ; 
202+             } 
203+         } 
204+         let  sample_mean:  Vec < f64 >  = sums. iter ( ) . map ( |x| x / n as  f64 ) . collect ( ) ; 
205+         let  alpha_sum:  f64  = d. alpha . iter ( ) . sum ( ) ; 
206+         let  expected_mean:  Vec < f64 >  = d. alpha . iter ( ) . map ( |x| x / alpha_sum) . collect ( ) ; 
207+         for  i in  0 ..alpha_len { 
208+             assert_almost_eq ! ( sample_mean[ i] ,  expected_mean[ i] ,  rtol) ; 
209+         } 
210+     } 
211+ 
145212    #[ test]  
146213    fn  test_dirichlet ( )  { 
147214        let  d = Dirichlet :: new ( & [ 1.0 ,  2.0 ,  3.0 ] ) . unwrap ( ) ; 
@@ -172,6 +239,48 @@ mod test {
172239            . collect ( ) ; 
173240    } 
174241
242+     #[ test]  
243+     fn  test_dirichlet_means ( )  { 
244+         // Check the means of 20000 samples for several different alphas. 
245+         let  alpha_set = vec ! [ 
246+             vec![ 0.5 ,  0.25 ] , 
247+             vec![ 123.0 ,  75.0 ] , 
248+             vec![ 2.0 ,  2.5 ,  5.0 ,  7.0 ] , 
249+             vec![ 0.1 ,  8.0 ,  1.0 ,  2.0 ,  2.0 ,  0.85 ,  0.05 ,  12.5 ] , 
250+         ] ; 
251+         let  n = 20000 ; 
252+         let  rtol = 2e-2 ; 
253+         let  seed = 1317624576693539401 ; 
254+         for  alpha in  alpha_set { 
255+             check_dirichlet_means ( & alpha,  n,  rtol,  seed) ; 
256+         } 
257+     } 
258+ 
259+     #[ test]  
260+     fn  test_dirichlet_means_very_small_alpha ( )  { 
261+         // With values of alpha that are all 0.001, check that the means of the 
262+         // components of 10000 samples are within 1% of the expected means. 
263+         // With the sampling method based on gamma variates, this test would 
264+         // fail, with about 10% of the samples containing nan. 
265+         let  alpha = vec ! [ 0.001 ,  0.001 ,  0.001 ] ; 
266+         let  n = 10000 ; 
267+         let  rtol = 1e-2 ; 
268+         let  seed = 1317624576693539401 ; 
269+         check_dirichlet_means ( & alpha,  n,  rtol,  seed) ; 
270+     } 
271+ 
272+     #[ test]  
273+     fn  test_dirichlet_means_small_alpha ( )  { 
274+         // With values of alpha that are all less than 0.1, check that the 
275+         // means of the components of 150000 samples are within 0.1% of the 
276+         // expected means. 
277+         let  alpha = vec ! [ 0.05 ,  0.025 ,  0.075 ,  0.05 ] ; 
278+         let  n = 150000 ; 
279+         let  rtol = 1e-3 ; 
280+         let  seed = 1317624576693539401 ; 
281+         check_dirichlet_means ( & alpha,  n,  rtol,  seed) ; 
282+     } 
283+ 
175284    #[ test]  
176285    #[ should_panic]  
177286    fn  test_dirichlet_invalid_length ( )  { 
0 commit comments