Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion swanlab/data/run/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ def __register_exp(
continue
# 其他情况下,说明是用户自定义的后缀,需要报错
else:
raise ExistedError(f"Experiment {exp_name} has existed, please try another name.")
Experiment.purely_delete(run_id=self.__run_id)
raise ExistedError(f"Experiment {exp_name} has existed in local, please try another name.")

# 实验创建成功,设置实验相关信息
self.__settings.exp_name = exp_name
Expand Down
21 changes: 14 additions & 7 deletions swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@
)
from .config import SwanLabConfig
from .utils.file import check_dir_and_create, formate_abs_path
from ..db import Project, connect
from ..db import Project, connect, Experiment
from ..env import init_env, ROOT, get_swanlab_folder
from ..log import swanlog
from ..utils import FONT, check_load_json_yaml
from ..utils.key import get_key
from swanlab.api import create_http, get_http, code_login, LoginInfo, terminal_login
from swanlab.api.upload.model import ColumnModel
from swanlab.package import version_limit, get_package_version, get_host_api, get_host_web
from swanlab.error import KeyFileError
from swanlab.error import KeyFileError, ApiError
from swanlab.cloud import LogSnifferTask, ThreadPool
from swanlab.cloud import UploadType

Expand Down Expand Up @@ -243,11 +243,18 @@ def init(
# ---------------------------------- 注册实验,开启线程 ----------------------------------
if cloud:
# 注册实验信息
get_http().mount_exp(
exp_name=run.settings.exp_name,
colors=run.settings.exp_colors,
description=run.settings.description,
)
try:
get_http().mount_exp(
exp_name=run.settings.exp_name,
colors=run.settings.exp_colors,
description=run.settings.description,
)
except ApiError as e:
if e.resp.status_code == 409:
FONT.brush("", 50)
swanlog.error("The experiment name already exists, please change the experiment name")
Experiment.purely_delete(run_id=run.exp.db.run_id)
sys.exit(409)
sniffer = LogSnifferTask(run.settings.files_dir)
run.pool.create_thread(sniffer.task, name="sniffer", callback=sniffer.callback)

Expand Down
28 changes: 28 additions & 0 deletions swanlab/db/models/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@
@Description:
实验数据表
"""
import os.path

from ..model import SwanModel
from peewee import ForeignKeyField, CharField, IntegerField, TextField, IntegrityError, Check, DatabaseProxy
from .projects import Project
from ..error import ExistedError, NotExistedError, ForeignProNotExistedError
from swanlab.package import get_package_version
from ...utils import generate_color
from ...utils import create_time
from swanlab.env import get_swanlog_dir
import shutil


# 定义模型类
Expand Down Expand Up @@ -197,3 +201,27 @@ def update_status(self, status: int):
self.status = status
self.finish_time = create_time()
self.save()

@classmethod
def purely_delete(cls, **kwargs):
"""
删除某一个实验,顺便删除实验的文件夹

Parameters
----------
id : int
实验id
"""
# 获取实验实例
try:
exp = cls.get(**kwargs)
# 删除实验
exp.delete_instance()
# 更新项目的实验数量
Project.decrease_sum()
# 删除实验文件夹
shutil.rmtree(os.path.join(get_swanlog_dir(), exp.run_id.__str__()))
except NotExistedError:
# 删除实验文件夹
if "run_id" in kwargs:
shutil.rmtree(os.path.join(get_swanlog_dir(), kwargs["run_id"]))
21 changes: 21 additions & 0 deletions swanlab/db/models/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,24 @@ def increase_sum(cls, id=DEFAULT_PROJECT_ID) -> int:
project.sum = project.sum + 1 if project.sum else 1
project.save()
return project.sum

@classmethod
def decrease_sum(cls, id=DEFAULT_PROJECT_ID) -> int:
"""
静态方法
更新实验统计数量,减少1
此方法通常在删除实验时被Experiment类调用
Parameters
----------
id : int
实验id, 默认为DEFAULT_PROJECT_ID

Returns
-------
int:
当前实验统计数量
"""
project: "Project" = cls.filter(cls.id == id)[0]
project.sum = project.sum - 1 if project.sum else 0
project.save()
return project.sum
2 changes: 1 addition & 1 deletion swanlab/log/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def format(self, record):

def concat_messages(func):
"""
装饰器,当传递打印信息有多个时,拼接为一个
装饰器,当传递打印信息有多个时,拼接为一个,并且拦截记录它们
"""

def wrapper(self, *args, **kwargs):
Expand Down
20 changes: 9 additions & 11 deletions swanlab/utils/font.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import re
from typing import Coroutine
import threading
import time
import asyncio

light_colors = [
Expand Down Expand Up @@ -81,40 +82,37 @@ def loading(s: str, func: Coroutine, interval: float = 0.4, prefix: str = None,
prefix = FONT.bold(FONT.blue("swanlab")) + ': ' if prefix is None else prefix
symbols = ["\\", "|", "/", "-"]

running = True
running, result, error = True, None, None

async def _():
def loading():
index = 0
while True:
sys.stdout.write("\r" + prefix + symbols[index % len(symbols)] + " " + s)
sys.stdout.flush()
index += 1
await asyncio.sleep(interval)
time.sleep(interval)
if not running:
break

result, error = None, None

# 再次封装传入的func,当func执行完毕以后,将running置为False
async def __():
nonlocal running, result, error
def task():
nonlocal result, error, running
try:
result = await func
result = asyncio.run(func)
except Exception as e:
error = e
finally:
running = False

# 开启新线程
t1 = threading.Thread(target=lambda: asyncio.run(_()))
t2 = threading.Thread(target=lambda: asyncio.run(__()))
t1 = threading.Thread(target=loading)
t2 = threading.Thread(target=task)
t1.start()
t2.start()
t2.join()
t1.join()
if error is not None:
raise error
FONT.brush("", length=brush_length)
return result

@staticmethod
Expand Down
7 changes: 3 additions & 4 deletions test/create_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,16 @@
log_level="debug",
config="test/config/config.json",
load="test/config/load.yaml",
suffix=None,
# cloud=True,
)
swanlab.config.epoches = epochs
swanlab.config.learning_rate = lr
swanlab.config.debug = "这是一串" + "很长" * 100 + "的字符串"
# 模拟训练
for epoch in range(2, swanlab.config.epoches):
acc = 1 - 2**-epoch - random.random() / epoch - offset
loss = 2**-epoch + random.random() / epoch + offset
loss2 = 3**-epoch + random.random() / epoch + offset * 3
acc = 1 - 2 ** -epoch - random.random() / epoch - offset
loss = 2 ** -epoch + random.random() / epoch + offset
loss2 = 3 ** -epoch + random.random() / epoch + offset * 3
print(f"epoch={epoch}, accuracy={acc}, loss={loss}")
if epoch % 10 == 0:
# 测试audio
Expand Down