1515class TestExperimentalPjrt (parameterized .TestCase ):
1616
1717 def setUp (self ):
18- global xr
19- reload (xr )
2018 xr .set_device_type ('CPU' )
2119
2220 @parameterized .parameters (('CPU' , 'CPU' ), ('CUDA' , 'CUDA' ), ('TPU' , 'TPU' ))
@@ -44,12 +42,6 @@ def test_set_device_type_same_device(self):
4442 torch_xla ._XLAC , '_xla_runtime_is_initialized' , return_value = True ):
4543 xr .set_device_type ('CPU' )
4644
47- def test_requires_pjrt (self ):
48- with mock .patch .dict (
49- os .environ , {'PJRT_SELECT_DEFAULT_DEVICE' : '0' }, clear = True ):
50- with self .assertRaises (NotImplementedError ):
51- xr .xla_device ()
52-
5345 def test_default_ordinals (self ):
5446 global_ordinal = xr .global_ordinal ()
5547 self .assertEqual (global_ordinal , 0 )
@@ -65,9 +57,6 @@ def test_num_global_devices(self):
6557 self .assertLen (torch_xla ._XLAC ._xla_get_all_devices (),
6658 xr .global_device_count ())
6759
68- def test_world_size (self ):
69- self .assertEqual (xr .world_size (), xr .world_size ())
70-
7160 def test_xla_device_error (self ):
7261 with self .assertRaises (IndexError ):
7362 xm .xla_device (10 )
@@ -87,21 +76,19 @@ def test_xla_device_error(self):
8776 'GPU_NUM_DEVICES' : '4'
8877 }, True ))
8978 def test_pjrt_default_device (self , env_vars , expect_using_pjrt ):
90- with mock .patch .dict (os .environ , env_vars , clear = True ):
91- # Print a warningif we had to select a default runtime
92- if 'PJRT_DEVICE' not in os .environ and expect_using_pjrt :
93- logs_context = self .assertLogs (level = logging .WARNING )
94- else :
79+ # Prevent flag checking during reinitialization of PJRT backend.
80+ # Without the patch, the test will be impacted by other tests when torch_xla reloads.
81+ with mock .patch (
82+ 'torch_xla._XLAC._xla_runtime_is_initialized' , return_value = False ):
83+ with mock .patch .dict (os .environ , env_vars , clear = True ):
84+ # We need to reload the torch_xla module because clear=True will clear all os.environ.
85+ global torch_xla
86+ reload (torch_xla )
9587 logs_context = contextlib .nullcontext ()
96-
97- with logs_context :
98- # Configure default device
99- xr .using_pjrt ()
100-
101- if expect_using_pjrt :
102- self .assertIn (xr .device_type (), ['CPU' , 'CUDA' , 'TPU' ])
103- else :
104- self .assertIsNone (xr .device_type ())
88+ if expect_using_pjrt :
89+ self .assertIn (xr .device_type (), ['CPU' , 'CUDA' , 'TPU' ])
90+ else :
91+ self .assertIsNone (xr .device_type ())
10592
10693 def test_host_index (self ):
10794 self .assertEqual (xr .host_index (), 0 )
0 commit comments