Skip to content

Commit 9314156

Browse files
committed
fix mock test that clears os env
1 parent 17fe1b8 commit 9314156

File tree

2 files changed

+4
-9
lines changed

2 files changed

+4
-9
lines changed

test/pjrt/test_runtime.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
class TestExperimentalPjrt(parameterized.TestCase):
1616

1717
def setUp(self):
18-
global xr
19-
reload(xr)
2018
xr.set_device_type('CPU')
2119

2220
@parameterized.parameters(('CPU', 'CPU'), ('CUDA', 'CUDA'), ('TPU', 'TPU'))
@@ -82,12 +80,10 @@ def test_xla_device_error(self):
8280
}, True))
8381
def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
8482
with mock.patch.dict(os.environ, env_vars, clear=True):
85-
# Print a warningif we had to select a default runtime
86-
if 'PJRT_DEVICE' not in os.environ and expect_using_pjrt:
87-
logs_context = self.assertLogs(level=logging.WARNING)
88-
else:
89-
logs_context = contextlib.nullcontext()
90-
83+
# We need to reload the torch_xla module because clear=True will clear all os.environ.
84+
global torch_xla
85+
reload(torch_xla)
86+
logs_context = contextlib.nullcontext()
9187
if expect_using_pjrt:
9288
self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU'])
9389
else:

torch_xla/_internal/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import functools
21
import re
32

43

0 commit comments

Comments
 (0)