Skip to content

Commit 29399ed

Browse files
Try lightning when pytorch_lightning is not found (#10404)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 16486ed commit 29399ed

File tree

3 files changed

+29
-13
lines changed

3 files changed

+29
-13
lines changed

CHANGELOG.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
77

88
### Fixed
99

10-
- Fix `detach()` warnings in example scripts involving tensor conversions. ([#10357](https://github.com/pyg-team/pytorch_geometric/pull/10357))
11-
- Fix non-tuple indexing to resolve PyTorch deprecation warning. ([#10389](https://github.com/pyg-team/pytorch_geometric/pull/10389))
10+
- Fixed importing PyTorch Lightning in `torch_geometric.graphgym` and `torch_geometric.data.lightning` when using `lightning` instead of `pytorch-lightning` ([#10404](https://github.com/pyg-team/pytorch_geometric/pull/10404))
11+
- Fixed `detach()` warnings in example scripts involving tensor conversions ([#10357](https://github.com/pyg-team/pytorch_geometric/pull/10357))
12+
- Fixed non-tuple indexing to resolve PyTorch deprecation warning ([#10389](https://github.com/pyg-team/pytorch_geometric/pull/10389))
1213

1314
### Added
1415

torch_geometric/data/lightning/datamodule.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,27 @@
1111
from torch_geometric.typing import InputEdges, InputNodes, OptTensor
1212

1313
try:
14-
from pytorch_lightning import LightningDataModule as PLLightningDataModule
15-
no_pytorch_lightning = False
14+
from pytorch_lightning import LightningDataModule as _LightningDataModule
15+
_pl_is_available = True
1616
except ImportError:
17-
PLLightningDataModule = object # type: ignore
18-
no_pytorch_lightning = True
17+
try:
18+
from lightning.pytorch import \
19+
LightningDataModule as _LightningDataModule
20+
_pl_is_available = True
21+
except ImportError:
22+
_pl_is_available = False
23+
_LightningDataModule = object
1924

2025

21-
class LightningDataModule(PLLightningDataModule):
26+
class LightningDataModule(_LightningDataModule):
2227
def __init__(self, has_val: bool, has_test: bool, **kwargs: Any) -> None:
2328
super().__init__()
2429

25-
if no_pytorch_lightning:
30+
if not _pl_is_available:
2631
raise ModuleNotFoundError(
27-
"No module named 'pytorch_lightning' found on this machine. "
28-
"Run 'pip install pytorch_lightning' to install the library.")
32+
"No module named 'pytorch_lightning' (or 'lightning') found "
33+
"in your Python environment. Run 'pip install "
34+
"pytorch_lightning' or 'pip install lightning'")
2935

3036
if not has_val:
3137
self.val_dataloader = None # type: ignore

torch_geometric/graphgym/imports.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,23 @@
44

55
try:
66
import pytorch_lightning as pl
7+
_pl_is_available = True
8+
except ImportError:
9+
try:
10+
import lightning.pytorch as pl
11+
_pl_is_available = True
12+
except ImportError:
13+
_pl_is_available = False
14+
15+
if _pl_is_available:
716
LightningModule = pl.LightningModule
817
Callback = pl.Callback
9-
except ImportError:
18+
else:
1019
pl = object
1120
LightningModule = torch.nn.Module
1221
Callback = object
1322

1423
warnings.warn(
15-
"Please install 'pytorch_lightning' via "
16-
"'pip install pytorch_lightning' in order to use GraphGym",
24+
"To use GraphGym, install 'pytorch_lightning' or 'lightning' via "
25+
"'pip install pytorch_lightning' or 'pip install lightning'",
1726
stacklevel=2)

0 commit comments

Comments
 (0)