-
Couldn't load subscription status.
- Fork 158
Feat/object3d chart #382
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Feat/object3d chart #382
Changes from 3 commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,7 @@ | |
| Image, | ||
| Text, | ||
| Video, | ||
| Object3D, | ||
| ) | ||
| from .sdk import ( | ||
| init, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,191 @@ | ||
| # -*- coding: utf-8 -*- | ||
| """ | ||
| Author: nexisato | ||
| Date: 2024-02-21 17:52:22 | ||
| FilePath: /SwanLab/swanlab/data/modules/object_3d.py | ||
| Description: | ||
| 3D Point Cloud data parsing | ||
| """ | ||
|
|
||
| from .base import BaseType | ||
| import numpy as np | ||
| from typing import Union, ClassVar, Set, List, Optional | ||
| from ..utils.file import get_file_hash_numpy_array, get_file_hash_path | ||
| import os | ||
| import json | ||
| import shutil | ||
| from io import BytesIO | ||
|
|
||
| # 格式化输出 json | ||
| import codecs | ||
|
|
||
|
|
||
| class Object3D(BaseType): | ||
| """Object 3D class constructor | ||
|
|
||
| Parameters | ||
| ---------- | ||
| data_or_path: numpy.array, string, io | ||
| Path to an object3d format file or numpy array of object3d or Bytes.IO. | ||
|
|
||
| numpy.array: 3D point cloud data, shape (N, 3), (N, 4) or (N, 6). | ||
| (N, 3) : N * (x, y, z) coordinates | ||
| (N, 4) : N * (x, y, z, c) coordinates, where c in range [1, 14] | ||
| (N, 6) : N * (x, y, z, r, g, b) coordinates | ||
| caption: str | ||
| caption associated with the object3d for display | ||
| """ | ||
|
|
||
| SUPPORTED_TYPES: ClassVar[Set[str]] = { | ||
| "obj", | ||
| "gltf", | ||
| "glb", | ||
| "babylon", | ||
| "stl", | ||
| "pts.json", | ||
| } | ||
|
|
||
| def __init__( | ||
| self, | ||
| data_or_path: Union[str, "np.ndarray", "BytesIO", List["Object3D"]], | ||
| caption: Optional[str] = None, | ||
| ): | ||
| super().__init__(data_or_path) | ||
| self.object3d_data = None | ||
| self.caption = self.__convert_caption(caption) | ||
| self.extension = None | ||
|
|
||
| def get_data(self): | ||
| # 如果传入的是Object3D类列表 | ||
| if isinstance(self.value, list): | ||
| return self.get_data_list() | ||
|
|
||
| self.object3d_data = self.__preprocess(self.value) | ||
|
|
||
| # 根据不同的输入类型进行不同的哈希校验 | ||
| hash_name = ( | ||
| get_file_hash_numpy_array(self.object3d_data)[:16] | ||
| if isinstance(self.object3d_data, np.ndarray) | ||
| else get_file_hash_path(self.object3d_data)[:16] | ||
| ) | ||
|
|
||
| save_dir = os.path.join(self.settings.static_dir, self.tag) | ||
| save_name = ( | ||
| f"{self.caption}-step{self.step}-{hash_name}.{self.extension}" | ||
| if self.caption is not None | ||
| else f"object3d-step{self.step}-{hash_name}.{self.extension}" | ||
| ) | ||
| # 如果不存在目录则创建 | ||
| if os.path.exists(save_dir) is False: | ||
| os.makedirs(save_dir) | ||
| save_path = os.path.join(save_dir, save_name) | ||
|
|
||
| self.__save(save_path) | ||
| return save_name | ||
|
|
||
| def __preprocess(self, data_or_path): | ||
| """根据输入不同的输入类型进行不同处理""" | ||
| # 如果类型为 str,进行文件后缀格式检查 | ||
| if isinstance(data_or_path, str): | ||
| extension = None | ||
| for SUPPORTED_TYPE in Object3D.SUPPORTED_TYPES: | ||
| if data_or_path.endswith(SUPPORTED_TYPE): | ||
| extension = SUPPORTED_TYPE | ||
| break | ||
| if not extension: | ||
| raise TypeError( | ||
| "File '" | ||
| + data_or_path | ||
| + "' is not compatible with Object3D: supported types are: " | ||
| + ", ".join(Object3D.SUPPORTED_TYPES) | ||
| ) | ||
| self.extension = extension | ||
| return data_or_path | ||
| # 如果类型为 io.BytesIO 二进制流,直接返回 | ||
| elif isinstance(data_or_path, BytesIO): | ||
| self.extension = "pts.json" | ||
| return data_or_path | ||
|
|
||
| # 如果类型为 numpy.array,进行numpy格式检查 | ||
| elif isinstance(data_or_path, np.ndarray): | ||
| if len(data_or_path.shape) != 2 or data_or_path.shape[1] not in {3, 4, 6}: | ||
| raise TypeError( | ||
| """ | ||
| The shape of the numpy array must be one of either: | ||
| (N, 3) : N * (x, y, z) coordinates | ||
| (N, 4) : N * (x, y, z, c) coordinates, where c in range [1, 14] | ||
| (N, 6) : N * (x, y, z, r, g, b) coordinates | ||
| """ | ||
| ) | ||
| self.extension = "pts.json" | ||
| return data_or_path | ||
| else: | ||
| raise TypeError("swanlab.Object3D accepts a file path or numpy like data as input") | ||
|
|
||
| def __convert_caption(self, caption): | ||
| """将caption转换为字符串""" | ||
| # 如果类型是字符串,则不做转换 | ||
| if isinstance(caption, str): | ||
| caption = caption | ||
| # 如果类型是数字,则转换为字符串 | ||
| elif isinstance(caption, (int, float)): | ||
| caption = str(caption) | ||
| # 如果类型是None,则转换为默认字符串 | ||
| elif caption is None: | ||
| caption = None | ||
| else: | ||
| raise TypeError("caption must be a string, int or float.") | ||
| return caption | ||
|
|
||
| def __save_numpy(self, save_path): | ||
| """保存 numpy.array 格式的 3D点云资源文件 .pts.json 到指定路径""" | ||
| try: | ||
| list_data = self.object3d_data.tolist() | ||
| with codecs.open(save_path, "w", encoding="utf-8") as fp: | ||
| json.dump( | ||
| list_data, | ||
| fp, | ||
| separators=(",", ":"), | ||
| sort_keys=True, | ||
| indent=4, | ||
| ) | ||
| except Exception as e: | ||
| raise TypeError(f"Could not save the 3D point cloud data to the path: {save_path}") from e | ||
|
|
||
| def __save(self, save_path): | ||
| """ | ||
| 保存 3D点云资源文件到指定路径 | ||
| """ | ||
| if isinstance(self.object3d_data, str): | ||
| shutil.copy(self.object3d_data, save_path) | ||
| elif isinstance(self.object3d_data, BytesIO): | ||
| with open(save_path, "wb") as f: | ||
| f.write(self.object3d_data.read()) | ||
| elif isinstance(self.object3d_data, np.ndarray): | ||
| self.__save_numpy(save_path) | ||
|
|
||
| def get_more(self, *args, **kwargs) -> dict: | ||
| """返回config数据""" | ||
| # 如果传入的是Objet3d类列表 | ||
| if isinstance(self.value, list): | ||
| return self.get_more_list() | ||
| else: | ||
| return ( | ||
| { | ||
| "caption": self.caption, | ||
| } | ||
| if self.caption is not None | ||
| else None | ||
| ) | ||
|
|
||
| def expect_types(self, *args, **kwargs) -> list: | ||
| """返回支持的文件类型""" | ||
| return ["str", "numpy.array", "io"] | ||
|
|
||
| def get_namespace(self, *args, **kwargs) -> str: | ||
| """设定分组名""" | ||
| return "Object3D" | ||
|
|
||
| def get_chart_type(self) -> str: | ||
| """设定图表类型""" | ||
| return self.chart.object3d |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,19 +1,61 @@ | ||
| import swanlab | ||
| import random | ||
| import numpy as np | ||
|
|
||
| epochs = 50 | ||
| lr = 0.01 | ||
| offset = random.random() / 5 | ||
|
|
||
| run = swanlab.init( | ||
| experiment_name="Example", | ||
| description="这是一个机器学习模拟实验", | ||
| config={ | ||
| "learning_rate": 0.01, | ||
| "epochs": 20, | ||
| "epochs": epochs, | ||
| "learning_rate": lr, | ||
| "test": 1, | ||
| "debug": "这是一串" + "很长" * 100 + "的字符串", | ||
| "verbose": 1, | ||
| }, | ||
| logggings=True, | ||
| ) | ||
|
|
||
| # 模拟机器学习训练过程 | ||
| for epoch in range(2, run.config.epochs): | ||
|
|
||
| def generate_random_nx3(n): | ||
| """生成形状为nx3的随机数组""" | ||
| return np.random.rand(n, 3) | ||
|
|
||
|
|
||
| def generate_random_nx4(n): | ||
| """生成形状为nx4的随机数组,最后一列是[1,14]范围内的整数分类""" | ||
| xyz = np.random.rand(n, 3) | ||
| c = np.random.randint(1, 15, size=(n, 1)) | ||
| return np.hstack((xyz, c)) | ||
|
|
||
|
|
||
| def generate_random_nx6(n): | ||
| """生成形状为nx6的随机数组,包含RGB颜色""" | ||
| xyz = np.random.rand(n, 3) | ||
| rgb = np.random.rand(n, 3) # RGB颜色值也可以是[0,1]之间的随机数 | ||
| rgb = (rgb * 255).astype(np.uint8) # 转换为[0,255]之间的整数 | ||
| return np.hstack((xyz, rgb)) | ||
|
|
||
|
|
||
| for epoch in range(2, epochs): | ||
| if epoch % 10 == 0: | ||
|
|
||
| # swanlab.log( | ||
| # { | ||
| # "test/object3d1": | ||
| # }, | ||
| # step=epoch, | ||
| # ) | ||
| swanlab.log( | ||
| { | ||
| "test-object3d1": swanlab.Object3D("./assets/bunny.obj", caption="bunny-obj"), | ||
| "test-object3d2": swanlab.Object3D("./assets/test1.pts.json", caption="test1-pts"), | ||
| }, | ||
| step=epoch, | ||
| ) | ||
| acc = 1 - 2**-epoch - random.random() / epoch - offset | ||
| loss = 2**-epoch + random.random() / epoch + offset | ||
| swanlab.log({"loss": loss, "accuracy": acc}) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.