Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 35 additions & 12 deletions swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ..db import Project, connect, Experiment
from ..env import init_env, ROOT, get_swanlab_folder
from ..log import swanlog
from ..utils import FONT, check_load_json_yaml
from ..utils import FONT, check_load_json_yaml, check_proj_name_format
from ..utils.key import get_key
from ..utils.judgment import in_jupyter, show_button_html
from swanlab.api import create_http, get_http, code_login, LoginInfo, terminal_login
Expand All @@ -51,6 +51,31 @@
"""


def _check_proj_name(name: str) -> str:
"""检查项目名称是否合法,如果不合法则抛出ValueError异常
项目名称必须是一个非空字符串,长度不能超过255个字符

Parameters
----------
name : str
待检查的项目名称

Returns
-------
str
返回项目名称

Raises
------
ValueError
项目名称不合法
"""
_name = check_proj_name_format(name)
if len(name) != len(_name):
swanlog.warning(f"project name is too long, auto cut to {_name}")
return _name


def _create_metric_callback(pool: ThreadPool):
"""
创建指标回调函数
Expand Down Expand Up @@ -89,7 +114,7 @@ def _is_inited():
return get_run() is not None


def login(api_key: str):
def login(api_key: str = None):
"""
Login to SwanLab Cloud. If you already have logged in, you can use this function to relogin.
Every time you call this function, the previous login information will be overwritten.
Expand All @@ -98,12 +123,12 @@ def login(api_key: str):
Parameters
----------
api_key : str
authentication key.
authentication key, if not provided, the key will be read from the key file.
"""
if _is_inited():
raise RuntimeError("You must call swanlab.login() before using init()")
global login_info
login_info = code_login(api_key)
login_info = code_login(api_key) if api_key else _login_in_init()


def init(
Expand All @@ -129,7 +154,6 @@ def init(
project : str, optional
The project name of the current experiment, the default is None,
which means the current project name is the same as the current working directory.
If you are using cloud mode, you must provide the project name.
workspace : str, optional
Where the current project is located, it can be an organization or a user (currently only supports yourself).
The default is None, which means the current entity is the same as the current user.
Expand Down Expand Up @@ -167,7 +191,7 @@ def init(
cloud : bool, optional
Whether to use the cloud mode, the default is True.
If you use the cloud mode, the log file will be stored in the cloud, which will still be saved locally.
If you are not using cloud mode, the `project` and `entity` fields are invalid.
If you are not using cloud mode, the `workspace` fields are invalid.
load : str, optional
If you pass this parameter,SwanLab will search for the configuration file you specified
(which must be in JSON or YAML format)
Expand All @@ -183,8 +207,6 @@ def init(
swanlog.warning("You have already initialized a run, the init function will be ignored")
return run
# ---------------------------------- 一些变量、格式检查 ----------------------------------
# 默认实验名称为当前目录名
project = (project or os.path.basename(os.getcwd())) if cloud else None
# 如果传入了load,则加载load文件,如果load文件不存在,报错
if load:
load_data = check_load_json_yaml(load, load)
Expand All @@ -197,6 +219,8 @@ def init(
cloud = _load_data(load_data, "cloud", cloud)
project = _load_data(load_data, "project", project)
workspace = _load_data(load_data, "workspace", workspace)
# 默认实验名称为当前目录名
project = _check_proj_name(project if project else os.path.basename(os.getcwd()))
# 初始化logdir参数,接下来logdir被设置为绝对路径且当前程序有写权限
logdir = _init_logdir(logdir)
# 初始化confi参数
Expand All @@ -218,11 +242,10 @@ def init(
exp_num = http.mount_project(project, workspace).history_exp_count
# 初始化、挂载线程池
pool = ThreadPool()

# 连接本地数据库,要求路径必须存在,但是如果数据库文件不存在,会自动创建
connect(autocreate=True)
# 初始化项目数据库
Project.init(os.path.basename(os.getcwd()))
Project.init(project)
# ---------------------------------- 实例化实验 ----------------------------------
# 如果是云端环境,设置回调函数
callbacks = (
Expand Down Expand Up @@ -267,8 +290,8 @@ def _write_call_call(message):
# 注册清理函数
atexit.register(clean_handler)
# ---------------------------------- 终端输出 ----------------------------------
if not cloud and not (project is None and workspace is None):
swanlog.warning("The `project` or `workspace` parameters are invalid in non-cloud mode")
if not cloud and workspace is not None:
swanlog.warning("The `workspace` field is invalid in local mode")
swanlog.debug("SwanLab Runtime has initialized")
swanlog.debug("SwanLab will take over all the print information of the terminal from now on")
swanlog.info("Tracking run with swanlab version " + get_package_version())
Expand Down
9 changes: 0 additions & 9 deletions swanlab/db/models/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,6 @@
class Project(SwanModel):
"""项目表
目前,在一个工程中,只有一个项目

Attributes
----------
experiments: list of Experiment
由 Experiment 表中外键反链接生成的实验列表
charts: list of Chart
由 Chart 表中外键反链接生成的图表列表
namespaces: list of Namespace
由 Namespace 表中外键反链接生成的命名空间数据列表
"""

# 默认的项目id应该是1
Expand Down
2 changes: 1 addition & 1 deletion swanlab/log/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@File: swanlab/log/__init__.py
@IDE: vscode
@Description:
日志记录模块
日志记录模块,在设计上swanlog作为一个独立的模块被使用,你可以在除了utils的任何地方使用它
"""
from typing import Optional
from .log import SwanLog
Expand Down
2 changes: 1 addition & 1 deletion swanlab/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
"""
from .font import generate_color, FONT, COLOR_LIST
from .time import create_time
from .file import check_load_json_yaml
from .file import check_load_json_yaml, check_proj_name_format, check_exp_name_format
Loading