Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions keras/src/backend/jax/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@
"bicubic",
)

def rot90(array, k=1, axes=(0, 1)):
"""Rotate an array by 90 degrees in the specified plane."""
if array.ndim < 2:
raise ValueError(
f"Input array must have at least 2 dimensions. Received: array.ndim={array.ndim}"
)
if len(axes) != 2 or axes[0] == axes[1]:
raise ValueError(
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
)
return jnp.rot90(array, k=k, axes=axes)

def rgb_to_grayscale(images, data_format=None):
images = convert_to_tensor(images)
Expand Down
11 changes: 11 additions & 0 deletions keras/src/backend/numpy/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@
"bicubic",
)

def rot90(array, k=1, axes=(0, 1)):
"""Rotate an array by 90 degrees in the specified plane."""
if array.ndim < 2:
raise ValueError(
f"Input array must have at least 2 dimensions. Received: array.ndim={array.ndim}"
)
if len(axes) != 2 or axes[0] == axes[1]:
raise ValueError(
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
)
return np.rot90(array, k=k, axes=axes)

def rgb_to_grayscale(images, data_format=None):
images = convert_to_tensor(images)
Expand Down
13 changes: 13 additions & 0 deletions keras/src/backend/tensorflow/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from keras.src import backend
from keras.src.backend.tensorflow.core import convert_to_tensor

import numpy as np

RESIZE_INTERPOLATIONS = (
"bilinear",
"nearest",
Expand All @@ -16,6 +18,17 @@
"area",
)

def rot90(array, k=1, axes=(0, 1)):
"""Rotate an array by 90 degrees in the specified plane."""
if array.ndim < 2:
raise ValueError(
f"Input array must have at least 2 dimensions. Received: array.ndim={array.ndim}"
)
if len(axes) != 2 or axes[0] == axes[1]:
raise ValueError(
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
)
return np.rot90(array, k=k, axes=axes)

def rgb_to_grayscale(images, data_format=None):
images = convert_to_tensor(images)
Expand Down
48 changes: 48 additions & 0 deletions keras/src/backend/torch/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,54 @@
"lanczos5",
)

def rot90(array, k=1, axes=(0, 1)):
"""Rotate an array by 90 degrees in the specified plane using PyTorch.

Args:
array: Input tensor
k: Number of 90-degree rotations (default=1)
axes: Tuple of two axes that define the plane of rotation (default=(0,1))

Returns:
Rotated tensor
"""
x = convert_to_tensor(array)

if x.ndim < 2:
raise ValueError(
f"Input array must have at least 2 dimensions. Received: array.ndim={x.ndim}"
)
if len(axes) != 2 or axes[0] == axes[1]:
raise ValueError(
f"Invalid axes: {axes}. Axes must be a tuple of two different dimensions."
)

k = k % 4
if k == 0:
return x

axes = tuple(axis if axis >= 0 else x.ndim + axis for axis in axes)

if not all(0 <= axis < x.ndim for axis in axes):
raise ValueError(f"Invalid axes {axes} for tensor with {x.ndim} dimensions")

for _ in range(k):
perm = list(range(x.ndim))
for i, axis in enumerate(axes):
perm.remove(axis)
perm.append(axis)
x = x.permute(perm)

x = torch.flip(x, dims=[-1])
x = x.transpose(-2, -1)

perm = list(range(x.ndim))
for i, axis in enumerate(axes):
perm.remove(x.ndim - 2 + i)
perm.insert(axis, x.ndim - 2 + i)
x = x.permute(perm)

return x

def rgb_to_grayscale(images, data_format=None):
images = convert_to_tensor(images)
Expand Down
62 changes: 62 additions & 0 deletions keras/src/ops/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,68 @@
from keras.src.ops.operation_utils import compute_conv_output_shape


class Rot90(Operation):
def __init__(self, k=1, axes=(0, 1)):
super().__init__()
self.k = k
self.axes = axes

def call(self, array):
return backend.image.rot90(array, k=self.k, axes=self.axes)

def compute_output_spec(self, array):
array_shape = list(array.shape)
if len(array_shape) < 2:
raise ValueError(
"Input array must have at least 2 dimensions. "
f"Received: array.shape={array_shape}"
)
if len(self.axes) != 2 or self.axes[0] == self.axes[1]:
raise ValueError(
f"Invalid axes: {self.axes}. Axes must be a tuple of two different dimensions."
)
axis1, axis2 = self.axes
array_shape[axis1], array_shape[axis2] = array_shape[axis2], array_shape[axis1]
return KerasTensor(shape=array_shape, dtype=array.dtype)


@keras_export("keras.ops.image.rot90")
def rot90(array, k=1, axes=(0, 1)):
"""Rotate an array by 90 degrees in the plane specified by axes.

This function rotates an array counterclockwise by 90 degrees `k` times
in the plane specified by `axes`. Supports arrays of two or more dimensions.

Args:
array: Input array to rotate.
k: Number of times the array is rotated by 90 degrees.
axes: A tuple of two integers specifying the plane for rotation.

Returns:
Rotated array.

Examples:

>>> import numpy as np
>>> from keras import ops
>>> m = np.array([[1, 2], [3, 4]])
>>> rotated = ops.image.rot90(m)
>>> rotated
array([[2, 4],
[1, 3]])

>>> m = np.arange(8).reshape((2, 2, 2))
>>> rotated = ops.image.rot90(m, k=1, axes=(1, 2))
>>> rotated
array([[[1, 3],
[0, 2]],
[[5, 7],
[4, 6]]])
"""
if any_symbolic_tensors((array,)):
return Rot90(k=k, axes=axes).symbolic_call(array)
return backend.image.rot90(array, k=k, axes=axes)

class RGBToGrayscale(Operation):
def __init__(self, data_format=None):
super().__init__()
Expand Down
66 changes: 66 additions & 0 deletions keras/src/ops/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,72 @@
from keras.src.ops import image as kimage
from keras.src.testing.test_utils import named_product

class TestRot90:
def test_basic(self):
array = np.array([[1, 2], [3, 4]])
rotated = kimage.rot90(array)
expected = np.array([[2, 4], [1, 3]])
assert np.array_equal(rotated, expected), f"Failed basic 2D test: {rotated}"

def test_multiple_k(self):
array = np.array([[1, 2], [3, 4]])

# k=2 (180 degrees rotation)
rotated = kimage.rot90(array, k=2)
expected = np.array([[4, 3], [2, 1]])
assert np.array_equal(rotated, expected), f"Failed k=2 test: {rotated}"

# k=3 (270 degrees rotation)
rotated = kimage.rot90(array, k=3)
expected = np.array([[3, 1], [4, 2]])
assert np.array_equal(rotated, expected), f"Failed k=3 test: {rotated}"

# k=4 (full rotation)
rotated = kimage.rot90(array, k=4)
expected = array
assert np.array_equal(rotated, expected), f"Failed k=4 test: {rotated}"

def test_axes(self):
array = np.arange(8).reshape((2, 2, 2))
rotated = kimage.rot90(array, k=1, axes=(1, 2))
expected = np.array([[[1, 3], [0, 2]], [[5, 7], [4, 6]]])
assert np.array_equal(rotated, expected), f"Failed custom axes test: {rotated}"

def test_single_image(self):
array = np.random.random((4, 4, 3))
rotated = kimage.rot90(array, k=1, axes=(0, 1))
expected = np.rot90(array, k=1, axes=(0, 1))
assert np.allclose(rotated, expected), "Failed single image test"

def test_batch_images(self):
array = np.random.random((2, 4, 4, 3))
rotated = kimage.rot90(array, k=1, axes=(1, 2))
expected = np.rot90(array, k=1, axes=(1, 2))
assert np.allclose(rotated, expected), "Failed batch images test"

def test_invalid_axes(self):
array = np.array([[1, 2], [3, 4]])
try:
kimage.rot90(array, axes=(0, 0))
except ValueError as e:
assert (
"Invalid axes: (0, 0). Axes must be a tuple of two different dimensions."
in str(e)
), f"Failed invalid axes test: {e}"
else:
raise AssertionError("Failed to raise error for invalid axes")

def test_invalid_rank(self):
array = np.array([1, 2, 3]) # 1D array
try:
kimage.rot90(array)
except ValueError as e:
assert (
"Input array must have at least 2 dimensions." in str(e)
), f"Failed invalid rank test: {e}"
else:
raise AssertionError("Failed to raise error for invalid input rank")


class ImageOpsDynamicShapeTest(testing.TestCase):
def setUp(self):
Expand Down
Loading