Merged
Changes from 1 commit
6 changes: 6 additions & 0 deletions swanlab/cli/commands/sync/__init__.py
@@ -0,0 +1,6 @@
"""
@author: cunyue
@file: __init__.py
@time: 2025/6/5 14:03
@description: Sync local data to the cloud
"""
2 changes: 1 addition & 1 deletion swanlab/data/callbacker/backup.py
@@ -1,6 +1,6 @@
"""
@author: cunyue
-@file: backup.py
+@file: crypto.py
@time: 2025/6/2 15:07
@description: Log backup callback
"""
2 changes: 1 addition & 1 deletion swanlab/log/backup/__init__.py
@@ -5,7 +5,7 @@
@description: Log backup handling module
"""

-from .backup_handler import BackupHandler, backup
+from .handler import BackupHandler, backup


__all__ = ["BackupHandler", "backup"]
229 changes: 229 additions & 0 deletions swanlab/log/backup/datastore.py
@@ -0,0 +1,229 @@
"""
@author: cunyue
@file: datastore.py
@time: 2025/6/5 16:32
@description: 记录的数据遵循LevelDB格式:https://github.com/google/leveldb/blob/main/doc/log_format.md
我们使用 crc32 计算数据校验和,crc32 相对轻量,且计算速度较快
字符编码使用 utf-8, 确保数据兼容性 (这与 LevelDB 的规范有所冲突,如果有必要,未来可以升级版本并通过LEVELDBLOG_HEADER_VERSION兼容)
这为后续引入 protobuf 或其他序列化格式打下基础
DataStore 大致代码借鉴自 W&B
"""

import os
import struct
import zlib
from typing import Optional, Any, IO

LEVELDBLOG_HEADER_LEN = 7
LEVELDBLOG_BLOCK_LEN = 32768
LEVELDBLOG_DATA_LEN = LEVELDBLOG_BLOCK_LEN - LEVELDBLOG_HEADER_LEN

LEVELDBLOG_FULL = 1
LEVELDBLOG_FIRST = 2
LEVELDBLOG_MIDDLE = 3
LEVELDBLOG_LAST = 4


LEVELDBLOG_HEADER_IDENT = ":SWL"
LEVELDBLOG_HEADER_MAGIC = 0xE1D6 # zlib.crc32(bytes("SwanLab", 'utf-8')) & 0xffff
LEVELDBLOG_HEADER_VERSION = 0


def strtobytes(x):
    """
    Convert a string to UTF-8 bytes
    """
    return bytes(x, "utf-8")


def bytestostr(x):
    """
    Convert UTF-8 bytes to a string
    """
    return str(x, 'utf-8')


class DataStore:

    def __init__(self):
        self._filename: Optional[str] = None
        self._fp: Optional[IO[Any]] = None
        # Current offset into the file
        self._index: int = 0
        # Offset up to which the file has been flushed
        self._flush_offset = 0
        # Pre-compute and cache a CRC32 seed for each record type, stored at the type's index
        self._crc = [0] * (LEVELDBLOG_LAST + 1)
        for x in range(1, LEVELDBLOG_LAST + 1):
            self._crc[x] = zlib.crc32(strtobytes(chr(x))) & 0xFFFFFFFF

        # Whether the file was opened in scan mode
        self._opened_for_scan = False
        # Current file size (only valid in scan mode)
        self._size_bytes: int = 0

    # ---------------------------------- Reading ----------------------------------

    def open_for_scan(self, filename: str):
        self._filename = filename
        self._fp = open(filename, "r+b")
        self._index = 0
        self._size_bytes = os.stat(filename).st_size
        self._opened_for_scan = True
        self._read_header()

    def _read_header(self):
        header = self._fp.read(LEVELDBLOG_HEADER_LEN)
        assert (
            len(header) == LEVELDBLOG_HEADER_LEN
        ), f"header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
        ident, magic, version = struct.unpack("<4sHB", header)
        if ident != strtobytes(LEVELDBLOG_HEADER_IDENT):
            raise Exception("Invalid header")
        if magic != LEVELDBLOG_HEADER_MAGIC:
            raise Exception("Invalid header")
        if version != LEVELDBLOG_HEADER_VERSION:
            raise Exception("Invalid header")
        self._index += len(header)

    def _scan_record(self) -> Optional[tuple[int, bytes]]:
        """
        Scan a single record
        """
        assert self._opened_for_scan, "file not open for scanning"
        # 1. Read the record header
        header = self._fp.read(LEVELDBLOG_HEADER_LEN)
        if len(header) == 0:
            return None
        assert (
            len(header) == LEVELDBLOG_HEADER_LEN
        ), f"record header is {len(header)} bytes instead of the expected {LEVELDBLOG_HEADER_LEN}"
        # 2. Parse the header and verify data integrity
        checksum, data_length, data_type = struct.unpack("<IHB", header)
        self._index += LEVELDBLOG_HEADER_LEN
        data = self._fp.read(data_length)
        checksum_computed = zlib.crc32(data, self._crc[data_type]) & 0xFFFFFFFF
        assert checksum == checksum_computed, "record checksum is invalid, data may be corrupt"
        self._index += data_length
        # 3. Return the data
        return data_type, data

    def scan(self) -> Optional[str]:
        """
        Scan the log file and return one logical record
        """
        # 1. Read one record at a time; if the space left in the block cannot hold a record
        #    header, verify the zero padding and skip it (the inverse of the write path)
        offset = self._index % LEVELDBLOG_BLOCK_LEN
        space_left = LEVELDBLOG_BLOCK_LEN - offset
        if space_left < LEVELDBLOG_HEADER_LEN:
            pad_check = strtobytes("\x00" * space_left)
            pad = self._fp.read(space_left)
            # The padding must be all zero bytes
            assert pad == pad_check, "invalid padding"
            self._index += space_left
        # 2. Scan one record
        record = self._scan_record()
        if record is None:  # eof
            return None
        dtype, data = record
        if dtype == LEVELDBLOG_FULL:
            return bytestostr(data)
        # 3. If this is a FIRST fragment, keep scanning until the LAST fragment is found
        assert dtype == LEVELDBLOG_FIRST, f"expected record to be type {LEVELDBLOG_FIRST} but found {dtype}"
        while True:
            record = self._scan_record()
            if record is None:  # eof
                return None
            dtype, new_data = record
            if dtype == LEVELDBLOG_LAST:
                data += new_data
                break
            assert dtype == LEVELDBLOG_MIDDLE, f"expected record to be type {LEVELDBLOG_MIDDLE} but found {dtype}"
            data += new_data
        return bytestostr(data)

    # ---------------------------------- Writing ----------------------------------

    def open_for_write(self, filename: str):
        self._filename = filename
        self._fp = open(filename, "xb")
        # Write the file header; its length equals LEVELDBLOG_HEADER_LEN
        data = struct.pack(
            "<4sHB",
            strtobytes(LEVELDBLOG_HEADER_IDENT),
            LEVELDBLOG_HEADER_MAGIC,
            LEVELDBLOG_HEADER_VERSION,
        )
        assert len(data) == LEVELDBLOG_HEADER_LEN, f"header size is {len(data)} bytes, expected {LEVELDBLOG_HEADER_LEN}"
        self._fp.write(data)
        self._index += len(data)

    def _write_record(self, data: bytes, data_type: int = LEVELDBLOG_FULL):
        """
        Write a record to the log file
        """
        assert len(data) + LEVELDBLOG_HEADER_LEN <= (
            LEVELDBLOG_BLOCK_LEN - self._index % LEVELDBLOG_BLOCK_LEN
        ), "not enough space to write new records"
        data_length = len(data)
        # Compute the checksum: a CRC32 over the data, seeded with the record type's cached CRC
        checksum = zlib.crc32(data, self._crc[data_type]) & 0xFFFFFFFF
        # Write the record header in "<IHB" format: checksum, data length, record type
        # I: unsigned int (4 bytes), H: unsigned short (2 bytes), B: unsigned char (1 byte)
        self._fp.write(struct.pack("<IHB", checksum, data_length, data_type))
        if data_length:
            self._fp.write(data)
        self._index += LEVELDBLOG_HEADER_LEN + len(data)

    def write(self, s: str):
        """
        Write data to the log file, following the LevelDB log format
        :param s: the data to write, as a string
        :return: the record's start offset, the current offset, and the flushed offset
        """
        data = strtobytes(s)
        # 1. Compute offsets
        start_offset = self._index
        offset = self._index % LEVELDBLOG_BLOCK_LEN
        space_left = LEVELDBLOG_BLOCK_LEN - offset
        data_used = 0
        data_left = len(data)
        # 2. If the space left is smaller than a record header, zero-pad to the next block boundary
        if space_left < LEVELDBLOG_HEADER_LEN:
            pad = "\x00" * space_left
            self._fp.write(strtobytes(pad))
            self._index += space_left
            space_left = LEVELDBLOG_BLOCK_LEN
        # 3. If the record (header included) fits in the remaining space, write it directly
        if data_left + LEVELDBLOG_HEADER_LEN <= space_left:
            self._write_record(data)
        # 4. Otherwise split it across blocks (note: we may start mid-block)
        else:
            # 4.1 Write the first fragment so that subsequent fragments start on a block boundary
            data_room = space_left - LEVELDBLOG_HEADER_LEN
            self._write_record(data[:data_room], LEVELDBLOG_FIRST)
            data_used += data_room
            data_left -= data_room
            assert data_left, "data_left should be non-zero"
            # 4.2 Write the middle fragments
            while data_left > LEVELDBLOG_DATA_LEN:
                self._write_record(
                    data[data_used : data_used + LEVELDBLOG_DATA_LEN],
                    LEVELDBLOG_MIDDLE,
                )
                data_used += LEVELDBLOG_DATA_LEN
                data_left -= LEVELDBLOG_DATA_LEN
            # 4.3 Write the last fragment
            self._write_record(data[data_used:], LEVELDBLOG_LAST)
        # Flush the complete record to disk
        self._fp.flush()
        os.fsync(self._fp.fileno())
        self._flush_offset = self._index

        return start_offset, self._index, self._flush_offset

    # ---------------------------------- Helpers ----------------------------------

    def ensure_flushed(self) -> None:
        self._fp.flush()

    def close(self):
        # Close the file handle
        self._fp.close()
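A round-trip sketch of the DataStore API added above. Note that both headers pack to exactly 7 bytes ("<4sHB" and "<IHB" are each 4 + 2 + 1), matching LEVELDBLOG_HEADER_LEN. The temporary path and JSON payloads below are illustrative; SwanLab's actual record payloads come from its backup models:

import json
import os
import tempfile

from swanlab.log.backup.datastore import DataStore, LEVELDBLOG_BLOCK_LEN

path = os.path.join(tempfile.mkdtemp(), "backup.swanlab")

# Write: open_for_write() emits the 7-byte file header (ident, magic, version),
# then each write() appends one logical record, zero-padding block remainders and
# splitting into FIRST/MIDDLE/LAST fragments across 32 KiB blocks as needed.
writer = DataStore()
writer.open_for_write(path)
writer.write(json.dumps({"type": "log", "message": "hello"}))
writer.write("x" * (2 * LEVELDBLOG_BLOCK_LEN))  # forces the fragmentation path
writer.close()

# Read: open_for_scan() validates the file header, then each scan() reassembles
# and returns one logical record as a string, or None at EOF.
reader = DataStore()
reader.open_for_scan(path)
records = []
while (record := reader.scan()) is not None:
    records.append(record)
reader.close()

assert json.loads(records[0])["message"] == "hello"
assert len(records[1]) == 2 * LEVELDBLOG_BLOCK_LEN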
swanlab/log/backup/{backup_handler.py → handler.py}
@@ -1,18 +1,19 @@
"""
@author: cunyue
@file: backup_handler.py
@file: handler.py
@time: 2025/6/4 12:43
@description: 备份处理器,负责出入日志备份写入操作
@description: 备份处理器,负责日志编解码和写入操作
"""

-import os.path
+import os
from concurrent.futures import ThreadPoolExecutor
-from typing import List, Optional, TextIO
+from typing import List, Optional

import wrapt
from swankit.callback import ColumnInfo, MetricInfo, RuntimeInfo
from swankit.env import create_time

+from swanlab.log.backup.datastore import DataStore
from swanlab.log.backup.models import Experiment, Log, Project, Column, Runtime, Metric, Header
from swanlab.log.backup.writer import write_media_buffer, write_runtime_info
from swanlab.log.type import LogData
@@ -72,7 +73,7 @@ def __init__(self, enable: bool = True, backup_type: str = "DEFAULT", save_media
        # Thread executor
        self.executor: Optional[ThreadPoolExecutor] = None
        # Log file write handle
-        self.f: Optional[TextIO] = None
+        self.f = DataStore()
        # Runtime file backup directory
        self.files_dir: Optional[str] = None
        self.save_media: bool = save_media
@@ -98,16 +99,15 @@ def start(self, run_dir: str, files_dir: str, exp_name: str, description: str, t
        # 2. Avoid data corruption from multiple threads writing the same file
        # 3. Some users mount the swanlog folder on NAS or other storage with limited write concurrency
        self.executor = ThreadPoolExecutor(max_workers=1)
-        self.f = open(os.path.join(run_dir, "backup.swanlab"), "a", encoding="utf-8")
+        self.f.open_for_write(os.path.join(run_dir, "backup.swanlab"))
        self.f.write(
            Header.model_validate(
                {
                    "create_time": create_time(),
                    "version": get_package_version(),
                    "backup_type": self.backup_type,
                }
-            ).to_backup()
-            + "\n"
+            ).to_record()
        )
        self.backup_proj()
        self.backup_exp(exp_name, description, tags)
@@ -124,8 +124,9 @@ def stop(self, epoch: int, error: str = None):
        # If there is an error message, record it in the log
        if error is not None:
            log = Log.model_validate({"level": "ERROR", "message": error, "create_time": create_time(), "epoch": epoch})
-            self.f.write(log.to_backup() + "\n")
+            self.f.write(log.to_record())
        # Close the log file handle
+        self.f.ensure_flushed()
        self.f.close()
        self.f = None

@@ -136,7 +137,7 @@ def backup_terminal(self, log_data: LogData):
"""
logs = Log.from_log_data(log_data)
for log in logs:
self.f.write(log.to_backup() + "\n")
self.f.write(log.to_record())

    @async_io()
    def backup_proj(self):
@@ -150,7 +151,7 @@ def backup_proj(self):
"public": self.cache_public,
}
)
self.f.write(project.to_backup() + "\n")
self.f.write(project.to_record())

    @async_io()
    def backup_exp(self, exp_name: str, description: str, tags: List[str]):
@@ -164,23 +165,23 @@ def backup_exp(self, exp_name: str, description: str, tags: List[str]):
"tags": tags,
}
)
self.f.write(experiment.to_backup() + "\n")
self.f.write(experiment.to_record())

    @async_io()
    def backup_column(self, column_info: ColumnInfo):
        """
        Back up metric column information
        """
        column = Column.from_column_info(column_info)
-        self.f.write(column.to_backup() + "\n")
+        self.f.write(column.to_record())

    @async_io()
    def backup_runtime(self, runtime_info: RuntimeInfo):
        """
        Back up runtime information
        """
        runtime = Runtime.from_runtime_info(runtime_info)
-        self.f.write(runtime.to_backup() + "\n")
+        self.f.write(runtime.to_record())
        write_runtime_info(self.files_dir, runtime_info)

    @async_io()
@@ -189,7 +190,7 @@ def backup_metric(self, metric_info: MetricInfo):
        Back up metric information
        """
        metric = Metric.from_metric_info(metric_info)
-        self.f.write(metric.to_backup() + "\n")
+        self.f.write(metric.to_record())
        if self.save_media:
            write_media_buffer(metric_info)

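With this change, backup.swanlab is no longer newline-delimited text: framing is handled by DataStore, and each model's to_record() replaces to_backup() + "\n". A sketch of reading such a file back, assuming only the DataStore API from this diff (the path is hypothetical, and JSON payloads are an assumption, since the record encoding lives in models.py, which is not shown here):

from swanlab.log.backup.datastore import DataStore

store = DataStore()
store.open_for_scan("swanlog/run-xxx/backup.swanlab")  # hypothetical path
# The first record is the Header model written by start(); project, experiment,
# column, runtime, metric, and log records follow in write order.
while (record := store.scan()) is not None:
    print(record)  # assumption: to_record() produces a JSON-serialized model
store.close()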