
Commit 2ba4e05

Improve device auto-detection (#7787)
1 parent eef7bb4 commit 2ba4e05

13 files changed: +19 -119 lines changed

test/cpp/cpp_test_util.cpp

Lines changed: 0 additions & 6 deletions
@@ -437,12 +437,6 @@ torch::lazy::NodePtr CreateNonZeroNode2d(int64_t num_non_zero_element,
   return nonzero_node;
 }
 
-bool UsingPjRt() {
-  static bool using_pjrt =
-      !torch_xla::runtime::sys_util::GetEnvString("PJRT_DEVICE", "").empty();
-  return using_pjrt;
-}
-
 bool UsingTpu() {
   static bool using_tpu =
       absl::StartsWith(

test/cpp/cpp_test_util.h

Lines changed: 0 additions & 2 deletions
@@ -116,8 +116,6 @@ void TestBackward(
 torch::lazy::NodePtr CreateNonZeroNode2d(int64_t num_non_zero_element,
                                          int64_t num_row, int64_t num_col);
 
-bool UsingPjRt();
-
 bool UsingTpu();
 
 }  // namespace cpp_test

test/pjrt/test_profiler.py

Lines changed: 0 additions & 2 deletions
@@ -31,8 +31,6 @@ def _profile(logdir: str, port: int = 9012):
 class TestPjRtProfiler(absltest.TestCase):
 
   def setUp(self):
-    assert xr.using_pjrt()
-
     # HACK: ensure libtpu is loaded if using TPU
     xm.xla_device()
 
test/pjrt/test_runtime.py

Lines changed: 12 additions & 25 deletions
@@ -15,8 +15,6 @@
 class TestExperimentalPjrt(parameterized.TestCase):
 
   def setUp(self):
-    global xr
-    reload(xr)
     xr.set_device_type('CPU')
 
   @parameterized.parameters(('CPU', 'CPU'), ('CUDA', 'CUDA'), ('TPU', 'TPU'))
@@ -44,12 +42,6 @@ def test_set_device_type_same_device(self):
         torch_xla._XLAC, '_xla_runtime_is_initialized', return_value=True):
       xr.set_device_type('CPU')
 
-  def test_requires_pjrt(self):
-    with mock.patch.dict(
-        os.environ, {'PJRT_SELECT_DEFAULT_DEVICE': '0'}, clear=True):
-      with self.assertRaises(NotImplementedError):
-        xr.xla_device()
-
   def test_default_ordinals(self):
     global_ordinal = xr.global_ordinal()
     self.assertEqual(global_ordinal, 0)
@@ -65,9 +57,6 @@ def test_num_global_devices(self):
     self.assertLen(torch_xla._XLAC._xla_get_all_devices(),
                    xr.global_device_count())
 
-  def test_world_size(self):
-    self.assertEqual(xr.world_size(), xr.world_size())
-
   def test_xla_device_error(self):
     with self.assertRaises(IndexError):
       xm.xla_device(10)
@@ -87,21 +76,19 @@ def test_xla_device_error(self):
           'GPU_NUM_DEVICES': '4'
       }, True))
   def test_pjrt_default_device(self, env_vars, expect_using_pjrt):
-    with mock.patch.dict(os.environ, env_vars, clear=True):
-      # Print a warning if we had to select a default runtime
-      if 'PJRT_DEVICE' not in os.environ and expect_using_pjrt:
-        logs_context = self.assertLogs(level=logging.WARNING)
-      else:
+    # Prevent flag checking during reinitialization of PJRT backend.
+    # Without the patch, the test will be impacted by other tests when torch_xla reloads.
+    with mock.patch(
+        'torch_xla._XLAC._xla_runtime_is_initialized', return_value=False):
+      with mock.patch.dict(os.environ, env_vars, clear=True):
+        # We need to reload the torch_xla module because clear=True will clear all os.environ.
+        global torch_xla
+        reload(torch_xla)
         logs_context = contextlib.nullcontext()
-
-      with logs_context:
-        # Configure default device
-        xr.using_pjrt()
-
-        if expect_using_pjrt:
-          self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU'])
-        else:
-          self.assertIsNone(xr.device_type())
+        if expect_using_pjrt:
+          self.assertIn(xr.device_type(), ['CPU', 'CUDA', 'TPU'])
+        else:
+          self.assertIsNone(xr.device_type())
 
   def test_host_index(self):
     self.assertEqual(xr.host_index(), 0)
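
With auto-detection now happening at import time, the updated test only inspects the outcome through xr.device_type(). A minimal sketch of what this looks like from a user script (illustrative only; the detected device depends on the machine):

    import os

    # Leave PJRT_DEVICE unset so torch_xla has to pick a default device on import.
    os.environ.pop('PJRT_DEVICE', None)

    import torch_xla.runtime as xr

    # Auto-detection has already run during import; device_type() reports the choice.
    print(xr.device_type())  # e.g. 'CPU' on a machine without TPU/GPU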

test/test_operations.py

Lines changed: 1 addition & 3 deletions
@@ -1659,9 +1659,7 @@ def test_fn(t):
       t += torch.tensor(i, dtype=torch.float, device=t.device)
       return t
 
-    # This test is for PjRT only
-    if xr.using_pjrt():
-      self.runAtenTest([torch.tensor(20.0)], test_fn)
+    self.runAtenTest([torch.tensor(20.0)], test_fn)
 
   def test_view_and_copy_(self):
     xla_device = xm.xla_device()

test/test_train_mp_imagenet.py

Lines changed: 1 addition & 2 deletions
@@ -255,8 +255,7 @@ def train_imagenet():
 
   # Initialization is nondeterministic with multiple threads in PjRt.
   # Synchronize model parameters across replicas manually.
-  if xr.using_pjrt():
-    xm.broadcast_master_param(model)
+  xm.broadcast_master_param(model)
 
   if FLAGS.ddp:
     model = DDP(model, gradient_as_bucket_view=True, broadcast_buffers=False)

test/test_train_mp_mnist.py

Lines changed: 1 addition & 2 deletions
@@ -135,8 +135,7 @@ def train_mnist(flags, **kwargs):
 
   # Initialization is nondeterministic with multiple threads in PjRt.
   # Synchronize model parameters across replicas manually.
-  if xr.using_pjrt():
-    xm.broadcast_master_param(model)
+  xm.broadcast_master_param(model)
 
   if flags.ddp:
     model = DDP(model, gradient_as_bucket_view=True)
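
Both training scripts now call xm.broadcast_master_param unconditionally, since PjRt is the only runtime. A minimal usage sketch (the toy model and device setup below are illustrative, not taken from the scripts):

    import torch.nn as nn
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()
    model = nn.Linear(784, 10).to(device)

    # Initialization is nondeterministic with multiple threads in PjRt,
    # so synchronize the master replica's parameters to all other replicas.
    xm.broadcast_master_param(model)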

torch_xla/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -251,3 +251,6 @@ def _init_xla_lazy_backend():
 
 # register all custom kenels and decomp by default
 from ._internal import custom_kernel, decomp_registration, c10d_registration
+
+# select default PJRT_DEVICE before any execution
+runtime._maybe_select_default_device()
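
The body of runtime._maybe_select_default_device is not part of this diff. As a rough, hypothetical sketch only (not the torch_xla implementation), an env-driven default selection of this kind amounts to respecting an explicit PJRT_DEVICE and otherwise warning and falling back to a default; the real helper presumably also checks for available TPU/GPU accelerators before settling on CPU:

    import logging
    import os

    def _maybe_select_default_device_sketch():
      # Hypothetical illustration; the real logic lives in torch_xla.runtime.
      if os.environ.get('PJRT_DEVICE'):
        return  # The user already picked a device explicitly.
      # Warn about the implicit choice and fall back to a default.
      logging.warning('PJRT_DEVICE is not set; defaulting to CPU.')
      os.environ['PJRT_DEVICE'] = 'CPU'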

torch_xla/_internal/pjrt.py

Lines changed: 0 additions & 6 deletions
@@ -41,7 +41,6 @@ def _merge_replica_results(
   return dict(replica_results)
 
 
-@runtime.requires_pjrt
 def _run_thread_per_device(
     local_rank: int, local_world_size: int, fn: Callable[[], R],
     initializer_fn: Callable[[int, int], None]) -> Dict[int, R]:
@@ -81,7 +80,6 @@ def _thread_fn(device: torch.device):
   return _merge_replica_results(replica_results)
 
 
-@runtime.requires_pjrt
 def _run_singleprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
   """Runs `fn` on a single device core.
 
@@ -99,7 +97,6 @@ def _run_singleprocess(fn: Callable[..., R], *args, **kwargs) -> Dict[int, R]:
   return fn(*args, **kwargs)
 
 
-@runtime.requires_pjrt
 def initialize_singleprocess():
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, '1')
 
@@ -110,7 +107,6 @@ def initialize_singleprocess():
   xm.set_replication(xm.xla_device(), [])
 
 
-@runtime.requires_pjrt
 def initialize_multiprocess(local_rank: int, local_world_size: int):
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_RANK, str(local_rank))
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, str(local_world_size))
@@ -126,7 +122,6 @@ def initialize_multiprocess(local_rank: int, local_world_size: int):
   xm.set_replication(xm.xla_device(), devices)
 
 
-@runtime.requires_pjrt
 def run_multiprocess(fn: Callable[..., R],
                      *args,
                      start_method: str = 'spawn',
@@ -214,7 +209,6 @@ def spawn(fn: Callable,
   run_multiprocess(spawn_fn, start_method=start_method)
 
 
-@runtime.requires_pjrt
 def _initialize_single_process(local_rank: int, local_world_size: int):
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_RANK, str(local_rank))
   os.environ.setdefault(xenv.PJRT_LOCAL_PROCESS_COUNT, str(local_world_size))
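
The removed runtime.requires_pjrt decorator is defined elsewhere and not shown in this diff; it becomes redundant once a PJRT device is always selected at import. Purely as a hypothetical reconstruction of such a guard (names and the exact error message are assumptions, not the torch_xla code, though the deleted test_requires_pjrt test expected NotImplementedError):

    import functools
    import os

    def requires_pjrt(fn):
      # Hypothetical guard: refuse to run unless a PJRT device has been configured.
      @functools.wraps(fn)
      def wrapper(*args, **kwargs):
        if not os.environ.get('PJRT_DEVICE'):
          raise NotImplementedError(
              f'{fn.__name__} requires PJRT_DEVICE to be set')
        return fn(*args, **kwargs)

      return wrapper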

torch_xla/_internal/utils.py

Lines changed: 0 additions & 14 deletions
@@ -1,21 +1,7 @@
-import functools
 import re
 
 
 def parse_xla_device(device: str):
   m = re.match(r'([A-Z]+):(\d+)$', device)
   if m:
     return (m.group(1), int(m.group(2)))
-
-
-def run_once(func):
-  result = None
-
-  @functools.wraps(func)
-  def wrapper(*args, **kwargs):
-    nonlocal result
-    if result is None:
-      result = func(*args, **kwargs)
-    return result
-
-  return wrapper
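
The deleted run_once helper memoized a function's first result so repeated calls were cheap and idempotent. Reproduced below from the deleted lines, with a small usage example to show the caching behaviour (the pick_device function is illustrative only):

    import functools

    def run_once(func):
      # Deleted helper: cache the first non-None result and reuse it on later calls.
      result = None

      @functools.wraps(func)
      def wrapper(*args, **kwargs):
        nonlocal result
        if result is None:
          result = func(*args, **kwargs)
        return result

      return wrapper

    @run_once
    def pick_device():
      print('selecting a device...')
      return 'CPU'

    pick_device()  # runs the body and prints once
    pick_device()  # returns the cached 'CPU' without re-running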
