Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions swanlab/data/modules/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,21 @@ class Image(BaseType):

Parameters
----------
data_or_path: (str or numpy.array or PIL.Image)
Path to the image file, numpy array of image data or PIL.Image data.
data_or_path: (str or numpy.array or PIL.Image or torch.Tensor or matplotlib figure or List["Image"])
Path to the image file, numpy array of image data, PIL.Image data, torch.Tensor data or matplotlib figure data.
mode: (str)
The PIL mode for an image. The most common modes are 'L', 'RGB' and 'RGBA'.
More information about the mode can be found at https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes
caption: (str)
Caption for the image.
file_type: (str)
File type for the image. It is used to save the image in the specified format. The default is 'png'.
File type for the image. It is used to save the image in the specified format. The default is 'png'. The supported file types are ['png', 'jpg', 'jpeg', 'bmp'].
size: (int or list or tuple)
The size of the image can be controlled in four ways:
1. If int type, it represents the maximum side length of the image, that is, the width and height cannot exceed this maximum side length. The image will be scaled proportionally to ensure that the maximum side length does not exceed MAX_DIMENSION.
2. If list or tuple type with both specified values, e.g. (500, 500), then the image will be scaled to the specified width and height.
3. If list or tuple type with only one specified value and another value as None, e.g. (500, None), it means resize the image to the specified width, and the height is scaled proportionally.
4. If it is None, it means no scaling for the image.
quality: (int)
Quality of the image.
"""

def __init__(
Expand Down Expand Up @@ -100,7 +98,7 @@ def get_data(self):
return save_name

def expect_types(self, *args, **kwargs) -> list:
    """Return the names of the input types this data type expects.

    The positional and keyword arguments are accepted but not used.

    Returns
    -------
    list
        Type names as strings: ``"str"``, ``"numpy.array"``,
        ``"PIL.Image.Image"`` and ``"torch.Tensor"``.
    """
    # A single return statement: previously a stale list without
    # "torch.Tensor" was returned first, leaving the updated list unreachable.
    return ["str", "numpy.array", "PIL.Image.Image", "torch.Tensor"]

def __convert_caption(self, caption):
"""将caption转换为字符串"""
Expand Down Expand Up @@ -209,17 +207,26 @@ def __preprocess(self, data):
image = self.__convert_plt_to_image(data)
elif is_pytorch_tensor_typename(get_full_typename(data)):
# 如果输入为pytorch tensor
import torchvision
try:
import torchvision
except ImportError as e:
raise TypeError(
"swanlab.Image requires `torchvision` when process torch.tensor data. Install with 'pip install torchvision'."
)

if hasattr(data, "requires_grad") and data.requires_grad:
data = data.detach()
if hasattr(data, "detype") and str(data.type) == "torch.uint8":
data = data.to(float)
data = torchvision.utils.make_grid(data, normalize=True)
image = PILImage.fromarray(data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy())
image = PILImage.fromarray(
data.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy(), mode=self.mode
)
else:
# 以上都不是,则报错
raise TypeError("Unsupported image type. Please provide a valid path, numpy array, or PIL.Image.")
raise TypeError(
"Unsupported image type. Please provide a valid path, numpy array, PIL.Image, torch.Tensor or matplotlib figure."
)

self.image_data = self.__resize(image, self.size)

Expand All @@ -239,7 +246,9 @@ def __convert_numpy_array_to_image(self, array):
else:
raise TypeError("Invalid numpy array: the numpy array must be 2D or 3D with 3 or 4 channels.")
except Exception as e:
raise TypeError("Invalid numpy array for the image") from e
raise TypeError(
"Invalid numpy array for the image: the numpy array must be 2D or 3D with 3 or 4 channels."
) from e

def __convert_plt_to_image(self, plt_obj):
""" """
Expand Down