Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 21 additions & 28 deletions swanlab/data/run/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,28 @@
@Description:
在此处定义SwanLabRun类并导出
"""
import random
from datetime import datetime
from enum import Enum
from typing import Callable, Optional, Dict

from swankit.core import SwanLabSharedSettings
from swanlab.log import swanlog

from swanlab.data.modules import MediaType, DataWrapper, FloatConvertible, Line
from .system import get_system_info, get_requirements
from swanlab.env import get_mode, get_swanlog_dir
from swanlab.log import swanlog
from swanlab.package import get_package_version
from . import namer as N
from .config import SwanLabConfig
from enum import Enum
from .exp import SwanLabExp
from datetime import datetime
from typing import Callable, Optional, Dict
from .helper import SwanLabRunOperator, RuntimeInfo
from .public import SwanLabPublicConfig
from .system import get_system_info, get_requirements
from ..formater import check_key_format, check_exp_name_format, check_desc_format
from swanlab.env import get_mode, get_swanlog_dir
from . import namer as N
import random

MAX_LIST_LENGTH = 108


class SwanLabRunState(Enum):
"""SwanLabRunState is an enumeration class that represents the state of the experiment.
We Recommend that you use this enumeration class to represent the state of the experiment.
Expand Down Expand Up @@ -94,6 +98,7 @@ def __init__(
should_save=not self.__operator.disabled,
version=get_package_version(),
)
self.__public = SwanLabPublicConfig(self.__project_name, self.__settings)
self.__operator.before_run(self.__settings)
# ---------------------------------- 初始化日志记录器 ----------------------------------
swanlog.level = self.__check_log_level(log_level)
Expand Down Expand Up @@ -129,17 +134,13 @@ def _(state: SwanLabRunState):
self.__operator.on_runtime_info_update(
RuntimeInfo(
requirements=get_requirements(),
metadata=get_system_info(get_package_version(), self.settings.log_dir),
metadata=get_system_info(get_package_version(), self.__settings.log_dir),
)
)

@property
def operator(self) -> SwanLabRunOperator:
return self.__operator

@property
def project_name(self) -> str:
return self.__project_name
def public(self):
return self.__public

@property
def mode(self) -> str:
Expand Down Expand Up @@ -210,7 +211,7 @@ def finish(state: SwanLabRunState = SwanLabRunState.SUCCESS, error=None):
_set_run_state(state)
error = error if state == SwanLabRunState.CRASHED else None
# 退出回调
run.operator.on_stop(error)
getattr(run, "_SwanLabRun__operator").on_stop(error)
try:
swanlog.uninstall()
except RuntimeError:
Expand All @@ -225,14 +226,6 @@ def finish(state: SwanLabRunState = SwanLabRunState.SUCCESS, error=None):

return _run

@property
def settings(self) -> SwanLabSharedSettings:
"""
This property allows you to access the 'settings' content passed through `init`,
and runtime settings can not be modified.
"""
return self.__settings

@property
def config(self) -> SwanLabConfig:
"""
Expand Down Expand Up @@ -353,10 +346,10 @@ def __register_exp(
description = "" if description is None else description
colors = N.generate_colors(num)
self.__operator.before_init_experiment(self.__run_id, experiment_name, description, num, colors)
self.settings.exp_name = experiment_name
self.settings.exp_colors = colors
self.settings.description = description
return SwanLabExp(self.settings, operator=self.__operator)
self.__settings.exp_name = experiment_name
self.__settings.exp_colors = colors
self.__settings.description = description
return SwanLabExp(self.__settings, operator=self.__operator)

@staticmethod
def __check_log_level(log_level: str) -> str:
Expand Down
148 changes: 148 additions & 0 deletions swanlab/data/run/public.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
from swankit.core import SwanLabSharedSettings

from swanlab.api import get_http
from swanlab.package import get_project_url, get_experiment_url


class SwanlabCloudConfig:
"""
public data for the SwanLab project when running in cloud mode.
"""

def __init__(self):
self.__http = None

def __get_property_from_http(self, name: str):
"""
Get the property from the http object.
if the http object is None, it will be initialized.
if initialization fails, it will return None.
"""
if self.available:
return getattr(self.__http, name)
return None

@property
def available(self):
"""
Whether the SwanLab is running in cloud mode.
"""
try:
if self.__http is None:
self.__http = get_http()
return True
except ValueError:
return False

@property
def project_name(self):
"""
The name of the project. Equal to `run.public.project_name`.
If swanlab is not running in cloud mode, it will return None.
"""
return self.__get_property_from_http("projname")

@property
def project_url(self):
"""
The url of the project. It is the url of the project page on the SwanLab.
If swanlab is not running in cloud mode, it will return None.
"""
if not self.available:
return None
groupname = self.__get_property_from_http("groupname")
projname = self.__get_property_from_http("projname")
return get_project_url(groupname, projname)

@property
def experiment_name(self):
"""
The name of the experiment. It may be different from the name of swanboard.
"""
return self.__get_property_from_http("expname")

@property
def experiment_url(self):
"""
The url of the experiment. It is the url of the experiment page on the SwanLab.
"""
if not self.available:
return None
groupname = self.__get_property_from_http("groupname")
projname = self.__get_property_from_http("projname")
exp_id = self.__get_property_from_http("exp_id")
return get_experiment_url(groupname, projname, exp_id)


class SwanLabPublicConfig:
"""
Public data for the SwanLab project.
"""

def __init__(self, project_name: str, settings: SwanLabSharedSettings):
self.__project_name = project_name
self.__cloud = SwanlabCloudConfig()
self.__settings = settings

def json(self):
"""
Return a dict of the public config.
This method is used to serialize the public config to json.
"""
return {
"project_name": self.project_name,
"version": self.version,
"run_id": self.run_id,
"swanlog_dir": self.swanlog_dir,
"run_dir": self.run_dir,
"cloud": {
"project_name": self.cloud.project_name,
"project_url": self.cloud.project_url,
"experiment_name": self.cloud.experiment_name,
"experiment_url": self.cloud.experiment_url,
},
}

@property
def cloud(self):
"""
The cloud configuration.
"""
return self.__cloud

@property
def project_name(self):
"""
The name of the project. Equal to `run.public.project_name`.
"""
return self.__project_name

# ---------------------------------- 继承settings的属性 ----------------------------------

@property
def version(self) -> str:
"""
The version of the SwanLab.
"""
return self.__settings.version

@property
def run_id(self) -> str:
"""
The id of the run.
"""
return self.__settings.run_id

@property
def swanlog_dir(self) -> str:
"""
The directory of the SwanLab log.
"""
return self.__settings.swanlog_dir

@property
def run_dir(self) -> str:
"""
The directory of the run.
"""
return self.__settings.run_dir
43 changes: 24 additions & 19 deletions test/unit/data/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@
@Description:
测试sdk的一些api
"""
import tutils as T
import os

import pytest
from nanoid import generate

import swanlab.data.sdk as S
import swanlab.error as Err
from swanlab.log import swanlog
from swanlab.env import SwanLabEnv, get_save_dir
import tutils as T
from swanlab.data.run import get_run
from nanoid import generate
import pytest
import os
from swanlab.env import SwanLabEnv, get_save_dir
from swanlab.log import swanlog


@pytest.fixture(scope="function", autouse=True)
Expand Down Expand Up @@ -76,7 +78,7 @@ def test_init_disabled(self):
assert not os.path.exists(logdir)
assert os.environ[MODE] == "disabled"
run.log({"TestInitMode": 1}) # 不会报错
a = run.settings.run_dir
a = run.public.run_dir
assert not os.path.exists(a)
assert get_run() is not None

Expand All @@ -85,6 +87,7 @@ def test_init_local(self):
assert os.environ[MODE] == "local"
run.log({"TestInitMode": 1}) # 不会报错
assert get_run() is not None
assert run.public.cloud.project_name is None

@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test")
def test_init_cloud(self):
Expand All @@ -93,6 +96,8 @@ def test_init_cloud(self):
assert os.environ[MODE] == "cloud"
run.log({"TestInitMode": 1}) # 不会报错
assert get_run() is not None
for key in run.public.json()['cloud']:
assert run.public.json()['cloud'][key] is not None

def test_init_error(self):
with pytest.raises(ValueError):
Expand All @@ -106,7 +111,7 @@ def test_init_disabled_env(self):
run = S.init()
assert os.environ[MODE] == "disabled"
run.log({"TestInitMode": 1})
a = run.settings.run_dir
a = run.public.run_dir
assert not os.path.exists(a)
assert get_run() is not None

Expand All @@ -131,7 +136,7 @@ def test_init_disabled_env_mode(self):
run = S.init(mode="disabled")
assert os.environ[MODE] == "disabled"
run.log({"TestInitMode": 1})
a = run.settings.run_dir
a = run.public.run_dir
assert not os.path.exists(a)
assert get_run() is not None

Expand All @@ -146,15 +151,15 @@ def test_init_project_none(self):
设置project为None
"""
run = S.init(project=None, mode="disabled")
assert run.project_name == os.path.basename(os.getcwd())
assert run.public.project_name == os.path.basename(os.getcwd())

def test_init_project(self):
"""
设置project为字符串
"""
project = "test_project"
run = S.init(project=project, mode="disabled")
assert run.project_name == project
assert run.public.project_name == project


LOG_DIR = SwanLabEnv.SWANLOG_FOLDER.value
Expand All @@ -171,26 +176,26 @@ def test_init_logdir_disabled(self):
"""
logdir = generate()
run = S.init(logdir=logdir, mode="disabled")
assert run.settings.swanlog_dir != logdir
assert run.settings.swanlog_dir == os.environ[LOG_DIR]
assert run.public.swanlog_dir != logdir
assert run.public.swanlog_dir == os.environ[LOG_DIR]
run.finish()
del os.environ[LOG_DIR]
run = S.init(logdir=logdir, mode="disabled")
assert run.settings.swanlog_dir != logdir
assert run.settings.swanlog_dir == os.path.join(os.getcwd(), "swanlog")
assert run.public.swanlog_dir != logdir
assert run.public.swanlog_dir == os.path.join(os.getcwd(), "swanlog")

def test_init_logdir_enabled(self):
"""
其他模式下设置logdir生效
"""
logdir = os.path.join(T.TEMP_PATH, generate()).__str__()
run = S.init(logdir=logdir, mode="local")
assert run.settings.swanlog_dir == logdir
assert run.public.swanlog_dir == logdir
run.finish()
del os.environ[LOG_DIR]
logdir = os.path.join(T.TEMP_PATH, generate()).__str__()
run = S.init(logdir=logdir, mode="local")
assert run.settings.swanlog_dir == logdir
assert run.public.swanlog_dir == logdir

def test_init_logdir_env(self):
"""
Expand All @@ -199,13 +204,13 @@ def test_init_logdir_env(self):
logdir = os.path.join(T.TEMP_PATH, generate()).__str__()
os.environ[LOG_DIR] = logdir
run = S.init(mode="local")
assert run.settings.swanlog_dir == logdir
assert run.public.swanlog_dir == logdir
run.finish()
del os.environ[LOG_DIR]
logdir = os.path.join(T.TEMP_PATH, generate()).__str__()
os.environ[LOG_DIR] = logdir
run = S.init(mode="local")
assert run.settings.swanlog_dir == logdir
assert run.public.swanlog_dir == logdir


@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test")
Expand Down