Commit da60579

Replace `using_pjrt()` with xla runtime `device_type()` check in xla.py

Fixes Lightning-AI#20419. `torch_xla.runtime.using_pjrt()` was removed in pytorch/xla#7787. This PR replaces references to that function with a check of [`device_type()`](https://github.com/pytorch/xla/blob/master/torch_xla/runtime.py#L83) to recreate its behavior, minus the manual initialization.
1 parent 8ce5287 commit da60579

File tree

1 file changed: +1 −1 lines changed
  • src/lightning/fabric/accelerators


src/lightning/fabric/accelerators/xla.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -109,7 +109,7 @@ def _using_pjrt() -> bool:
     if _XLA_GREATER_EQUAL_2_1:
         from torch_xla import runtime as xr

-        return xr.using_pjrt()
+        return xr.device_type() is not None

     from torch_xla.experimental import pjrt

     return pjrt.using_pjrt()
```
