Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions swanlab/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Image,
Text,
Video,
Object3D,
)
from .sdk import (
init,
Expand Down
4 changes: 2 additions & 2 deletions swanlab/data/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
from .image import Image
from .text import Text
from .video import Video
from .object_3d import Object3D

# from .video import Video
from typing import Protocol, Union


class FloatConvertible(Protocol):
def __float__(self) -> float:
...
def __float__(self) -> float: ...


DataType = Union[float, FloatConvertible, int, BaseType]
2 changes: 2 additions & 0 deletions swanlab/data/modules/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ class Chart:
text = "text", [list, str]
# 视频类型 list代表一步多视频
video = "video", [list, str]
# 3D点云类型,list代表一步多3D点云
object3d = "object3d", [list, str]
191 changes: 191 additions & 0 deletions swanlab/data/modules/object_3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
"""
Author: nexisato
Date: 2024-02-21 17:52:22
FilePath: /SwanLab/swanlab/data/modules/object_3d.py
Description:
3D Point Cloud data parsing
"""

from .base import BaseType
import numpy as np
from typing import Union, ClassVar, Set, List, Optional
from ..utils.file import get_file_hash_numpy_array, get_file_hash_path
import os
import json
import shutil
from io import BytesIO

# 格式化输出 json
import codecs


class Object3D(BaseType):
"""Object 3D class constructor

Parameters
----------
data_or_path: numpy.array, string, io
Path to an object3d format file or numpy array of object3d or Bytes.IO.

numpy.array: 3D point cloud data, shape (N, 3), (N, 4) or (N, 6).
(N, 3) : N * (x, y, z) coordinates
(N, 4) : N * (x, y, z, c) coordinates, where c in range [1, 14]
(N, 6) : N * (x, y, z, r, g, b) coordinates
caption: str
caption associated with the object3d for display
"""

SUPPORTED_TYPES: ClassVar[Set[str]] = {
"obj",
"gltf",
"glb",
"babylon",
"stl",
"pts.json",
}

def __init__(
self,
data_or_path: Union[str, "np.ndarray", "BytesIO", List["Object3D"]],
caption: Optional[str] = None,
):
super().__init__(data_or_path)
self.object3d_data = None
self.caption = self.__convert_caption(caption)
self.extension = None

def get_data(self):
# 如果传入的是Object3D类列表
if isinstance(self.value, list):
return self.get_data_list()

self.object3d_data = self.__preprocess(self.value)

# 根据不同的输入类型进行不同的哈希校验
hash_name = (
get_file_hash_numpy_array(self.object3d_data)[:16]
if isinstance(self.object3d_data, np.ndarray)
else get_file_hash_path(self.object3d_data)[:16]
)

save_dir = os.path.join(self.settings.static_dir, self.tag)
save_name = (
f"{self.caption}-step{self.step}-{hash_name}.{self.extension}"
if self.caption is not None
else f"object3d-step{self.step}-{hash_name}.{self.extension}"
)
# 如果不存在目录则创建
if os.path.exists(save_dir) is False:
os.makedirs(save_dir)
save_path = os.path.join(save_dir, save_name)

self.__save(save_path)
return save_name

def __preprocess(self, data_or_path):
"""根据输入不同的输入类型进行不同处理"""
# 如果类型为 str,进行文件后缀格式检查
if isinstance(data_or_path, str):
extension = None
for SUPPORTED_TYPE in Object3D.SUPPORTED_TYPES:
if data_or_path.endswith(SUPPORTED_TYPE):
extension = SUPPORTED_TYPE
break
if not extension:
raise TypeError(
"File '"
+ data_or_path
+ "' is not compatible with Object3D: supported types are: "
+ ", ".join(Object3D.SUPPORTED_TYPES)
)
self.extension = extension
return data_or_path
# 如果类型为 io.BytesIO 二进制流,直接返回
elif isinstance(data_or_path, BytesIO):
self.extension = "pts.json"
return data_or_path

# 如果类型为 numpy.array,进行numpy格式检查
elif isinstance(data_or_path, np.ndarray):
if len(data_or_path.shape) != 2 or data_or_path.shape[1] not in {3, 4, 6}:
raise TypeError(
"""
The shape of the numpy array must be one of either:
(N, 3) : N * (x, y, z) coordinates
(N, 4) : N * (x, y, z, c) coordinates, where c in range [1, 14]
(N, 6) : N * (x, y, z, r, g, b) coordinates
"""
)
self.extension = "pts.json"
return data_or_path
else:
raise TypeError("swanlab.Object3D accepts a file path or numpy like data as input")

def __convert_caption(self, caption):
"""将caption转换为字符串"""
# 如果类型是字符串,则不做转换
if isinstance(caption, str):
caption = caption
# 如果类型是数字,则转换为字符串
elif isinstance(caption, (int, float)):
caption = str(caption)
# 如果类型是None,则转换为默认字符串
elif caption is None:
caption = None
else:
raise TypeError("caption must be a string, int or float.")
return caption

def __save_numpy(self, save_path):
"""保存 numpy.array 格式的 3D点云资源文件 .pts.json 到指定路径"""
try:
list_data = self.object3d_data.tolist()
with codecs.open(save_path, "w", encoding="utf-8") as fp:
json.dump(
list_data,
fp,
separators=(",", ":"),
sort_keys=True,
indent=4,
)
except Exception as e:
raise TypeError(f"Could not save the 3D point cloud data to the path: {save_path}") from e

def __save(self, save_path):
"""
保存 3D点云资源文件到指定路径
"""
if isinstance(self.object3d_data, str):
shutil.copy(self.object3d_data, save_path)
elif isinstance(self.object3d_data, BytesIO):
with open(save_path, "wb") as f:
f.write(self.object3d_data.read())
elif isinstance(self.object3d_data, np.ndarray):
self.__save_numpy(save_path)

def get_more(self, *args, **kwargs) -> dict:
"""返回config数据"""
# 如果传入的是Objet3d类列表
if isinstance(self.value, list):
return self.get_more_list()
else:
return (
{
"caption": self.caption,
}
if self.caption is not None
else None
)

def expect_types(self, *args, **kwargs) -> list:
"""返回支持的文件类型"""
return ["str", "numpy.array", "io"]

def get_namespace(self, *args, **kwargs) -> str:
"""设定分组名"""
return "Object3D"

def get_chart_type(self) -> str:
"""设定图表类型"""
return self.chart.object3d
50 changes: 46 additions & 4 deletions test/create_experiment.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,61 @@
import swanlab
import random
import numpy as np

epochs = 50
lr = 0.01
offset = random.random() / 5

run = swanlab.init(
experiment_name="Example",
description="这是一个机器学习模拟实验",
config={
"learning_rate": 0.01,
"epochs": 20,
"epochs": epochs,
"learning_rate": lr,
"test": 1,
"debug": "这是一串" + "很长" * 100 + "的字符串",
"verbose": 1,
},
logggings=True,
)

# 模拟机器学习训练过程
for epoch in range(2, run.config.epochs):

def generate_random_nx3(n):
"""生成形状为nx3的随机数组"""
return np.random.rand(n, 3)


def generate_random_nx4(n):
"""生成形状为nx4的随机数组,最后一列是[1,14]范围内的整数分类"""
xyz = np.random.rand(n, 3)
c = np.random.randint(1, 15, size=(n, 1))
return np.hstack((xyz, c))


def generate_random_nx6(n):
"""生成形状为nx6的随机数组,包含RGB颜色"""
xyz = np.random.rand(n, 3)
rgb = np.random.rand(n, 3) # RGB颜色值也可以是[0,1]之间的随机数
rgb = (rgb * 255).astype(np.uint8) # 转换为[0,255]之间的整数
return np.hstack((xyz, rgb))


for epoch in range(2, epochs):
if epoch % 10 == 0:

# swanlab.log(
# {
# "test/object3d1":
# },
# step=epoch,
# )
swanlab.log(
{
"test-object3d1": swanlab.Object3D("./assets/bunny.obj", caption="bunny-obj"),
"test-object3d2": swanlab.Object3D("./assets/test1.pts.json", caption="test1-pts"),
},
step=epoch,
)
acc = 1 - 2**-epoch - random.random() / epoch - offset
loss = 2**-epoch + random.random() / epoch + offset
swanlab.log({"loss": loss, "accuracy": acc})