Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
swankit==0.2.0
swankit==0.2.2
urllib3>=1.26.0
requests>=2.25.0
setuptools
Expand All @@ -10,4 +10,5 @@ boto3>=1.35.49
botocore
pydantic>=2.9.0
pyecharts>=2.0.0
wrapt>=1.17.0
typing_extensions; python_version < '3.9'
2 changes: 1 addition & 1 deletion swanlab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .env import SwanLabEnv
from .package import get_package_version
from .swanlab_settings import Settings
from .sync import sync_wandb, sync_tensorboardX, sync_tensorboard_torch, sync_mlflow
from .sync import sync_wandb, sync_tensorboardX, sync_tensorboard_torch, sync_mlflow, sync

# 设置默认环境变量
SwanLabEnv.set_default()
Expand Down
9 changes: 5 additions & 4 deletions swanlab/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ def proj_id(self):
def projname(self):
return self.__proj.name

# Read-only property: returns `history_exp_count` from the project info object
# held in `self.__proj` (mount_project assigns that attribute after the project
# has been fetched or created).
@property
def history_exp_count(self):
return self.__proj.history_exp_count

# Read-only property: returns the `cuid` of the experiment object held in
# `self.__exp`.
@property
def exp_id(self):
return self.__exp.cuid
Expand Down Expand Up @@ -244,7 +248,7 @@ def upload_files(self, buffers: List[MediaBuffer]) -> Dict[str, Union[bool, List

# ---------------------------------- 接入后端api ----------------------------------

def mount_project(self, name: str, username: str = None, public: bool = None) -> ProjectInfo:
def mount_project(self, name: str, username: str = None, public: bool = None):
"""
创建项目,如果项目已存在,则获取项目信息
:param name: 项目名称
Expand Down Expand Up @@ -295,7 +299,6 @@ def _():

project: ProjectInfo = FONT.loading("Getting project...", _)
self.__proj = project
return project

def mount_exp(self, exp_name, colors: Tuple[str, str], description: str = None, tags: List[str] = None):
"""
Expand All @@ -309,13 +312,11 @@ def mount_exp(self, exp_name, colors: Tuple[str, str], description: str = None,
def _():
"""
先创建实验,后生成cos凭证
:return:
"""
post_data = {
"name": exp_name,
"colors": list(colors),
}

if description is not None:
post_data["description"] = description
if tags is not None:
Expand Down
1 change: 1 addition & 0 deletions swanlab/api/upload/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,5 @@ def upload_columns(columns: List[ColumnModel], per_request_len: int = 3000):
"MediaModel",
"ColumnModel",
"FileModel",
"LogModel",
]
10 changes: 10 additions & 0 deletions swanlab/api/upload/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@Description:
上传请求模型
"""
import json
from datetime import datetime
from enum import Enum
from typing import List, Optional, TypedDict, Literal
Expand Down Expand Up @@ -123,6 +124,15 @@ def __init__(
epoch: int,
buffers: List[MediaBuffer] = None,
):

# -------------------------- 🤡这里是一点小小的💩 --------------------------
# 要求上传时的文件路径必须带key_encoded前缀
if buffers is not None:
metric = json.loads(json.dumps(metric))
for i, d in enumerate(metric["data"]):
metric["data"][i] = "{}/{}".format(key_encoded, d)
# ------------------------------------------------------------------------

self.metric = metric
self.step = step
self.epoch = epoch
Expand Down
1 change: 1 addition & 0 deletions swanlab/cli/commands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from .auth import login, logout
from .converter import convert
from .dashboard import watch
from .sync import sync
74 changes: 74 additions & 0 deletions swanlab/cli/commands/sync/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""
@author: cunyue
@file: __init__.py
@time: 2025/6/5 14:03
@description: 同步本地数据到云端
"""

import click

from swanlab.api import terminal_login, create_http
from swanlab.error import KeyFileError
from swanlab.package import get_key, HostFormatter
from swanlab.sync import sync as sync_logs


# `sync` CLI command: logs in (without persisting the key), builds the shared
# http client, then hands the local log directory to `swanlab.sync.sync`.
#
# Arguments / options (all supplied by click):
#   path       existing, readable directory; click resolves it to an absolute path
#   api_key    -k/--api-key; when omitted, the locally stored key is tried
#   host       -h/--host; passed to HostFormatter for validation/formatting
#   workspace  -w/--workspace; forwarded to the sync routine
#   project    -p/--project; forwarded to the sync routine as `project_name`
#
# NOTE: the function docstring is printed by click as the command's help text,
# so it is intentionally kept short; developer documentation lives in comments.
@click.command()
@click.argument(
    "path",
    type=click.Path(
        exists=True,
        dir_okay=True,
        file_okay=False,
        resolve_path=True,
        readable=True,
    ),
    nargs=1,
    required=True,
)
@click.option(
    "--api-key",
    "-k",
    default=None,
    type=str,
    # Adjacent literals are concatenated verbatim, so the first one must end
    # with a space to keep the two sentences separated in `--help` output.
    help="The API key to use for authentication. If not specified, it will use the default API key from the environment. "
    "If specified, it will log in using this API key but will not save the key.",
)
@click.option(
    "--host",
    "-h",
    default=None,
    type=str,
    help="The host to sync the logs to. If not specified, it will use the default host.",
)
@click.option(
    "--workspace",
    "-w",
    default=None,
    type=str,
    help="The workspace to sync the logs to. If not specified, it will use the default workspace.",
)
@click.option(
    "--project",
    "-p",
    default=None,
    type=str,
    help="The project to sync the logs to. If not specified, it will use the default project.",
)
def sync(path, api_key, workspace, project, host):
    """
    Synchronize local logs to the cloud.
    """
    # 1. Create the http object.
    # 1.1 Validate and format the host, then inject it into the environment variables.
    HostFormatter(host)()
    # 1.2 An explicitly passed api-key is used for this login only and is never saved.
    #     Without one, fall back to the locally stored key.
    if api_key is None:
        try:
            api_key = get_key()
        except KeyFileError:
            # Deliberate best-effort: no stored key is not an error here.
            # api_key stays None and terminal_login decides how to proceed
            # (presumably by prompting the user -- confirm against terminal_login).
            pass
    # 1.3 Log in and create the http object.
    log_info = terminal_login(api_key=api_key, save_key=False)
    create_http(log_info)
    # 2. Sync the logs. The login was already performed above, hence login_required=False.
    sync_logs(path, workspace=workspace, project_name=project, login_required=False)
3 changes: 3 additions & 0 deletions swanlab/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def cli():
# noinspection PyTypeChecker
cli.add_command(C.convert) # 转换命令,用于转换其他实验跟踪工具

# noinspection PyTypeChecker
cli.add_command(C.sync) # 同步命令,用于同步本地数据到云端


if __name__ == "__main__":
cli()
84 changes: 33 additions & 51 deletions swanlab/data/callbacker/cloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
云端回调
"""

import json
import sys
from typing import Literal

from swankit.callback.models import RuntimeInfo, MetricInfo, ColumnInfo
from swankit.core import SwanLabSharedSettings
from swankit.env import create_time
from swankit.log import FONT

Expand All @@ -27,23 +25,22 @@
get_package_latest_version,
get_key,
)
from .utils import error_print, traceback_error
from ..run import get_run, SwanLabRunState
from ..run.callback import SwanLabRunCallback
from ...log.backup import backup
from ...log.type import LogData
from ...swanlab_settings import get_settings


class CloudRunCallback(SwanLabRunCallback):

def __init__(self, public: bool):
super().__init__()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pool = ThreadPool()
self.exiting = False
"""
标记是否正在退出云端环境
"""
self.public = public

@classmethod
def create_login_info(cls, save: bool = True):
Expand Down Expand Up @@ -80,28 +77,20 @@ def _view_web_print():
return exp_url

def _clean_handler(self):
run = get_run()
if run is None:
return swanlog.debug("SwanLab Runtime has been cleaned manually.")
if self.exiting:
return swanlog.debug("SwanLab is exiting, please wait.")
self._train_finish_print()
# 如果正在运行
run.finish() if run.running else swanlog.debug("Duplicate finish, ignore it.")
super()._clean_handler()

def _except_handler(self, tp, val, tb):
if self.exiting:
print("")
swanlog.error("Aborted uploading by user")
sys.exit(1)
error_print(tp)
# 结束运行
get_run().finish(SwanLabRunState.CRASHED, error=traceback_error(tb, tp(val)))
if tp != KeyboardInterrupt:
print(traceback_error(tb, tp(val)), file=sys.stderr)
super()._except_handler(tp, val, tb)

def _write_handler(self, log_data: LogData):
level: Literal['INFO', 'WARN', 'ERROR']
@backup("terminal")
def _terminal_handler(self, log_data: LogData):
level: Literal['INFO', 'WARN']
if log_data['type'] == 'stdout':
level = "INFO"
elif log_data['type'] == 'stderr':
Expand All @@ -113,18 +102,19 @@ def _write_handler(self, log_data: LogData):
# String form of this callback: always the fixed identifier returned below.
def __str__(self):
return "SwanLabCloudRunCallback"

def on_init(self, project: str, workspace: str, logdir: str = None, *args, **kwargs) -> int:
def on_init(self, proj_name: str, workspace: str, public: bool = None, logdir: str = None, *args, **kwargs):
try:
http = get_http()
except ValueError:
swanlog.debug("Login info is None, get login info.")
http = create_http(self.create_login_info())
# 检测是否有最新的版本
self._get_package_latest_version()
return http.mount_project(project, workspace, self.public).history_exp_count

def before_run(self, settings: SwanLabSharedSettings, *args, **kwargs):
self.settings = settings
http.mount_project(proj_name, workspace, public)
# 设置项目缓存
self.backup.cache_proj_name = proj_name
self.backup.cache_workspace = workspace
self.backup.cache_public = public

def on_run(self, *args, **kwargs):
http = get_http()
Expand All @@ -135,25 +125,19 @@ def on_run(self, *args, **kwargs):
description=self.settings.description,
tags=self.settings.tags,
)
# 注册终端输出流代理
settings = get_settings()
if settings.log_proxy_type != "none":
swanlog.start_proxy(
proxy_type=settings.log_proxy_type,
max_log_length=settings.max_log_length,
handler=self._write_handler,
)
# 注册系统回调
self._register_sys_callback()
# 打印信息
self._train_begin_print()
# 注册运行状态
self.handle_run()
# 打印实验开始信息,在 cloud 模式下如果没有开启 backup 的话不打印“数据保存在 xxx”的信息
swanlab_settings = get_settings()
self._train_begin_print(save_dir=self.settings.run_dir if swanlab_settings.backup else None)
swanlog.info("👋 Hi " + FONT.bold(FONT.default(get_http().username)) + ", welcome to swanlab!")
swanlog.info("Syncing run " + FONT.yellow(self.settings.exp_name) + " to the cloud")
experiment_url = self._view_web_print()
# 在Jupyter Notebook环境下,显示按钮
if in_jupyter():
show_button_html(experiment_url)

@backup("runtime")
def on_runtime_info_update(self, r: RuntimeInfo, *args, **kwargs):
# 添加上传任务到线程池
rc = r.config.to_dict() if r.config is not None else None
Expand All @@ -164,6 +148,7 @@ def on_runtime_info_update(self, r: RuntimeInfo, *args, **kwargs):
f = FileModel(requirements=rr, config=rc, metadata=rm, conda=ro)
self.pool.queue.put((UploadType.FILE, [f]))

@backup("column")
def on_column_create(self, column_info: ColumnInfo, *args, **kwargs):
error = None
if column_info.error is not None:
Expand Down Expand Up @@ -192,6 +177,7 @@ def on_column_create(self, column_info: ColumnInfo, *args, **kwargs):
)
self.pool.queue.put((UploadType.COLUMN, [column]))

@backup("metric")
def on_metric_create(self, metric_info: MetricInfo, *args, **kwargs):
# 有错误就不上传
if metric_info.error:
Expand All @@ -206,29 +192,24 @@ def on_metric_create(self, metric_info: MetricInfo, *args, **kwargs):
scalar = ScalarModel(metric, key, step, epoch)
return self.pool.queue.put((UploadType.SCALAR_METRIC, [scalar]))
# 媒体指标数据

# -------------------------- 🤡这里是一点小小的💩 --------------------------
# 要求上传时的文件路径必须带key_encoded前缀
if metric_info.metric_buffers is not None:
metric = json.loads(json.dumps(metric))
for i, d in enumerate(metric["data"]):
metric["data"][i] = "{}/{}".format(key_encoded, d)
# ------------------------------------------------------------------------

media = MediaModel(metric, key, key_encoded, step, epoch, metric_info.metric_buffers)
self.pool.queue.put((UploadType.MEDIA_METRIC, [media]))

def on_stop(self, error: str = None, *args, **kwargs):
# 打印信息
self._view_web_print()
run = get_run()
# 如果正在退出或者run对象为None或者不在云端环境下
# 如果正在退出或者run对象为None或者不在云端环境下,则不执行任何操作
# 原因是在云端环境下退出时会新建一个线程完成上传日志等操作,此时回调会重复执行
# 必须要有个标志表明正在退出
if self.exiting or run is None:
return swanlog.debug("SwanLab is exiting or run is None, ignore it.")
state = run.state
# 标志正在退出(需要在下面的逻辑之前标志)
# ---------------------------------- 正在退出 ----------------------------------
self.exiting = True
# 打印信息
self._view_web_print()
state = run.state
sys.excepthook = self._except_handler
swanlog_epoch = run.swanlog_epoch
self.backup.stop(error=error, epoch=swanlog_epoch + 1)

def _():
# 关闭线程池,等待上传线程完成
Expand All @@ -237,7 +218,7 @@ def _():
if error is not None:
logs = LogModel(
level="ERROR",
contents=[{"message": error, "create_time": create_time(), "epoch": swanlog.epoch + 1}],
contents=[{"message": error, "create_time": create_time(), "epoch": swanlog_epoch + 1}],
)
upload_logs([logs])

Expand All @@ -247,6 +228,7 @@ def _():
# 取消注册系统回调
self._unregister_sys_callback()
self.exiting = False
# -------------------------------------------------------------------------


def show_button_html(experiment_url):
Expand Down
Loading