@@ -30,7 +30,7 @@ class Normal(distribution.Distribution):
3030
3131 Mathematical details
3232
33- The probability density function (pdf) is
33+ If 'loc' is a real number, the probability density function (pdf) is
3434
3535 .. math::
3636
@@ -40,14 +40,24 @@ class Normal(distribution.Distribution):
4040
4141 Z = (2 \pi \sigma^2)^{0.5}
4242
43- In the above equation:
43+ If 'loc' is a complex number, the probability density function (pdf) is
44+
45+ .. math::
46+
47+ pdf(x; \mu, \sigma) = \frac{1}{Z}e^{\frac {-|x - \mu|^2} {\sigma^2} }
48+
49+ .. math::
50+
51+ Z = \pi \sigma^2
52+
53+ In the above equations:
4454
4555 * :math:`loc = \mu`: is the mean.
4656 * :math:`scale = \sigma`: is the std.
4757 * :math:`Z`: is the normalization constant.
4858
4959 Args:
50- loc(int|float|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution.The data type is float32 and float64 .
60+ loc(int|float|complex|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution. The data type is float32, float64, complex64 or complex128.
5161 scale(int|float|list|tuple|numpy.ndarray|Tensor): The std of normal distribution.The data type is float32 and float64.
5262 name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
5363
@@ -102,6 +112,7 @@ def __init__(self, loc, scale, name=None):
102112 (
103113 int ,
104114 float ,
115+ complex ,
105116 np .ndarray ,
106117 Variable ,
107118 paddle .pir .Value ,
@@ -128,33 +139,82 @@ def __init__(self, loc, scale, name=None):
128139 self .all_arg_is_float = False
129140 self .name = name if name is not None else 'Normal'
130141 self .dtype = 'float32'
142+ self ._complex_gaussian = False
131143
132144 if isinstance (loc , int ):
133145 loc = float (loc )
134146 if isinstance (scale , int ):
135147 scale = float (scale )
136148
137- if self ._validate_args (loc , scale ):
138- self .loc = loc
139- self .scale = scale
140- self .dtype = convert_dtype (loc .dtype )
141- else :
142- if isinstance (loc , float ) and isinstance (scale , float ):
149+ if isinstance (loc , (tuple , list )):
150+ loc = np .array (loc )
151+ if loc .dtype == np .float64 :
152+ loc = loc .astype ('float32' )
153+ if loc .dtype == np .complex128 :
154+ loc = loc .astype ('complex64' )
155+
156+ if isinstance (scale , (tuple , list )):
157+ scale = np .array (scale , dtype = np .float32 )
158+
159+ if (
160+ isinstance (loc , complex )
161+ or (
162+ isinstance (loc , np .ndarray )
163+ and loc .dtype in [np .complex64 , np .complex128 ]
164+ )
165+ or (self ._validate_args (loc ) and loc .is_complex ())
166+ ):
167+ self ._complex_gaussian = True
168+ if isinstance (loc , complex ) and isinstance (scale , float ):
143169 self .all_arg_is_float = True
144- if isinstance (loc , np .ndarray ) and str (loc .dtype ) in [
145- 'float32' ,
146- 'float64' ,
147- ]:
148- self .dtype = loc .dtype
149- elif isinstance (scale , np .ndarray ) and str (scale .dtype ) in [
150- 'float32' ,
151- 'float64' ,
152- ]:
153- self .dtype = scale .dtype
154- self .loc , self .scale = self ._to_tensor (loc , scale )
155- if self .dtype != convert_dtype (self .loc .dtype ):
156- self .loc = paddle .cast (self .loc , dtype = self .dtype )
157- self .scale = paddle .cast (self .scale , dtype = self .dtype )
170+
171+ if isinstance (loc , np .ndarray ):
172+ real_dtype = (
173+ 'float32' if loc .dtype == np .complex64 else 'float64'
174+ )
175+ imag_dtype = (
176+ 'float32' if loc .dtype == np .complex64 else 'float64'
177+ )
178+ real = paddle .to_tensor (loc .real , real_dtype )
179+ imag = paddle .to_tensor (loc .imag , imag_dtype )
180+ self .loc = paddle .complex (real , imag )
181+ elif isinstance (loc , complex ):
182+ real = paddle .to_tensor (loc .real , dtype = 'float32' )
183+ imag = paddle .to_tensor (loc .imag , dtype = 'float32' )
184+ self .loc = paddle .complex (real , imag )
185+ else :
186+ self .loc = loc
187+
188+ if isinstance (scale , np .ndarray ):
189+ self .scale = paddle .to_tensor (scale , dtype = scale .dtype )
190+ elif isinstance (scale , float ):
191+ self .scale = paddle .to_tensor (scale , dtype = 'float32' )
192+ else :
193+ self .scale = scale
194+
195+ self .dtype = convert_dtype (self .loc .dtype )
196+ else :
197+ if self ._validate_args (loc , scale ):
198+ self .loc = loc
199+ self .scale = scale
200+ self .dtype = convert_dtype (loc .dtype )
201+ else :
202+ if isinstance (loc , float ) and isinstance (scale , float ):
203+ self .all_arg_is_float = True
204+ if isinstance (loc , np .ndarray ) and str (loc .dtype ) in [
205+ 'float32' ,
206+ 'float64' ,
207+ ]:
208+ self .dtype = loc .dtype
209+ elif isinstance (scale , np .ndarray ) and str (scale .dtype ) in [
210+ 'float32' ,
211+ 'float64' ,
212+ ]:
213+ self .dtype = scale .dtype
214+ self .loc , self .scale = self ._to_tensor (loc , scale )
215+ if self .dtype != convert_dtype (self .loc .dtype ):
216+ self .loc = paddle .cast (self .loc , dtype = self .dtype )
217+ self .scale = paddle .cast (self .scale , dtype = self .dtype )
158218 super ().__init__ (self .loc .shape )
159219
160220 @property
@@ -204,15 +264,23 @@ def sample(self, shape=(), seed=0):
204264
205265 zero_tmp_shape = paddle .shape (zero_tmp_reshape )
206266 normal_random_tmp = random .gaussian (
207- zero_tmp_shape , mean = 0.0 , std = 1.0 , seed = seed , dtype = self .dtype
267+ zero_tmp_shape ,
268+ mean = (0.0 + 0.0j ) if self ._complex_gaussian else 0.0 ,
269+ std = 1.0 ,
270+ seed = seed ,
271+ dtype = self .dtype ,
208272 )
209273 output = normal_random_tmp * (zero_tmp_reshape + self .scale )
210274 output = paddle .add (output , self .loc , name = name )
211275 return output
212276 else :
213277 output_shape = shape + batch_shape
214278 output = random .gaussian (
215- output_shape , mean = 0.0 , std = 1.0 , seed = seed , dtype = self .dtype
279+ output_shape ,
280+ mean = (0.0 + 0.0j ) if self ._complex_gaussian else 0.0 ,
281+ std = 1.0 ,
282+ seed = seed ,
283+ dtype = self .dtype ,
216284 ) * (paddle .zeros (output_shape , dtype = self .dtype ) + self .scale )
217285 output = paddle .add (output , self .loc , name = name )
218286 if self .all_arg_is_float :
@@ -234,18 +302,26 @@ def rsample(self, shape=()):
234302 raise TypeError ('sample shape must be Iterable object.' )
235303
236304 shape = self ._extend_shape (tuple (shape ))
237- eps = paddle .normal (shape = shape )
305+ eps = paddle .normal (
306+ mean = (0.0 + 0.0j ) if self ._complex_gaussian else 0.0 , shape = shape
307+ )
238308 return self .loc + eps * self .scale
239309
240310 def entropy (self ):
241311 r"""Shannon entropy in nats.
242312
243- The entropy is
313+ If non-complex, the entropy is
244314
245315 .. math::
246316
247317 entropy(\sigma) = 0.5 \log (2 \pi e \sigma^2)
248318
319+ If complex gaussian, the entropy is
320+
321+ .. math::
322+
323+ entropy(\sigma) = \log (\pi e \sigma^2)
324+
249325 In the above equation:
250326
251327 * :math:`scale = \sigma`: is the std.
@@ -256,18 +332,33 @@ def entropy(self):
256332 """
257333 name = self .name + '_entropy'
258334 batch_shape = list ((self .loc + self .scale ).shape )
259- if - 1 in batch_shape :
260- fill_shape = list (batch_shape )
261- fill_shape [0 ] = paddle .shape (self .loc + self .scale )[0 ].item ()
262- fill_dtype = (self .loc + self .scale ).dtype
263- zero_tmp = paddle .full (fill_shape , 0.0 , fill_dtype )
335+
336+ if self ._complex_gaussian :
337+ if - 1 in batch_shape :
338+ fill_shape = list (batch_shape )
339+ fill_shape [0 ] = paddle .shape (self .loc + self .scale )[0 ].item ()
340+ fill_dtype = self .scale .dtype
341+ zero_tmp = paddle .full (fill_shape , 0.0 , fill_dtype )
342+ else :
343+ zero_tmp = paddle .full (batch_shape , 0.0 , self .scale .dtype )
344+ return paddle .add (
345+ 1.0 + zero_tmp ,
346+ math .log (math .pi ) + 2.0 * paddle .log (self .scale + zero_tmp ),
347+ name = name ,
348+ )
264349 else :
265- zero_tmp = paddle .full (batch_shape , 0.0 , self .dtype )
266- return paddle .add (
267- 0.5 + zero_tmp ,
268- 0.5 * math .log (2 * math .pi ) + paddle .log (self .scale + zero_tmp ),
269- name = name ,
270- )
350+ if - 1 in batch_shape :
351+ fill_shape = list (batch_shape )
352+ fill_shape [0 ] = paddle .shape (self .loc + self .scale )[0 ].item ()
353+ fill_dtype = (self .loc + self .scale ).dtype
354+ zero_tmp = paddle .full (fill_shape , 0.0 , fill_dtype )
355+ else :
356+ zero_tmp = paddle .full (batch_shape , 0.0 , self .dtype )
357+ return paddle .add (
358+ 0.5 + zero_tmp ,
359+ 0.5 * math .log (2 * math .pi ) + paddle .log (self .scale + zero_tmp ),
360+ name = name ,
361+ )
271362
272363 def log_prob (self , value ):
273364 """Log probability density/mass function.
@@ -284,11 +375,18 @@ def log_prob(self, value):
284375
285376 var = self .scale * self .scale
286377 log_scale = paddle .log (self .scale )
287- return paddle .subtract (
288- - 1.0 * ((value - self .loc ) * (value - self .loc )) / (2.0 * var ),
289- log_scale + math .log (math .sqrt (2.0 * math .pi )),
290- name = name ,
291- )
378+ if self ._complex_gaussian :
379+ return paddle .subtract (
380+ - 1.0 * ((value - self .loc ).conj () * (value - self .loc )) / (var ),
381+ 2.0 * log_scale + math .log (math .pi ),
382+ name = name ,
383+ )
384+ else :
385+ return paddle .subtract (
386+ - 1.0 * ((value - self .loc ) * (value - self .loc )) / (2.0 * var ),
387+ log_scale + math .log (math .sqrt (2.0 * math .pi )),
388+ name = name ,
389+ )
292390
293391 def probs (self , value ):
294392 """Probability density/mass function.
@@ -304,23 +402,42 @@ def probs(self, value):
304402 value = self ._check_values_dtype_in_probs (self .loc , value )
305403
306404 var = self .scale * self .scale
307- return paddle .divide (
308- paddle .exp (
309- - 1.0 * ((value - self .loc ) * (value - self .loc )) / (2.0 * var )
310- ),
311- (math .sqrt (2 * math .pi ) * self .scale ),
312- name = name ,
313- )
405+ if self ._complex_gaussian :
406+ return paddle .divide (
407+ paddle .exp (
408+ - 1.0
409+ * ((value - self .loc ).conj () * (value - self .loc ))
410+ / (var )
411+ ),
412+ (math .pi * var ),
413+ name = name ,
414+ )
415+ else :
416+ return paddle .divide (
417+ paddle .exp (
418+ - 1.0
419+ * ((value - self .loc ) * (value - self .loc ))
420+ / (2.0 * var )
421+ ),
422+ (math .sqrt (2 * math .pi ) * self .scale ),
423+ name = name ,
424+ )
314425
315426 def kl_divergence (self , other ):
316427 r"""The KL-divergence between two normal distributions.
317428
318- The probability density function (pdf) is
429+ If non-complex, the KL-divergence is
319430
320431 .. math::
321432
322433 KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\frac{diff}{\sigma_1})^2 - 1 - 2 \ln {ratio})
323434
435+ If complex gaussian:
436+
437+ .. math::
438+
439+ KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = ratio^2 + \left|\frac{diff}{\sigma_1}\right|^2 - 1 - 2 \ln {ratio}
440+
324441 .. math::
325442
326443 ratio = \frac{\sigma_0}{\sigma_1}
@@ -348,11 +465,21 @@ def kl_divergence(self, other):
348465 if not in_dynamic_mode ():
349466 check_type (other , 'other' , Normal , 'kl_divergence' )
350467
468+ if self ._complex_gaussian != other ._complex_gaussian :
469+ raise ValueError (
470+ "The kl divergence must be computed between two distributions in the same number field."
471+ )
351472 name = self .name + '_kl_divergence'
352473 var_ratio = self .scale / other .scale
353474 var_ratio = var_ratio * var_ratio
354475 t1 = (self .loc - other .loc ) / other .scale
355- t1 = t1 * t1
356- return paddle .add (
357- 0.5 * var_ratio , 0.5 * (t1 - 1.0 - paddle .log (var_ratio )), name = name
358- )
476+ if self ._complex_gaussian :
477+ t1 = t1 .conj () * t1
478+ return var_ratio + t1 - 1.0 - paddle .log (var_ratio )
479+ else :
480+ t1 = t1 * t1
481+ return paddle .add (
482+ 0.5 * var_ratio ,
483+ 0.5 * (t1 - 1.0 - paddle .log (var_ratio )),
484+ name = name ,
485+ )
0 commit comments