Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions swanlab/data/callback_cloud.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/5/5 20:22
@File: callback_cloud.py
@IDE: pycharm
@Description:
云端回调
"""
from .run.callback import NewKeyInfo
from swanlab.cloud import UploadType
from typing import Optional, Dict
from swanlab.error import ApiError
from swanlab.api.upload.model import ColumnModel
from urllib.parse import quote
from swanlab.api import LoginInfo, create_http, terminal_login
from swanlab.api.upload import upload_logs
from swanlab.log import swanlog
from swanlab.utils.font import FONT
from swanlab.api import get_http
from swanlab.utils.key import get_key
from swanlab.utils.judgment import in_jupyter, show_button_html
from swanlab.package import get_host_web, get_host_api
from swanlab.error import KeyFileError
from swanlab.env import get_swanlab_folder
from .callback_local import LocalRunCallback, get_run, SwanLabRunState
from swanlab.cloud import LogSnifferTask, ThreadPool
from swanlab.db import Experiment
from swanlab.utils import create_time
import sys
import os


class CloudRunCallback(LocalRunCallback):
    """Run callback that additionally syncs the run to the cloud.

    On top of ``LocalRunCallback`` it logs the user in, mounts the project
    and experiment through the HTTP client, pushes logs, metrics and columns
    onto the upload queue of a ``ThreadPool`` and reports the final run state
    when training ends.
    """

    # Login information of the user; ``None`` until ``before_init_project``
    # fetches it through ``get_login_info``.
    login_info: Optional[LoginInfo] = None

    def __init__(self):
        super().__init__()
        self.pool = ThreadPool()
        # True while ``on_train_end`` is waiting for uploads to finish.
        # ``_clean_handler`` and ``_except_handler`` check it to tell a
        # normal shutdown from an interruption during the upload phase.
        self.exiting = False

    def before_init_project(self, project: str, workspace: str, *args, **kwargs) -> int:
        """Log in if needed, mount the project and return its history experiment count.

        :param project: project name passed to ``mount_project``
        :param workspace: workspace passed to ``mount_project``
        :return: ``history_exp_count`` of the mounted project
        """
        if self.login_info is None:
            swanlog.debug("Login info is None, get login info.")
            self.login_info = self.get_login_info()
        http = create_http(self.login_info)
        return http.mount_project(project, workspace).history_exp_count

    def _clean_handler(self):
        """Finish the current run if it is still running.

        Does nothing (besides a debug message) when there is no run or when
        the callback is already in its exiting phase.
        """
        run = get_run()
        if run is None:
            return swanlog.debug("SwanLab Runtime has been cleaned manually.")
        if self.exiting:
            return swanlog.debug("SwanLab is exiting, please wait.")
        self._train_finish_print()
        # Only finish a run that is still running.
        if run.is_running:
            run.finish()
        else:
            swanlog.debug("Duplicate finish, ignore it.")

    def _except_handler(self, tp, val, tb):
        """Exception hook: mark the run as crashed, then re-raise.

        :param tp: exception type
        :param val: exception value
        :param tb: traceback object
        """
        if self.exiting:
            # An exception while uploads are being flushed is treated as the
            # user aborting the upload.
            # FIXME not a good way to fix '\n' problem
            print("")
            swanlog.error("Aborted uploading by user")
            sys.exit(1)
        self._error_print(tp)
        # Finish the run in the crashed state.
        get_run().finish(SwanLabRunState.CRASHED, error=self._traceback_error(tb))
        if tp != KeyboardInterrupt:
            # Re-raise the original exception object. ``tp(val)`` would wrap
            # it inside a new instance, dropping its traceback and
            # attributes, and fails for exception types whose constructor
            # needs several arguments.
            if isinstance(val, BaseException):
                raise val.with_traceback(tb)
            # ``val`` is not an exception instance: keep the old behaviour.
            raise tp(val)

    def _view_web_print(self):
        """Print the project and run URLs and return the run URL."""
        self._watch_tip_print()
        http = get_http()
        project_url = get_host_web() + f"/@{http.groupname}/{http.projname}"
        experiment_url = project_url + f"/runs/{http.exp_id}"
        swanlog.info("🏠 View project at " + FONT.blue(FONT.underline(project_url)))
        swanlog.info("🚀 View run at " + FONT.blue(FONT.underline(experiment_url)))
        return experiment_url

    def on_train_begin(self):
        """Register the experiment and start the upload machinery.

        Exits the process with code 409 when the experiment name already
        exists; any other ``ApiError`` from ``mount_exp`` is re-raised.
        """
        # Register the experiment.
        try:
            get_http().mount_exp(
                exp_name=self.settings.exp_name,
                colors=self.settings.exp_colors,
                description=self.settings.description,
            )
        except ApiError as e:
            if e.resp.status_code != 409:
                # Do not continue with an experiment that was never mounted.
                raise
            FONT.brush("", 50)
            swanlog.error("The experiment name already exists, please change the experiment name")
            Experiment.purely_delete(run_id=self.settings.run_id)
            sys.exit(409)

        # Log sniffer task running on the thread pool.
        sniffer = LogSnifferTask(self.settings.files_dir)
        self.pool.create_thread(sniffer.task, name="sniffer", callback=sniffer.callback)

        # Forward every message written through swanlog to the upload queue.
        def _write_call_call(message):
            self.pool.queue.put((UploadType.LOG, [message]))

        swanlog.set_write_callback(_write_call_call)

        # Register the system callbacks.
        self._register_sys_callback()
        # Print the start-up information.
        self._train_begin_print()
        swanlog.info("👋 Hi " + FONT.bold(FONT.default(self.login_info.username)) + ", welcome to swanlab!")
        swanlog.info("Syncing run " + FONT.yellow(self.settings.exp_name) + " to the cloud")
        experiment_url = self._view_web_print()

        # In a Jupyter notebook, additionally show a button for the run.
        if in_jupyter():
            show_button_html(experiment_url)

    def on_train_end(self, error: Optional[str] = None):
        """Flush pending uploads and report the final run state.

        :param error: error text to upload as an ERROR log, if any
        """
        # Print the URLs.
        self._view_web_print()
        run = get_run()
        # Nothing to do when already exiting or when there is no run.
        if self.exiting or run is None:
            return swanlog.debug("SwanLab is exiting or run is None, ignore it.")
        state = run.state
        # Mark the exiting phase before any of the logic below runs.
        self.exiting = True
        sys.excepthook = self._except_handler

        def _flush_uploads():
            # Close the thread pool and wait for the upload threads.
            self.pool.finish()
            # Upload the error log.
            if error is not None:
                msg = [{"message": error, "create_time": create_time(), "epoch": swanlog.epoch + 1}]
                upload_logs(msg, level="ERROR")

        FONT.loading("Waiting for uploading complete", _flush_uploads)
        get_http().update_state(state == SwanLabRunState.SUCCESS)
        # Unregister the system callbacks.
        self._unregister_sys_callback()
        self.exiting = False

    def on_metric_create(self, key: str, key_info: NewKeyInfo, static_dir: str):
        """Metric callback, invoked when new metric data is added.

        :param key: metric key name
        :param key_info: metric information; ``None`` is ignored
        :param static_dir: media file directory
        """
        if key_info is None:
            return
        new_data, data_type, step, epoch = key_info
        new_data['key'] = key
        new_data['index'] = step
        new_data['epoch'] = epoch
        if data_type == "default":
            return self.pool.queue.put((UploadType.SCALAR_METRIC, [new_data]))
        # Non-default data is queued as a media metric with a fully
        # percent-encoded key.
        key = quote(key, safe="")
        data = (new_data, key, data_type, static_dir)
        self.pool.queue.put((UploadType.MEDIA_METRIC, [data]))

    def on_column_create(self, key, data_type: str, error: Optional[Dict] = None):
        """Queue a new column (with upper-cased data type) for upload."""
        self.pool.queue.put((UploadType.COLUMN, [ColumnModel(key, data_type.upper(), error)]))

    @classmethod
    def get_login_info(cls):
        """Start the login flow and return its result.

        The key is read from the ``.netrc`` file in the swanlab folder. When
        it cannot be read, ``terminal_login`` is called with ``None``, unless
        stdin is not a TTY and the code is not running in Jupyter.

        :raises KeyFileError: no key is available, stdin is not a TTY and
            the environment is not Jupyter
        """
        key = None
        try:
            key = get_key(os.path.join(get_swanlab_folder(), ".netrc"), get_host_api())[2]
        except KeyFileError:
            fd = sys.stdin.fileno()
            # Not a standard terminal and not Jupyter: the echo cannot be
            # controlled, so an interactive login is not possible.
            if not os.isatty(fd) and not in_jupyter():
                raise KeyFileError("The key file is not found, call `swanlab.login()` or use `swanlab login` ")
        return terminal_login(key)
90 changes: 90 additions & 0 deletions swanlab/data/callback_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/5/5 20:23
@File: callback_local.py
@IDE: pycharm
@Description:
基本回调函数注册表,此时不考虑云端情况
"""
from swanlab.log import swanlog
from swanlab.utils.font import FONT
from swanlab.data.run.main import get_run, SwanLabRunState
from swanlab.data.run.callback import SwanLabRunCallback
import traceback


class LocalRunCallback(SwanLabRunCallback):
    """Basic run callback registry for local runs; the cloud is not considered here."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def _traceback_error(tb):
        """Return the formatted traceback entries of ``tb`` as one string.

        Every entry produced by ``traceback.format_tb`` is followed by an
        additional newline.
        """
        return "".join(entry + "\n" for entry in traceback.format_tb(tb))

    @staticmethod
    def _error_print(tp):
        """Log an error message that depends on the exception type ``tp``."""
        # Keyboard interrupts get their own message.
        if tp == KeyboardInterrupt:
            swanlog.error("KeyboardInterrupt by user")
        else:
            swanlog.error("Error happened while training")

    def before_init_project(self, *args, **kwargs):
        """No-op for local runs."""
        pass

    def _except_handler(self, tp, val, tb):
        """Exception handling: mark the run as crashed, then re-raise.

        :param tp: exception type
        :param val: exception value
        :param tb: traceback object
        """
        self._error_print(tp)
        # Finish the run in the crashed state.
        get_run().finish(SwanLabRunState.CRASHED, error=self._traceback_error(tb))
        if tp != KeyboardInterrupt:
            # Re-raise the original exception object. ``tp(val)`` would wrap
            # it inside a new instance, dropping its traceback and
            # attributes, and fails for exception types whose constructor
            # needs several arguments.
            if isinstance(val, BaseException):
                raise val.with_traceback(tb)
            # ``val`` is not an exception instance: keep the old behaviour.
            raise tp(val)

    def _clean_handler(self):
        """Finish the current run if it is still running."""
        run = get_run()
        if run is None:
            return swanlog.debug("SwanLab Runtime has been cleaned manually.")
        self._train_finish_print()
        # Only finish a run that is still running.
        if run.is_running:
            run.finish()
        else:
            swanlog.debug("Duplicate finish, ignore it.")

    def on_train_begin(self, *args, **kwargs):
        """Training starts: register the system callbacks and print run info."""
        # Install the system callbacks.
        self._register_sys_callback()
        # Print the start-up information.
        self._train_begin_print()
        swanlog.info("Experiment_name: " + FONT.yellow(self.settings.exp_name))
        self._watch_tip_print()

    def on_train_end(self, error: str = None):
        """Training ends: unregister the system callbacks.

        This function is called by ``run.finish``.
        """
        # Print the watch tip.
        self._watch_tip_print()
        # Unregister the system callbacks.
        self._unregister_sys_callback()

    def on_metric_create(self, *args, **kwargs):
        """No-op for local runs."""
        pass

    def on_column_create(self, *args, **kwargs):
        """No-op for local runs."""
        pass
41 changes: 41 additions & 0 deletions swanlab/data/modules/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/5/5 22:47
@File: _utils.py
@IDE: pycharm
@Description:
工具函数
"""
import io
import hashlib


def get_file_hash_path(file_path: str) -> str:
    """Compute the SHA-256 hex digest of the file at ``file_path``."""
    digest = hashlib.sha256()
    with open(file_path, "rb") as stream:  # binary read mode
        # Feed the file to the hash in 8192-byte pieces; iteration stops at
        # the first empty read, i.e. at end of file.
        for piece in iter(lambda: stream.read(8192), b""):
            digest.update(piece)
    return digest.hexdigest()


def get_file_hash_numpy_array(array) -> str:
    """Compute the SHA-256 hex digest of ``array.tobytes()``."""
    # Only the raw bytes returned by ``tobytes()`` go into the hash.
    raw_bytes = array.tobytes()
    return hashlib.sha256(raw_bytes).hexdigest()


def get_file_hash_pil(image) -> str:
    """Compute the SHA-256 hex digest of ``image`` saved as PNG."""
    # Save the image into an in-memory buffer and keep the resulting bytes.
    with io.BytesIO() as png_stream:
        image.save(png_stream, format="PNG")  # another format such as 'JPEG' could be chosen
        encoded = png_stream.getvalue()
    return hashlib.sha256(encoded).hexdigest()
5 changes: 1 addition & 4 deletions swanlab/data/modules/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@
音频数据解析
"""
from .base import BaseType
from ..utils.file import get_file_hash_numpy_array, get_file_hash_path
import os
from ._utils import get_file_hash_numpy_array, get_file_hash_path
from typing import Union, List

### 以下为音频数据解析的依赖库
import soundfile as sf
import numpy as np
import json
Expand Down
2 changes: 1 addition & 1 deletion swanlab/data/modules/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from PIL import Image as PILImage
from .base import BaseType
from .utils_modules import BoundingBoxes, ImageMask
from ..utils.file import get_file_hash_pil
from ._utils import get_file_hash_pil
from typing import Union, List, Dict, Any
from io import BytesIO
import os
Expand Down
2 changes: 1 addition & 1 deletion swanlab/data/modules/object_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .base import BaseType
import numpy as np
from typing import Union, ClassVar, Set, List, Optional
from ..utils.file import get_file_hash_numpy_array, get_file_hash_path
from ._utils import get_file_hash_numpy_array, get_file_hash_path
import os
import json
import shutil
Expand Down
2 changes: 1 addition & 1 deletion swanlab/data/modules/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
视频数据解析
"""
from .base import BaseType
from ..utils.file import get_file_hash_numpy_array, get_file_hash_path
from ._utils import get_file_hash_numpy_array, get_file_hash_path
import os
from typing import Union, List
import logging
Expand Down
2 changes: 1 addition & 1 deletion swanlab/data/run/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
@Description:
在此处导出SwanLabRun类,一次实验运行应该只有一个SwanLabRun实例
"""
from .main import SwanLabRun, get_run, except_handler, clean_handler, SwanLabRunState
from .main import SwanLabRun, get_run, SwanLabRunState


def register(*args, **kwargs) -> SwanLabRun:
Expand Down
Loading