
Commit c5a96ee

paddle.distribution.StudentT: improve input data type handling and fix returned dimension (usability improvement) (#68895)
1 parent fbfb46d · commit c5a96ee

File tree

    python/paddle/distribution/distribution.py
    python/paddle/distribution/student_t.py
    test/distribution/test_distribution_student_t.py

3 files changed: +110 -18 lines changed

python/paddle/distribution/distribution.py

Lines changed: 41 additions & 0 deletions
@@ -306,3 +306,44 @@ def _logits_to_probs(
             if is_binary
             else paddle.nn.functional.softmax(logits, axis=-1)
         )
+
+    def _broadcast_all(
+        self, *args: TensorLike | NestedNumbericSequence
+    ) -> tuple[Tensor, ...]:
+        r"""
+        Returns a list where each arg is broadcasted. Scalar args are upcast to tensors
+        having the same data type as the first Tensor passed to `args`. If all the
+        args are scalars, then they are upcasted to Tensors with paddle default data type.
+
+        Args:
+            value (float, list, numpy.ndarray, Tensor)
+
+        Returns:
+            Broadcasted Tensor of args.
+        """
+        for arg in args:
+            if not isinstance(
+                arg,
+                (float, list, tuple, np.ndarray, Variable, paddle.pir.Value),
+            ):
+                raise TypeError(
+                    f"Type of input args must be float, list, tuple, numpy.ndarray or Tensor, but received type {type(arg)}"
+                )
+        if not all(
+            isinstance(arg, (Variable, paddle.pir.Value)) for arg in args
+        ):
+            dtype = paddle.get_default_dtype()
+            for arg in args:
+                if isinstance(arg, (Variable, paddle.pir.Value)):
+                    dtype = arg.dtype
+                    break
+            new_args = [
+                (
+                    arg
+                    if isinstance(arg, (Variable, paddle.pir.Value))
+                    else paddle.to_tensor(arg, dtype=dtype)
+                )
+                for arg in args
+            ]
+            return paddle.broadcast_tensors(new_args)
+        return paddle.broadcast_tensors(args)
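
The new _broadcast_all helper is what lets subclasses accept any mix of Python scalars, sequences, NumPy arrays, and Tensors. Below is a minimal, hypothetical sketch of the promotion rule it implements, written against the public paddle API rather than the protected helper; the variable names and shapes are illustrative only.

import numpy as np
import paddle

# Illustrative inputs: one Tensor argument plus two non-Tensor arguments.
df = 10.0                                                 # Python float
loc = paddle.to_tensor([[0.0], [1.0]], dtype='float64')   # (2, 1) Tensor
scale = np.array([1.0, 2.0, 3.0])                         # (3,) ndarray

# Mimic the helper: non-Tensor args take the dtype of the first Tensor arg,
# then everything is broadcast to a common shape.
dtype = loc.dtype
tensors = [
    a if isinstance(a, paddle.Tensor) else paddle.to_tensor(a, dtype=dtype)
    for a in (df, loc, scale)
]
df_t, loc_t, scale_t = paddle.broadcast_tensors(tensors)
print(df_t.shape, loc_t.shape, scale_t.shape)  # expected: [2, 3] [2, 3] [2, 3]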

python/paddle/distribution/student_t.py

Lines changed: 3 additions & 18 deletions
@@ -18,7 +18,7 @@
 from typing import TYPE_CHECKING
 
 import paddle
-from paddle.base.data_feeder import check_type, convert_dtype
+from paddle.base.data_feeder import check_type
 from paddle.base.framework import Variable
 from paddle.distribution import Gamma, distribution
 from paddle.framework import in_dynamic_mode
@@ -135,18 +135,7 @@ def __init__(
             )
 
         self.name = name if name is not None else 'StudentT'
-        self.dtype = paddle.get_default_dtype()
-
-        if self._validate_args(df, loc, scale):
-            self.df = df
-            self.loc = loc
-            self.scale = scale
-            self.df, self.loc, self.scale = paddle.broadcast_tensors(
-                [self.df, self.loc, self.scale]
-            )
-            self.dtype = convert_dtype(df.dtype)
-        else:
-            self.df, self.loc, self.scale = self._to_tensor(df, loc, scale)
+        self.df, self.loc, self.scale = self._broadcast_all(df, loc, scale)
 
         if not self._check_nonnegative(self.df):
             raise ValueError(
@@ -157,10 +146,6 @@ def __init__(
                 'Every element of input parameter `scale` should be nonnegative.'
            )
 
-        if self.df.shape == []:
-            self.df = self.df.reshape([1])
-            self.loc = self.loc.reshape([1])
-            self.scale = self.scale.reshape([1])
         batch_shape = self.df.shape
         super().__init__(batch_shape)
         self._chi2 = Gamma(0.5 * self.df, paddle.full_like(self.df, 0.5))
@@ -222,7 +207,7 @@ def sample(self, shape: Sequence[int] = ()) -> Tensor:
            raise TypeError('sample shape must be Sequence object.')
 
         output_shape = self._extend_shape(shape)
-        z = paddle.cast(paddle.normal(shape=output_shape), self.dtype)
+        z = paddle.normal(shape=output_shape)
         chi2 = self._chi2.sample(shape)
         x = z * paddle.rsqrt(chi2 / self.df)
         return self.loc + self.scale * x
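
With __init__ now routed through _broadcast_all, StudentT accepts mixed scalar and Tensor parameters directly, and all-scalar parameters keep a 0-d batch shape instead of being force-reshaped to [1]. A quick usage sketch; the printed shapes are what the change implies, not output copied from a run.

import paddle
from paddle.distribution import StudentT

# Mixed scalar/Tensor parameters are broadcast inside __init__.
d = StudentT(df=10.0, loc=paddle.to_tensor([[0.0], [1.0]]), scale=2.0)
print(d.batch_shape)         # expected batch shape: (2, 1)
print(d.sample([4]).shape)   # expected sample shape: [4, 2, 1]

# All-scalar parameters are no longer reshaped to [1].
d0 = StudentT(df=10.0, loc=0.0, scale=1.0)
print(d0.batch_shape)        # expected batch shape: ()
print(d0.sample([4]).shape)  # expected sample shape: [4]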

test/distribution/test_distribution_student_t.py

Lines changed: 66 additions & 0 deletions
@@ -187,6 +187,47 @@ def _np_entropy(self):
         return scipy.stats.t.entropy(df, loc, scale)
 
 
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'df', 'loc', 'scale'),
+    [
+        (
+            'float-tensor',
+            10.0,
+            paddle.to_tensor(1.0),
+            2.0,
+        ),
+        (
+            'float-tensor1',
+            10.0,
+            parameterize.xrand((2, 3), dtype='float32', min=1, max=10),
+            2.0,
+        ),
+        (
+            'float-tensor2',
+            parameterize.xrand((2, 1), dtype='float64', min=4, max=30),
+            parameterize.xrand((2, 3), dtype='float64', min=1, max=10),
+            2.0,
+        ),
+        (
+            'float-tensor3',
+            parameterize.xrand((2, 1), dtype='float64', min=4, max=30),
+            1.0,
+            parameterize.xrand((2, 1), dtype='float64', min=0.1, max=3),
+        ),
+        (
+            'float-tensor4',
+            5.0,
+            parameterize.xrand((2, 1), dtype='float32', min=-1, max=-10),
+            parameterize.xrand((2, 3), dtype='float32', min=0.1, max=3),
+        ),
+    ],
+)
+class TestStudentT2(TestStudentT):
+    def setUp(self):
+        self._dist = StudentT(self.df, self.loc, self.scale)
+
+
 @parameterize.place(config.DEVICES)
 @parameterize.parameterize_cls(
     (parameterize.TEST_CASE_NAME, 'df', 'loc', 'scale', 'value'),
@@ -247,6 +288,31 @@ def test_log_prob(self):
         )
 
 
+@parameterize.place(config.DEVICES)
+@parameterize.parameterize_cls(
+    (parameterize.TEST_CASE_NAME, 'df', 'loc', 'scale', 'value'),
+    [
+        (
+            'float-tensor1',
+            10.0,
+            parameterize.xrand((2, 1), dtype='float32', min=-10, max=10),
+            1.0,
+            np.array(3.3).astype("float32"),
+        ),
+        (
+            'float-tensor2',
+            parameterize.xrand((2, 1), dtype='float64', min=4, max=30),
+            1.0,
+            parameterize.xrand((2, 1), dtype='float64', min=0.1, max=5),
+            parameterize.xrand((2, 4), dtype='float64', min=-10, max=10),
+        ),
+    ],
+)
+class TestStudentTProbs2(TestStudentTProbs):
+    def setUp(self):
+        self._dist = StudentT(self.df, self.loc, self.scale)
+
+
 @parameterize.place(config.DEVICES)
 @parameterize_cls([TEST_CASE_NAME], ['StudentTTestError'])
 class StudentTTestError(unittest.TestCase):
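
The new parameterized cases exercise these mixed scalar/Tensor inputs in the probability tests. Below is a standalone sketch of the kind of check they perform, comparing a broadcast StudentT against scipy.stats.t; the value grid and tolerance here are illustrative, not taken from the test file.

import numpy as np
import scipy.stats
import paddle
from paddle.distribution import StudentT

# Broadcast a (2, 1) df against a (4,) value grid and compare with scipy.
df = paddle.to_tensor([[5.0], [10.0]], dtype='float64')  # (2, 1)
loc, scale = 1.0, 2.0
value = paddle.to_tensor(np.linspace(-3.0, 3.0, 4))      # (4,), float64

dist = StudentT(df, loc, scale)
expected = scipy.stats.t.logpdf(value.numpy(), df.numpy(), loc, scale)  # (2, 4)
np.testing.assert_allclose(dist.log_prob(value).numpy(), expected, rtol=1e-5)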
