Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/source/package_reference/tracking.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,8 @@ rendered properly in your Markdown viewer.

[[autodoc]] tracking.ClearMLTracker
- __init__

## SwanLabTracker

[[autodoc]] tracking.SwanLabTracker
- __init__
2 changes: 1 addition & 1 deletion examples/by_feature/deepspeed_with_config_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def parse_args():
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
' `"wandb"`, `"comet_ml"`, `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
Expand Down
2 changes: 1 addition & 1 deletion examples/by_feature/megatron_lm_gpt_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def parse_args():
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
' `"wandb"`, `"comet_ml"`, and `"dvclive"`, and `"swanlab"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
Expand Down
10 changes: 9 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@
extras["rich"] = ["rich"]

extras["test_fp8"] = ["torchao"] # note: TE for now needs to be done via pulling down the docker image directly
extras["test_trackers"] = ["wandb", "comet-ml", "tensorboard", "dvclive", "mlflow", "matplotlib"]
extras["test_trackers"] = [
"wandb",
"comet-ml",
"tensorboard",
"dvclive",
"mlflow",
"matplotlib",
"swanlab",
]
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]

extras["sagemaker"] = [
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ class Accelerator:
- `"tensorboard"`
- `"wandb"`
- `"comet_ml"`
- `"swanlab"`
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
project_config ([`~utils.ProjectConfiguration`], *optional*):
Expand Down
10 changes: 9 additions & 1 deletion src/accelerate/test_utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
is_pytest_available,
is_schedulefree_available,
is_sdaa_available,
is_swanlab_available,
is_tensorboard_available,
is_timm_available,
is_torch_version,
Expand Down Expand Up @@ -482,6 +483,13 @@ def require_dvclive(test_case):
return unittest.skipUnless(is_dvclive_available(), "test requires dvclive")(test_case)


def require_swanlab(test_case):
"""
Decorator marking a test that requires swanlab installed. These tests are skipped when swanlab isn't installed
"""
return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)


def require_pandas(test_case):
"""
Decorator marking a test that requires pandas installed. These tests are skipped when pandas isn't installed
Expand Down Expand Up @@ -536,7 +544,7 @@ def require_matplotlib(test_case):


_atleast_one_tracker_available = (
any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available()
any([is_wandb_available(), is_tensorboard_available(), is_swanlab_available()]) and not is_comet_ml_available()
)


Expand Down
106 changes: 106 additions & 0 deletions src/accelerate/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
is_comet_ml_available,
is_dvclive_available,
is_mlflow_available,
is_swanlab_available,
is_tensorboard_available,
is_wandb_available,
listify,
Expand Down Expand Up @@ -63,6 +64,9 @@
if is_dvclive_available():
_available_trackers.append(LoggerType.DVCLIVE)

if is_swanlab_available():
_available_trackers.append(LoggerType.SWANLAB)

logger = get_logger(__name__)


Expand Down Expand Up @@ -1061,6 +1065,106 @@ def finish(self):
self.live.end()


class SwanLabTracker(GeneralTracker):
"""
A `Tracker` class that supports `swanlab`. Should be initialized at the start of your script.

Args:
run_name (`str`):
The name of the experiment run.
**kwargs (additional keyword arguments, *optional*):
Additional key word arguments passed along to the `swanlab.init` method.
"""

name = "swanlab"
requires_logging_directory = False
main_process_only = False

def __init__(self, run_name: str, **kwargs):
super().__init__()
self.run_name = run_name
self.init_kwargs = kwargs

@on_main_process
def start(self):
import swanlab

self.run = swanlab.init(project=self.run_name, **self.init_kwargs)
swanlab.config["FRAMEWORK"] = "🤗Accelerate" # add accelerate logo in config
logger.debug(f"Initialized SwanLab project {self.run_name}")
logger.debug(
"Make sure to log any initial configurations with `self.store_init_configuration` before training!"
)

@property
def tracker(self):
return self.run

@on_main_process
def store_init_configuration(self, values: dict):
"""
Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.

Args:
values (Dictionary `str` to `bool`, `str`, `float` or `int`):
Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
`str`, `float`, `int`, or `None`.
"""
import swanlab

swanlab.config.update(values, allow_val_change=True)
logger.debug("Stored initial configuration hyperparameters to SwanLab")

@on_main_process
def log(self, values: dict, step: Optional[int] = None, **kwargs):
"""
Logs `values` to the current run.

Args:
data : Dict[str, DataType]
Data must be a dict. The key must be a string with 0-9, a-z, A-Z, " ", "_", "-", "/". The value must be a
`float`, `float convertible object`, `int` or `swanlab.data.BaseType`.
step : int, optional
The step number of the current data, if not provided, it will be automatically incremented.
If step is duplicated, the data will be ignored.
kwargs:
Additional key word arguments passed along to the `swanlab.log` method. Likes:
print_to_console : bool, optional
Whether to print the data to the console, the default is False.
"""
self.run.log(values, step=step, **kwargs)
logger.debug("Successfully logged to SwanLab")

@on_main_process
def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
"""
Logs `images` to the current run.

Args:
values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`):
Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
step (`int`, *optional*):
The run step. If included, the log will be affiliated with this step.
kwargs:
Additional key word arguments passed along to the `swanlab.log` method. Likes:
print_to_console : bool, optional
Whether to print the data to the console, the default is False.
"""
import swanlab

for k, v in values.items():
self.log({k: [swanlab.Image(image) for image in v]}, step=step, **kwargs)
logger.debug("Successfully logged images to SwanLab")

@on_main_process
def finish(self):
"""
Closes `swanlab` writer
"""
self.run.finish()
logger.debug("SwanLab run closed")


LOGGER_TYPE_TO_CLASS = {
"aim": AimTracker,
"comet_ml": CometMLTracker,
Expand All @@ -1069,6 +1173,7 @@ def finish(self):
"wandb": WandBTracker,
"clearml": ClearMLTracker,
"dvclive": DVCLiveTracker,
"swanlab": SwanLabTracker,
}


Expand All @@ -1093,6 +1198,7 @@ def filter_trackers(
- `"comet_ml"`
- `"mlflow"`
- `"dvclive"`
- `"swanlab"`
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
logging_dir (`str`, `os.PathLike`, *optional*):
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
is_sagemaker_available,
is_schedulefree_available,
is_sdaa_available,
is_swanlab_available,
is_tensorboard_available,
is_timm_available,
is_torch_xla_available,
Expand Down
2 changes: 2 additions & 0 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,7 @@ class LoggerType(BaseEnum):
- **WANDB** -- wandb as an experiment tracker
- **COMETML** -- comet_ml as an experiment tracker
- **DVCLIVE** -- dvclive as an experiment tracker
- **SWANLAB** -- swanlab as an experiment tracker
"""

ALL = "all"
Expand All @@ -711,6 +712,7 @@ class LoggerType(BaseEnum):
MLFLOW = "mlflow"
CLEARML = "clearml"
DVCLIVE = "dvclive"
SWANLAB = "swanlab"


class PrecisionType(str, BaseEnum):
Expand Down
4 changes: 4 additions & 0 deletions src/accelerate/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ def is_comet_ml_available():
return _is_package_available("comet_ml")


def is_swanlab_available():
return _is_package_available("swanlab")


def is_boto3_available():
return _is_package_available("boto3")

Expand Down
Loading