Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 42 additions & 104 deletions swanlab/integration/accelerate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from swanlab.integration.accelerate import SwanLabTracker
...
tracker = SwanLabTracker("some_run_name")
accelerator = Accelerator(log_with=tracker)
accelerator = Accelerator(log_with=[tracker])
...
---------------------------------
These also can be mixed with existing trackers, including with "all":
Expand Down Expand Up @@ -46,107 +46,73 @@ class SwanLabTracker(GeneralTracker):

Args:
run_name (`str`):
The name of the experiment run
logging_dir (`str`, `os.PathLike`):
Location for swanlab logs to be stored.
kwargs:
The name of the experiment run.
**kwargs (additional keyword arguments, *optional*):
Additional key word arguments passed along to the `swanlab.init` method.
"""

name = "swanlab"
requires_logging_directory = False
main_process_only = False

@on_main_process
def __init__(
self,
project: Optional[str] = None,
workspace: Optional[str] = None,
experiment_name: Optional[str] = None,
description: Optional[str] = None,
logdir: Optional[str] = None,
mode: Optional[str] = None,
tags: Optional[List[str]] = None,
**kwargs,
):
def __init__(self, run_name: str, **kwargs):
super().__init__()

tags = tags or []
tags.append("accelerate") if "accelerate" not in tags else None

self._swanlab_init: Dict[str, Any] = {
"project": project,
"workspace": workspace,
"experiment_name": experiment_name,
"description": description,
"logdir": logdir,
"mode": mode,
"tags": tags,
}

self._swanlab_init.update(**kwargs)

self._project = self._swanlab_init.get("project")
self._workspace = self._swanlab_init.get("workspace")
self._experiment_name = self._swanlab_init.get("experiment_name")
self._description = self._swanlab_init.get("description")
self._logdir = self._swanlab_init.get("logdir")
self._mode = self._swanlab_init.get("mode")
self._tags = self._swanlab_init.get("tags")

self.logdir = os.path.join(logdir, self._project) if self._logdir is not None else None

swanlab.config["FRAMEWORK"] = "accelerate"
if swanlab.get_run() is None:
self.writer = swanlab.init(**self._swanlab_init)
else:
self.writer = swanlab.get_run()

logger.debug(f"Initialized swanlab project {self._project} logging to {self._logdir}")
self.run_name = run_name
self.init_kwargs = kwargs

@on_main_process
def start(self):
import swanlab

self.run = swanlab.init(project=self.run_name, **self.init_kwargs)
swanlab.config["FRAMEWORK"] = "accelerate" # add accelerate logo in config
logger.debug(f"Initialized SwanLab project {self.run_name}")
logger.debug(
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
)

def update_config(self, config: Dict[str, Any]):
swanlab.config.update(config)

@property
def tracker(self):
return self.writer
return self.run

@on_main_process
def store_init_configuration(self, values: dict):
"""
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
hyperparameters in a yaml file for future use.
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.

Args:
values (Dictionary `str` to `bool`, `str`, `float` or `int`):
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
`str`, `float`, `int`, or `None`.
"""
swanlab.config.update(values)
logger.debug("Stored initial configuration hyperparameters to swanlab")
import swanlab

swanlab.config.update(values, allow_val_change=True)
logger.debug("Stored initial configuration hyperparameters to SwanLab")

@on_main_process
def log(self, values: dict, step: Optional[int] = None, **kwargs):
"""
Logs `values` to the current run.

Args:
values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
`str` to `float`/`int`.
step (`int`, *optional*):
The run step. If included, the log will be affiliated with this step.
data : Dict[str, DataType]
Data must be a dict.
The key must be a string with 0-9, a-z, A-Z, " ", "_", "-", "/".
The value must be a `float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
step : int, optional
The step number of the current data, if not provided, it will be automatically incremented.
If step is duplicated, the data will be ignored.
kwargs:
Additional key word arguments passed along to either `SummaryWriter.add_scaler`,
`SummaryWriter.add_text`, or `SummaryWriter.add_scalers` method based on the contents of `values`.
Additional key word arguments passed along to the `swanlab.log` method. Likes:
print_to_console : bool, optional
Whether to print the data to the console, the default is False.
"""
self.writer.log(values, step=step)
logger.debug("Successfully logged to swanlab")
self.run.log(values, step=step, **kwargs)
logger.debug("Successfully logged to SwanLab")

@on_main_process
def log_images(self, values: dict, step: Optional[int], **kwargs):
def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
"""
Logs `images` to the current run.

Expand All @@ -156,48 +122,20 @@ def log_images(self, values: dict, step: Optional[int], **kwargs):
step (`int`, *optional*):
The run step. If included, the log will be affiliated with this step.
kwargs:
Additional key word arguments passed along to the `SummaryWriter.add_image` method.
Additional key word arguments passed along to the `swanlab.log` method. Likes:
print_to_console : bool, optional
Whether to print the data to the console, the default is False.
"""
import swanlab

for k, v in values.items():
self.writer.log(k, swanlab.Image(v), step=step)

logger.debug("Successfully logged images to swanlab")

# @on_main_process
# def log_table(
# self,
# table_name: str,
# columns: List[str] = None,
# data: List[List[Any]] = None,
# dataframe: Any = None,
# step: Optional[int] = None,
# **kwargs,
# ):
# """
# Log a Table containing any object type (text, image, audio, video, molecule, html, etc). Can be defined either
# with `columns` and `data` or with `dataframe`.

# Args:
# table_name (`str`):
# The name to give to the logged table on the swanlab workspace
# columns (List of `str`'s *optional*):
# The name of the columns on the table
# data (List of List of Any data type *optional*):
# The data to be logged in the table
# dataframe (Any data type *optional*):
# The data to be logged in the table
# step (`int`, *optional*):
# The run step. If included, the log will be affiliated with this step.
# """

# values = {table_name: swanlab.Table(columns=columns, data=data, dataframe=dataframe)}
# self.log(values, step=step, **kwargs)
self.log({k: [swanlab.Image(image) for image in v]}, step=step, **kwargs)
logger.debug("Successfully logged images to SwanLab")

@on_main_process
def finish(self):
"""
Closes `swanlab` writer
"""
self.writer.finish()
logger.debug("swanlab run closed")
self.run.finish()
logger.debug("SwanLab run closed")
Loading