@@ -283,7 +283,7 @@ def sample(self, shape=()):
283283            shape (Sequence[int], optional): Prepended shape of the generated samples. 
284284
285285        Returns: 
286-             Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is the global default dtype . 
286+             Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is the same as `self.loc` . 
287287        """ 
288288        with  paddle .no_grad ():
289289            return  self .rsample (shape )
@@ -295,7 +295,7 @@ def rsample(self, shape=()):
295295            shape (Sequence[int], optional): Prepended shape of the generated samples. 
296296
297297        Returns: 
298-             Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is the global default dtype . 
298+             Tensor, Sampled data with shape `sample_shape` + `batch_shape` + `event_shape`. The data type is the same as `self.loc` . 
299299        """ 
300300        if  not  isinstance (shape , Sequence ):
301301            raise  TypeError ('sample shape must be Sequence object.' )
@@ -312,7 +312,7 @@ def log_prob(self, value):
312312          value (Tensor): The input tensor. 
313313
314314        Returns: 
315-           Tensor: log probability. The data type is same with :attr:`value`  . 
315+           Tensor: log probability. The data type is the  same as `self.loc` . 
316316        """ 
317317        value  =  paddle .cast (value , dtype = self .dtype )
318318
@@ -335,7 +335,7 @@ def prob(self, value):
335335            value (Tensor): The input tensor. 
336336
337337        Returns: 
338-             Tensor: probability. The data type is same with :attr:`value`  . 
338+             Tensor: probability. The data type is the  same as `self.loc` . 
339339        """ 
340340        return  paddle .exp (self .log_prob (value ))
341341
@@ -353,7 +353,7 @@ def entropy(self):
353353        * :math:\Omega: is the support of the distribution. 
354354
355355        Returns: 
356-             Tensor, Shannon entropy of Multivariate Normal distribution. The data type is the global default dtype . 
356+             Tensor, Shannon entropy of Multivariate Normal distribution. The data type is the same as `self.loc` . 
357357        """ 
358358        half_log_det  =  (
359359            self ._unbroadcasted_scale_tril .diagonal (axis1 = - 2 , axis2 = - 1 )
@@ -382,7 +382,7 @@ def kl_divergence(self, other):
382382            other (MultivariateNormal): instance of Multivariate Normal. 
383383
384384        Returns: 
385-             Tensor, kl-divergence between two Multivariate Normal distributions. The data type is the global default dtype . 
385+             Tensor, kl-divergence between two Multivariate Normal distributions. The data type is the same as `self.loc` . 
386386
387387        """ 
388388        if  (
0 commit comments