Skip to content

Commit 4922fc4

Browse files
LiyulingyueHermitSun
authored andcommitted
【PIR API adaptor No.158】nanmedian (PaddlePaddle#58889)
1 parent 183cbfb commit 4922fc4

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

python/paddle/tensor/stat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def nanmedian(x, axis=None, keepdim=False, name=None):
329329
elif isinstance(axis, int):
330330
axis = [axis]
331331

332-
if in_dynamic_mode():
332+
if in_dynamic_or_pir_mode():
333333
return _C_ops.nanmedian(x, axis, keepdim)
334334
else:
335335
check_variable_and_dtype(

test/legacy_test/test_nanmedian.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import paddle
2121
from paddle.base import core
22+
from paddle.pir_utils import test_with_pir_api
2223

2324
np.random.seed(102)
2425

@@ -79,6 +80,7 @@ def setUp(self):
7980
[0, 2, 1, 3],
8081
]
8182

83+
@test_with_pir_api
8284
def test_api_static(self):
8385
data = self.fake_data["col_nan_odd"]
8486
paddle.enable_static()
@@ -257,10 +259,10 @@ def setUp(self):
257259
self.outputs = {'Out': Out}
258260

259261
def test_check_output(self):
260-
self.check_output()
262+
self.check_output(check_pir=True)
261263

262264
def test_check_grad(self):
263-
self.check_grad(['X'], 'Out')
265+
self.check_grad(['X'], 'Out', check_pir=True)
264266

265267

266268
@unittest.skipIf(
@@ -282,11 +284,11 @@ def setUp(self):
282284

283285
def test_check_output(self):
284286
place = core.CUDAPlace(0)
285-
self.check_output_with_place(place)
287+
self.check_output_with_place(place, check_pir=True)
286288

287289
def test_check_grad(self):
288290
place = core.CUDAPlace(0)
289-
self.check_grad_with_place(place, ['X'], 'Out')
291+
self.check_grad_with_place(place, ['X'], 'Out', check_pir=True)
290292

291293

292294
if __name__ == "__main__":

0 commit comments

Comments
 (0)