3
3
from .base import BaseType
4
4
from .utils_modules import BoundingBoxes , ImageMask
5
5
from ..utils .file import get_file_hash_pil
6
- from typing import Union , List , Dict
6
+ from typing import Union , List , Dict , Any
7
7
from io import BytesIO
8
8
import os
9
9
10
10
11
def is_pytorch_tensor_typename(typename: str) -> bool:
    """Tell whether a fully qualified type name denotes a PyTorch tensor.

    A name qualifies when it starts with ``torch.`` and also contains
    either ``Tensor`` or ``Variable``.
    """
    if not typename.startswith("torch."):
        return False
    return any(marker in typename for marker in ("Tensor", "Variable"))
13
+
14
+
15
def get_full_typename(o: Any) -> Any:
    """Build a dotted type name for *o* from its class metadata.

    Working from names alone lets callers recognise objects from PyTorch,
    TensorFlow, etc. without importing (and therefore depending on) them.

    Returns ``o.__name__`` when *o* is itself a module; otherwise the
    module of *o*'s class joined to the class name with a dot.
    """
    cls = o.__class__
    qualified = ".".join((cls.__module__, cls.__name__))
    # Modules are reported under their own name rather than as "builtins.module"
    module_type_names = ("builtins.module", "__builtin__.module")
    return o.__name__ if qualified in module_type_names else qualified
25
+
26
+
11
27
class Image (BaseType ):
12
28
"""Image class constructor
13
29
@@ -20,30 +36,43 @@ class Image(BaseType):
20
36
More information about the mode can be found at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes
21
37
caption: (str)
22
38
Caption for the image.
39
+ file_type: (str)
40
+ File type for the image. It is used to save the image in the specified format. The default is 'png'.
41
+ size: (int or list or tuple)
42
+ The size of the image can be controlled in four ways:
43
+ 1. If int type, it represents the maximum side length of the image, that is, the width and height cannot exceed this maximum side length. The image will be scaled proportionally to ensure that the maximum side length does not exceed MAX_DIMENSION.
44
+ 2. If list or tuple type with both specified values, e.g. (500, 500), then the image will be scaled to the specified width and height.
45
+ 3. If list or tuple type with only one specified value and another value as None, e.g. (500, None), it means resize the image to the specified width, and the height is scaled proportionally.
46
+ 4. If it is None, it means no scaling for the image.
47
+ quality: (int)
48
+ Quality of the image.
23
49
"""
24
50
25
51
def __init__ (
26
52
self ,
27
53
data_or_path : Union [str , np .ndarray , PILImage .Image , List ["Image" ]],
28
54
mode : str = "RGB" ,
29
55
caption : str = None ,
56
+ file_type : str = None ,
57
+ size : Union [int , list , tuple ] = None ,
30
58
# boxes: dict = None,
31
59
# masks: dict = None,
32
60
):
33
61
super ().__init__ (data_or_path )
34
62
self .image_data = None
35
63
self .mode = mode
36
64
self .caption = self .__convert_caption (caption )
65
+ self .format = self .__convert_file_type (file_type )
66
+ self .size = self .__convert_size (size )
67
+
68
+ # TODO: 等前端支持Boxes和Masks后再开启
37
69
38
70
# self.boxes = None
39
71
# self.boxes_total_classes = None
40
72
# self.masks = None
41
73
# self.masks_total_classes = None
42
-
43
- # TODO: 等前端支持Boxes和Masks后再开启
44
74
# if boxes:
45
75
# self.boxes, self.boxes_total_classes = self.__convert_boxes(boxes)
46
-
47
76
# if masks:
48
77
# self.masks, self.masks_total_classes = self.__convert_masks(masks)
49
78
@@ -58,9 +87,9 @@ def get_data(self):
58
87
# 设置保存路径, 保存文件名
59
88
save_dir = os .path .join (self .settings .static_dir , self .tag )
60
89
save_name = (
61
- f"{ self .caption } -step{ self .step } -{ hash_name } .png "
90
+ f"{ self .caption } -step{ self .step } -{ hash_name } .{ self . format } "
62
91
if self .caption is not None
63
- else f"image-step{ self .step } -{ hash_name } .png "
92
+ else f"image-step{ self .step } -{ hash_name } .{ self . format } "
64
93
)
65
94
# 如果不存在目录则创建
66
95
if os .path .exists (save_dir ) is False :
@@ -128,6 +157,42 @@ def __convert_masks(self, masks):
128
157
129
158
return masks_final , total_classes
130
159
160
def __convert_file_type(self, file_type):
    """Validate ``file_type`` and return the normalised image format.

    Args:
        file_type: Requested file type, or ``None`` for the default.
            Matching ignores letter case, so ``"PNG"`` is treated as
            ``"png"``.

    Returns:
        One of ``"png"``, ``"jpg"``, ``"jpeg"`` or ``"bmp"`` in lower case;
        ``"png"`` when ``file_type`` is ``None``.

    Raises:
        ValueError: If ``file_type`` is not one of the accepted formats.
    """
    accepted_formats = ["png", "jpg", "jpeg", "bmp"]
    if file_type is None:
        return "png"

    # Only strings can be lower-cased; any other value is left untouched so
    # that it is rejected by the membership check below.
    image_format = file_type.lower() if isinstance(file_type, str) else file_type
    if image_format not in accepted_formats:
        raise ValueError(f"file_type must be one of {accepted_formats}")
    return image_format
172
+
173
def __convert_size(self, size):
    """Normalise the user-supplied ``size`` into the form used for resizing.

    Args:
        size: ``None``, an int (maximum side length), or a list/tuple of
            one or two items, each an int-convertible value or ``None``.

    Returns:
        ``None`` when no scaling is requested, an ``int`` maximum side
        length, or a ``(width, height)`` tuple in which at most one entry
        is ``None``.

    Raises:
        ValueError: If ``size`` has an unsupported type or length, or if a
            requested dimension is not a positive integer.
    """

    def to_dimension(value):
        # None means "derive this side from the aspect ratio"
        if value is None:
            return None
        dimension = int(value)
        # Zero or negative sides can never be rendered; fail here instead of
        # deep inside the imaging library.
        if dimension <= 0:
            raise ValueError("size values must be positive integers")
        return dimension

    if size is None:
        return None
    if isinstance(size, int):
        return to_dimension(size)
    if isinstance(size, (list, tuple)):
        if len(size) == 2:
            width, height = to_dimension(size[0]), to_dimension(size[1])
            if width is None and height is None:
                return None
            return (width, height)
        if len(size) == 1:
            return to_dimension(size[0])
    raise ValueError("size must be an int, list or tuple with 2 or 1 elements")
195
+
131
196
def __preprocess (self , data ):
132
197
"""将不同类型的输入转换为PIL图像"""
133
198
if isinstance (data , str ):
@@ -142,13 +207,21 @@ def __preprocess(self, data):
142
207
elif hasattr (self .value , "savefig" ):
143
208
# 如果输入为matplotlib图像
144
209
image = self .__convert_plt_to_image (data )
210
+ elif is_pytorch_tensor_typename (get_full_typename (data )):
211
+ # 如果输入为pytorch tensor
212
+ import torchvision
213
+
214
+ if hasattr (data , "requires_grad" ) and data .requires_grad :
215
+ data = data .detach ()
216
+ if hasattr (data , "detype" ) and str (data .type ) == "torch.uint8" :
217
+ data = data .to (float )
218
+ data = torchvision .utils .make_grid (data , normalize = True )
219
+ image = PILImage .fromarray (data .mul (255 ).clamp (0 , 255 ).byte ().permute (1 , 2 , 0 ).cpu ().numpy ())
145
220
else :
146
221
# 以上都不是,则报错
147
222
raise TypeError ("Unsupported image type. Please provide a valid path, numpy array, or PIL.Image." )
148
- # 缩放大小
149
- image = self .__resize (image )
150
223
151
- self .image_data = image
224
+ self .image_data = self . __resize ( image , self . size )
152
225
153
226
def __load_image_from_path (self , path ):
154
227
"""判断字符串是否为正确的图像路径,如果是则返回np.ndarray类型对象,如果不是则报错"""
@@ -172,18 +245,40 @@ def __convert_plt_to_image(self, plt_obj):
172
245
""" """
173
246
try :
174
247
buf = BytesIO ()
175
- plt_obj .savefig (buf , format = "png" ) # 将图像保存到BytesIO对象
248
+ plt_obj .savefig (buf , format = self . format ) # 将图像保存到BytesIO对象
176
249
buf .seek (0 ) # 移动到缓冲区的开始位置
177
250
image = PILImage .open (buf ).convert (self .mode ) # 使用PIL打开图像
178
251
buf .close () # 关闭缓冲区
179
252
return image
180
253
except Exception as e :
181
254
raise TypeError ("Invalid matplotlib figure for the image" ) from e
182
255
183
- def __resize (self , image , MAX_DIMENSION = 1280 ):
184
- """将图像调整大小, 保证最大边长不超过MAX_DIMENSION"""
185
- if max (image .size ) > MAX_DIMENSION :
186
- image .thumbnail ((MAX_DIMENSION , MAX_DIMENSION ))
256
def __resize(self, image, size=None):
    """Scale ``image`` according to ``size`` and return the result.

    Args:
        image: The PIL image to scale.
        size: ``None`` to leave the image untouched; an int giving the
            maximum side length (the image is only shrunk, never enlarged);
            or a two-item list/tuple ``(width, height)`` in which a ``None``
            entry means "derive this side from the aspect ratio".

    Returns:
        The scaled image. For an int ``size`` the input is shrunk in place
        with ``thumbnail`` and returned; for a list/tuple the value returned
        by ``resize`` is handed back.
    """
    if size is None:
        return image

    if isinstance(size, int):
        # thumbnail() works in place, so only call it when the image is too large
        if max(image.size) > size:
            image.thumbnail((size, size))
        return image

    if isinstance(size, (list, tuple)):
        width, height = size
        if width is not None and height is not None:
            # Both sides given: scale to exactly that size
            return image.resize((width, height))
        if width is None and height is None:
            return image

        src_width, src_height = image.size
        if width is not None:
            # Width fixed: derive the height from the aspect ratio (at least 1 px)
            ratio = width / float(src_width)
            height = max(1, int(float(src_height) * float(ratio)))
        else:
            # Height fixed: derive the width from the aspect ratio (at least 1 px)
            ratio = height / float(src_height)
            width = max(1, int(float(src_width) * float(ratio)))
        # LANCZOS is the filter formerly named ANTIALIAS, an alias removed in Pillow 10
        return image.resize((width, height), PILImage.LANCZOS)

    return image
188
283
189
284
def __save (self , save_path ):
@@ -192,7 +287,11 @@ def __save(self, save_path):
192
287
if not isinstance (pil_image , PILImage .Image ):
193
288
raise TypeError ("Invalid image data for the image" )
194
289
try :
195
- pil_image .save (save_path , format = "png" )
290
+ if self .format == "jpg" :
291
+ pil_image .save (save_path , format = "JPEG" )
292
+ else :
293
+ pil_image .save (save_path , format = self .format )
294
+
196
295
except Exception as e :
197
296
raise TypeError (f"Could not save the image to the path: { save_path } " ) from e
198
297
0 commit comments