Skip to content

Commit 32baca9

Browse files
authored
Case7:paddle.distribution.Beta:fix beta(true stack) (#51847)
1 parent 65c6d2e commit 32baca9

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

paddle/phi/kernels/cpu/stack_kernel.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ void StackKernel(const Context& dev_ctx,
2525
int axis,
2626
DenseTensor* out) {
2727
if (axis < 0) axis += (x[0]->dims().size() + 1);
28+
29+
auto x_dims = x[0]->dims();
30+
for (int i = 0; i < x_dims.size(); i++) {
31+
PADDLE_ENFORCE_GT(x_dims[i],
32+
0,
33+
phi::errors::InvalidArgument(
34+
"The dims of Input(X) should be greater than 0"));
35+
}
36+
2837
int n = static_cast<int>(x.size());
2938
T* y_data = dev_ctx.template Alloc<T>(out);
3039
std::vector<const T*> x_datas(n);

paddle/phi/kernels/funcs/stack_and_unstack.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,12 @@ void StackRawKernel(const Context& ctx,
7777

7878
// Split x dim from axis to matrix of shape [x_row, x_col], and the output
7979
// tensor's shape is [x_row, out_col].
80-
int64_t x_row = 1;
80+
int64_t x_row = 1, x_row_bak = 1;
8181
for (int i = 0; i < axis; ++i) {
8282
x_row *= x[0]->dims()[i];
8383
}
84-
int64_t x_col = x[0]->numel() / x_row;
84+
x_row_bak = x_row == 0 ? 1 : x_row;
85+
int64_t x_col = x[0]->numel() / x_row_bak;
8586
int64_t out_col = x_col * num;
8687

8788
if (out->numel() < std::numeric_limits<int32_t>::max()) {

python/paddle/fluid/tests/unittests/distribution/test_distribution_beta.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ def test_sample_shape(self):
113113
== case.get('expect')
114114
)
115115

116+
def test_errors(self):
117+
with self.assertRaises(ValueError):
118+
array = np.array([], dtype=np.float32)
119+
x = paddle.to_tensor(np.reshape(array, [0]), dtype='int32')
120+
paddle.distribution.Beta(alpha=x, beta=x)
121+
116122

117123
if __name__ == '__main__':
118124
unittest.main()

0 commit comments

Comments
 (0)