Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions swanlab/integration/mmengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class SwanlabVisBackend(BaseVisBackend):

def __init__(
self,
save_dir: str,
save_dir: str = None,
init_kwargs: Optional[dict] = None,
):
self._save_dir = save_dir
Expand All @@ -90,8 +90,9 @@ def experiment(self) -> Any:

def _init_env(self) -> Any:
"""Setup env for swanlab."""
if not os.path.exists(self._save_dir):
os.makedirs(self._save_dir, exist_ok=True) # type: ignore
if self._save_dir is not None:
if not os.path.exists(self._save_dir):
os.makedirs(self._save_dir, exist_ok=True) # type: ignore
if self._init_kwargs is None:
self._init_kwargs = {"logdir": self._save_dir}
else:
Expand Down
40 changes: 40 additions & 0 deletions test/integration/mmengine/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# SwanLab integrate MMEngine test docs

## User`s guidance

There are two scripts to test intgration module

* `mmengine_train.py` for full test. It is a implement for cifar10 classification mission. Use resnet50 as our classifier and use mmengine with our training farmwork. Ref [mmengine docs](https://mmengine.readthedocs.io/en/latest/get_started/15_minutes.html).

* `mmengine_visualizer_import.py` is a simple nad efficient test scripts. It use `mmengine.registry` initialize mmengine 'visualizer' with swanlab backend. Ref [mmegine docs](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/visualization.html#customize-storage-backends-and-visualizers)

## Install

Refer to [mmengine official document](https://mmengine.readthedocs.io/en/latest/get_started/installation.html).

Install the environment with following command.

```sh
# with cuda12.1 or you can find torch version you want at pytorch.org
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121

pip install -U openmim
mim install mmengine
pip install swanlab
```

Or use `pip install -r requirements.txt` to install env. (hope it can work)

## Start Test

Simple test with init visualizer

```sh
python mmengine_visualizer_import.py
```

Full test with training mission

```sh
python mmengine_train.py
```
94 changes: 94 additions & 0 deletions test/integration/mmengine/mmengine_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.optim import SGD
from torch.utils.data import DataLoader

from mmengine.evaluator import BaseMetric
from mmengine.model import BaseModel
from mmengine.runner import Runner
from mmengine.visualization import Visualizer


from tutils import open_dev_mode
import swanlab
from swanlab.integration.mmengine import SwanlabVisBackend

swanlab.login(open_dev_mode())


class MMResNet50(BaseModel):
def __init__(self):
super().__init__()
self.resnet = torchvision.models.resnet50()

def forward(self, imgs, labels, mode):
x = self.resnet(imgs)
if mode == "loss":
return {"loss": F.cross_entropy(x, labels)}
elif mode == "predict":
return x, labels


class Accuracy(BaseMetric):
def process(self, data_batch, data_samples):
score, gt = data_samples
self.results.append(
{
"batch_size": len(gt),
"correct": (score.argmax(dim=1) == gt).sum().cpu(),
}
)

def compute_metrics(self, results):
total_correct = sum(item["correct"] for item in results)
total_size = sum(item["batch_size"] for item in results)
return dict(accuracy=100 * total_correct / total_size)


norm_cfg = dict(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
train_dataloader = DataLoader(
batch_size=32,
shuffle=True,
dataset=torchvision.datasets.CIFAR10(
"data/cifar10",
train=True,
download=True,
transform=transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(**norm_cfg),
]
),
),
)

val_dataloader = DataLoader(
batch_size=32,
shuffle=False,
dataset=torchvision.datasets.CIFAR10(
"data/cifar10",
train=False,
download=True,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(**norm_cfg)]),
),
)

visualizer = Visualizer(
vis_backends=SwanlabVisBackend(init_kwargs={})
) # init args can be found in https://docs.swanlab.cn/zh/guide_cloud/integration/integration-mmengine.html

runner = Runner(
model=MMResNet50(),
work_dir="./work_dir",
train_dataloader=train_dataloader,
optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
val_dataloader=val_dataloader,
val_cfg=dict(),
val_evaluator=dict(type=Accuracy),
visualizer=visualizer,
)
runner.train()
37 changes: 37 additions & 0 deletions test/integration/mmengine/mmengine_visualizer_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from mmengine.config import Config
from mmengine.registry import VISUALIZERS
import mmengine
import swanlab

print(f"MMEngine Version: {mmengine.__version__}")
print(f"SwanLab Version: {swanlab.__version__}")

cfg_text = """
# swanlab visualizer
custom_imports = dict( # 引入SwanLab作为日志记录器,对于部分不支持custom_imports的项目可以直接初始化SwanlabVisBackend并加入vis_backends
imports=["swanlab.integration.mmengine"], allow_failed_imports=False
)

vis_backends = [
dict(
type="SwanlabVisBackend",
init_kwargs={ # swanlab.init 参数
"project": "swanlab-mmengine",
"experiment_name": "Your exp", # 实验名称
"description": "Note whatever you want", # 实验的描述信息
},
),
]

visualizer = dict(
type="Visualizer",
vis_backends=vis_backends,
name="visualizer",
)

"""

cfg = Config.fromstring(cfg_text, ".py")

custom_vis = VISUALIZERS.build(cfg.visualizer)
print(custom_vis)
4 changes: 4 additions & 0 deletions test/integration/mmengine/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch
torchvision
mmengine
swanlab