|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
16 | | -from typing import TYPE_CHECKING, List, Optional, Tuple, Union |
| 16 | +import warnings |
| 17 | +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union |
17 | 18 |
|
18 | 19 | import numpy as np |
19 | 20 |
|
|
25 | 26 |
|
26 | 27 | from .image_utils import ( |
27 | 28 | ChannelDimension, |
| 29 | + get_channel_dimension_axis, |
28 | 30 | get_image_size, |
29 | 31 | infer_channel_dimension_format, |
30 | 32 | is_jax_tensor, |
31 | 33 | is_tf_tensor, |
32 | 34 | is_torch_tensor, |
| 35 | + to_numpy_array, |
33 | 36 | ) |
34 | 37 |
|
35 | 38 |
|
@@ -257,3 +260,59 @@ def resize( |
257 | 260 | resized_image = np.array(resized_image) |
258 | 261 | resized_image = to_channel_dimension_format(resized_image, data_format) |
259 | 262 | return resized_image |
| 263 | + |
| 264 | + |
| 265 | +def normalize( |
| 266 | + image: np.ndarray, |
| 267 | + mean: Union[float, Iterable[float]], |
| 268 | + std: Union[float, Iterable[float]], |
| 269 | + data_format: Optional[ChannelDimension] = None, |
| 270 | +) -> np.ndarray: |
| 271 | + """ |
| 272 | + Normalizes `image` using the mean and standard deviation specified by `mean` and `std`. |
| 273 | +
|
| 274 | + image = (image - mean) / std |
| 275 | +
|
| 276 | + Args: |
| 277 | + image (`np.ndarray`): |
| 278 | + The image to normalize. |
| 279 | + mean (`float` or `Iterable[float]`): |
| 280 | + The mean to use for normalization. |
| 281 | + std (`float` or `Iterable[float]`): |
| 282 | + The standard deviation to use for normalization. |
| 283 | + data_format (`ChannelDimension`, *optional*): |
| 284 | + The channel dimension format of the output image. If `None`, will use the inferred format from the input. |
| 285 | + """ |
| 286 | + if isinstance(image, PIL.Image.Image): |
| 287 | + warnings.warn( |
| 288 | + "PIL.Image.Image inputs are deprecated and will be removed in v4.26.0. Please use numpy arrays instead.", |
| 289 | + FutureWarning, |
| 290 | + ) |
| 291 | + # Convert PIL image to numpy array with the same logic as in the previous feature extractor normalize - |
| 292 | + # casting to numpy array and dividing by 255. |
| 293 | + image = to_numpy_array(image) |
| 294 | + image = rescale(image, scale=1 / 255) |
| 295 | + |
| 296 | + input_data_format = infer_channel_dimension_format(image) |
| 297 | + channel_axis = get_channel_dimension_axis(image) |
| 298 | + num_channels = image.shape[channel_axis] |
| 299 | + |
| 300 | + if isinstance(mean, Iterable): |
| 301 | + if len(mean) != num_channels: |
| 302 | + raise ValueError(f"mean must have {num_channels} elements if it is an iterable, got {len(mean)}") |
| 303 | + else: |
| 304 | + mean = [mean] * num_channels |
| 305 | + |
| 306 | + if isinstance(std, Iterable): |
| 307 | + if len(std) != num_channels: |
| 308 | + raise ValueError(f"std must have {num_channels} elements if it is an iterable, got {len(std)}") |
| 309 | + else: |
| 310 | + std = [std] * num_channels |
| 311 | + |
| 312 | + if input_data_format == ChannelDimension.LAST: |
| 313 | + image = (image - mean) / std |
| 314 | + else: |
| 315 | + image = ((image.T - mean) / std).T |
| 316 | + |
| 317 | + image = to_channel_dimension_format(image, data_format) if data_format is not None else image |
| 318 | + return image |
0 commit comments