Skip to content

Conversation

Zeyi-Lin
Copy link
Member

@Zeyi-Lin Zeyi-Lin commented Apr 18, 2024

Description

与PyTorch Lightning库的集成,测试代码为:

from swanlab.integration.pytorch_lightning import SwanLabLogger

import importlib.util
import os

if importlib.util.find_spec("lightning"):
    import lightning.pytorch as pl
elif importlib.util.find_spec("pytorch_lightning"):  # noqa F401
    import pytorch_lightning as pl
else:
    raise RuntimeError(
        "This contrib module requires PyTorch Lightning to be installed. "
        "Please install it with command: \n pip install pytorch-lightning \n"
        "or \n pip install lightning"
    )
from torch import nn, optim, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))


# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        # test_step defines the test loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("test_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)


# setup data
dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
train_dataset, val_dataset = utils.data.random_split(dataset, [55000, 5000])
test_dataset = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())

train_loader = utils.data.DataLoader(train_dataset)
val_loader = utils.data.DataLoader(val_dataset)
test_loader = utils.data.DataLoader(test_dataset)

swanlab_logger = SwanLabLogger(
    project="swanlab_example",
    experiment_name="example_experiment",
    cloud=False,
)

trainer = pl.Trainer(limit_train_batches=100, max_epochs=5, logger=swanlab_logger)


trainer.fit(model=autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(dataloaders=test_loader)

Closes: #478

@Zeyi-Lin Zeyi-Lin requested a review from SAKURA-CAT April 18, 2024 20:11
@Zeyi-Lin Zeyi-Lin self-assigned this Apr 18, 2024
@Zeyi-Lin Zeyi-Lin added this to the Integration milestone Apr 18, 2024
@Zeyi-Lin Zeyi-Lin added the 💪 enhancement New feature or request label Apr 18, 2024
@SAKURA-CAT SAKURA-CAT merged commit 1584f95 into main Apr 28, 2024
@SAKURA-CAT SAKURA-CAT deleted the feat/integration-lightning branch April 28, 2024 06:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

💪 enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[REQUEST] 集成PyTorch Lightning

2 participants