Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1322,7 +1322,6 @@ void BindValue(py::module *m) {
"get_defining_op",
[](Value self) -> pir::Operation * { return self.defining_op(); },
return_value_policy::reference)
.def("numel", [](Value self) { return phi::product(GetValueDims(self)); })
.def("type", &Value::type)
.def("index",
[](Value self) -> uint32_t {
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3476,6 +3476,7 @@
data_transform:
skip_transform : x
no_need_buffer : x
traits : paddle::dialect::ForwardOnlyTrait
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : one_hot
Expand Down
6 changes: 5 additions & 1 deletion python/paddle/decomposition/recompute.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import annotations

import functools
import math
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -696,7 +697,10 @@ def cal_value_node_size(value_node):
# todo(wanghao107) hack for dynamic shape
if is_dynamic_value_node(value_node):
return 1
return value_node.numel() * _PADDLE_DTYPE_2_NBYTES[value_node.dtype]
return (
functools.reduce(lambda x, y: x * y, value_node.shape, 1)
* _PADDLE_DTYPE_2_NBYTES[value_node.dtype]
)


def cal_value_nodes_dist_to_backward(all_ops, required_fw_value_nodes):
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/static/nn/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3492,7 +3492,7 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):

# create input and parameters
input_shape = weight.shape
assert weight.numel() > 0, "Any dimension of input cannot be equal to 0."
assert 0 not in input_shape, "Any dimension of input cannot be equal to 0."

if dim not in [0, 1]:
raise ValueError(
Expand Down
4 changes: 3 additions & 1 deletion python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

import functools
import math
from typing import TYPE_CHECKING, Any, Literal, overload

Expand Down Expand Up @@ -4459,8 +4460,9 @@ def check_input(x, repeat_times):
else:
for elem in repeat_times:
if isinstance(elem, (Variable, paddle.pir.Value)):
numel = functools.reduce(lambda x, y: x * y, elem.shape, 1)
assert (
elem.numel() == 1
numel == 1
), 'Elements in repeat_times must be Tensor with one element or integers.'
else:
type_tuple = (int, np.int32, np.int64)
Expand Down