Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion swanlab/db/db_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..env import get_db_path
import os
from peewee import SqliteDatabase
from .table_config import tables, Tag, Experiment, Namespace
from .table_config import tables, Tag, Experiment, Namespace, Chart
from .migrate import *

# 判断是否已经binded了
Expand Down Expand Up @@ -57,6 +57,10 @@ def connect(autocreate=False) -> SqliteDatabase:
swandb.bind(tables)
swandb.create_tables(tables)
swandb.close()
# 完成数据迁移,如果chart表中没有status字段,则添加
if not Chart.field_exists("status"):
# 不启用外键约束
add_status(SqliteDatabase(path))
# 完成数据迁移,如果namespace表中没有opened字段,则添加
if not Namespace.field_exists("opened"):
# 不启用外键约束
Expand Down
7 changes: 4 additions & 3 deletions swanlab/db/migrate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
@Description:
数据库迁移模块
"""
from .tag_sort import add_sort
from .exp_finish_time import add_finish_time
from .namespace_opened import add_opened
from .tag import add_sort
from .experiment import add_finish_time
from .namespace import add_opened
from .chart import add_status
41 changes: 41 additions & 0 deletions swanlab/db/migrate/chart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024-03-09 17:39:12
@File: swanlab/db/migrate/chart.py
@IDE: vscode
@Description:
chart相关迁移操作
"""
from peewee import SqliteDatabase, IntegerField
from playhouse.migrate import migrate, SqliteMigrator
from .experiment import (
add_hidden_opened as add_hidden_opened_experiment,
add_pinned_opened as add_pinned_opened_experiment,
)
from .project import add_hidden_opened as add_hidden_opened_project, add_pinned_opened as add_pinned_opened_project


def add_status(db):
"""
为数据库的Chart表添加status字段,默认为0
"""
migrator = SqliteMigrator(db)
status = IntegerField(default=0)
migrate(migrator.add_column("chart", "status", status))
# 同时,添加sort字段
add_sort(db)
# 同时为experiment和project添加hidden_opened、pinned_opened字段
add_hidden_opened_experiment(db)
add_pinned_opened_experiment(db)
add_hidden_opened_project(db)
add_pinned_opened_project(db)


def add_sort(db):
"""
为数据库的Chart表添加sort字段,默认为NULL
"""
migrator = SqliteMigrator(db)
sort = IntegerField(default=None, null=True)
migrate(migrator.add_column("chart", "sort", sort))
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
实验结束时间是指实验被标记为完成/崩溃的时间
"""

from peewee import SqliteDatabase, CharField
from peewee import SqliteDatabase, CharField, IntegerField
from playhouse.migrate import migrate, SqliteMigrator


Expand All @@ -24,3 +24,23 @@ def add_finish_time(db):
# 原子化操作,遍历所有status不为0的实验,将其finish_time设置为update_time
for exp in db.execute_sql("select id, update_time from experiment where status != 0"):
db.execute_sql(f"update experiment set finish_time = '{exp[1]}' where id = {exp[0]}")


def add_pinned_opened(db):
"""
为数据库的Experiment表添加pinned_opened字段,用于记录实验置顶部分是否打开
默认值为1,表示打开
"""
migrator = SqliteMigrator(db)
pinned_opened = IntegerField(default=1, choices=[0, 1])
migrate(migrator.add_column("experiment", "pinned_opened", pinned_opened))


def add_hidden_opened(db):
"""
为数据库的Experiment表添加hidden_opened字段,用于记录实验隐藏部分是否打开
默认值为0,表示关闭
"""
migrator = SqliteMigrator(db)
hidden_opened = IntegerField(default=0, choices=[0, 1])
migrate(migrator.add_column("experiment", "hidden_opened", hidden_opened))
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ def add_opened(db):
默认值为False
"""
migrator = SqliteMigrator(db)
opened = IntegerField(default=1)
opened = IntegerField(default=1, choices=[0, 1])
migrate(migrator.add_column("namespace", "opened", opened))
31 changes: 31 additions & 0 deletions swanlab/db/migrate/project.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024-03-09 17:43:27
@File: swanlab/db/migrate/project.py
@IDE: vscode
@Description:
为数据库的Project表添加迁移操作
"""
from peewee import IntegerField
from playhouse.migrate import migrate, SqliteMigrator


def add_pinned_opened(db):
"""
为数据库的Project表添加pinned_opened字段,用于记录项目置顶部分是否打开
默认值为1,表示打开
"""
migrator = SqliteMigrator(db)
pinned_opened = IntegerField(default=1, choices=[0, 1])
migrate(migrator.add_column("project", "pinned_opened", pinned_opened))


def add_hidden_opened(db):
"""
为数据库的Project表添加hidden_opened字段,用于记录项目隐藏部分是否打开
默认值为0,表示关闭
"""
migrator = SqliteMigrator(db)
hidden_opened = IntegerField(default=0, choices=[0, 1])
migrate(migrator.add_column("project", "hidden_opened", hidden_opened))
File renamed without changes.
105 changes: 102 additions & 3 deletions swanlab/db/models/charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
实验图标表
"""
from ..model import SwanModel
from peewee import CharField, IntegerField, ForeignKeyField, TextField, IntegrityError, Check, DatabaseProxy
from peewee import CharField, IntegerField, ForeignKeyField, TextField, IntegrityError, Check, DatabaseProxy, fn
from ..error import ExistedError, ForeignExpNotExistedError, ForeignProNotExistedError
from .experiments import Experiment
from .projects import Project
Expand Down Expand Up @@ -47,6 +47,10 @@ class Meta:
"""图表类型,由创建者决定,这与数据库本身无关"""
reference = CharField(max_length=10, default="step")
"""图表数据参考,由创建者决定,这与数据库本身无关"""
status = IntegerField(null=False, default=0)
"""图表状态 0代表默认状态 1代表被pinned -1代表被hidden"""
sort = IntegerField(null=True, default=None)
"""图表在被pinned或hidden时的排序,越小越靠前,越大越靠后,如果为NULL则代表未被pinned或hidden"""
config = TextField(null=True)
"""图表的其他配置,这实际上是一个json字符串"""
more = TextField(null=True)
Expand All @@ -57,16 +61,20 @@ class Meta:
"""更新时间"""

def __dict__(self):
experiment_id = None if self.experiment_id is None else self.experiment_id.__dict__()
project_id = None if self.project_id is None else self.project_id.__dict__()
return {
"id": self.id,
"experiment_id": self.experiment_id,
"project_id": self.project_id,
"experiment_id": experiment_id,
"project_id": project_id,
"name": self.name,
"description": self.description,
"system": self.system,
"type": self.type,
"reference": self.reference,
"config": self.config,
"status": self.status,
"sort": self.sort,
"more": self.more,
"create_time": self.create_time,
"update_time": self.update_time,
Expand Down Expand Up @@ -143,3 +151,94 @@ def create(
)
except IntegrityError:
raise ExistedError("图表已存在")

@classmethod
def pin(cls, id: int, sort: int = None) -> "Chart":
"""
把某个chart置顶

Parameters
----------
id : int
图表id
sort : int
图表在被pinned或hidden时的排序,越小越靠前,越大越靠后,如果为NULL则代表未被pinned或hidden, 默认为None
默认排在最后,也就是传入None的时候

Returns
-------
Chart : Chart
置顶的图表
"""
chart: Chart = cls.get(cls.id == id)
chart.status = 1
if sort is None:
# 判断当前实验/项目下有多少个hidden的chart
project_id = None if chart.project_id is None else chart.project_id.id
experiment_id = None if chart.experiment_id is None else chart.experiment_id.id
# 获取当前namespace下的最大索引,如果没有,则为0
sort = (
cls.select(fn.Max(cls.sort))
.where(cls.project_id == project_id, cls.experiment_id == experiment_id, cls.status == 1)
.scalar()
)
sort = 0 if sort is None else sort + 1
chart.sort = sort
chart.save()
return chart

@classmethod
def restore(cls, id: int) -> "Chart":
"""
把某个chart恢复正常状态

Parameters
----------
id : int
图表id

Returns
-------
Chart : Chart
恢复的图表
"""
chart: Chart = cls.get(cls.id == id)
chart.status = 0
chart.sort = None
chart.save()
return chart

@classmethod
def hide(cls, id: int, sort: int = None) -> "Chart":
"""
把某个chart隐藏

Parameters
----------
id : int
图表id
sort : int
图表在被pinned或hidden时的排序,越小越靠前,越大越靠后,如果为NULL则代表未被pinned或hidden, 默认为None
默认排在最后,也就是传入None的时候

Returns
-------
Chart : Chart
隐藏的图表
"""
chart: Chart = cls.get(cls.id == id)
chart.status = -1
if sort is None:
# 判断当前实验/项目下有多少个hidden的chart
project_id = None if chart.project_id is None else chart.project_id.id
experiment_id = None if chart.experiment_id is None else chart.experiment_id.id
# 获取当前namespace下的最大索引,如果没有,则为0
sort = (
cls.select(fn.Max(cls.sort))
.where(cls.project_id == project_id, cls.experiment_id == experiment_id, cls.status == -1)
.scalar()
)
sort = 0 if sort is None else sort + 1
chart.sort = sort
chart.save()
return chart
6 changes: 6 additions & 0 deletions swanlab/db/models/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ class Meta:
dark = CharField(max_length=20, null=True)
"""暗色主题颜色"""

pinned_opened = IntegerField(default=1, choices=[0, 1])
"""实验图表置顶部分是否打开,默认值为1,表示打开"""

hidden_opened = IntegerField(default=0, choices=[0, 1])
"""实验图表隐藏部分是否打开,默认值为0,表示关闭"""

more = TextField(null=True)
"""更多信息配置,json格式,将在表函数中检查并解析"""

Expand Down
2 changes: 1 addition & 1 deletion swanlab/db/models/namespaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Meta:
"""命名空间描述,可为空"""
sort = IntegerField()
"""命名空间索引,用于排序,同一个项目/实验下,命名空间索引不能重复,索引越小,排序越靠前,索引>=0"""
opened = IntegerField(default=1)
opened = IntegerField(default=1, choices=[0, 1])
"""命名空间是否已经被打开,默认为1,表示已经被打开,0表示未被打开"""
more = TextField(default=None, null=True)
"""更多信息配置,json格式,将在表函数中检查并解析"""
Expand Down
7 changes: 7 additions & 0 deletions swanlab/db/models/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ class Meta:
"""项目下实验数量,包括已删除的实验,这是一个只增不减的值"""
charts = IntegerField(default=0, choices=[0, 1], null=False)
"""是否已经生成项目级别图表,0 未生成,1 已生成"""
pinned_opened = IntegerField(default=1, choices=[0, 1])
"""多实验图表置顶部分是否打开,默认值为1,表示打开"""
hidden_opened = IntegerField(default=0, choices=[0, 1])
"""多实验图表隐藏部分是否打开,默认值为0,表示关闭"""
more = CharField(null=True)
"""更多信息配置,json格式"""
version = CharField(max_length=30, null=False)
Expand All @@ -59,6 +63,9 @@ def __dict__(self):
"sum": self.sum,
"charts": self.charts,
"more": self.more,
"pinned_opened": self.pinned_opened,
"hidden_opened": self.hidden_opened,
"version": self.version,
"create_time": self.create_time,
"update_time": self.update_time,
}
Expand Down
2 changes: 2 additions & 0 deletions swanlab/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ async def _(*args, **kwargs):
from .router.experiment import router as experiment
from .router.project import router as project
from .router.namespace import router as namespace
from .router.chart import router as chart

# 媒体文件路由,允许前端获取其他产生的媒体文件
from .router.media import router as media
Expand All @@ -96,3 +97,4 @@ async def _(*args, **kwargs):
app.include_router(experiment, prefix=prefix + "/experiment")
app.include_router(media, prefix=prefix + "/media")
app.include_router(namespace, prefix=prefix + "/namespace")
app.include_router(chart, prefix=prefix + "/chart")
44 changes: 44 additions & 0 deletions swanlab/server/controller/chart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024-03-10 13:53:16
@File: swanlab/server/controller/chart.py
@IDE: vscode
@Description:
图表相关操作api
"""
from .db import Chart
from .utils import get_exp_charts, get_proj_charts
from ..module import SUCCESS_200, PARAMS_ERROR_422


def update_charts_status(chart_id: int, status: int):
"""
更新图表状态,status=1表示置顶,status=0表示正常状态,status=-1表示隐藏

Parameters
----------
chart_id : int
图表id
status : int
状态,1表示置顶,0表示正常状态,-1表示隐藏
"""
chart: Chart = Chart.get_by_id(chart_id)

if not chart:
return PARAMS_ERROR_422("chart_id not exist")
if status == 1:
Chart.pin(id=chart_id)
elif status == -1:
Chart.hide(id=chart_id)
else:
Chart.restore(id=chart_id)
if chart.project_id:
chart_list, namespace_list = get_proj_charts(chart.project_id.id)
else:
chart_list, namespace_list = get_exp_charts(chart.experiment_id.id)
for namespace in namespace_list:
for j, chart_id in enumerate(namespace["charts"]):
# 在charts中找到对应的chart_id
namespace["charts"][j] = next((x for x in chart_list if x["id"] == chart_id), None)
return SUCCESS_200({"groups": namespace_list})
10 changes: 10 additions & 0 deletions swanlab/server/controller/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
r"""
@DATE: 2024-03-09 21:44:40
@File: swanlab/server/controller/db.py
@IDE: vscode
@Description:
数据库,从数据库模块中导入数据库相关函数,供其他模块调用,不用每次都通过相对路径导入了
"""
from ...db import *
Loading