Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1088,8 +1088,8 @@ bool HistogramOpInferSymbolicShape(
const symbol::ShapeOrDataDimExprs &input_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
int64_t bins = op->attribute<pir::Int64Attribute>("bins").data();
int min = op->attribute<pir::Int32Attribute>("min").data();
int max = op->attribute<pir::Int32Attribute>("max").data();
float min = op->attribute<pir::FloatAttribute>("min").data();
float max = op->attribute<pir::FloatAttribute>("max").data();
PADDLE_ENFORCE_GE(bins,
1,
common::errors::InvalidArgument(
Expand All @@ -1100,7 +1100,7 @@ bool HistogramOpInferSymbolicShape(
max,
min,
common::errors::InvalidArgument("max must be larger or equal to min."
"But received max is %d, min is %d",
"But received max is %f, min is %f",
max,
min));
if (op->operand_source(1)) {
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/infermeta/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2257,8 +2257,8 @@ void HingeLossInferMeta(const MetaTensor& logits,
void HistogramInferMeta(const MetaTensor& input,
const MetaTensor& weight,
int64_t bins,
int min,
int max,
float min,
float max,
bool density,
MetaTensor* out) {
PADDLE_ENFORCE_GE(bins,
Expand All @@ -2271,7 +2271,7 @@ void HistogramInferMeta(const MetaTensor& input,
max,
min,
common::errors::InvalidArgument("max must be larger or equal to min."
"But received max is %d, min is %d",
"But received max is %f, min is %f",
max,
min));
if (weight) {
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/infermeta/binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,8 +418,8 @@ void HingeLossInferMeta(const MetaTensor& logits,
void HistogramInferMeta(const MetaTensor& input,
const MetaTensor& weight,
int64_t bins,
int min,
int max,
float min,
float max,
bool density,
MetaTensor* out);

Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/cpu/histogram_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ void HistogramKernel(const Context& dev_ctx,
const DenseTensor& input,
const paddle::optional<DenseTensor>& weight,
int64_t bins,
int min,
int max,
float min,
float max,
bool density,
DenseTensor* output) {
auto& nbins = bins;
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/gpu/histogram_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ void HistogramKernel(const Context& dev_ctx,
const DenseTensor& input,
const paddle::optional<DenseTensor>& weight,
int64_t bins,
int min,
int max,
float min,
float max,
bool density,
DenseTensor* output) {
auto& nbins = bins;
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/histogram_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ void HistogramKernel(const Context& dev_ctx,
const DenseTensor& input,
const paddle::optional<DenseTensor>& weight,
int64_t bins,
int min,
int max,
float min,
float max,
bool density,
DenseTensor* output);

Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2526,7 +2526,7 @@
backward: hinge_loss_grad

- op : histogram
args : (Tensor input, Tensor weight, int64_t bins = 100, int min = 0, int max = 0, bool density = false)
args : (Tensor input, Tensor weight, int64_t bins = 100, float min = 0.0, float max = 0.0, bool density = false)
output : Tensor(out)
infer_meta :
func : HistogramInferMeta
Expand Down
30 changes: 20 additions & 10 deletions python/paddle/tensor/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2483,8 +2483,8 @@ def bmm(x: Tensor, y: Tensor, name: str | None = None) -> Tensor:
def histogram(
input: Tensor,
bins: int = 100,
min: int = 0,
max: int = 0,
min: float = 0.0,
max: float = 0.0,
weight: Tensor | None = None,
density: bool = False,
name: str | None = None,
Expand All @@ -2497,8 +2497,8 @@ def histogram(
input (Tensor): A Tensor with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
should be float32, float64, int32, int64.
bins (int, optional): number of histogram bins. Default: 100.
min (int, optional): lower end of the range (inclusive). Default: 0.
max (int, optional): upper end of the range (inclusive). Default: 0.
min (float, optional): lower end of the range (inclusive). Default: 0.0.
max (float, optional): upper end of the range (inclusive). Default: 0.0.
weight (Tensor, optional): If provided, it must have the same shape as input. Each value in input contributes its associated
weight towards the bin count (instead of 1). Default: None.
density (bool, optional): If False, the result will contain the count (or total weight) in each bin. If True, the result is the
Expand All @@ -2519,6 +2519,11 @@ def histogram(
Tensor(shape=[4], dtype=int64, place=Place(cpu), stop_gradient=True,
[0, 2, 1, 0])
"""
if isinstance(min, int):
min = float(min)
if isinstance(max, int):
max = float(max)

if in_dynamic_or_pir_mode():
return _C_ops.histogram(input, weight, bins, min, max, density)
else:
Expand Down Expand Up @@ -2560,8 +2565,8 @@ def histogram(
def histogram_bin_edges(
input: Tensor,
bins: int = 100,
min: int = 0,
max: int = 0,
min: float = 0.0,
max: float = 0.0,
name: str | None = None,
) -> Tensor:
"""
Expand All @@ -2571,8 +2576,8 @@ def histogram_bin_edges(
Args:
input (Tensor): The data type of the input Tensor should be float32, float64, int32, int64.
bins (int, optional): number of histogram bins.
min (int, optional): lower end of the range (inclusive). Default: 0.
max (int, optional): upper end of the range (inclusive). Default: 0.
min (float, optional): lower end of the range (inclusive). Default: 0.0.
max (float, optional): upper end of the range (inclusive). Default: 0.0.
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

Returns:
Expand All @@ -2589,6 +2594,11 @@ def histogram_bin_edges(
Tensor(shape=[5], dtype=float32, place=Place(cpu), stop_gradient=True,
[0. , 0.75000000, 1.50000000, 2.25000000, 3. ])
"""
if isinstance(min, int):
min = float(min)
if isinstance(max, int):
max = float(max)

check_type(input, 'input', (Variable), 'histogram_bin_edges')
check_dtype(
input.dtype,
Expand All @@ -2597,13 +2607,13 @@ def histogram_bin_edges(
'histogram_bin_edges',
)
check_type(bins, 'bins', int, 'histogram_bin_edges')
if max == 0 and min == 0:
if max == 0.0 and min == 0.0:
min = paddle.min(input)
max = paddle.max(input)
else:
if max < min:
raise ValueError("max must be larger than min in range parameter")
if (min - max) == 0:
if (min - max) == 0.0:
max = max + 0.5
min = min - 0.5
return paddle.linspace(min, max, bins + 1, name=name)
Expand Down
10 changes: 10 additions & 0 deletions test/legacy_test/test_histogram_bin_edges_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,15 @@ def setUp(self):
)


class TestHistogramBinEdgesOpTest2(TestHistogramBinEdgesOp):
def setUp(self):
self.x = np.random.randn(5, 20).astype('float32')
self.bin = 10
self.range = (0.2, 0.9)
self.out = np.histogram_bin_edges(
self.x, bins=self.bin, range=self.range
)


if __name__ == "__main__":
unittest.main()
12 changes: 11 additions & 1 deletion test/legacy_test/test_histogram_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def net_func():
)
paddle.histogram(input=input_value, bins=1, min=-np.inf, max=5)

with self.assertRaises(TypeError):
with self.assertRaises(ValueError):
self.run_network(net_func)

def test_input_range_error(self):
Expand Down Expand Up @@ -302,6 +302,16 @@ def init_test_case(self):
self.is_weight = False


class TestHistogramOpAPIWithFloatminMax(TestHistogram):
def init_test_case(self):
self.in_shape = (10, 12)
self.bins = 4
self.min = 2.2
self.max = 4.5
self.density = False
self.is_weight = False


if __name__ == "__main__":
paddle.enable_static()
unittest.main()