Skip to content

Commit dddee4a

Browse files
authored
[PIR]Fix value numel method (#67754)
1 parent 27df5e8 commit dddee4a

File tree

5 files changed

+10
-4
lines changed

5 files changed

+10
-4
lines changed

paddle/fluid/pybind/pir.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1322,7 +1322,6 @@ void BindValue(py::module *m) {
13221322
"get_defining_op",
13231323
[](Value self) -> pir::Operation * { return self.defining_op(); },
13241324
return_value_policy::reference)
1325-
.def("numel", [](Value self) { return phi::product(GetValueDims(self)); })
13261325
.def("type", &Value::type)
13271326
.def("index",
13281327
[](Value self) -> uint32_t {

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3476,6 +3476,7 @@
34763476
data_transform:
34773477
skip_transform : x
34783478
no_need_buffer : x
3479+
traits : paddle::dialect::ForwardOnlyTrait
34793480
interfaces : paddle::dialect::InferSymbolicShapeInterface
34803481

34813482
- op : one_hot

python/paddle/decomposition/recompute.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import functools
1617
import math
1718
import os
1819
from typing import TYPE_CHECKING
@@ -797,7 +798,10 @@ def cal_value_node_size(value_node):
797798
# todo(wanghao107) hack for dynamic shape
798799
if is_dynamic_value_node(value_node):
799800
return 1
800-
return value_node.numel() * _PADDLE_DTYPE_2_NBYTES[value_node.dtype]
801+
return (
802+
functools.reduce(lambda x, y: x * y, value_node.shape, 1)
803+
* _PADDLE_DTYPE_2_NBYTES[value_node.dtype]
804+
)
801805

802806

803807
def cal_value_nodes_dist_to_backward(all_ops, required_fw_value_nodes):

python/paddle/static/nn/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3492,7 +3492,7 @@ def spectral_norm(weight, dim=0, power_iters=1, eps=1e-12, name=None):
34923492

34933493
# create input and parameters
34943494
input_shape = weight.shape
3495-
assert weight.numel() > 0, "Any dimension of input cannot be equal to 0."
3495+
assert 0 not in input_shape, "Any dimension of input cannot be equal to 0."
34963496

34973497
if dim not in [0, 1]:
34983498
raise ValueError(

python/paddle/tensor/manipulation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import functools
1718
import math
1819
from typing import TYPE_CHECKING, Any, Literal, overload
1920

@@ -4459,8 +4460,9 @@ def check_input(x, repeat_times):
44594460
else:
44604461
for elem in repeat_times:
44614462
if isinstance(elem, (Variable, paddle.pir.Value)):
4463+
numel = functools.reduce(lambda x, y: x * y, elem.shape, 1)
44624464
assert (
4463-
elem.numel() == 1
4465+
numel == 1
44644466
), 'Elements in repeat_times must be Tensor with one element or integers.'
44654467
else:
44664468
type_tuple = (int, np.int32, np.int64)

0 commit comments

Comments
 (0)