Labels: bug, callback: early stopping, logger: wandb
Description
Bug description
When an EarlyStopping callback would halt training before min_epochs has elapsed, the stop request is (correctly) overridden and the trainer prints the warning shown below. However, starting at the exact step where the warning is printed, WandbLogger begins logging the training metrics on every single batch instead of every log_every_n_steps batches. This slows training and produces strange output graphs.
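My guess at the mechanism (an assumption from skimming the 2.2 source, not confirmed): the logger connector decides whether to log by checking the step count against log_every_n_steps, ORed with trainer.should_stop so that the final metrics get flushed when training ends. EarlyStopping sets should_stop, the fit loop defers the stop because min_epochs has not been met, and the OR branch then fires on every remaining batch. A minimal sketch of a gate that would produce exactly this behavior:

# Minimal sketch (my assumption, not the actual Lightning source) of the
# logging gate that would explain the symptom.
def should_update_logs(batches_stepped: int, log_every_n_steps: int, should_stop: bool) -> bool:
    # Normally, log only every `log_every_n_steps` batches.
    should_log = (batches_stepped + 1) % log_every_n_steps == 0
    # A pending stop request forces a log so the last metrics are flushed.
    # But when `min_epochs` overrides the stop, `should_stop` stays True,
    # so this returns True for every remaining batch.
    return should_log or should_stop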
What version are you seeing the problem on?
v2.2
How to reproduce the bug
import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer, seed_everything
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import WandbLogger


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    seed_everything(42, workers=True)

    wandb_logger = WandbLogger(project="bug-report", entity="example-user", name="debug_logging")
    early_stopping_callback = EarlyStopping(monitor="train_loss", patience=2)
    callbacks = [early_stopping_callback]

    kwargs = {
        "log_every_n_steps": 8,
        "logger": wandb_logger,
        "num_sanity_val_steps": 0,
        "callbacks": callbacks,
        "val_check_interval": 0.1,
        "max_epochs": 10,
        "min_epochs": 2,
        "deterministic": True,
    }

    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(**kwargs)
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)


if __name__ == "__main__":
    run()
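A possible workaround until this is fixed: delay the early-stopping check until min_epochs have elapsed, so trainer.should_stop is never set while the trainer would override it anyway. This relies on the private _run_early_stopping_check hook, so treat it as a fragile sketch rather than a supported API; min_epochs_before_check is a name I made up:

from lightning.pytorch.callbacks import EarlyStopping


class DelayedEarlyStopping(EarlyStopping):
    """Hypothetical workaround: skip early-stopping checks until
    `min_epochs` have elapsed, so `trainer.should_stop` is never set
    while the trainer would defer the stop anyway."""

    def __init__(self, min_epochs_before_check: int, **kwargs):
        super().__init__(**kwargs)
        self.min_epochs_before_check = min_epochs_before_check

    def _run_early_stopping_check(self, trainer):
        # `_run_early_stopping_check` is a private EarlyStopping method;
        # relying on it may break across Lightning versions.
        if trainer.current_epoch < self.min_epochs_before_check:
            return
        super()._run_early_stopping_check(trainer)

In the script above, this would replace the callback construction with early_stopping_callback = DelayedEarlyStopping(min_epochs_before_check=2, monitor="train_loss", patience=2).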
Error messages and logs
Epoch 0: 38%|████████████████████████████████████████▉ | 12/32 [00:00<00:00, 53.14it/s, v_num=yqz5]
Trainer was signaled to stop but the required `min_epochs=2` or `min_steps=None` has not been met. Training will continue...
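For context, this warning is emitted when the fit loop receives a stop request before min_epochs/min_steps are satisfied; the stop is deferred, but should_stop itself is not cleared. A simplified paraphrase of that check (my assumption, not the verbatim source):

# Simplified paraphrase (assumption) of the fit loop's stop handling.
def stop_now(should_stop: bool, epoch: int, min_epochs, step: int, min_steps) -> bool:
    if not should_stop:
        return False
    met_min_epochs = min_epochs is None or epoch >= min_epochs
    met_min_steps = min_steps is None or step >= min_steps
    # If the minimums are not met, training continues, the warning is
    # printed, and `should_stop` remains True for the rest of the run.
    return met_min_epochs and met_min_steps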
Environment
- CUDA:
- GPU: None
- available: False
- version: None
- Lightning:
- lightning: 2.2.5
- lightning-utilities: 0.11.2
- pytorch-lightning: 2.2.2
- torch: 2.3.0
- torchmetrics: 1.4.0.post0
- Packages:
- appdirs: 1.4.4
- appnope: 0.1.4
- asttokens: 2.4.1
- brotli: 1.1.0
- certifi: 2024.6.2
- chardet: 5.2.0
- charset-normalizer: 3.3.2
- click: 8.1.7
- colorama: 0.4.6
- comm: 0.2.2
- contourpy: 1.2.1
- cycler: 0.12.1
- debugpy: 1.8.1
- decorator: 5.1.1
- docker-pycreds: 0.4.0
- exceptiongroup: 1.2.0
- executing: 2.0.1
- filelock: 3.14.0
- fonttools: 4.53.0
- freetype-py: 2.3.0
- fsspec: 2024.6.0
- gitdb: 4.0.11
- gitpython: 3.1.43
- gmpy2: 2.1.5
- greenlet: 3.0.3
- idna: 3.7
- importlib-metadata: 7.1.0
- ipykernel: 6.29.3
- ipython: 8.25.0
- jedi: 0.19.1
- jinja2: 3.1.4
- joblib: 1.4.2
- jupyter-client: 8.6.2
- jupyter-core: 5.7.2
- kiwisolver: 1.4.5
- lightning: 2.2.5
- lightning-utilities: 0.11.2
- markupsafe: 2.1.5
- matplotlib: 3.8.4
- matplotlib-inline: 0.1.7
- mpmath: 1.3.0
- munkres: 1.1.4
- nest-asyncio: 1.6.0
- networkx: 3.3
- numexpr: 2.10.0
- numpy: 1.26.4
- packaging: 24.0
- pandas: 2.2.2
- parso: 0.8.4
- pathtools: 0.1.2
- pexpect: 4.9.0
- pickleshare: 0.7.5
- pillow: 10.3.0
- pip: 24.0
- platformdirs: 4.2.2
- prompt-toolkit: 3.0.46
- protobuf: 4.25.3
- psutil: 5.9.8
- ptyprocess: 0.7.0
- pure-eval: 0.2.2
- py-cpuinfo: 9.0.0
- pycairo: 1.26.0
- pygments: 2.18.0
- pyparsing: 3.1.2
- pysocks: 1.7.1
- python-dateutil: 2.9.0
- pytorch-lightning: 2.2.2
- pytz: 2024.1
- pyyaml: 6.0.1
- pyzmq: 26.0.3
- rdkit: 2024.3.3
- reportlab: 4.1.0
- requests: 2.32.3
- rlpycairo: 0.2.0
- scikit-learn: 1.5.0
- scipy: 1.13.1
- sentry-sdk: 2.4.0
- setproctitle: 1.3.3
- setuptools: 70.0.0
- six: 1.16.0
- smmap: 5.0.0
- sqlalchemy: 2.0.30
- stack-data: 0.6.2
- sympy: 1.12
- tables: 3.9.2
- threadpoolctl: 3.5.0
- torch: 2.3.0
- torchmetrics: 1.4.0.post0
- tornado: 6.4.1
- tqdm: 4.66.4
- traitlets: 5.14.3
- typing-extensions: 4.12.1
- tzdata: 2024.1
- urllib3: 2.2.1
- wandb: 0.16.5
- wcwidth: 0.2.13
- wheel: 0.43.0
- zipp: 3.17.0
- System:
- OS: Darwin
- architecture:
  - 64bit
- processor: arm
- python: 3.11.9
- release: 23.5.0
- version: Darwin Kernel Version 23.5.0: Wed May 1 20:19:05 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T8112
More info
The symptoms of this bug are somewhat similar to those of #16821 and #13525, but based on those threads the underlying causes appear to be different.