Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions swanlab/server/api/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
"""
from datetime import datetime
from fastapi import APIRouter, Request
from ..module.resp import SUCCESS_200, NOT_FOUND_404
from ..module.resp import SUCCESS_200, NOT_FOUND_404, PARAMS_ERROR_422
import os
import ujson
from urllib.parse import quote, unquote # 转码路径参数
Expand Down Expand Up @@ -377,10 +377,16 @@ async def update_experiment_config(experiment_id: int, request: Request):
experiment_index = experiment["index"] - 1
# 修改实验名称
if not experiment["name"] == body["name"]:
# 修改实验名称
if not experiment["name"] == body["name"]:
# 检测实验名是否重复
for expr in project["experiments"]:
if expr["name"] == body["name"]:
return PARAMS_ERROR_422("Experiment's target name already exists")
project["experiments"][experiment_index]["name"] = body["name"]
# 修改实验目录名
old_path = os.path.join(PROJECT_PATH, experiment["name"])
new_path = os.path.join(PROJECT_PATH, body["name"])
old_path = os.path.join(SWANLOG_DIR, experiment["name"])
new_path = os.path.join(SWANLOG_DIR, body["name"])
os.rename(old_path, new_path)
# 修改实验描述
if not experiment["description"] == body["description"]:
Expand Down
6 changes: 3 additions & 3 deletions swanlab/server/api/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from swanlab.env import get_swanlog_dir
import ujson
from urllib.parse import unquote
from ..settings import PROJECT_PATH

router = APIRouter()

Expand Down Expand Up @@ -75,8 +76,7 @@ async def summaries(experiment_names: str):
@router.patch("/update")
async def update(request: Request):
body = await request.json()
file_path = os.path.join(swc.root, "project.json")
with open(file_path, "r") as f:
with open(PROJECT_PATH, "r") as f:
project = ujson.load(f)
# 检查名字
if "name" in project and project["name"] == body["name"]:
Expand All @@ -90,6 +90,6 @@ async def update(request: Request):
project.update({"description": body["description"]})
# project["update_time"] = create_time()
# 写入文件
with get_a_lock(file_path, "w") as f:
with get_a_lock(PROJECT_PATH, "w") as f:
ujson.dump(project, f, indent=4, ensure_ascii=False)
return SUCCESS_200({"project": project})