Skip to content
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ __pycache__/
test/*
!test/db/
!test/_server.py
!test/load.yaml
!test/config.json

# C extensions
*.so
Expand Down
54 changes: 5 additions & 49 deletions swanlab/data/run/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
check_desc_format,
get_a_lock,
json_serializable,
FONT,
)
from datetime import datetime
import time
Expand Down Expand Up @@ -326,7 +325,6 @@ def __init__(
experiment_name: str = None,
description: str = None,
config: dict = None,
config_file: str = None,
log_level: str = None,
suffix: str = None,
):
Expand Down Expand Up @@ -369,48 +367,6 @@ def __init__(
level = self.__check_log_level(log_level)
swanlog.setLevel(level)

# ---------------------------------- 检查config_file并与config合并 ----------------------------------
# 如果config_file不是None,说明用户提供了配置文件,需要读取配置文件
if config_file is not None:
# 检查config_file的后缀是否是json/yaml,否则报错
config_file_suffix = config_file.split(".")[-1]
if not config_file.endswith((".json", ".yaml", ".yml")):
raise ValueError(
"config_file must be a json or yaml file ('.json', '.yaml', '.yml'), but got {}, please check if the content of config_file is correct.".format(
config_file_suffix
)
)

# 读取配置文件
with open(config_file, "r") as f:
if config_file_suffix == "json":
# 读取配置文件的内容
config_from_file = ujson.load(f)
# 如果读取的内容不是字典类型,则报错
if not isinstance(config_from_file, dict):
raise TypeError(
"The configuration file must be a dictionary, but got {}".format(type(config_from_file))
)
elif config_file_suffix in ["yaml", "yml"]:
# 读取配置文件的内容
config_from_file = yaml.safe_load(f)
# 如果读取的内容不是字典类型,则报错
if not isinstance(config_from_file, dict):
raise TypeError(
"The configuration file must be a dictionary, but got {}".format(type(config_from_file))
)

# 如果config不是None,说明用户提供了配置,需要合并配置文件和配置
if config is not None:
# 如果config不是字典类型,则报错
if not isinstance(config, dict):
raise TypeError("The configuration must be a dictionary, but got {}".format(type(config)))
# 合并配置文件和配置
config = {**config, **config_from_file}
# 否则config就是配置文件的内容
else:
config = config_from_file

# ---------------------------------- 初始化配置 ----------------------------------
# 给外部1个config
self.__config = SwanLabConfig(config, self.__settings)
Expand Down Expand Up @@ -496,12 +452,12 @@ def log(self, data: dict, step: int = None):
# 数据类型的检查将在创建chart配置的时候完成,因为数据类型错误并不会影响实验进行
self.__exp.add(key=key, data=d, step=step)

def success(self):
"""标记实验成功"""
def _success(self):
"""Mark the experiment as success. Users should not use this function."""
self.__set_exp_status(1)

def fail(self):
"""标记实验失败"""
def _fail(self):
"""Mark the experiment as failure. Users should not use this function."""
self.__set_exp_status(-1)

def __str__(self) -> str:
Expand Down Expand Up @@ -553,7 +509,7 @@ def __get_exp_name(self, experiment_name: str = None, suffix: str = None) -> Tup
swanlog.warning(tip)

# 如果suffix为None, 则不添加后缀,直接返回
if suffix is None:
if suffix is None or suffix is False:
return experiment_name_checked, experiment_name

# suffix必须是字符串
Expand Down
1 change: 1 addition & 0 deletions swanlab/data/run/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
get_a_lock,
check_exp_name_format,
check_desc_format,
check_load_json_yaml,
)
from ...utils import get_package_version, create_time, generate_color, FONT
import datetime
Expand Down
142 changes: 90 additions & 52 deletions swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
import atexit, sys, traceback, os
from datetime import datetime
from .run import SwanLabRun, SwanLabConfig, register
from typing import Optional
from typing import Optional, Union
from ..log import swanlog
from .modules import DataType
from typing import Dict
from ..env import init_env, ROOT, is_login, get_user_api_key
from .utils.file import check_dir_and_create, formate_abs_path
from ..db import Project, connect
from ..utils import version_limit, FONT, get_package_version
from ..utils import version_limit, FONT, get_package_version, check_load_json_yaml
from ..utils.package import get_host_web
from ..auth import get_exp_token, terminal_login, code_login
from ..error import NotLoginError, ValidationError
Expand Down Expand Up @@ -60,14 +60,14 @@ def login(api_key: str):
def init(
experiment_name: str = None,
description: str = None,
config: dict = None,
config_file: str = None,
config: Union[dict, str] = None,
logdir: str = None,
suffix: str = "default",
log_level: str = None,
# cloud: bool = False,
# project: str = None,
# organization: str = None,
load: str = None,
**kwargs,
) -> SwanLabRun:
"""
Start a new run to track and log.
Expand All @@ -82,17 +82,16 @@ def init(
description : str, optional
The experiment description you currently have open, used for a more detailed introduction or labeling of the current experiment.
If you do not provide this parameter, you can modify it later in the web interface.
config : dict, optional
Some experiment parameter configurations that can be displayed on the web interface, such as learning rate, batch size, etc.
config_file: str, optional
The path to the configuration file, the default is None.
If you provide this parameter, SwanLab will read the configuration from the file and update 'config' you provide.
The configuration file must be in the format of json or yaml.
log_level : str, optional
The log level of the current experiment, the default is 'info', you can choose from 'debug', 'info', 'warning', 'error', 'critical'.
config : Union[dict, str], optional
If you provide as a dict, it will be used as the configuration of the current experiment.
If you provide as a string, SwanLab will read the configuration from the file. And the configuration file must be in the format of `json` or `yaml`.
Anyway, you can modify the configuration later after this function is called.
logdir : str, optional
The directory where the log file is stored, the default is current working directory.
You can also specify a directory to store the log file, whether using an absolute path or a relative path, but you must ensure that the directory exists.
The folder will store all the log information generated during the execution of Swanlab.
If the parameter is None, Swanlab will generate a folder named "swanlog" in the same path as the code execution to store the data.
If you want to visualize the generated log files, simply run the command `swanlab watch` in the same path where the code is executed (without entering the "swanlog" folder).
You can also specify your own folder, but you must ensure that the folder exists and preferably does not contain anything other than data generated by Swanlab.
In this case, if you want to view the logs, you must use something like `swanlab watch -l ./your_specified_folder` to specify the folder path.
suffix : str, optional
The suffix of the experiment name, the default is 'default'.
If this parameter is 'default', suffix will be '%b%d-%h-%m-%s_<hostname>'(example:'Feb03_14-45-37_windowsX'), which represents the current time.
Expand All @@ -109,59 +108,52 @@ def init(
If you are using cloud mode, you must provide the project name.
organization : str, optional
The organization name of the current experiment, the default is None, which means the log file will be stored in your personal space.
load : str, optional
If you pass this parameter,swanlab will search for the configuration file you specified (which must be in JSON or YAML format) and automatically fill in some explicit parameters of this function for you (excluding parameters in `**kwargs` and the parameters if they are None).
In terms of priority, if the parameters passed to init are `None`, swanlab will attempt to replace them from the configuration file you provided;
otherwise, it will use the parameters you passed as the definitive ones.
log_level : str, optional
The default log level for logging to the terminal in Swanlab is "info", but it can be chosen as "debug", "info", "warning", "error", or "critical".
This is unrelated to the `swanlab.log` function and is typically used for development and debugging purposes.
Therefore, it is only passed implicitly through the `**kwargs` parameter without explicit prompting in this context.
"""
global run, inited
# ---------------------------------- 一些变量、格式检查 ----------------------------------
# 如果已经初始化过了,直接返回run
if inited:
swanlog.warning("You have already initialized a run, the init function will be ignored")
return run
# 如果传入了logdir,则将logdir设置为环境变量,代表日志文件存放的路径
if logdir is not None:
try:
logdir = check_dir_and_create(logdir)
except ValueError:
raise ValueError("logdir must be a str.")
except IOError:
raise IOError("logdir must be a path and have Write permission.")
os.environ[ROOT] = logdir
# 如果没有传入logdir,则使用默认的logdir, 即当前工作目录,但是需要保证目录存在
else:
logdir = os.path.abspath("swanlog")
try:
os.makedirs(logdir, exist_ok=True)
if not os.access(logdir, os.W_OK):
raise IOError
except:
raise IOError("logdir must have Write permission.")

# 如果传入了config_file,则检查config_file是否是一个字符串,以及转换为绝对路径
if config_file is not None:
if not isinstance(config_file, str):
raise ValueError("config_file must be a string")
if not os.path.isabs(config_file):
config_file = os.path.abspath(config_file)

# ---------------------------------- 一些变量、格式检查 ----------------------------------
# 如果传入了load,则加载load文件,如果load文件不存在,报错
if load:
load_data = check_load_json_yaml(load, "load")
# 尝试更改传入的参数为None的情况,如果传入的参数不是None,不做任何操作
experiment_name = _load_data(load_data, "experiment_name", experiment_name)
description = _load_data(load_data, "description", description)
config = _load_data(load_data, "config", config)
logdir = _load_data(load_data, "logdir", logdir)
suffix = _load_data(load_data, "suffix", suffix)
# 初始化logdir参数,接下来logdir被设置为绝对路径且当前程序有写权限
logdir = _init_logdir(logdir)
# 初始化confi参数
config = _init_config(config)
# 检查logdir内文件的版本,如果<=0.1.4则报错
version_limit(logdir, mode="init")
# 初始化环境变量
init_env()

# ---------------------------------- 用户登录、格式、权限校验 ----------------------------------
# 1. 如果没有登录,提示登录
# 2. 如果登录了,发起请求,如果请求失败,重新登录,返回步骤1
# token = _get_exp_token(cloud=cloud)
# 连接本地数据库,要求路径必须存在,但是如果数据库文件不存在,会自动创建
connect(autocreate=True)

# 初始化项目数据库
Project.init(os.path.basename(os.getcwd()))
# 注册实验
run = register(
experiment_name=experiment_name,
description=description,
config=config,
config_file=config_file,
log_level=log_level,
log_level=kwargs.get("log_level", "info"),
suffix=suffix,
)
# 如果使用云端模式,在此开启其他线程负责同步数据
Expand All @@ -180,7 +172,7 @@ def init(
# 云端版本有一些额外的信息展示
# cloud and swanlog.info("Syncing run " + FONT.yellow(run.settings.exp_name) + " to the cloud")
swanlog.info(
"🌟 [Offline Dashboard] Run `"
"🌟 Run `"
+ FONT.bold("swanlab watch -l {}".format(formate_abs_path(run.settings.swanlog_dir)))
+ "` to view SwanLab Experiment Dashboard locally"
)
Expand Down Expand Up @@ -223,17 +215,63 @@ def finish():
but you can also execute it manually and mark the experiment as 'completed'.
Once the experiment is marked as 'completed', no more data can be logged to the experiment by 'swanlab.log'.
"""
global run
global run, inited
if not inited:
raise RuntimeError("You must call swanlab.data.init() before using finish()")
if run is None:
return swanlog.error("After calling finish(), you can no longer close the current experiment")
run.success()
run._success()
swanlog.setSuccess()
swanlog.reset_console()
run = None


def _init_logdir(logdir: Optional[str]) -> str:
    """
    Resolve the ``logdir`` argument received by ``init`` into a usable directory.

    Parameters
    ----------
    logdir : str, optional
        Directory requested by the caller. When given, it is validated/created
        through ``check_dir_and_create`` and the resulting path is exported via
        the ``ROOT`` environment variable. When ``None``, a ``swanlog`` folder
        under the current working directory is used instead (the environment
        variable is left untouched in that case).

    Returns
    -------
    str
        The log directory path to use for the rest of the initialization.

    Raises
    ------
    ValueError
        If ``check_dir_and_create`` rejects ``logdir`` with a ``ValueError``.
    IOError
        If the directory cannot be created or is not writable.
    """
    if logdir is not None:
        # An explicit directory was requested: validate it and publish it
        # through the environment so the rest of the package can find it.
        try:
            logdir = check_dir_and_create(logdir)
        except ValueError as err:
            raise ValueError("logdir must be a str.") from err
        except IOError as err:
            raise IOError("logdir must be a path and have Write permission.") from err
        os.environ[ROOT] = logdir
        return logdir

    # No directory requested: fall back to ./swanlog, creating it if needed.
    logdir = os.path.abspath("swanlog")
    try:
        os.makedirs(logdir, exist_ok=True)
    except OSError as err:
        # Only filesystem failures are translated; KeyboardInterrupt and
        # SystemExit must propagate instead of being reported as a
        # permission problem.
        raise IOError("logdir must have Write permission.") from err
    if not os.access(logdir, os.W_OK):
        raise IOError("logdir must have Write permission.")
    return logdir


def _init_config(config: Union[dict, str]):
    """
    Normalize the ``config`` argument received by ``init``.

    A dict or ``None`` is returned unchanged. Any other value is treated as a
    reference to a configuration file: a notice is printed and the value is
    handed to ``check_load_json_yaml`` (together with the name ``"config"``),
    whose result is returned.
    """
    needs_loading = config is not None and not isinstance(config, dict)
    if needs_loading:
        notice = "The parameter config is loaded from the configuration file: {}".format(config)
        print(FONT.swanlab(notice))
        return check_load_json_yaml(config, "config")
    # Already a dict (or absent): nothing to load.
    return config


def _load_data(load_data: dict, key: str, value):
    """
    Pick the effective value of one ``init`` parameter.

    An explicitly passed ``value`` (anything that is not ``None``) always wins.
    Only when ``value`` is ``None`` is ``load_data`` consulted, returning the
    entry stored under ``key`` or ``None`` if the key is absent.
    """
    if value is None:
        # Nothing was passed explicitly: fall back to the loaded file data.
        return load_data.get(key)
    return value


def _get_exp_token(cloud: bool = False):
"""获取当前实验的相关信息
可能包含实验的token、实验的id、用户信息等信息
Expand Down Expand Up @@ -262,7 +300,7 @@ def __clean_handler():
run.settings.exp_name
)
)
run.success()
run._success()
swanlog.setSuccess()
swanlog.reset_console()

Expand All @@ -271,10 +309,10 @@ def __clean_handler():
def __except_handler(tp, val, tb):
"""定义异常处理函数"""
if run is None:
return swanlog.warning("SwanLab Runtime has been cleaned manually, the exception will be ignored")
return
swanlog.error("Error happended while training, SwanLab will throw it")
# 标记实验失败
run.fail()
run._fail()
swanlog.setError()
# 记录异常信息
# 追踪信息
Expand Down
1 change: 1 addition & 0 deletions swanlab/server/controller/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ async def update_experiment_info(experiment_id: int, request: Request):

db = connect()
body = await request.json()
body["name"] = body["name"].strip()
with db.atomic():
experiment = Experiment.get(experiment_id)
experiment.name = body.get("name")
Expand Down
2 changes: 1 addition & 1 deletion swanlab/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .font import generate_color, FONT, COLOR_LIST
from .time import create_time
from .file import lock_file, get_a_lock
from .file import lock_file, get_a_lock, check_load_json_yaml
from .package import get_package_version, version_limit
Loading