Skip to content

Commit 401062e

Browse files
committed
fix test
1 parent 74b18ea commit 401062e

File tree

4 files changed

+37
-7
lines changed

4 files changed

+37
-7
lines changed

paddle/fluid/operators/bincount_op.cu

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,29 @@ void BincountCUDAInner(const framework::ExecutionContext& context) {
6363
}
6464
auto input_x = framework::EigenVector<InputT>::Flatten(*input);
6565

66-
framework::Tensor input_max_t;
66+
framework::Tensor input_min_t, input_max_t;
6767
auto* input_max_data =
6868
input_max_t.mutable_data<InputT>({1}, context.GetPlace());
69+
auto* input_min_data =
70+
input_min_t.mutable_data<InputT>({1}, context.GetPlace());
71+
6972
auto input_max_scala = framework::EigenScalar<InputT>::From(input_max_t);
73+
auto input_min_scala = framework::EigenScalar<InputT>::From(input_min_t);
7074

7175
auto* place = context.template device_context<DeviceContext>().eigen_device();
7276
input_max_scala.device(*place) = input_x.maximum();
77+
input_min_scala.device(*place) = input_x.minimum();
7378

74-
Tensor input_max_cpu;
79+
Tensor input_min_cpu, input_max_cpu;
7580
TensorCopySync(input_max_t, platform::CPUPlace(), &input_max_cpu);
81+
TensorCopySync(input_min_t, platform::CPUPlace(), &input_min_cpu);
82+
83+
InputT input_min = input_min_cpu.data<InputT>()[0];
84+
85+
PADDLE_ENFORCE_GE(
86+
input_min, static_cast<InputT>(0),
87+
platform::errors::InvalidArgument(
88+
"The elements in input tensor must be non-negative ints"));
7689

7790
int64_t output_size =
7891
static_cast<int64_t>(input_max_cpu.data<InputT>()[0]) + 1L;

paddle/fluid/operators/bincount_op.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ void BincountInner(const framework::ExecutionContext& context) {
4343
return;
4444
}
4545

46+
PADDLE_ENFORCE_GE(
47+
*std::min_element(input_data, input_data + input_numel),
48+
static_cast<InputT>(0),
49+
platform::errors::InvalidArgument(
50+
"The elements in input tensor must be non-negative ints"));
51+
4652
int64_t output_size = static_cast<int64_t>(*std::max_element(
4753
input_data, input_data + input_numel)) +
4854
1L;

python/paddle/fluid/tests/unittests/test_bincount_op.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_static_graph(self):
4747
'weights': w},
4848
fetch_list=[output])
4949
actual = np.array(res[0])
50-
expected = np.bincount(inputs, weights=weights)
50+
expected = np.bincount(img, weights=w)
5151
self.assertTrue(
5252
(actual == expected).all(),
5353
msg='bincount output is wrong, out =' + str(actual))
@@ -70,6 +70,16 @@ def run_network(self, net_func):
7070
with fluid.dygraph.guard():
7171
net_func()
7272

73+
def test_input_value_error(self):
74+
"""Test input tensor should be non-negative."""
75+
76+
def net_func():
77+
input_value = paddle.to_tensor([1, 2, 3, 4, -5])
78+
paddle.bincount(input_value)
79+
80+
with self.assertRaises(ValueError):
81+
self.run_network(net_func)
82+
7383
def test_input_shape_error(self):
7484
"""Test input tensor should be 1-D tansor."""
7585

@@ -97,7 +107,7 @@ def net_func():
97107
input_value = paddle.to_tensor([1., 2., 3., 4., 5.])
98108
paddle.bincount(input_value)
99109

100-
with self.assertRaises(ValueError):
110+
with self.assertRaises(TypeError):
101111
self.run_network(net_func)
102112

103113
def test_weights_shape_error(self):

python/paddle/tensor/linalg.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,15 +1318,16 @@ def bincount(x, weights=None, minlength=0, name=None):
13181318
result2 = paddle.bincount(x, weights=w)
13191319
print(result2) # [0., 2.19999981, 0.40000001, 0., 0.50000000, 0.50000000]
13201320
"""
1321-
check_variable_and_dtype(x, 'X', ['int32', 'int64'], 'bincount')
1321+
if x.dtype not in [paddle.int32, paddle.int64]:
1322+
raise TypeError("Elements in Input(x) should all be integers")
13221323

1323-
if paddle.min(x) < 0:
1324-
raise ValueError("Elements in Input(x) should all be non-negative")
13251324
if in_dygraph_mode():
13261325
return _C_ops.bincount(x, weights, "minlength", minlength)
13271326

13281327
helper = LayerHelper('bincount', **locals())
13291328

1329+
check_variable_and_dtype(x, 'X', ['int32', 'int64'], 'bincount')
1330+
13301331
if weights is not None:
13311332
check_variable_and_dtype(weights, 'Weights',
13321333
['int32', 'int64', 'float32', 'float64'],

0 commit comments

Comments
 (0)