Skip to content

Commit c8f2b2b

Browse files
authored
Feat/image enhance (#393)
* cancel resize * update size, format * Update image.py
1 parent 9e33e0e commit c8f2b2b

File tree

1 file changed

+114
-15
lines changed

1 file changed

+114
-15
lines changed

swanlab/data/modules/image.py

Lines changed: 114 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,27 @@
33
from .base import BaseType
44
from .utils_modules import BoundingBoxes, ImageMask
55
from ..utils.file import get_file_hash_pil
6-
from typing import Union, List, Dict
6+
from typing import Union, List, Dict, Any
77
from io import BytesIO
88
import os
99

1010

11+
def is_pytorch_tensor_typename(typename: str) -> bool:
12+
return typename.startswith("torch.") and ("Tensor" in typename or "Variable" in typename)
13+
14+
15+
def get_full_typename(o: Any) -> Any:
16+
"""Determine types based on type names.
17+
18+
Avoids needing to to import (and therefore depend on) PyTorch, TensorFlow, etc.
19+
"""
20+
instance_name = o.__class__.__module__ + "." + o.__class__.__name__
21+
if instance_name in ["builtins.module", "__builtin__.module"]:
22+
return o.__name__
23+
else:
24+
return instance_name
25+
26+
1127
class Image(BaseType):
1228
"""Image class constructor
1329
@@ -20,30 +36,43 @@ class Image(BaseType):
2036
More information about the mode can be found at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes
2137
caption: (str)
2238
Caption for the image.
39+
file_type: (str)
40+
File type for the image. It is used to save the image in the specified format. The default is 'png'.
41+
size: (int or list or tuple)
42+
The size of the image can be controlled in four ways:
43+
1. If int type, it represents the maximum side length of the image, that is, the width and height cannot exceed this maximum side length. The image will be scaled proportionally to ensure that the maximum side length does not exceed MAX_DIMENSION.
44+
2. If list or tuple type with both specified values, e.g. (500, 500), then the image will be scaled to the specified width and height.
45+
3. If list or tuple type with only one specified value and another value as None, e.g. (500, None), it means resize the image to the specified width, and the height is scaled proportionally.
46+
4. If it is None, it means no scaling for the image.
47+
qualityu: (int)
48+
Quality of the image.
2349
"""
2450

2551
def __init__(
2652
self,
2753
data_or_path: Union[str, np.ndarray, PILImage.Image, List["Image"]],
2854
mode: str = "RGB",
2955
caption: str = None,
56+
file_type: str = None,
57+
size: Union[int, list, tuple] = None,
3058
# boxes: dict = None,
3159
# masks: dict = None,
3260
):
3361
super().__init__(data_or_path)
3462
self.image_data = None
3563
self.mode = mode
3664
self.caption = self.__convert_caption(caption)
65+
self.format = self.__convert_file_type(file_type)
66+
self.size = self.__convert_size(size)
67+
68+
# TODO: 等前端支持Boxes和Masks后再开启
3769

3870
# self.boxes = None
3971
# self.boxes_total_classes = None
4072
# self.masks = None
4173
# self.masks_total_classes = None
42-
43-
# TODO: 等前端支持Boxes和Masks后再开启
4474
# if boxes:
4575
# self.boxes, self.boxes_total_classes = self.__convert_boxes(boxes)
46-
4776
# if masks:
4877
# self.masks, self.masks_total_classes = self.__convert_masks(masks)
4978

@@ -58,9 +87,9 @@ def get_data(self):
5887
# 设置保存路径, 保存文件名
5988
save_dir = os.path.join(self.settings.static_dir, self.tag)
6089
save_name = (
61-
f"{self.caption}-step{self.step}-{hash_name}.png"
90+
f"{self.caption}-step{self.step}-{hash_name}.{self.format}"
6291
if self.caption is not None
63-
else f"image-step{self.step}-{hash_name}.png"
92+
else f"image-step{self.step}-{hash_name}.{self.format}"
6493
)
6594
# 如果不存在目录则创建
6695
if os.path.exists(save_dir) is False:
@@ -128,6 +157,42 @@ def __convert_masks(self, masks):
128157

129158
return masks_final, total_classes
130159

160+
def __convert_file_type(self, file_type):
161+
"""转换file_type,并检测file_type是否正确"""
162+
accepted_formats = ["png", "jpg", "jpeg", "bmp"]
163+
if file_type is None:
164+
format = "png"
165+
else:
166+
format = file_type
167+
168+
if format not in accepted_formats:
169+
raise ValueError(f"file_type must be one of {accepted_formats}")
170+
171+
return format
172+
173+
def __convert_size(self, size):
174+
"""将size转换为PIL图像的size"""
175+
if size is None:
176+
return None
177+
if isinstance(size, int):
178+
return size
179+
if isinstance(size, (list, tuple)):
180+
if len(size) == 2:
181+
if size[0] is None and size[1] is None:
182+
return None
183+
elif size[0] is None:
184+
return (None, int(size[1]))
185+
elif size[1] is None:
186+
return (int(size[0]), None)
187+
else:
188+
return (int(size[0]), int(size[1]))
189+
if len(size) == 1:
190+
if size[0] is None:
191+
return None
192+
else:
193+
return int(size[0])
194+
raise ValueError("size must be an int, list or tuple with 2 or 1 elements")
195+
131196
def __preprocess(self, data):
132197
"""将不同类型的输入转换为PIL图像"""
133198
if isinstance(data, str):
@@ -142,13 +207,21 @@ def __preprocess(self, data):
142207
elif hasattr(self.value, "savefig"):
143208
# 如果输入为matplotlib图像
144209
image = self.__convert_plt_to_image(data)
210+
elif is_pytorch_tensor_typename(get_full_typename(data)):
211+
# 如果输入为pytorch tensor
212+
import torchvision
213+
214+
if hasattr(data, "requires_grad") and data.requires_grad:
215+
data = data.detach()
216+
if hasattr(data, "detype") and str(data.type) == "torch.uint8":
217+
data = data.to(float)
218+
data = torchvision.utils.make_grid(data, normalize=True)
219+
image = PILImage.fromarray(data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy())
145220
else:
146221
# 以上都不是,则报错
147222
raise TypeError("Unsupported image type. Please provide a valid path, numpy array, or PIL.Image.")
148-
# 缩放大小
149-
image = self.__resize(image)
150223

151-
self.image_data = image
224+
self.image_data = self.__resize(image, self.size)
152225

153226
def __load_image_from_path(self, path):
154227
"""判断字符串是否为正确的图像路径,如果是则返回np.ndarray类型对象,如果不是则报错"""
@@ -172,18 +245,40 @@ def __convert_plt_to_image(self, plt_obj):
172245
""" """
173246
try:
174247
buf = BytesIO()
175-
plt_obj.savefig(buf, format="png") # 将图像保存到BytesIO对象
248+
plt_obj.savefig(buf, format=self.format) # 将图像保存到BytesIO对象
176249
buf.seek(0) # 移动到缓冲区的开始位置
177250
image = PILImage.open(buf).convert(self.mode) # 使用PIL打开图像
178251
buf.close() # 关闭缓冲区
179252
return image
180253
except Exception as e:
181254
raise TypeError("Invalid matplotlib figure for the image") from e
182255

183-
def __resize(self, image, MAX_DIMENSION=1280):
184-
"""将图像调整大小, 保证最大边长不超过MAX_DIMENSION"""
185-
if max(image.size) > MAX_DIMENSION:
186-
image.thumbnail((MAX_DIMENSION, MAX_DIMENSION))
256+
def __resize(self, image, size=None):
257+
"""将图像调整大小"""
258+
# 如果size是None, 则返回原图
259+
if size is None:
260+
return image
261+
# 如果size是int类型,且图像的最大边长超过了size,则进行缩放
262+
if isinstance(size, int):
263+
MAX_DIMENSION = size
264+
if max(image.size) > MAX_DIMENSION:
265+
image.thumbnail((MAX_DIMENSION, MAX_DIMENSION))
266+
# 如果size是list或tuple类型
267+
elif isinstance(size, (list, tuple)):
268+
# 如果size是两个值的list或tuple,如(500, 500),则进行缩放
269+
if None not in size:
270+
image = image.resize(size)
271+
else:
272+
# 如果size中有一个值为None,且图像对应的边长超过了size中的另一个值,则进行缩放
273+
if size[0] is not None:
274+
wpercent = size[0] / float(image.size[0])
275+
hsize = int(float(image.size[1]) * float(wpercent))
276+
image = image.resize((size[0], hsize), PILImage.ANTIALIAS)
277+
elif size[1] is not None:
278+
hpercent = size[1] / float(image.size[1])
279+
wsize = int(float(image.size[0]) * float(hpercent))
280+
image = image.resize((wsize, size[1]), PILImage.ANTIALIAS)
281+
187282
return image
188283

189284
def __save(self, save_path):
@@ -192,7 +287,11 @@ def __save(self, save_path):
192287
if not isinstance(pil_image, PILImage.Image):
193288
raise TypeError("Invalid image data for the image")
194289
try:
195-
pil_image.save(save_path, format="png")
290+
if self.format == "jpg":
291+
pil_image.save(save_path, format="JPEG")
292+
else:
293+
pil_image.save(save_path, format=self.format)
294+
196295
except Exception as e:
197296
raise TypeError(f"Could not save the image to the path: {save_path}") from e
198297

0 commit comments

Comments
 (0)