Skip to content

Commit dd59f8b

Browse files
authored
update histogram histogram_bin_edge (#69750)
1 parent e7f09be commit dd59f8b

File tree

10 files changed

+56
-26
lines changed

10 files changed

+56
-26
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,8 +1088,8 @@ bool HistogramOpInferSymbolicShape(
10881088
const symbol::ShapeOrDataDimExprs &input_shape_or_data =
10891089
infer_context->GetShapeOrDataForValue(op->operand_source(0));
10901090
int64_t bins = op->attribute<pir::Int64Attribute>("bins").data();
1091-
int min = op->attribute<pir::Int32Attribute>("min").data();
1092-
int max = op->attribute<pir::Int32Attribute>("max").data();
1091+
float min = op->attribute<pir::FloatAttribute>("min").data();
1092+
float max = op->attribute<pir::FloatAttribute>("max").data();
10931093
PADDLE_ENFORCE_GE(bins,
10941094
1,
10951095
common::errors::InvalidArgument(
@@ -1100,7 +1100,7 @@ bool HistogramOpInferSymbolicShape(
11001100
max,
11011101
min,
11021102
common::errors::InvalidArgument("max must be larger or equal to min."
1103-
"But received max is %d, min is %d",
1103+
"But received max is %f, min is %f",
11041104
max,
11051105
min));
11061106
if (op->operand_source(1)) {

paddle/phi/infermeta/binary.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2257,8 +2257,8 @@ void HingeLossInferMeta(const MetaTensor& logits,
22572257
void HistogramInferMeta(const MetaTensor& input,
22582258
const MetaTensor& weight,
22592259
int64_t bins,
2260-
int min,
2261-
int max,
2260+
float min,
2261+
float max,
22622262
bool density,
22632263
MetaTensor* out) {
22642264
PADDLE_ENFORCE_GE(bins,
@@ -2271,7 +2271,7 @@ void HistogramInferMeta(const MetaTensor& input,
22712271
max,
22722272
min,
22732273
common::errors::InvalidArgument("max must be larger or equal to min."
2274-
"But received max is %d, min is %d",
2274+
"But received max is %f, min is %f",
22752275
max,
22762276
min));
22772277
if (weight) {

paddle/phi/infermeta/binary.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,8 @@ void HingeLossInferMeta(const MetaTensor& logits,
418418
void HistogramInferMeta(const MetaTensor& input,
419419
const MetaTensor& weight,
420420
int64_t bins,
421-
int min,
422-
int max,
421+
float min,
422+
float max,
423423
bool density,
424424
MetaTensor* out);
425425

paddle/phi/kernels/cpu/histogram_kernel.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ void HistogramKernel(const Context& dev_ctx,
2626
const DenseTensor& input,
2727
const paddle::optional<DenseTensor>& weight,
2828
int64_t bins,
29-
int min,
30-
int max,
29+
float min,
30+
float max,
3131
bool density,
3232
DenseTensor* output) {
3333
auto& nbins = bins;

paddle/phi/kernels/gpu/histogram_kernel.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ void HistogramKernel(const Context& dev_ctx,
140140
const DenseTensor& input,
141141
const paddle::optional<DenseTensor>& weight,
142142
int64_t bins,
143-
int min,
144-
int max,
143+
float min,
144+
float max,
145145
bool density,
146146
DenseTensor* output) {
147147
auto& nbins = bins;

paddle/phi/kernels/histogram_kernel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ void HistogramKernel(const Context& dev_ctx,
2222
const DenseTensor& input,
2323
const paddle::optional<DenseTensor>& weight,
2424
int64_t bins,
25-
int min,
26-
int max,
25+
float min,
26+
float max,
2727
bool density,
2828
DenseTensor* output);
2929

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2526,7 +2526,7 @@
25262526
backward: hinge_loss_grad
25272527

25282528
- op : histogram
2529-
args : (Tensor input, Tensor weight, int64_t bins = 100, int min = 0, int max = 0, bool density = false)
2529+
args : (Tensor input, Tensor weight, int64_t bins = 100, float min = 0.0, float max = 0.0, bool density = false)
25302530
output : Tensor(out)
25312531
infer_meta :
25322532
func : HistogramInferMeta

python/paddle/tensor/linalg.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2519,8 +2519,8 @@ def bmm(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
25192519
def histogram(
25202520
input: Tensor,
25212521
bins: int = 100,
2522-
min: int = 0,
2523-
max: int = 0,
2522+
min: float = 0.0,
2523+
max: float = 0.0,
25242524
weight: Tensor | None = None,
25252525
density: bool = False,
25262526
name: str | None = None,
@@ -2533,8 +2533,8 @@ def histogram(
25332533
input (Tensor): A Tensor with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
25342534
should be float32, float64, int32, int64.
25352535
bins (int, optional): number of histogram bins. Default: 100.
2536-
min (int, optional): lower end of the range (inclusive). Default: 0.
2537-
max (int, optional): upper end of the range (inclusive). Default: 0.
2536+
min (float, optional): lower end of the range (inclusive). Default: 0.0.
2537+
max (float, optional): upper end of the range (inclusive). Default: 0.0.
25382538
weight (Tensor, optional): If provided, it must have the same shape as input. Each value in input contributes its associated
25392539
weight towards the bin count (instead of 1). Default: None.
25402540
density (bool, optional): If False, the result will contain the count (or total weight) in each bin. If True, the result is the
@@ -2555,6 +2555,11 @@ def histogram(
25552555
Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True,
25562556
[0, 2, 1, 0])
25572557
"""
2558+
if isinstance(min, int):
2559+
min = float(min)
2560+
if isinstance(max, int):
2561+
max = float(max)
2562+
25582563
if in_dynamic_or_pir_mode():
25592564
return _C_ops.histogram(input, weight, bins, min, max, density)
25602565
else:
@@ -2596,8 +2601,8 @@ def histogram(
25962601
def histogram_bin_edges(
25972602
input: Tensor,
25982603
bins: int = 100,
2599-
min: int = 0,
2600-
max: int = 0,
2604+
min: float = 0.0,
2605+
max: float = 0.0,
26012606
name: str | None = None,
26022607
) -> Tensor:
26032608
"""
@@ -2607,8 +2612,8 @@ def histogram_bin_edges(
26072612
Args:
26082613
input (Tensor): The data type of the input Tensor should be float32, float64, int32, int64.
26092614
bins (int, optional): number of histogram bins.
2610-
min (int, optional): lower end of the range (inclusive). Default: 0.
2611-
max (int, optional): upper end of the range (inclusive). Default: 0.
2615+
min (float, optional): lower end of the range (inclusive). Default: 0.0.
2616+
max (float, optional): upper end of the range (inclusive). Default: 0.0.
26122617
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
26132618
26142619
Returns:
@@ -2625,6 +2630,11 @@ def histogram_bin_edges(
26252630
Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
26262631
[0. , 0.75000000, 1.50000000, 2.25000000, 3. ])
26272632
"""
2633+
if isinstance(min, int):
2634+
min = float(min)
2635+
if isinstance(max, int):
2636+
max = float(max)
2637+
26282638
check_type(input, 'input', (Variable), 'histogram_bin_edges')
26292639
check_dtype(
26302640
input.dtype,
@@ -2633,13 +2643,13 @@ def histogram_bin_edges(
26332643
'histogram_bin_edges',
26342644
)
26352645
check_type(bins, 'bins', int, 'histogram_bin_edges')
2636-
if max == 0 and min == 0:
2646+
if max == 0.0 and min == 0.0:
26372647
min = paddle.min(input)
26382648
max = paddle.max(input)
26392649
else:
26402650
if max < min:
26412651
raise ValueError("max must be larger than min in range parameter")
2642-
if (min - max) == 0:
2652+
if (min - max) == 0.0:
26432653
max = max + 0.5
26442654
min = min - 0.5
26452655
return paddle.linspace(min, max, bins + 1, name=name)

test/legacy_test/test_histogram_bin_edges_op.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,15 @@ def setUp(self):
6262
)
6363

6464

65+
class TestHistogramBinEdgesOpTest2(TestHistogramBinEdgesOp):
66+
def setUp(self):
67+
self.x = np.random.randn(5, 20).astype('float32')
68+
self.bin = 10
69+
self.range = (0.2, 0.9)
70+
self.out = np.histogram_bin_edges(
71+
self.x, bins=self.bin, range=self.range
72+
)
73+
74+
6575
if __name__ == "__main__":
6676
unittest.main()

test/legacy_test/test_histogram_op.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def net_func():
108108
)
109109
paddle.histogram(input=input_value, bins=1, min=-np.inf, max=5)
110110

111-
with self.assertRaises(TypeError):
111+
with self.assertRaises(ValueError):
112112
self.run_network(net_func)
113113

114114
def test_input_range_error(self):
@@ -302,6 +302,16 @@ def init_test_case(self):
302302
self.is_weight = False
303303

304304

305+
class TestHistogramOpAPIWithFloatminMax(TestHistogram):
306+
def init_test_case(self):
307+
self.in_shape = (10, 12)
308+
self.bins = 4
309+
self.min = 2.2
310+
self.max = 4.5
311+
self.density = False
312+
self.is_weight = False
313+
314+
305315
if __name__ == "__main__":
306316
paddle.enable_static()
307317
unittest.main()

0 commit comments

Comments
 (0)