Skip to content

Commit 510f91c

Browse files
committed
fix
1 parent c21114f commit 510f91c

File tree

7 files changed

+183
-28
lines changed

7 files changed

+183
-28
lines changed

python/paddle/static/nn/metric.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,11 @@ def auc(
250250
ins_tag_weight = paddle.full(
251251
shape=[1, 1], dtype="float32", fill_value=1.0
252252
)
253+
check_variable_and_dtype(input, 'input', ['float32', 'float64'], 'auc')
254+
check_variable_and_dtype(label, 'label', ['int32', 'int64'], 'auc')
255+
check_variable_and_dtype(
256+
ins_tag_weight, 'ins_tag_weight', ['float32', 'float64'], 'auc'
257+
)
253258
stat_pos = paddle.zeros(shape=[1, num_thresholds + 1], dtype="int64")
254259
stat_neg = paddle.zeros(shape=[1, num_thresholds + 1], dtype="int64")
255260
auc_out, batch_stat_pos, batch_stat_neg = _C_ops.auc(
@@ -260,7 +265,7 @@ def auc(
260265
ins_tag_weight,
261266
curve,
262267
num_thresholds,
263-
slide_steps,
268+
0,
264269
)
265270
return (
266271
auc_out,

test/deprecated/legacy_test/test_auc_op.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from op_test import OpTest
1919

2020
import paddle
21-
from paddle import base
2221

2322

2423
class TestAucOp(OpTest):
@@ -139,7 +138,9 @@ def test_static(self):
139138

140139
class TestAucOpError(unittest.TestCase):
141140
def test_errors(self):
142-
with base.program_guard(base.Program(), base.Program()):
141+
with paddle.static.program_guard(
142+
paddle.static.Program(), paddle.static.Program()
143+
):
143144

144145
def test_type1():
145146
data1 = paddle.static.data(

test/legacy_test/test_auc_op.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import numpy as np
18+
from op_test import OpTest
19+
20+
import paddle
21+
22+
23+
class TestAucOp(OpTest):
24+
def setUp(self):
25+
self.op_type = "auc"
26+
pred = np.random.random((128, 2)).astype("float32")
27+
labels = np.random.randint(0, 2, (128, 1)).astype("int64")
28+
num_thresholds = 200
29+
slide_steps = 1
30+
31+
stat_pos = np.zeros(
32+
(1 + slide_steps) * (num_thresholds + 1) + 1,
33+
).astype("int64")
34+
stat_neg = np.zeros(
35+
(1 + slide_steps) * (num_thresholds + 1) + 1,
36+
).astype("int64")
37+
38+
self.inputs = {
39+
'Predict': pred,
40+
'Label': labels,
41+
"StatPos": stat_pos,
42+
"StatNeg": stat_neg,
43+
}
44+
self.attrs = {
45+
'curve': 'ROC',
46+
'num_thresholds': num_thresholds,
47+
"slide_steps": slide_steps,
48+
}
49+
50+
python_auc = paddle.metric.Auc(
51+
name="auc", curve='ROC', num_thresholds=num_thresholds
52+
)
53+
python_auc.update(pred, labels)
54+
55+
pos = python_auc._stat_pos.tolist() * 2
56+
pos.append(1)
57+
neg = python_auc._stat_neg.tolist() * 2
58+
neg.append(1)
59+
self.outputs = {
60+
'AUC': np.array(python_auc.accumulate()),
61+
'StatPosOut': np.array(pos),
62+
'StatNegOut': np.array(neg),
63+
}
64+
65+
def test_check_output(self):
66+
self.check_output(check_dygraph=False)
67+
68+
69+
class TestGlobalAucOp(OpTest):
70+
def setUp(self):
71+
self.op_type = "auc"
72+
pred = np.random.random((128, 2)).astype("float32")
73+
labels = np.random.randint(0, 2, (128, 1)).astype("int64")
74+
num_thresholds = 200
75+
slide_steps = 0
76+
77+
stat_pos = np.zeros((1, (num_thresholds + 1))).astype("int64")
78+
stat_neg = np.zeros((1, (num_thresholds + 1))).astype("int64")
79+
80+
self.inputs = {
81+
'Predict': pred,
82+
'Label': labels,
83+
"StatPos": stat_pos,
84+
"StatNeg": stat_neg,
85+
}
86+
self.attrs = {
87+
'curve': 'ROC',
88+
'num_thresholds': num_thresholds,
89+
"slide_steps": slide_steps,
90+
}
91+
92+
python_auc = paddle.metric.Auc(
93+
name="auc", curve='ROC', num_thresholds=num_thresholds
94+
)
95+
python_auc.update(pred, labels)
96+
97+
pos = python_auc._stat_pos
98+
neg = python_auc._stat_neg
99+
self.outputs = {
100+
'AUC': np.array(python_auc.accumulate()),
101+
'StatPosOut': np.array([pos]),
102+
'StatNegOut': np.array([neg]),
103+
}
104+
105+
def test_check_output(self):
106+
self.check_output(check_dygraph=False)
107+
108+
109+
class TestAucAPI(unittest.TestCase):
110+
def test_static(self):
111+
paddle.enable_static()
112+
data = paddle.static.data(name="input", shape=[-1, 1], dtype="float32")
113+
label = paddle.static.data(name="label", shape=[4], dtype="int64")
114+
ins_tag_weight = paddle.static.data(
115+
name="ins_tag_weight", shape=[4], dtype="float32"
116+
)
117+
result = paddle.static.auc(
118+
input=data, label=label, ins_tag_weight=ins_tag_weight
119+
)
120+
121+
place = paddle.CPUPlace()
122+
exe = paddle.static.Executor(place)
123+
124+
exe.run(paddle.static.default_startup_program())
125+
126+
x = np.array([[0.0474], [0.5987], [0.7109], [0.9997]]).astype("float32")
127+
128+
y = np.array([0, 0, 1, 0]).astype('int64')
129+
z = np.array([1, 1, 1, 1]).astype('float32')
130+
(output,) = exe.run(
131+
feed={"input": x, "label": y, "ins_tag_weight": z},
132+
fetch_list=[result[0]],
133+
)
134+
auc_np = np.array(0.66666667).astype("float32")
135+
np.testing.assert_allclose(output, auc_np, rtol=1e-05)
136+
assert auc_np.shape == auc_np.shape
137+
138+
139+
class TestAucOpError(unittest.TestCase):
140+
def test_errors(self):
141+
with paddle.static.program_guard(
142+
paddle.static.Program(), paddle.static.Program()
143+
):
144+
145+
def test_type1():
146+
data1 = paddle.static.data(
147+
name="input1", shape=[-1, 2], dtype="int"
148+
)
149+
label1 = paddle.static.data(
150+
name="label1", shape=[-1], dtype="int"
151+
)
152+
ins_tag_w1 = paddle.static.data(
153+
name="label1", shape=[-1], dtype="int"
154+
)
155+
result1 = paddle.static.auc(
156+
input=data1, label=label1, ins_tag_weight=ins_tag_w1
157+
)
158+
159+
self.assertRaises(TypeError, test_type1)
160+
161+
def test_type2():
162+
data2 = paddle.static.data(
163+
name="input2", shape=[-1, 2], dtype="float32"
164+
)
165+
label2 = paddle.static.data(
166+
name="label2", shape=[-1], dtype="float32"
167+
)
168+
result2 = paddle.static.auc(input=data2, label=label2)
169+
170+
self.assertRaises(TypeError, test_type2)
171+
172+
173+
if __name__ == '__main__':
174+
unittest.main()

test/deprecated/legacy_test/test_sparse_conv_op.py renamed to test/legacy_test/test_sparse_conv_op.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,10 @@
1616
import unittest
1717

1818
import numpy as np
19-
from utils import compare_legacy_with_pt
2019

2120
import paddle
2221
from paddle import sparse
2322
from paddle.base import core
24-
from paddle.pir_utils import test_with_pir_api
2523

2624
logging.basicConfig(
2725
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO
@@ -316,8 +314,6 @@ def test_Conv3D_bias(self):
316314

317315

318316
class TestStatic(unittest.TestCase):
319-
@compare_legacy_with_pt
320-
@test_with_pir_api
321317
def test(self):
322318
paddle.enable_static()
323319
main = paddle.static.Program()
@@ -386,8 +382,6 @@ def test(self):
386382
self.assertTrue(out_indices.dtype == paddle.int32)
387383
paddle.disable_static()
388384

389-
@compare_legacy_with_pt
390-
@test_with_pir_api
391385
def test_cpu(self):
392386
paddle.enable_static()
393387
main = paddle.static.Program()
@@ -457,8 +451,6 @@ def test_cpu(self):
457451
self.assertTrue(out_indices.dtype == paddle.int32)
458452
paddle.disable_static()
459453

460-
@compare_legacy_with_pt
461-
@test_with_pir_api
462454
def test2D(self):
463455
paddle.enable_static()
464456
main = paddle.static.Program()
@@ -522,8 +514,6 @@ def test2D(self):
522514
self.assertTrue(out_indices.dtype == paddle.int32)
523515
paddle.disable_static()
524516

525-
@compare_legacy_with_pt
526-
@test_with_pir_api
527517
def test2D_cpu(self):
528518
paddle.enable_static()
529519
main = paddle.static.Program()

test/deprecated/legacy_test/test_sparse_isnan_op.py renamed to test/legacy_test/test_sparse_isnan_op.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
import unittest
1616

1717
import numpy as np
18-
from utils import compare_legacy_with_pt
1918

2019
import paddle
21-
from paddle.pir_utils import test_with_pir_api
2220

2321

2422
class TestSparseIsnan(unittest.TestCase):
@@ -65,8 +63,6 @@ def test_isnan_dtype(self):
6563

6664

6765
class TestStatic(unittest.TestCase):
68-
@compare_legacy_with_pt
69-
@test_with_pir_api
7066
def test(self):
7167
paddle.enable_static()
7268
main_program = paddle.static.Program()

test/deprecated/legacy_test/test_sparse_slice_op.py renamed to test/legacy_test/test_sparse_slice_op.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import unittest
1616

1717
import numpy as np
18-
from utils import compare_legacy_with_pt
1918

2019
import paddle
2120

@@ -207,32 +206,26 @@ def check_result_with_list(self, x, axes, starts, ends, format='coo'):
207206
if format == 'coo':
208207
self._check_result_coo(np_x, axes, starts, ends)
209208

210-
@compare_legacy_with_pt
211209
def test_coo_5d(self):
212210
for item in data_5d:
213211
self.check_result_with_shape(*item, format='coo')
214212

215-
@compare_legacy_with_pt
216213
def test_coo_4d(self):
217214
for item in data_4d:
218215
self.check_result_with_shape(*item, format='coo')
219216

220-
@compare_legacy_with_pt
221217
def test_coo_3d(self):
222218
for item in data_3d:
223219
self.check_result_with_shape(*item, format='coo')
224220

225-
@compare_legacy_with_pt
226221
def test_coo_2d(self):
227222
for item in data_2d:
228223
self.check_result_with_shape(*item, format='coo')
229224

230-
@compare_legacy_with_pt
231225
def test_coo_1d(self):
232226
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
233227
self.check_result_with_list(x, [0], [3], [5], format='coo')
234228

235-
@compare_legacy_with_pt
236229
def test_coo_1d_zero(self):
237230
x = [-49, 55, -5, 0, 3, 0, 0, -60, -21, 0, 0, 0]
238231
self.check_result_with_list(x, [0], [-3], [-1], format='coo')

test/deprecated/legacy_test/test_sparse_sum_op.py renamed to test/legacy_test/test_sparse_sum_op.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
import unittest
1616

1717
import numpy as np
18-
from utils import compare_legacy_with_pt
1918

2019
import paddle
21-
from paddle.pir_utils import test_with_dygraph_pir
2220

2321
devices = ['cpu']
2422
if paddle.device.get_device() != "cpu":
@@ -178,8 +176,6 @@ def check_result_coo(self, x_shape, dims, keepdim, dtype=None):
178176
)
179177
paddle.disable_static()
180178

181-
@compare_legacy_with_pt
182-
@test_with_dygraph_pir
183179
def test_sum(self):
184180
# 1d
185181
self.check_result_coo([5], None, False)

0 commit comments

Comments
 (0)