Skip to content

Commit 81ccf1f

Browse files
authored
ENH: Add polar decomposition and convergence monitoring. (#819)
* ENH: Add polar decomposition and convergence monitoring. * BUG: typo.
1 parent 2ac0e45 commit 81ccf1f

File tree

7 files changed

+96
-60
lines changed

7 files changed

+96
-60
lines changed

ants/deeplearn/randomly_transform_image_data.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -101,17 +101,6 @@ def randomly_transform_image_data(reference_image,
101101
>>> transform_type = "affineAndDeformation" )
102102
"""
103103

104-
def polar_decomposition(X):
105-
U, d, V = np.linalg.svd(X, full_matrices=False)
106-
P = np.matmul(U, np.matmul(np.diag(d), np.transpose(U)))
107-
Z = np.matmul(U, V)
108-
if np.linalg.det(Z) < 0:
109-
n = X.shape[0]
110-
reflection_matrix = np.identity(n)
111-
reflection_matrix[0,0] = -1.0
112-
Z = np.matmul(Z, reflection_matrix)
113-
return({"P" : P, "Z" : Z, "Xtilde" : np.matmul(P, Z)})
114-
115104
def create_random_linear_transform(image,
116105
fixed_parameters,
117106
transform_type='affine',
@@ -131,7 +120,7 @@ def create_random_linear_transform(image,
131120
random_matrix = np.reshape(
132121
random_parameters[:(len(identity_parameters) - image.dimension)],
133122
newshape=(image.dimension, image.dimension))
134-
decomposition = polar_decomposition(random_matrix)
123+
decomposition = ants.polar_decomposition(random_matrix)
135124

136125
if transform_type == "rotation" or transform_type == "rigid":
137126
random_matrix = decomposition['Z']

ants/registration/create_jacobian_determinant_image.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from tempfile import mktemp
88

99
import ants
10+
import numpy as np
1011
from ants.internal import get_lib_fn, process_arguments
1112

1213

@@ -44,17 +45,6 @@ def deformation_gradient( warp_image, to_rotation=False, py_based=False ):
4445
>>> mytx = ants.registration(fixed=fi , moving=mi, type_of_transform = ('SyN') )
4546
>>> dg = ants.deformation_gradient( ants.image_read( mytx['fwdtransforms'][0] ) )
4647
"""
47-
import numpy as np
48-
def polar_decomposition(X):
49-
U, d, V = np.linalg.svd(X, full_matrices=False)
50-
P = np.matmul(U, np.matmul(np.diag(d), np.transpose(U)))
51-
Z = np.matmul(U, V)
52-
if np.linalg.det(Z) < 0:
53-
n = X.shape[0]
54-
reflection_matrix = np.identity(n)
55-
reflection_matrix[0,0] = -1.0
56-
Z = np.matmul(Z, reflection_matrix)
57-
return({"P" : P, "Z" : Z, "Xtilde" : np.matmul(P, Z)})
5848
if not py_based:
5949
if ants.is_image(warp_image):
6050
txuse = mktemp(suffix='.nii.gz')
@@ -78,7 +68,7 @@ def polar_decomposition(X):
7868
dg = np.reshape( dg.numpy(), newshape )
7969
it=np.ndindex(tshp)
8070
for i in it:
81-
dg[i]=polar_decomposition( dg[i] )['Z']
71+
dg[i]=ants.polar_decomposition( dg[i] )['Z']
8272
newshape = tshp + (dim*dim,)
8373
dg = np.reshape( dg, newshape )
8474
dg = ants.from_numpy( dg, has_components=True )
@@ -114,7 +104,7 @@ def polar_decomposition(X):
114104
if to_rotation:
115105
it=np.ndindex(tshp)
116106
for i in it:
117-
dg[i]=polar_decomposition( dg[i] )['Z']
107+
dg[i]=ants.polar_decomposition( dg[i] )['Z']
118108
newshape = tshp + (dim*dim,)
119109
dg = np.reshape( dg, newshape )
120110
dg = ants.from_numpy( dg, has_components=True )

ants/registration/landmark_transforms.py

Lines changed: 5 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,6 @@
77

88
import ants
99

10-
def convergence_monitoring(values, window_size=10):
11-
if len(values) >= window_size:
12-
u = np.linspace(0.0, 1.0, num=window_size)
13-
scattered_data = np.expand_dims(values[-window_size:], axis=-1)
14-
parametric_data = np.expand_dims(u, axis=-1)
15-
spacing = 1 / (window_size-1)
16-
bspline_line = ants.fit_bspline_object_to_scattered_data(scattered_data, parametric_data,
17-
parametric_domain_origin=[0.0], parametric_domain_spacing=[spacing],
18-
parametric_domain_size=[window_size], number_of_fitting_levels=1, mesh_size=1,
19-
spline_order=1)
20-
bspline_slope = -(bspline_line[1][0] - bspline_line[0][0]) / spacing
21-
return(bspline_slope)
22-
else:
23-
return None
24-
25-
2610
def fit_transform_to_paired_points(moving_points,
2711
fixed_points,
2812
transform_type="affine",
@@ -130,17 +114,6 @@ def fit_transform_to_paired_points(moving_points,
130114
>>> xfrm = ants.fit_transform_to_paired_points(moving, fixed, transform_type="diffeo", domain_image=domain_image, number_of_fitting_levels=6)
131115
"""
132116

133-
def polar_decomposition(X):
134-
U, d, V = np.linalg.svd(X, full_matrices=False)
135-
P = np.matmul(U, np.matmul(np.diag(d), np.transpose(U)))
136-
Z = np.matmul(U, V)
137-
if np.linalg.det(Z) < 0:
138-
n = X.shape[0]
139-
reflection_matrix = np.identity(n)
140-
reflection_matrix[0,0] = -1.0
141-
Z = np.matmul(Z, reflection_matrix)
142-
return({"P" : P, "Z" : Z, "Xtilde" : np.matmul(P, Z)})
143-
144117
def create_zero_displacement_field(domain_image):
145118
field_array = np.zeros((*domain_image.shape, domain_image.dimension))
146119
field = ants.from_numpy(field_array, origin=domain_image.origin,
@@ -191,7 +164,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
191164
M = x11 * (1.0 - regularization) + regularization * y_prior
192165
Minv = np.linalg.lstsq(M, y, rcond=None)[0]
193166

194-
p = polar_decomposition(Minv[0:dimensionality, 0:dimensionality].T)
167+
p = ants.polar_decomposition(Minv[0:dimensionality, 0:dimensionality].T)
195168
A = p['Xtilde']
196169
translation = Minv[dimensionality,:] + center_moving - center_fixed
197170

@@ -296,7 +269,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
296269
updated_fixed_points[j,:] = total_field_xfrm.apply_to_point(tuple(fixed_points[j,:]))
297270

298271
error_values.append(np.mean(np.sqrt(np.sum(np.square(updated_fixed_points - moving_points), axis=1, keepdims=True))))
299-
convergence_value = convergence_monitoring(error_values)
272+
convergence_value = ants.convergence_monitoring(error_values)
300273
if verbose:
301274
end_time = time.time()
302275
diff_time = end_time - start_time
@@ -395,7 +368,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
395368
updated_moving_points[j,:] = total_field_moving_to_middle_xfrm.apply_to_point(tuple(moving_points[j,:]))
396369

397370
error_values.append(np.mean(np.sqrt(np.sum(np.square(updated_fixed_points - updated_moving_points), axis=1, keepdims=True))))
398-
convergence_value = convergence_monitoring(error_values)
371+
convergence_value = ants.convergence_monitoring(error_values)
399372
if verbose:
400373
end_time = time.time()
401374
diff_time = end_time - start_time
@@ -512,7 +485,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
512485
has_components=True)
513486

514487
error_values.append(average_error)
515-
convergence_value = convergence_monitoring(error_values)
488+
convergence_value = ants.convergence_monitoring(error_values)
516489
if verbose:
517490
end_time = time.time()
518491
diff_time = end_time - start_time
@@ -816,7 +789,7 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
816789
has_components=True)
817790

818791
error_values.append(average_error)
819-
convergence_value = convergence_monitoring(error_values)
792+
convergence_value = ants.convergence_monitoring(error_values)
820793
if verbose:
821794
end_time = time.time()
822795
diff_time = end_time - start_time

ants/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@
1313
from .nifti_to_ants import nifti_to_ants
1414
from .scalar_rgb_vector import rgb_to_vector, vector_to_rgb, scalar_to_rgb
1515
from .sitk_to_ants import from_sitk, to_sitk
16+
from .convergence_monitoring import convergence_monitoring
17+
from .polar_decomposition import polar_decomposition

ants/utils/convergence_monitoring.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
import ants
3+
4+
__all__ = ['convergence_monitoring']
5+
6+
def convergence_monitoring(values, window_size=10):
7+
8+
if len(values) >= window_size:
9+
10+
u = np.linspace(0.0, 1.0, num=window_size)
11+
scattered_data = np.expand_dims(values[-window_size:], axis=-1)
12+
parametric_data = np.expand_dims(u, axis=-1)
13+
spacing = 1 / (window_size-1)
14+
bspline_line = ants.fit_bspline_object_to_scattered_data(scattered_data, parametric_data,
15+
parametric_domain_origin=[0.0], parametric_domain_spacing=[spacing],
16+
parametric_domain_size=[window_size], number_of_fitting_levels=1, mesh_size=1,
17+
spline_order=1)
18+
bspline_slope = -(bspline_line[1][0] - bspline_line[0][0]) / spacing
19+
20+
return(bspline_slope)
21+
22+
else:
23+
24+
return None
25+

ants/utils/polar_decomposition.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
import ants
3+
4+
__all__ = ['polar_decomposition']
5+
6+
def polar_decomposition(X):
7+
U, d, Vh = np.linalg.svd(X, full_matrices=False) # Vh is V transpose
8+
# This is the formula for P in a LEFT decomposition (X = PZ)
9+
P = np.matmul(U, np.matmul(np.diag(d), np.transpose(U)))
10+
# Z is the orthogonal part, U @ Vh
11+
Z = np.matmul(U, Vh)
12+
13+
# Correction to ensure Z is a proper rotation (det(Z) = +1)
14+
if np.linalg.det(Z) < 0:
15+
n = X.shape[0]
16+
reflection_matrix = np.identity(n)
17+
reflection_matrix[-1, -1] = -1.0 # More robust to change last element
18+
U_prime = U.copy()
19+
U_prime[:, -1] *= -1
20+
Z = U_prime @ Vh
21+
22+
# The returned Xtilde is P @ Z, consistent with a LEFT decomposition
23+
return({"P" : P, "Z" : Z, "Xtilde" : np.matmul(P, Z)})

tests/test_utils.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -990,15 +990,49 @@ def test_hausdorff_distance(self):
990990
stats = ants.hausdorff_distance(s16, s64)
991991

992992
def test_channels_first(self):
993-
import ants
994993
image = ants.image_read(ants.get_ants_data('r16'))
995994
image2 = ants.image_read(ants.get_ants_data('r16'))
996995
img3 = ants.merge_channels([image,image2])
997996
img4 = ants.merge_channels([image,image2], channels_first=True)
998-
999997
self.assertTrue(np.allclose(img3.numpy()[:,:,0], img4.numpy()[0,:,:]))
1000998
self.assertTrue(np.allclose(img3.numpy()[:,:,1], img4.numpy()[1,:,:]))
1001-
999+
1000+
def test_polar_decomposition(self):
1001+
# Helper functions for creating known matrices
1002+
def make_known_rotation(theta_deg):
1003+
theta = np.deg2rad(theta_deg)
1004+
R = np.eye(3)
1005+
R[:2, :2] = [[np.cos(theta), -np.sin(theta)],
1006+
[np.sin(theta), np.cos(theta)]]
1007+
return R
1008+
def make_known_scaling_matrix(sx, sy, sz):
1009+
return np.diag([sx, sy, sz])
1010+
1011+
# 1. Setup: Create a matrix from a known P and Z.
1012+
# The key is to multiply them in the order P @ Z.
1013+
P_known = make_known_scaling_matrix(2.5, 1.0, 1.5)
1014+
Z_known = make_known_rotation(45) # Use Z for "orthogonal" part
1015+
# Construct X using the left decomposition structure: X = P @ Z
1016+
X = P_known @ Z_known
1017+
1018+
# 2. Decompose the matrix using the function
1019+
result = ants.polar_decomposition(X)
1020+
P_est = result["P"]
1021+
Z_est = result["Z"]
1022+
X_reconstructed = result["Xtilde"]
1023+
1024+
# 3. Assertions
1025+
# a. Check if the reconstruction is close to the original matrix.
1026+
nptest.assert_almost_equal(X, X_reconstructed)
1027+
# b. Check if the estimated symmetric part is close to the known one.
1028+
nptest.assert_almost_equal(P_known, P_est)
1029+
# c. Check if the estimated orthogonal part is close to the known one.
1030+
nptest.assert_almost_equal(Z_known, Z_est)
1031+
1032+
def test_convergence_monitoring(self):
1033+
f = [1 / i for i in range(1, 100)]
1034+
convergence = ants.convergence_monitoring(f, window_size=10)
1035+
nptest.assert_almost_equal(convergence, 0.0, decimal=3)
10021036

10031037
@unittest.skipIf(sitk is None, "SimpleITK is not installed")
10041038
class TestModule_sitk_to_ants(unittest.TestCase):

0 commit comments

Comments
 (0)