Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions swanlab/data/run/metadata/hardware/gpu/nvidia.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,14 @@ def __init__(self, count: int):
self.per_gpu_configs[self.gpu_power_key].append(power_config.clone(metric_name=metric_name))
self.per_gpu_configs[self.gpu_util_key].append(util_config.clone(metric_name=metric_name))

@HardwareCollector.try_run()
def get_gpu_config(self, key: str, idx: int) -> HardwareConfig:
"""
Return the stored configuration for one metric of one GPU.

:param key: metric key used to index ``self.per_gpu_configs``
:param idx: position of the GPU within that metric's list of configs
:return: the entry at ``self.per_gpu_configs[key][idx]``

NOTE(review): this method is wrapped by ``HardwareCollector.try_run()``, so a
bad key or index presumably yields ``None`` instead of raising - confirm that
callers tolerate a ``None`` config.
"""
return self.per_gpu_configs[key][idx]

@HardwareCollector.try_run()
def get_gpu_util(self, idx: int) -> HardwareInfo:
"""
获取 GPU 利用率
Expand All @@ -118,6 +120,7 @@ def get_gpu_util(self, idx: int) -> HardwareInfo:
"config": self.get_gpu_config(self.gpu_util_key, idx),
}

@HardwareCollector.try_run()
def get_gpu_mem_pct(self, idx: int) -> HardwareInfo:
"""
获取 GPU 内存使用率
Expand All @@ -132,6 +135,7 @@ def get_gpu_mem_pct(self, idx: int) -> HardwareInfo:
"config": self.get_gpu_config(self.gpu_mem_pct_key, idx),
}

@HardwareCollector.try_run()
def get_gpu_temp(self, idx: int) -> HardwareInfo:
"""
获取 GPU 温度
Expand All @@ -145,6 +149,7 @@ def get_gpu_temp(self, idx: int) -> HardwareInfo:
"config": self.get_gpu_config(self.gpu_temp_key, idx),
}

@HardwareCollector.try_run()
def get_gpu_power(self, idx: int) -> HardwareInfo:
"""
获取 GPU 功耗
Expand Down
36 changes: 32 additions & 4 deletions swanlab/data/run/metadata/hardware/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from abc import ABC, abstractmethod
from typing import TypedDict, Tuple, Optional, Any, List, Union

from swankit.callback import YRange
from swankit.callback.models import ColumnConfig
from swankit.callback.models import ColumnConfig, YRange

from swanlab.data.run.namer import generate_colors
from swanlab.log import swanlog
Expand Down Expand Up @@ -121,16 +120,45 @@ def after_collect_impl(self):
class HardwareCollector(CollectGuard, ABC):
@abstractmethod
def collect(self) -> HardwareInfoList:
"""
Main entry point for collecting hardware information.

Subclasses override this method to implement the actual collection logic.
"""
pass

@staticmethod
def try_run():
    """
    Build a decorator that isolates the failure of a single collection step.

    Apply it to an individual collection function so that an exception raised
    by that function is logged at debug level and turned into ``None``,
    instead of propagating and failing the whole collection task.

    :return: a decorator; the decorated function returns the original result
        on success and ``None`` when the original raises ``Exception``
    """
    # Imported locally so this block does not depend on the module's import list.
    from functools import wraps

    def decorator(func):
        # wraps() keeps the collector's __name__/__doc__ and exposes
        # __wrapped__, so decorated methods stay identifiable when debugging.
        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                # Best effort by design: record the failure and carry on.
                swanlog.debug(f"Atoms collection failed: {func.__name__}, {str(e)}")
                return None

        return wrapper

    return decorator

def __call__(self) -> Optional[HardwareInfoList]:
"""
Run the collection task as a whole, including its pre- and post-collection steps.
"""
try:
self.before_collect()
return self.collect()
result = self.collect()
if result is None:
return None
# Drop None entries from the collection result
return [r for r in result if r is not None]
except NotImplementedError as n:
raise n
except Exception as e:
swanlog.debug(f"Hardware info collection failed: {self.__class__.__name__}, {str(e)}")
swanlog.debug(f"Collection failed: {self.__class__.__name__}, {str(e)}")
return None
finally:
self.after_collect()
Expand Down
23 changes: 15 additions & 8 deletions swanlab/data/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,24 @@
"""
import os
from typing import Optional, Union, Dict, Tuple, Literal

from swanboard import SwanBoardCallback
from swankit.env import SwanLabMode

from swanlab.api import code_login
from swanlab.env import SwanLabEnv
from swanlab.log import swanlog
from .callback_cloud import CloudRunCallback
from .callback_local import LocalRunCallback
from .formater import check_load_json_yaml, check_proj_name_format
from .modules import DataType
from .run import (
SwanLabRunState,
SwanLabRun,
register,
get_run,
)
from .formater import check_load_json_yaml, check_proj_name_format
from .callback_cloud import CloudRunCallback
from .callback_local import LocalRunCallback
from .run.helper import SwanLabRunOperator
from swanlab.log import swanlog
from swanlab.api import code_login
from swanlab.env import SwanLabEnv
from swankit.env import SwanLabMode
from swanboard import SwanBoardCallback


def _check_proj_name(name: str) -> str:
Expand Down Expand Up @@ -150,6 +152,11 @@ def init(
project = _load_data(load_data, "project", project)
workspace = _load_data(load_data, "workspace", workspace)
public = _load_data(load_data, "private", public)

# ---------------------------------- 模式选择 ----------------------------------
# for

# ---------------------------------- helper初始化 ----------------------------------
operator, c = _create_operator(mode, public)
project = _check_proj_name(project if project else os.path.basename(os.getcwd())) # 默认实验名称为当前目录名
exp_num = SwanLabRunOperator.parse_return(
Expand Down
51 changes: 48 additions & 3 deletions test/unit/data/run/metadata/hardware/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,42 @@
@description: 硬件信息采集工具测试
"""

from swanlab.data.run.metadata.hardware.type import HardwareCollector, CollectGuard
from swanlab.data.run.metadata.hardware.type import HardwareCollector, CollectGuard, HardwareInfoList
from swanlab.data.run.metadata.hardware.type import HardwareInfo


def test_try_run():
"""
Test the try_run wrapper.
"""
data = {"key": "test", "value": 1, "name": "test", "config": None}

class TestTryRun(HardwareCollector):
@HardwareCollector.try_run()
def collect(self) -> HardwareInfo:
return data

t = TestTryRun()
assert t.collect() == data

class TestErrorTryRun(HardwareCollector):
@HardwareCollector.try_run()
def collect(self) -> HardwareInfo:
raise Exception("test")

t = TestErrorTryRun()
assert t() is None


def test_hardware():
data = {"key": "test", "value": 1, "name": "test", "config": None}

class TestCollector(HardwareCollector):
def collect(self) -> HardwareInfo:
return {"key": "test", "value": 1, "name": "test", "config": None}
return data

t = TestCollector()
assert t.collect() == {"key": "test", "value": 1, "name": "test", "config": None}
assert t.collect() == data

class TestErrorCollector(HardwareCollector):
def collect(self) -> HardwareInfo:
Expand All @@ -24,6 +49,26 @@ def collect(self) -> HardwareInfo:
t = TestErrorCollector()
assert t() is None

# 采集任务中有部分采集失败,但是不影响其他采集任务
class TestErrorCollector(HardwareCollector):
@staticmethod
def collect_first() -> HardwareInfo:
return data

@HardwareCollector.try_run()
def collect_second(self) -> HardwareInfo:
raise Exception("test")

def collect(self) -> HardwareInfoList:
return [
self.collect_first(),
self.collect_second(),
]

t = TestErrorCollector()
assert t.collect() == [data, None]
assert t() == [data]


def test_collect_guard():
class TestGuard(CollectGuard):
Expand Down
Loading