Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions .github/workflows/test-pr-to-main.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
name: Test PR to Main

# 还在测试中
#on:
# pull_request:
# branches:
# - main

jobs:
build:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ 3.7, 3.8, 3.9, 3.10, 3.11, 3.12 ]

steps:
- uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements-dev.txt

- name: Run tests
run: |
python scripts/test.py
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ swanlog.bak/
.eslintrc-auto-import.json
dist/

# test
test/temp
tutils/config.json
tutils/package.mock.json

# playground
playground/

Expand Down
62 changes: 31 additions & 31 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# 在此处完成python项目的配置以及依赖安装
[build-system] # 指定构建系统与一些依赖
requires = [
"hatchling",
"hatch-requirements-txt",
"hatch-fancy-pypi-readme>=22.5.0",
"hatchling",
"hatch-requirements-txt",
"hatch-fancy-pypi-readme>=22.5.0",
]
build-backend = "hatchling.build"

Expand All @@ -14,18 +14,18 @@ dynamic = ["version", "dependencies", "optional-dependencies", "readme"] # 动
description = "Python library for streamlined tracking and management of AI training processes." # 项目描述
license = "Apache-2.0" # 项目许可证
requires-python = ">=3.8" # python版本要求,我们只维护python3.8以上版本
authors = [ # 项目作者
{ name = "Cunyue", email = "[email protected]" },
{ name = "Feudalman", email = "[email protected]" },
{name="ZeYi Lin", email = "[email protected]"},
{name="KashiwaByte", email="[email protected]"}
authors = [# 项目作者
{ name = "Cunyue", email = "[email protected]" },
{ name = "Feudalman", email = "[email protected]" },
{ name = "ZeYi Lin", email = "[email protected]" },
{ name = "KashiwaByte", email = "[email protected]" }
]
keywords = [ # 项目关键词
keywords = [# 项目关键词
"machine learning",
"reproducibility",
"visualization"
]
classifiers = [ # 项目分类
classifiers = [# 项目分类
'Development Status :: 3 - Alpha',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
Expand Down Expand Up @@ -78,10 +78,10 @@ replacement = "https://github.com/SwanHubX/SwanLab/blob/main/README.md"


[tool.hatch.build]
artifacts = [ # 将一些非python文件打包到项目中
"/swanlab/template",
"*.json",
"*.pyi", # 类型提示文件
artifacts = [# 将一些非python文件打包到项目中
"/swanlab/template",
"*.json",
"*.pyi", # 类型提示文件
]

# [tool.hatch.build.targets.wheel.hooks.custom] # 构建wheel时,用python执行一些脚本
Expand All @@ -90,13 +90,13 @@ artifacts = [ # 将一些非python文件打包到项目中

[tool.hatch.build.targets.sdist] # 在执行构建之前,将一些必要文件拷贝到虚拟环境中,以便构建,此时已经完成了vue的编译
include = [
"/swanlab",
"/test", # 包含一些测试脚本,确保测试成功
"/README.md", # 包含readme
"/requirements.txt", # 包含依赖
"/requirements-swan.txt", # 包含可选依赖
"/.config/copy_frontend.py", # 包含拷贝前端文件的脚本
"/package.json", # 包含前端的package.json文件
"/swanlab",
"/test", # 包含一些测试脚本,确保测试成功
"/README.md", # 包含readme
"/requirements.txt", # 包含依赖
"/requirements-swan.txt", # 包含可选依赖
"/.config/copy_frontend.py", # 包含拷贝前端文件的脚本
"/package.json", # 包含前端的package.json文件
]

[tool.hatch.build.targets.wheel]
Expand All @@ -110,16 +110,16 @@ exclude = [] # 排除的文件
target-version = "py37"
extend-select = ["B", "C", "I", "N", "SIM", "UP"]
ignore = [
"C901", # function is too complex (TODO: un-ignore this)
"B023", # function definition in loop (TODO: un-ignore this)
"B008", # function call in argument defaults
"B017", # pytest.raises considered evil
"B028", # explicit stacklevel for warnings
"E501", # from scripts/lint_backend.sh
"SIM105", # contextlib.suppress (has a performance cost)
"SIM117", # multiple nested with blocks (doesn't look good with gr.Row etc)
"UP007", # use X | Y for type annotations (TODO: can be enabled once Pydantic plays nice with them)
"UP006", # use `list` instead of `List` for type annotations (fails for 3.8)
"C901", # function is too complex (TODO: un-ignore this)
"B023", # function definition in loop (TODO: un-ignore this)
"B008", # function call in argument defaults
"B017", # pytest.raises considered evil
"B028", # explicit stacklevel for warnings
"E501", # from scripts/lint_backend.sh
"SIM105", # contextlib.suppress (has a performance cost)
"SIM117", # multiple nested with blocks (doesn't look good with gr.Row etc)
"UP007", # use X | Y for type annotations (TODO: can be enabled once Pydantic plays nice with them)
"UP006", # use `list` instead of `List` for type annotations (fails for 3.8)
]
exclude = []

Expand Down
7 changes: 7 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# pytest

pytest

pytest-asyncio

nanoid
14 changes: 14 additions & 0 deletions script/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024/4/3 16:41
@File: test.py.py
@IDE: pycharm
@Description:
运行单元测试,运行时应该在当前项目根目录下运行
"""
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
os.system("pytest test/unit")
2 changes: 1 addition & 1 deletion swanlab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
Run,
)

from .utils import get_package_version
from .package import get_package_version


__version__ = get_package_version()
1 change: 0 additions & 1 deletion swanlab/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@
这一块因为可能swanlog模块没有初始化,所以需要自己单独打印一下
"""
from .login import terminal_login, code_login
from .experiment import get_exp_token
59 changes: 0 additions & 59 deletions swanlab/auth/experiment.py

This file was deleted.

31 changes: 22 additions & 9 deletions swanlab/auth/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
@Description:
定义认证数据格式
"""
from ..utils.token import save_token
from ..env import get_api_key_file_path
from ..utils.package import get_host_api
import os.path
from ..utils.key import save_key
from ..env import get_swanlab_folder
from swanlab.package import get_host_api
import requests


class LoginInfo:
Expand All @@ -18,25 +20,36 @@ class LoginInfo:
无论接口请求成功还是失败,都会初始化一个LoginInfo对象
"""

def __init__(self, api_key: str, **kwargs):
self.api_key = api_key
def __init__(self, resp: requests.Response, api_key: str):
self.__resp = resp
self.__api_key = api_key

@property
def is_fail(self):
"""
判断登录是否失败
"""
# TODO 作为测试,api_key如果为123456时返回None
return self.api_key == "123456"
return self.__resp.status_code != 200

@property
def api_key(self):
"""
获取api_key
"""
if self.is_fail:
return None
return self.__api_key

def __str__(self) -> str:
return f"LoginInfo"
"""错误时会返回错误信息"""
return self.__resp.reason

def save(self):
"""
保存登录信息
"""
return save_token(get_api_key_file_path(), get_host_api(), "user", self.api_key)
path = os.path.join(get_swanlab_folder(), '.netrc')
return save_key(path, get_host_api(), "user", self.api_key)


class ExpInfo:
Expand Down
29 changes: 17 additions & 12 deletions swanlab/auth/login.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,19 @@
@IDE: vscode
@Description:
用户登录接口,输入用户的apikey,保存用户token到本地
进行一些交互定义和数据请求
"""
import asyncio
from ..error import ValidationError
from ..utils import FONT
from ..utils.package import USER_SETTING_PATH
from swanlab.package import get_user_setting_path, get_host_api
import sys
from .info import LoginInfo
import getpass
import requests


async def _login(api_key: str, timeout: int = 20) -> LoginInfo:
async def login_by_key(api_key: str, timeout: int = 20, save: bool = True) -> LoginInfo:
"""用户登录,异步调用接口完成验证
返回后端内容(dict),如果后端请求失败,返回None

Expand All @@ -26,11 +28,17 @@ async def _login(api_key: str, timeout: int = 20) -> LoginInfo:
用户api_key
timeout : int, optional
请求认证的超时时间,单位秒
save : bool, optional
是否保存到本地token文件
"""
await asyncio.sleep(5)
try:
resp = requests.post(f"{get_host_api()}/login/api_key", headers={'authorization': api_key}, timeout=timeout)
except requests.exceptions.RequestException:
# 请求超时等网络错误
raise ValidationError("Network error, please try again.")
# api key写入token文件
login_info = LoginInfo(api_key)
not login_info.is_fail and login_info.save()
login_info = LoginInfo(resp, api_key)
save and not login_info.is_fail and login_info.save()
return login_info


Expand All @@ -43,17 +51,16 @@ def input_api_key(

Parameters
----------
str : str
用户api_key
tip : str
提示信息
again : bool, optional
是否是重新输入api_key,如果是,不显示额外的提示信息

"""
_t = sys.excepthook
sys.excepthook = _abort_tip
if not again:
print(FONT.swanlab("Logging into swanlab cloud."))
print(FONT.swanlab("You can find your API key at: " + USER_SETTING_PATH))
print(FONT.swanlab("You can find your API key at: " + get_user_setting_path()))
key = getpass.getpass(FONT.swanlab(tip))
sys.excepthook = _t
return key
Expand All @@ -67,10 +74,8 @@ async def code_login(api_key: str):
----------
api_key : str
用户api_key
save : bool, optional
是否保存api_key到本地token文件
"""
login_task = asyncio.create_task(_login(api_key))
login_task = asyncio.create_task(login_by_key(api_key))
prefix = FONT.bold(FONT.blue("swanlab: "))
tip = "Waiting for the swanlab cloud response."
loading_task = asyncio.create_task(FONT.loading(tip, interval=0.5, prefix=prefix))
Expand Down
8 changes: 4 additions & 4 deletions swanlab/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

import click
from .utils import is_valid_ip, is_valid_port, is_valid_root_dir, URL
from ..utils import FONT, version_limit
from ..env import get_server_host, get_server_port, get_swanlog_dir, is_login
from ..utils import FONT
from swanlab.package import version_limit, get_package_version
from ..env import get_server_host, get_server_port, get_swanlog_dir
import time
from ..db import connect
from ..utils import get_package_version
from ..error import TokenFileError
from ..error import KeyFileError
from ..auth import terminal_login


Expand Down
Loading