Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 114 additions & 15 deletions swanlab/data/modules/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,27 @@
from .base import BaseType
from .utils_modules import BoundingBoxes, ImageMask
from ..utils.file import get_file_hash_pil
from typing import Union, List, Dict
from typing import Union, List, Dict, Any
from io import BytesIO
import os


def is_pytorch_tensor_typename(typename: str) -> bool:
return typename.startswith("torch.") and ("Tensor" in typename or "Variable" in typename)


def get_full_typename(o: Any) -> Any:
"""Determine types based on type names.

Avoids needing to to import (and therefore depend on) PyTorch, TensorFlow, etc.
"""
instance_name = o.__class__.__module__ + "." + o.__class__.__name__
if instance_name in ["builtins.module", "__builtin__.module"]:
return o.__name__
else:
return instance_name


class Image(BaseType):
"""Image class constructor

Expand All @@ -20,30 +36,43 @@ class Image(BaseType):
More information about the mode can be found at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes
caption: (str)
Caption for the image.
file_type: (str)
File type for the image. It is used to save the image in the specified format. The default is 'png'.
size: (int or list or tuple)
The size of the image can be controlled in four ways:
1. If int type, it represents the maximum side length of the image, that is, the width and height cannot exceed this maximum side length. The image will be scaled proportionally to ensure that the maximum side length does not exceed MAX_DIMENSION.
2. If list or tuple type with both specified values, e.g. (500, 500), then the image will be scaled to the specified width and height.
3. If list or tuple type with only one specified value and another value as None, e.g. (500, None), it means resize the image to the specified width, and the height is scaled proportionally.
4. If it is None, it means no scaling for the image.
qualityu: (int)
Quality of the image.
"""

def __init__(
self,
data_or_path: Union[str, np.ndarray, PILImage.Image, List["Image"]],
mode: str = "RGB",
caption: str = None,
file_type: str = None,
size: Union[int, list, tuple] = None,
# boxes: dict = None,
# masks: dict = None,
):
super().__init__(data_or_path)
self.image_data = None
self.mode = mode
self.caption = self.__convert_caption(caption)
self.format = self.__convert_file_type(file_type)
self.size = self.__convert_size(size)

# TODO: 等前端支持Boxes和Masks后再开启

# self.boxes = None
# self.boxes_total_classes = None
# self.masks = None
# self.masks_total_classes = None

# TODO: 等前端支持Boxes和Masks后再开启
# if boxes:
# self.boxes, self.boxes_total_classes = self.__convert_boxes(boxes)

# if masks:
# self.masks, self.masks_total_classes = self.__convert_masks(masks)

Expand All @@ -58,9 +87,9 @@ def get_data(self):
# 设置保存路径, 保存文件名
save_dir = os.path.join(self.settings.static_dir, self.tag)
save_name = (
f"{self.caption}-step{self.step}-{hash_name}.png"
f"{self.caption}-step{self.step}-{hash_name}.{self.format}"
if self.caption is not None
else f"image-step{self.step}-{hash_name}.png"
else f"image-step{self.step}-{hash_name}.{self.format}"
)
# 如果不存在目录则创建
if os.path.exists(save_dir) is False:
Expand Down Expand Up @@ -128,6 +157,42 @@ def __convert_masks(self, masks):

return masks_final, total_classes

def __convert_file_type(self, file_type):
"""转换file_type,并检测file_type是否正确"""
accepted_formats = ["png", "jpg", "jpeg", "bmp"]
if file_type is None:
format = "png"
else:
format = file_type

if format not in accepted_formats:
raise ValueError(f"file_type must be one of {accepted_formats}")

return format

def __convert_size(self, size):
"""将size转换为PIL图像的size"""
if size is None:
return None
if isinstance(size, int):
return size
if isinstance(size, (list, tuple)):
if len(size) == 2:
if size[0] is None and size[1] is None:
return None
elif size[0] is None:
return (None, int(size[1]))
elif size[1] is None:
return (int(size[0]), None)
else:
return (int(size[0]), int(size[1]))
if len(size) == 1:
if size[0] is None:
return None
else:
return int(size[0])
raise ValueError("size must be an int, list or tuple with 2 or 1 elements")

def __preprocess(self, data):
"""将不同类型的输入转换为PIL图像"""
if isinstance(data, str):
Expand All @@ -142,13 +207,21 @@ def __preprocess(self, data):
elif hasattr(self.value, "savefig"):
# 如果输入为matplotlib图像
image = self.__convert_plt_to_image(data)
elif is_pytorch_tensor_typename(get_full_typename(data)):
# 如果输入为pytorch tensor
import torchvision

if hasattr(data, "requires_grad") and data.requires_grad:
data = data.detach()
if hasattr(data, "detype") and str(data.type) == "torch.uint8":
data = data.to(float)
data = torchvision.utils.make_grid(data, normalize=True)
image = PILImage.fromarray(data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy())
else:
# 以上都不是,则报错
raise TypeError("Unsupported image type. Please provide a valid path, numpy array, or PIL.Image.")
# 缩放大小
image = self.__resize(image)

self.image_data = image
self.image_data = self.__resize(image, self.size)

def __load_image_from_path(self, path):
"""判断字符串是否为正确的图像路径,如果是则返回np.ndarray类型对象,如果不是则报错"""
Expand All @@ -172,18 +245,40 @@ def __convert_plt_to_image(self, plt_obj):
""" """
try:
buf = BytesIO()
plt_obj.savefig(buf, format="png") # 将图像保存到BytesIO对象
plt_obj.savefig(buf, format=self.format) # 将图像保存到BytesIO对象
buf.seek(0) # 移动到缓冲区的开始位置
image = PILImage.open(buf).convert(self.mode) # 使用PIL打开图像
buf.close() # 关闭缓冲区
return image
except Exception as e:
raise TypeError("Invalid matplotlib figure for the image") from e

def __resize(self, image, MAX_DIMENSION=1280):
"""将图像调整大小, 保证最大边长不超过MAX_DIMENSION"""
if max(image.size) > MAX_DIMENSION:
image.thumbnail((MAX_DIMENSION, MAX_DIMENSION))
def __resize(self, image, size=None):
"""将图像调整大小"""
# 如果size是None, 则返回原图
if size is None:
return image
# 如果size是int类型,且图像的最大边长超过了size,则进行缩放
if isinstance(size, int):
MAX_DIMENSION = size
if max(image.size) > MAX_DIMENSION:
image.thumbnail((MAX_DIMENSION, MAX_DIMENSION))
# 如果size是list或tuple类型
elif isinstance(size, (list, tuple)):
# 如果size是两个值的list或tuple,如(500, 500),则进行缩放
if None not in size:
image = image.resize(size)
else:
# 如果size中有一个值为None,且图像对应的边长超过了size中的另一个值,则进行缩放
if size[0] is not None:
wpercent = size[0] / float(image.size[0])
hsize = int(float(image.size[1]) * float(wpercent))
image = image.resize((size[0], hsize), PILImage.ANTIALIAS)
elif size[1] is not None:
hpercent = size[1] / float(image.size[1])
wsize = int(float(image.size[0]) * float(hpercent))
image = image.resize((wsize, size[1]), PILImage.ANTIALIAS)

return image

def __save(self, save_path):
Expand All @@ -192,7 +287,11 @@ def __save(self, save_path):
if not isinstance(pil_image, PILImage.Image):
raise TypeError("Invalid image data for the image")
try:
pil_image.save(save_path, format="png")
if self.format == "jpg":
pil_image.save(save_path, format="JPEG")
else:
pil_image.save(save_path, format=self.format)

except Exception as e:
raise TypeError(f"Could not save the image to the path: {save_path}") from e

Expand Down