Skip to content

Commit c6f2050

Browse files
【Fix PIR Unittest No.43 BUAA】Fix test_tensor_array_to_tensor (#66337)
1 parent 8aadae3 commit c6f2050

File tree

3 files changed

+12
-11
lines changed

3 files changed

+12
-11
lines changed

test/deprecated/legacy_test/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,6 @@ set(STATIC_BUILD_TESTS
709709
test_nce_deprecated
710710
test_sparse_conv_op
711711
test_sparse_norm_op
712-
test_tensor_array_to_tensor
713712
test_tensor_array_to_tensor_deprecated
714713
test_unique
715714
test_one_hot_v2_op)

test/legacy_test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1107,7 +1107,8 @@ set(STATIC_BUILD_TESTS
11071107
test_shuffle_batch_op
11081108
test_update_loss_scaling_op
11091109
test_segment_ops
1110-
test_while_op)
1110+
test_while_op
1111+
test_tensor_array_to_tensor)
11111112

11121113
if(NOT WITH_GPU)
11131114
list(REMOVE_ITEM STATIC_BUILD_TESTS test_fused_attention_op)

test/deprecated/legacy_test/test_tensor_array_to_tensor.py renamed to test/legacy_test/test_tensor_array_to_tensor.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -79,16 +79,17 @@ def set_program(self):
7979
self.output_vars = [output]
8080

8181
def run_check(self, executor, scope):
82-
executor.run(self.program, scope=scope)
82+
result = executor.run(
83+
self.program, fetch_list=self.output_vars, scope=scope
84+
)
8385
for i, output in enumerate(self.outputs):
84-
np.allclose(
85-
np.array(scope.var(self.output_vars[i].name).get_tensor()),
86-
output,
87-
atol=0,
88-
)
89-
tensor_array_grad = scope.var(self.array.name).get_lod_tensor_array()
90-
for i, input_grad in enumerate(self.input_grads):
91-
np.allclose(np.array(tensor_array_grad[i]), input_grad, atol=0)
86+
np.allclose(result[i], output, atol=0)
87+
if not paddle.framework.use_pir_api():
88+
tensor_array_grad = scope.var(
89+
self.array.name
90+
).get_lod_tensor_array()
91+
for i, input_grad in enumerate(self.input_grads):
92+
np.allclose(np.array(tensor_array_grad[i]), input_grad, atol=0)
9293

9394
def test_cpu(self):
9495
scope = core.Scope()

0 commit comments

Comments
 (0)