Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions ants/deeplearn/randomly_transform_image_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,17 +101,6 @@ def randomly_transform_image_data(reference_image,
>>> transform_type = "affineAndDeformation" )
"""

def polar_decomposition(X):
U, d, V = np.linalg.svd(X, full_matrices=False)
P = np.matmul(U, np.matmul(np.diag(d), np.transpose(U)))
Z = np.matmul(U, V)
if np.linalg.det(Z) < 0:
n = X.shape[0]
reflection_matrix = np.identity(n)
reflection_matrix[0,0] = -1.0
Z = np.matmul(Z, reflection_matrix)
return({"P" : P, "Z" : Z, "Xtilde" : np.matmul(P, Z)})

def create_random_linear_transform(image,
fixed_parameters,
transform_type='affine',
Expand All @@ -131,7 +120,7 @@ def create_random_linear_transform(image,
random_matrix = np.reshape(
random_parameters[:(len(identity_parameters) - image.dimension)],
newshape=(image.dimension, image.dimension))
decomposition = polar_decomposition(random_matrix)
decomposition = ants.polar_decomposition(random_matrix)

if transform_type == "rotation" or transform_type == "rigid":
random_matrix = decomposition['Z']
Expand Down
16 changes: 3 additions & 13 deletions ants/registration/create_jacobian_determinant_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tempfile import mktemp

import ants
import numpy as np
from ants.internal import get_lib_fn, process_arguments


Expand Down Expand Up @@ -44,17 +45,6 @@ def deformation_gradient( warp_image, to_rotation=False, py_based=False ):
>>> mytx = ants.registration(fixed=fi , moving=mi, type_of_transform = ('SyN') )
>>> dg = ants.deformation_gradient( ants.image_read( mytx['fwdtransforms'][0] ) )
"""
import numpy as np
def polar_decomposition(X):
U, d, V = np.linalg.svd(X, full_matrices=False)
P = np.matmul(U, np.matmul(np.diag(d), np.transpose(U)))
Z = np.matmul(U, V)
if np.linalg.det(Z) < 0:
n = X.shape[0]
reflection_matrix = np.identity(n)
reflection_matrix[0,0] = -1.0
Z = np.matmul(Z, reflection_matrix)
return({"P" : P, "Z" : Z, "Xtilde" : np.matmul(P, Z)})
if not py_based:
if ants.is_image(warp_image):
txuse = mktemp(suffix='.nii.gz')
Expand All @@ -78,7 +68,7 @@ def polar_decomposition(X):
dg = np.reshape( dg.numpy(), newshape )
it=np.ndindex(tshp)
for i in it:
dg[i]=polar_decomposition( dg[i] )['Z']
dg[i]=ants.polar_decomposition( dg[i] )['Z']
newshape = tshp + (dim*dim,)
dg = np.reshape( dg, newshape )
dg = ants.from_numpy( dg, has_components=True )
Expand Down Expand Up @@ -114,7 +104,7 @@ def polar_decomposition(X):
if to_rotation:
it=np.ndindex(tshp)
for i in it:
dg[i]=polar_decomposition( dg[i] )['Z']
dg[i]=ants.polar_decomposition( dg[i] )['Z']
newshape = tshp + (dim*dim,)
dg = np.reshape( dg, newshape )
dg = ants.from_numpy( dg, has_components=True )
Expand Down
37 changes: 5 additions & 32 deletions ants/registration/landmark_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,6 @@

import ants

def convergence_monitoring(values, window_size=10):
if len(values) >= window_size:
u = np.linspace(0.0, 1.0, num=window_size)
scattered_data = np.expand_dims(values[-window_size:], axis=-1)
parametric_data = np.expand_dims(u, axis=-1)
spacing = 1 / (window_size-1)
bspline_line = ants.fit_bspline_object_to_scattered_data(scattered_data, parametric_data,
parametric_domain_origin=[0.0], parametric_domain_spacing=[spacing],
parametric_domain_size=[window_size], number_of_fitting_levels=1, mesh_size=1,
spline_order=1)
bspline_slope = -(bspline_line[1][0] - bspline_line[0][0]) / spacing
return(bspline_slope)
else:
return None


def fit_transform_to_paired_points(moving_points,
fixed_points,
transform_type="affine",
Expand Down Expand Up @@ -130,17 +114,6 @@ def fit_transform_to_paired_points(moving_points,
>>> xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="diffeo", domain_image=domain_image, number_of_fitting_levels=6)
"""

def polar_decomposition(X):
U, d, V = np.linalg.svd(X, full_matrices=False)
P = np.matmul(U, np.matmul(np.diag(d), np.transpose(U)))
Z = np.matmul(U, V)
if np.linalg.det(Z) < 0:
n = X.shape[0]
reflection_matrix = np.identity(n)
reflection_matrix[0,0] = -1.0
Z = np.matmul(Z, reflection_matrix)
return({"P" : P, "Z" : Z, "Xtilde" : np.matmul(P, Z)})

def create_zero_displacement_field(domain_image):
field_array = np.zeros((*domain_image.shape, domain_image.dimension))
field = ants.from_numpy(field_array, origin=domain_image.origin,
Expand Down Expand Up @@ -191,7 +164,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
M = x11 * (1.0 - regularization) + regularization * y_prior
Minv = np.linalg.lstsq(M, y, rcond=None)[0]

p = polar_decomposition(Minv[0:dimensionality, 0:dimensionality].T)
p = ants.polar_decomposition(Minv[0:dimensionality, 0:dimensionality].T)
A = p['Xtilde']
translation = Minv[dimensionality,:] + center_moving - center_fixed

Expand Down Expand Up @@ -296,7 +269,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
updated_fixed_points[j,:] = total_field_xfrm.apply_to_point(tuple(fixed_points[j,:]))

error_values.append(np.mean(np.sqrt(np.sum(np.square(updated_fixed_points - moving_points), axis=1, keepdims=True))))
convergence_value = convergence_monitoring(error_values)
convergence_value = ants.convergence_monitoring(error_values)
if verbose:
end_time = time.time()
diff_time = end_time - start_time
Expand Down Expand Up @@ -395,7 +368,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
updated_moving_points[j,:] = total_field_moving_to_middle_xfrm.apply_to_point(tuple(moving_points[j,:]))

error_values.append(np.mean(np.sqrt(np.sum(np.square(updated_fixed_points - updated_moving_points), axis=1, keepdims=True))))
convergence_value = convergence_monitoring(error_values)
convergence_value = ants.convergence_monitoring(error_values)
if verbose:
end_time = time.time()
diff_time = end_time - start_time
Expand Down Expand Up @@ -512,7 +485,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
has_components=True)

error_values.append(average_error)
convergence_value = convergence_monitoring(error_values)
convergence_value = ants.convergence_monitoring(error_values)
if verbose:
end_time = time.time()
diff_time = end_time - start_time
Expand Down Expand Up @@ -816,7 +789,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
has_components=True)

error_values.append(average_error)
convergence_value = convergence_monitoring(error_values)
convergence_value = ants.convergence_monitoring(error_values)
if verbose:
end_time = time.time()
diff_time = end_time - start_time
Expand Down
2 changes: 2 additions & 0 deletions ants/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
from .nifti_to_ants import nifti_to_ants
from .scalar_rgb_vector import rgb_to_vector, vector_to_rgb, scalar_to_rgb
from .sitk_to_ants import from_sitk, to_sitk
from .convergence_monitoring import convergence_monitoring
from .polar_decomposition import polar_decomposition
25 changes: 25 additions & 0 deletions ants/utils/convergence_monitoring.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np
import ants

__all__ = ['convergence_monitoring']

def convergence_monitoring(values, window_size=10):

if len(values) >= window_size:

u = np.linspace(0.0, 1.0, num=window_size)
scattered_data = np.expand_dims(values[-window_size:], axis=-1)
parametric_data = np.expand_dims(u, axis=-1)
spacing = 1 / (window_size-1)
bspline_line = ants.fit_bspline_object_to_scattered_data(scattered_data, parametric_data,
parametric_domain_origin=[0.0], parametric_domain_spacing=[spacing],
parametric_domain_size=[window_size], number_of_fitting_levels=1, mesh_size=1,
spline_order=1)
bspline_slope = -(bspline_line[1][0] - bspline_line[0][0]) / spacing

return(bspline_slope)

else:

return None

23 changes: 23 additions & 0 deletions ants/utils/polar_decomposition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
import ants

__all__ = ['polar_decomposition']

def polar_decomposition(X):
U, d, Vh = np.linalg.svd(X, full_matrices=False) # Vh is V transpose
# This is the formula for P in a LEFT decomposition (X = PZ)
P = np.matmul(U, np.matmul(np.diag(d), np.transpose(U)))
# Z is the orthogonal part, U @ Vh
Z = np.matmul(U, Vh)

# Correction to ensure Z is a proper rotation (det(Z) = +1)
if np.linalg.det(Z) < 0:
n = X.shape[0]
reflection_matrix = np.identity(n)
reflection_matrix[-1, -1] = -1.0 # More robust to change last element
U_prime = U.copy()
U_prime[:, -1] *= -1
Z = U_prime @ Vh

# The returned Xtilde is P @ Z, consistent with a LEFT decomposition
return({"P" : P, "Z" : Z, "Xtilde" : np.matmul(P, Z)})
40 changes: 37 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -990,15 +990,49 @@ def test_hausdorff_distance(self):
stats = ants.hausdorff_distance(s16, s64)

def test_channels_first(self):
import ants
image = ants.image_read(ants.get_ants_data('r16'))
image2 = ants.image_read(ants.get_ants_data('r16'))
img3 = ants.merge_channels([image,image2])
img4 = ants.merge_channels([image,image2], channels_first=True)

self.assertTrue(np.allclose(img3.numpy()[:,:,0], img4.numpy()[0,:,:]))
self.assertTrue(np.allclose(img3.numpy()[:,:,1], img4.numpy()[1,:,:]))


def test_polar_decomposition(self):
# Helper functions for creating known matrices
def make_known_rotation(theta_deg):
theta = np.deg2rad(theta_deg)
R = np.eye(3)
R[:2, :2] = [[np.cos(theta), -np.sin(theta)],
[np.sin(theta), np.cos(theta)]]
return R
def make_known_scaling_matrix(sx, sy, sz):
return np.diag([sx, sy, sz])

# 1. Setup: Create a matrix from a known P and Z.
# The key is to multiply them in the order P @ Z.
P_known = make_known_scaling_matrix(2.5, 1.0, 1.5)
Z_known = make_known_rotation(45) # Use Z for "orthogonal" part
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would recommend making the rotation not 45 degrees so that the ground truth is not symmetric

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I can add this in a subsequent pull request.

# Construct X using the left decomposition structure: X = P @ Z
X = P_known @ Z_known

# 2. Decompose the matrix using the function
result = ants.polar_decomposition(X)
P_est = result["P"]
Z_est = result["Z"]
X_reconstructed = result["Xtilde"]

# 3. Assertions
# a. Check if the reconstruction is close to the original matrix.
nptest.assert_almost_equal(X, X_reconstructed)
# b. Check if the estimated symmetric part is close to the known one.
nptest.assert_almost_equal(P_known, P_est)
# c. Check if the estimated orthogonal part is close to the known one.
nptest.assert_almost_equal(Z_known, Z_est)

def test_convergence_monitoring(self):
f = [1 / i for i in range(1, 100)]
convergence = ants.convergence_monitoring(f, window_size=10)
nptest.assert_almost_equal(convergence, 0.0, decimal=3)

@unittest.skipIf(sitk is None, "SimpleITK is not installed")
class TestModule_sitk_to_ants(unittest.TestCase):
Expand Down
Loading