Skip to content

Commit 6cae3a5

Browse files
committed
Update constrained to inherit from deformable
- This reduces code redunancy - Ensure that updates to deformable_registration Class are automatically included in constrained_deformable_registration Class.
1 parent 6bbac07 commit 6cae3a5

File tree

1 file changed

+4
-76
lines changed

1 file changed

+4
-76
lines changed
+4-76
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from builtins import super
22
import numpy as np
33
import numbers
4-
from .emregistration import EMRegistration
5-
from .utility import gaussian_kernel, low_rank_eigen
4+
from .deformable_registration import DeformableRegistration
65

76

8-
class ConstrainedDeformableRegistration(EMRegistration):
7+
class ConstrainedDeformableRegistration(DeformableRegistration):
98
"""
109
Constrained deformable registration.
1110
@@ -28,16 +27,8 @@ class ConstrainedDeformableRegistration(EMRegistration):
2827
2928
"""
3029

31-
def __init__(self, alpha=None, beta=None, e_alpha = None, source_id = None, target_id= None, low_rank=False, num_eig=100, *args, **kwargs):
30+
def __init__(self, e_alpha = None, source_id = None, target_id= None, *args, **kwargs):
3231
super().__init__(*args, **kwargs)
33-
if alpha is not None and (not isinstance(alpha, numbers.Number) or alpha <= 0):
34-
raise ValueError(
35-
"Expected a positive value for regularization parameter alpha. Instead got: {}".format(alpha))
36-
37-
if beta is not None and (not isinstance(beta, numbers.Number) or beta <= 0):
38-
raise ValueError(
39-
"Expected a positive value for the width of the coherent Gaussian kerenl. Instead got: {}".format(beta))
40-
4132
if e_alpha is not None and (not isinstance(e_alpha, numbers.Number) or e_alpha <= 0):
4233
raise ValueError(
4334
"Expected a positive value for regularization parameter e_alpha. Instead got: {}".format(e_alpha))
@@ -50,24 +41,13 @@ def __init__(self, alpha=None, beta=None, e_alpha = None, source_id = None, targ
5041
raise ValueError(
5142
"The target ids (target_id) must be a 1D numpy array of ints.")
5243

53-
self.alpha = 2 if alpha is None else alpha
54-
self.beta = 2 if beta is None else beta
5544
self.e_alpha = 1e-8 if e_alpha is None else e_alpha
5645
self.source_id = source_id
5746
self.target_id = target_id
5847
self.P_tilde = np.zeros((self.M, self.N))
5948
self.P_tilde[self.source_id, self.target_id] = 1
6049
self.P1_tilde = np.sum(self.P_tilde, axis=1)
6150
self.PX_tilde = np.dot(self.P_tilde, self.X)
62-
self.W = np.zeros((self.M, self.D))
63-
self.G = gaussian_kernel(self.Y, self.beta)
64-
self.low_rank = low_rank
65-
self.num_eig = num_eig
66-
if self.low_rank is True:
67-
self.Q, self.S = low_rank_eigen(self.G, self.num_eig)
68-
self.inv_S = np.diag(1./self.S)
69-
self.S = np.diag(self.S)
70-
self.E = 0.
7151

7252
def update_transform(self):
7353
"""
@@ -93,56 +73,4 @@ def update_transform(self):
9373
np.linalg.solve((self.alpha * self.sigma2 * self.inv_S + np.matmul(self.Q.T, dPQ)),
9474
(np.matmul(self.Q.T, F))))))
9575
QtW = np.matmul(self.Q.T, self.W)
96-
self.E = self.E + self.alpha / 2 * np.trace(np.matmul(QtW.T, np.matmul(self.S, QtW)))
97-
98-
def transform_point_cloud(self, Y=None):
99-
"""
100-
Update a point cloud using the new estimate of the deformable transformation.
101-
102-
"""
103-
if Y is not None:
104-
G = gaussian_kernel(X=Y, beta=self.beta, Y=self.Y)
105-
return Y + np.dot(G, self.W)
106-
else:
107-
if self.low_rank is False:
108-
self.TY = self.Y + np.dot(self.G, self.W)
109-
110-
elif self.low_rank is True:
111-
self.TY = self.Y + np.matmul(self.Q, np.matmul(self.S, np.matmul(self.Q.T, self.W)))
112-
return
113-
114-
115-
def update_variance(self):
116-
"""
117-
Update the variance of the mixture model using the new estimate of the deformable transformation.
118-
See the update rule for sigma2 in Eq. 23 of of https://arxiv.org/pdf/0905.2635.pdf.
119-
120-
"""
121-
qprev = self.sigma2
122-
123-
# The original CPD paper does not explicitly calculate the objective functional.
124-
# This functional will include terms from both the negative log-likelihood and
125-
# the Gaussian kernel used for regularization.
126-
self.q = np.inf
127-
128-
xPx = np.dot(np.transpose(self.Pt1), np.sum(
129-
np.multiply(self.X, self.X), axis=1))
130-
yPy = np.dot(np.transpose(self.P1), np.sum(
131-
np.multiply(self.TY, self.TY), axis=1))
132-
trPXY = np.sum(np.multiply(self.TY, self.PX))
133-
134-
self.sigma2 = (xPx - 2 * trPXY + yPy) / (self.Np * self.D)
135-
136-
if self.sigma2 <= 0:
137-
self.sigma2 = self.tolerance / 10
138-
139-
# Here we use the difference between the current and previous
140-
# estimate of the variance as a proxy to test for convergence.
141-
self.diff = np.abs(self.sigma2 - qprev)
142-
143-
def get_registration_parameters(self):
144-
"""
145-
Return the current estimate of the deformable transformation parameters.
146-
147-
"""
148-
return self.G, self.W
76+
self.E = self.E + self.alpha / 2 * np.trace(np.matmul(QtW.T, np.matmul(self.S, QtW)))

0 commit comments

Comments
 (0)