Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 3 additions & 87 deletions swanlab/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,91 +8,7 @@
API模块,封装api请求接口
"""
from .auth import LoginInfo
import requests
from typing import Optional
from .auth.login import login_request, terminal_login, code_login, login_by_key
from swanlab.error import ValidationError
from datetime import datetime
import asyncio
from .auth.login import terminal_login, code_login
from .http import create_http


class HTTP:
"""
封装请求函数,添加get、post、put、delete方法
"""
REFRESH_TIME = 60 * 60 * 24 * 7 # 7天
"""
刷新时间,单位秒,如果sid过期时间减去当前时间小于这个时间,就刷新sid
"""

def __init__(self, login_info: LoginInfo):
"""
初始化会话
"""
self.__login_info = login_info
self.__session = self.__create_session()

def expired_at(self):
"""
获取sid的过期时间,字符串格式转时间
"""
return datetime.strptime(self.__login_info.expired_at, '%Y-%m-%dT%H:%M:%S.%fZ')

def __create_session(self) -> requests.Session:
"""
创建会话,这将在HTTP类实例化时调用
"""
req = requests.Session()
req.cookies.update({'sid': self.__login_info.sid})
return req

def __before_request(self):
# 判断是否已经达到了过期时间
if (self.expired_at() - datetime.utcnow()).total_seconds() < self.REFRESH_TIME:
# 刷新sid
self.__login_info = asyncio.run(login_by_key(self.__login_info.api_key))
self.__session = self.__create_session()

async def get(self, url: str, **kwargs) -> requests.Response:
"""
get请求
"""
self.__before_request()
return self.__session.get(url, **kwargs)

async def post(self, url: str, **kwargs) -> requests.Response:
"""
post请求
"""
self.__before_request()
return self.__session.post(url, **kwargs)

async def put(self, url: str, **kwargs) -> requests.Response:
"""
put请求
"""
self.__before_request()
return self.__session.put(url, **kwargs)

async def delete(self, url: str, **kwargs) -> requests.Response:
"""
delete请求
"""
self.__before_request()
return self.__session.delete(url, **kwargs)


http: Optional["HTTP"] = None


def create_http(login_info: LoginInfo) -> HTTP:
"""
创建http请求对象
:return: http请求对象
"""
global http
http = HTTP(login_info)
return http


__all__ = ["LoginInfo", "code_login", 'terminal_login', "HTTP", "create_http"]
__all__ = ["LoginInfo", "code_login", 'terminal_login', 'create_http']
134 changes: 134 additions & 0 deletions swanlab/api/http.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/4/7 16:51
@File: http.py
@IDE: pycharm
@Description:
http会话对象
"""
import requests
from requests.exceptions import RequestException
from typing import Optional
from datetime import datetime
from .auth import LoginInfo
from .auth.login import login_by_key
from swanlab.error import NetworkError
from swanlab.package import get_host_api
import asyncio


class HTTP:
"""
封装请求函数,添加get、post、put、delete方法
"""
REFRESH_TIME = 60 * 60 * 24 * 7 # 7天
"""
刷新时间,单位秒,如果sid过期时间减去当前时间小于这个时间,就刷新sid
"""

def __init__(self, login_info: LoginInfo):
"""
初始化会话
"""
self.__login_info = login_info
self.__session = self.__create_session()
self.base_url = get_host_api()

def expired_at(self):
"""
获取sid的过期时间,字符串格式转时间
"""
return datetime.strptime(self.__login_info.expired_at, '%Y-%m-%dT%H:%M:%S.%fZ')

def __create_session(self) -> requests.Session:
"""
创建会话,这将在HTTP类实例化时调用
"""
req = requests.Session()
req.cookies.update({'sid': self.__login_info.sid})
return req

def __before_request(self):
# 判断是否已经达到了过期时间
if (self.expired_at() - datetime.utcnow()).total_seconds() < self.REFRESH_TIME:
# 刷新sid
self.__login_info = asyncio.run(login_by_key(self.__login_info.api_key))
self.__session = self.__create_session()

async def get(self, url: str, **kwargs) -> requests.Response:
"""
get请求
"""
self.__before_request()
url = self.base_url + url
return self.__session.get(url, **kwargs)

async def post(self, url: str, data: dict = None) -> requests.Response:
"""
post请求
"""
self.__before_request()
url = self.base_url + url
return self.__session.post(url, json=data)

async def put(self, url: str, **kwargs) -> requests.Response:
"""
put请求
"""
self.__before_request()
url = self.base_url + url
return self.__session.put(url, **kwargs)

async def delete(self, url: str, **kwargs) -> requests.Response:
"""
delete请求
"""
self.__before_request()
url = self.base_url + url
return self.__session.delete(url, **kwargs)


http: Optional["HTTP"] = None
"""
一个进程只有一个http请求对象
"""


def create_http(login_info: LoginInfo) -> HTTP:
"""
创建http请求对象
"""
global http
if http is None:
http = HTTP(login_info)
return http


def get_http() -> HTTP:
"""
创建http请求对象
:return: http请求对象
"""
global http
if http is None:
raise ValueError("http object is not initialized")
return http


def async_error_handler(func):
"""
用于进行统一的错误捕获
"""

async def wrapper(*args, **kwargs):
try:
# 在装饰器中调用被装饰的异步函数
result = await func(*args, **kwargs)
return result
except RequestException:
return NetworkError()
except Exception as e:
return e

return wrapper
70 changes: 70 additions & 0 deletions swanlab/api/upload/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/4/7 16:56
@File: __init__.py
@IDE: pycharm
@Description:
上传相关接口
"""
from ..http import get_http, async_error_handler
from typing import List, Tuple
import asyncio
import os

url = '/house/metrics'


def mock_data(metrics: List[dict], metrics_type: str) -> dict:
"""
模拟一下,上传日志和实验指标信息
"""
return {
"projectId": "1",
"experimentId": "1",
"type": metrics_type,
"metrics": metrics
}


@async_error_handler
async def upload_logs(logs: List[str], level: str = "INFO"):
"""
模拟一下,上传日志和实验指标信息
:param logs: 日志列表
:param level: 日志级别,'INFO', 'ERROR',默认INFO
"""
http = get_http()
# 将logs解析为json格式
metrics = [{"level": level, "message": x} for x in logs]
data = mock_data(metrics, "log")
resp = await http.post(url, data)


@async_error_handler
async def upload_media_metrics(media_metrics: List[Tuple[dict, List[str]]]):
"""
上传指标的媒体数据
"""
print("上传媒体指标信息: ", media_metrics)


@async_error_handler
async def upload_scalar_metrics(scalar_metrics: List[dict]):
"""
上传指标的标量数据
"""
print("上传标量指标信息: ", scalar_metrics)


@async_error_handler
async def upload_files(files: List[str]):
"""
模拟一下,上传日志和实验指标信息
:param files: 文件列表,内部为文件绝对路径
"""
# 去重list
files = list(set(files))
files = [os.path.basename(x) for x in files]
await asyncio.sleep(1)
# print("上传文件信息: ", files)
16 changes: 4 additions & 12 deletions swanlab/cloud/_log_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import asyncio
from swanlab.log import swanlog
from .utils import LogQueue
from swanlab.error import UpLoadError
from swanlab.error import ApiError
from .files_types import FileType


Expand All @@ -39,7 +39,7 @@ def __init__(self):
self.__now_task = None

@staticmethod
def report_known_error(errors: List[UpLoadError]):
def report_known_error(errors: List[ApiError]):
"""
上报错误信息
:param errors: 错误信息列表
Expand All @@ -61,25 +61,17 @@ async def upload(self):
for msg in self.container:
tasks_dict[msg[0]].extend(msg[1])
# 此时应该只剩下最多 FileType内部枚举个数 个任务
# 媒体类型需要首先上传,后面几个可以一起上传
# 检查每一个上传结果
success_tasks_type = []
# 已知错误列表
known_errors = []
if len(tasks_dict[FileType.MEDIA]):
media_result = await FileType.MEDIA.value['upload'](tasks_dict[FileType.MEDIA])
if isinstance(media_result, UpLoadError):
known_errors.append(media_result)
elif isinstance(media_result, Exception):
swanlog.error(f"upload logs error: {media_result}, it might be a swanlab bug, data will be lost!")
success_tasks_type.append(FileType.MEDIA)
# 上传任务
tasks_key_list = [key for key in tasks_dict if len(tasks_dict[key]) > 0 and key != FileType.MEDIA]
tasks_key_list = [key for key in tasks_dict if len(tasks_dict[key]) > 0]
tasks = [x.value['upload'](tasks_dict[x]) for x in tasks_key_list]
results = await asyncio.gather(*tasks)
for index, result in enumerate(results):
# 如果出现已知问题
if isinstance(result, UpLoadError):
if isinstance(result, ApiError):
known_errors.append(result)
continue
# 如果出现其他问题,没有办法处理,就直接跳过,但是会有警告
Expand Down
Loading