Skip to content

Commit 78ec942

Browse files
authored
Fix 空指针 (Null pointer) of case15: paddle.broadcast_tensors (#49980)
* fix incorrect output shape of broadcast * add unittest
1 parent 1048b16 commit 78ec942

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

paddle/phi/infermeta/multiary.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,7 @@ void BroadcastTensorsInferMeta(const std::vector<const MetaTensor*>& x,
791791

792792
// We performed bcast semantics check at python level
793793
// So input tensors should all have legal shape
794-
target_dim_size = std::max(target_dim_size, dim_size);
794+
target_dim_size = dim_size == 1 ? target_dim_size : dim_size;
795795
}
796796
target_dims[target_rank - index - 1] = target_dim_size;
797797
}

python/paddle/fluid/tests/unittests/test_broadcast_tensors_op.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,12 @@ def find_output_shape(input_list):
3333
rank = len(x.shape)
3434
output_rank = max(output_rank, rank)
3535

36-
output_shape = [0 for i in range(output_rank)]
36+
output_shape = [1 for i in range(output_rank)]
3737
for i in range(output_rank):
3838
for x in input_list:
3939
shape = list(reversed(x.shape))
40-
size = 1
41-
if i < len(shape):
42-
size = shape[i]
43-
output_shape[i] = max(output_shape[i], size)
40+
if i < len(shape) and shape[i] != 1:
41+
output_shape[i] = shape[i]
4442

4543
return list(reversed(output_shape))
4644

@@ -80,6 +78,11 @@ def gen_mixed_tensors_test(dtype):
8078
return make_inputs_outputs(input_shapes, dtype)
8179

8280

81+
def gen_empty_tensors_test(dtype):
82+
input_shapes = [(0), (0), (0)]
83+
return make_inputs_outputs(input_shapes, dtype)
84+
85+
8386
class TestCPUBroadcastTensorsOp(OpTest):
8487
def set_place(self):
8588
self.place = core.CPUPlace()
@@ -95,6 +98,7 @@ def setUp(self):
9598
gen_rank_diff_test,
9699
gen_no_broadcast_test,
97100
gen_mixed_tensors_test,
101+
gen_empty_tensors_test,
98102
]
99103
self.set_place()
100104
self.set_dtypes()

0 commit comments

Comments
 (0)