Skip to content

Commit 66adcaf

Browse files
authored
Refactor/data-code (#1126)
* refactor: backup and proto * refactor: transfer * refactor: swanlab settings * Refactor backup and transfer modules structure Moved backup-related modules from swanlab.data.backup to swanlab.log.backup and transfer modules from swanlab.data.transfers to swanlab.transfers. Updated all relevant imports and usages throughout the codebase. Refactored Transfer and ProtoV0Transfer to accept media_dir and file_dir as constructor arguments. Improved error log uploading in CloudPyCallback and unified backup handler usage. * Refactor backup handler and callbackers for improved modularity Refactored BackupHandler to accept run parameters directly instead of accessing run_store internally, improving modularity and testability. Updated cloud, local, and offline callbackers to pass run parameters explicitly and to use run_store consistently. Added access control to get_run_store to restrict usage to swanlab.data module. Moved uploader imports to explicit usage and centralized error log uploading in ProtoV0Transfer. * Refactor callbacker module and unify callback logic Moved and refactored callback logic from swanlab/data/run/callback.py into swanlab/data/callbacker/callback.py, and updated cloud, local, and offline callback implementations to inherit from the new base class. Centralized utility functions in swanlab/data/callbacker/utils.py and updated all callbackers to use these shared utilities for printing and path formatting. Removed redundant code and improved maintainability by reducing duplication and clarifying callback registration and cleanup logic. * Refactor data transfer and backup architecture Removed legacy transfer and backup modules, consolidating data transfer logic into a new ProtoTransfer singleton in swanlab/data/transfer.py. Updated callbackers to use ProtoTransfer for logging, metric, and runtime info handling. Migrated DataStore to swanlab/data, removed async_io utility, and cleaned up related imports and usages. Adjusted tests and internal references to reflect new module structure and APIs. * Refactor DataPorter for unified data sync and trace modes Replaces the previous ProtoTransfer and ModelsParser logic with a new DataPorter class that supports both experiment trace and sync upload modes, centralizing backup file parsing, data publishing, and resource management. Updates all callbackers and sync logic to use DataPorter, removes ModelsParser from proto/v0.py, and adds platformdirs to requirements. This refactor improves maintainability and consistency for experiment data handling and synchronization. * Add synced decorator and improve backup handling Applied the @synced decorator to DataPorter methods to ensure thread safety. Improved backup handling in SwanLabRun by deleting the run directory when backup is disabled. Updated SwanLabInitializer to use a system runtime directory for logs when backup is off, ensuring proper log storage. * Update __init__.py * Refactor test setup and store access control Introduced an @inside decorator to restrict access to RunStore functions to the swanlab.data module or test runtime. Refactored test utilities to use a new UseMockRunState context manager for consistent client and store state management in tests. Updated imports and test code to align with these changes, improving test isolation and reliability. * Refactor local callback and improve config error handling Refactored LocalRunCallback to set logdir via environment variable and updated tests to use UseMockRunState for better isolation. Changed SwanLabRun to default operator to None. Cleared callbacks after validation in SwanLabInitializer. Improved and relocated error handling for invalid config parameters from test_config.py to test_sdk.py, adding more comprehensive tests for config input validation. * Add disabled callback mode and refactor logdir handling Introduced DisabledCallback for 'disabled' mode, ensuring proper handling when runs are not to be saved or uploaded. Refactored logdir initialization logic in SwanLabInitializer to unify and simplify directory setup for all modes, and moved run_dir cleanup logic from main.py to the cloud callback. Updated operator creation to use DisabledCallback in disabled mode. * chore: fix test * fix: test * chore: fix typing * chore: change cache dir * chore: some misc * fix: code
1 parent dadba39 commit 66adcaf

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+1768
-1697
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ botocore
1111
pydantic>=2.9.0
1212
pyecharts>=2.0.0
1313
wrapt>=1.17.0
14+
platformdirs>=4.2.0
1415
typing_extensions; python_version < '3.9'
1516
protobuf>=3.12.0,!=4.21.0,!=5.28.0,<7; python_version < '3.9' and sys_platform == 'linux'
1617
protobuf>=3.15.0,!=4.21.0,!=5.28.0,<7; python_version == '3.9' and sys_platform == 'linux'

swanlab/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# 导出初始化函数和log函数,以及一些数据处理模块
22
from .data import *
3+
from .data.modules import (
4+
Audio,
5+
Image,
6+
Object3D,
7+
Molecule,
8+
Text,
9+
echarts,
10+
)
311
from .env import SwanLabEnv
412
from .package import get_package_version
513
from .swanlab_settings import Settings

swanlab/cli/commands/sync/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,9 @@ def sync(path, api_key, workspace, project, host):
6767
api_key = get_key() if api_key is None else api_key
6868
except KeyFileError:
6969
pass
70-
for path in path:
70+
for p in path:
7171
# 1.3 登录,创建 http 对象
7272
log_info = auth.terminal_login(api_key=api_key, save_key=False)
7373
create_client(log_info)
7474
# 2. 同步日志
75-
sync_logs(path, workspace=workspace, project_name=project, login_required=False, raise_error=len(path) == 1)
75+
sync_logs(p, workspace=workspace, project_name=project, login_required=False, raise_error=len(path) == 1)

swanlab/core_python/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,3 @@
88
"""
99

1010
from .client import *
11-
from .uploader import *
12-
from .uploader import thread

swanlab/core_python/auth/providers/api_key.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
from rich.status import Status
1414
from rich.text import Text
1515

16-
from swanlab.env import is_windows
17-
from swanlab.error import ValidationError, APIKeyFormatError
16+
from swanlab.core_python import auth
17+
from swanlab.env import is_windows, is_interactive
18+
from swanlab.error import ValidationError, APIKeyFormatError, KeyFileError
1819
from swanlab.log import swanlog
19-
from swanlab.package import get_setting_url, get_host_api, get_host_web, fmt_web_host, save_key as sk
20+
from swanlab.package import get_setting_url, get_host_api, get_host_web, fmt_web_host, save_key as sk, get_key
2021

2122

2223
class LoginInfo:
@@ -206,6 +207,22 @@ def login_again(error: Exception):
206207
api_key = login_again(e)
207208

208209

210+
def create_login_info(save: bool = True):
211+
"""
212+
在代码运行时发起登录,获取登录信息,执行此方法会覆盖原有的login_info
213+
"""
214+
key = None
215+
try:
216+
key = get_key()
217+
except KeyFileError:
218+
pass
219+
if key is None and not is_interactive():
220+
raise KeyFileError(
221+
"api key not configured (no-tty), call `swanlab.login(api_key=[your_api_key])` or set `swanlab.init(mode=\"local\")`."
222+
)
223+
return auth.terminal_login(key, save)
224+
225+
209226
def _abort_tip(tp, _, __):
210227
"""处理用户在input_api_key输入时按下CTRL+C的情况"""
211228
if tp == KeyboardInterrupt:
@@ -219,4 +236,5 @@ def _abort_tip(tp, _, __):
219236
"terminal_login",
220237
"code_login",
221238
"LoginInfo",
239+
"create_login_info",
222240
]

swanlab/core_python/client/__init__.py

Lines changed: 51 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
import requests
1313
from requests.adapters import HTTPAdapter
14-
from rich.status import Status
1514
from urllib3.util.retry import Retry
1615

1716
from swanlab.error import NetworkError, ApiError
@@ -255,46 +254,44 @@ def mount_project(self, name: str, username: str = None, public: bool = None):
255254
:param public: 项目是否公开
256255
:return: 项目信息
257256
"""
258-
with Status("Getting project...", spinner="dots"):
259-
try:
260-
data = {"name": name}
261-
if username is not None:
262-
data["username"] = username
263-
if public is not None:
264-
data["visibility"] = "PUBLIC" if public else "PRIVATE"
265-
resp = self.post(f"/project", data=data)
266-
except ApiError as e:
267-
if e.resp.status_code == 409:
268-
# 项目已经存在,从对象中解析信息
269-
resp = decode_response(e.resp)
270-
elif e.resp.status_code == 404 and e.resp.reason == "Not Found":
271-
# WARNING: 早期 (私有化) swanlab 后端没有 /project 接口,需要使用 /project/{username} 接口,此时没有默认空间的特性
272-
self.__groupname = self.__groupname if username is None else username
273-
try:
274-
visibility = "PUBLIC" if public else "PRIVATE"
275-
resp = self.post(f"/project/{self.groupname}", data={"name": name, "visibility": visibility})
276-
except ApiError as e:
277-
# 如果为409,表示已经存在,获取项目信息
278-
if e.resp.status_code == 409:
279-
resp = self.get(f"/project/{self.groupname}/{name}")
280-
elif e.resp.status_code == 404:
281-
# 组织/用户不存在
282-
raise ValueError(f"Space `{self.groupname}` not found")
283-
elif e.resp.status_code == 403:
284-
# 权限不足
285-
raise ValueError(f"Space permission denied: " + self.groupname)
286-
else:
287-
raise e
288-
return ProjectInfo(resp)
289-
else:
290-
# 此接口为后端处理,sdk 在理论上不会出现其他错误,因此不需要处理其他错误
291-
raise e
292-
# 设置当前项目所属的用户名
293-
self.__groupname = resp['username']
294-
# 获取详细信息
295-
resp = self.get(f"/project/{self.groupname}/{name}")
296-
project = ProjectInfo(resp)
297-
self.__proj = project
257+
try:
258+
data = {"name": name}
259+
if username is not None:
260+
data["username"] = username
261+
if public is not None:
262+
data["visibility"] = "PUBLIC" if public else "PRIVATE"
263+
resp = self.post(f"/project", data=data)
264+
except ApiError as e:
265+
if e.resp.status_code == 409:
266+
# 项目已经存在,从对象中解析信息
267+
resp = decode_response(e.resp)
268+
elif e.resp.status_code == 404 and e.resp.reason == "Not Found":
269+
# WARNING: 早期 (私有化) swanlab 后端没有 /project 接口,需要使用 /project/{username} 接口,此时没有默认空间的特性
270+
self.__groupname = self.__groupname if username is None else username
271+
try:
272+
visibility = "PUBLIC" if public else "PRIVATE"
273+
resp = self.post(f"/project/{self.groupname}", data={"name": name, "visibility": visibility})
274+
except ApiError as e:
275+
# 如果为409,表示已经存在,获取项目信息
276+
if e.resp.status_code == 409:
277+
resp = self.get(f"/project/{self.groupname}/{name}")
278+
elif e.resp.status_code == 404:
279+
# 组织/用户不存在
280+
raise ValueError(f"Space `{self.groupname}` not found")
281+
elif e.resp.status_code == 403:
282+
# 权限不足
283+
raise ValueError(f"Space permission denied: " + self.groupname)
284+
else:
285+
raise e
286+
return ProjectInfo(resp)
287+
else:
288+
# 此接口为后端处理,sdk 在理论上不会出现其他错误,因此不需要处理其他错误
289+
raise e
290+
# 设置当前项目所属的用户名
291+
self.__groupname = resp['username']
292+
# 获取详细信息
293+
resp = self.get(f"/project/{self.groupname}/{name}")
294+
self.__proj = ProjectInfo(resp)
298295

299296
def mount_exp(self, exp_name, colors: Tuple[str, str], description: str = None, tags: List[str] = None):
300297
"""
@@ -304,20 +301,19 @@ def mount_exp(self, exp_name, colors: Tuple[str, str], description: str = None,
304301
:param description: 实验描述
305302
:param tags: 实验标签
306303
"""
307-
with Status("Getting experiment...", spinner="dots"):
308-
post_data = {
309-
"name": exp_name,
310-
"colors": list(colors),
311-
}
312-
if description is not None:
313-
post_data["description"] = description
314-
if tags is not None:
315-
post_data["labels"] = [{"name": tag} for tag in tags]
316-
317-
data = self.post(f"/project/{self.groupname}/{self.__proj.name}/runs", post_data)
318-
self.__exp = ExperimentInfo(data)
319-
# 获取cos信息
320-
self.__get_cos()
304+
post_data = {
305+
"name": exp_name,
306+
"colors": list(colors),
307+
}
308+
if description is not None:
309+
post_data["description"] = description
310+
if tags is not None:
311+
post_data["labels"] = [{"name": tag} for tag in tags]
312+
313+
data = self.post(f"/project/{self.groupname}/{self.__proj.name}/runs", post_data)
314+
self.__exp = ExperimentInfo(data)
315+
# 获取cos信息
316+
self.__get_cos()
321317

322318
def update_state(self, success: bool):
323319
"""

swanlab/data/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
@Description:
88
在此处完成回调注册、swanlog注册,并为外界提供api,提供运行时生成的配置
99
"""
10-
from .modules import Audio, Image, Text, Object3D, Molecule, echarts
1110
from .run import (
1211
SwanLabRun as Run,
1312
SwanLabRunState as State,
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""
2+
@author: cunyue
3+
@file: __init__.py
4+
@time: 2025/6/21 17:10
5+
@description: 回调器模块,四大回调器对应 swanlab 的四种运行模式。
6+
1. local: 本地模式回调器
7+
2. cloud: 云端模式回调器
8+
3. offline: 离线模式回调器
9+
4. disabled: 禁用回调器
10+
"""
11+
12+
from .cloud import CloudPyCallback
13+
from .disabled import DisabledCallback
14+
from .local import LocalRunCallback
15+
from .offline import OfflineCallback
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
r"""
4+
@DATE: 2024/6/19 16:46
5+
@File: callback.py
6+
@IDE: pycharm
7+
@Description:
8+
回调函数注册抽象模块
9+
"""
10+
import atexit
11+
import sys
12+
import traceback
13+
14+
from swanlab.data.run import SwanLabRunState, get_run
15+
from swanlab.log import swanlog
16+
from swanlab.swanlab_settings import get_settings
17+
from swanlab.toolkit import SwanKitCallback
18+
from . import utils
19+
from ..porter import DataPorter
20+
from ..store import get_run_store
21+
22+
23+
class SwanLabRunCallback(SwanKitCallback):
24+
"""
25+
SwanLabRunCallback,回调函数注册类,所有以`on_`和`before_`开头的函数都会在对应的时机被调用
26+
为了方便管理:
27+
1. `_`开头的函数为内部函数,不会被调用,且写在最开头
28+
2. 所有回调按照逻辑上的触发顺序排列
29+
3. 带有from_*后缀的回调函数代表调用者来自其他地方,比如config、operator等,这将通过settings对象传递
30+
4. 所有回调不要求全部实现,只需实现需要的回调即可
31+
"""
32+
33+
def __init__(self):
34+
self.run_store = get_run_store()
35+
self.porter = DataPorter()
36+
self.user_settings = get_settings()
37+
38+
def _register_sys_callback(self):
39+
"""
40+
注册系统回调,内部使用
41+
"""
42+
sys.excepthook = self._except_handler
43+
atexit.register(self._clean_handler)
44+
45+
def _unregister_sys_callback(self):
46+
"""
47+
注销系统回调,内部使用
48+
"""
49+
sys.excepthook = sys.__excepthook__
50+
atexit.unregister(self._clean_handler)
51+
52+
def _clean_handler(self):
53+
"""
54+
正常退出清理函数,此函数调用`run.finish`
55+
"""
56+
run = get_run()
57+
if run is None:
58+
return swanlog.debug("SwanLab Runtime has been cleaned manually.")
59+
# 打印训练结束信息
60+
utils.print_train_finish(self.run_store.run_name)
61+
# 如果正在运行
62+
run.finish() if run.running else swanlog.debug("Duplicate finish, ignore it.")
63+
64+
@staticmethod
65+
def _except_handler(tp, val, tb):
66+
"""
67+
异常退出清理函数
68+
"""
69+
# 1. 如果是KeyboardInterrupt异常,特殊显示
70+
if tp == KeyboardInterrupt:
71+
swanlog.info("KeyboardInterrupt by user")
72+
else:
73+
swanlog.info("Error happened while training")
74+
# 2. 生成错误堆栈
75+
trace_list = traceback.format_tb(tb)
76+
error = ""
77+
for line in trace_list:
78+
error += line
79+
error += str(val)
80+
# 3. 结束运行,注意此时终端错误还没打印
81+
get_run().finish(SwanLabRunState.CRASHED, error=error)
82+
assert swanlog.proxied is False, "except_handler should be called after swanlog.stop_proxy()"
83+
# 4. 打印终端错误,此时终端代理已经停止,不必担心此副作用
84+
print(error, file=sys.stderr)
85+
86+
def _start_terminal_proxy(self, handler=None):
87+
"""
88+
启动终端代理
89+
"""
90+
if handler is None:
91+
handler = lambda data: self.porter.trace_log(data)
92+
swanlog.start_proxy(
93+
proxy_type=self.user_settings.log_proxy_type,
94+
max_log_length=self.user_settings.max_log_length,
95+
handler=handler,
96+
)
97+
98+
def __str__(self):
99+
raise NotImplementedError("Please implement this method")

0 commit comments

Comments
 (0)