1
1
from builtins import super
2
2
import numpy as np
3
3
import numbers
4
- from .emregistration import EMRegistration
5
- from .utility import gaussian_kernel , low_rank_eigen
4
+ from .deformable_registration import DeformableRegistration
6
5
7
6
8
- class ConstrainedDeformableRegistration (EMRegistration ):
7
+ class ConstrainedDeformableRegistration (DeformableRegistration ):
9
8
"""
10
9
Constrained deformable registration.
11
10
@@ -28,16 +27,8 @@ class ConstrainedDeformableRegistration(EMRegistration):
28
27
29
28
"""
30
29
31
- def __init__ (self , alpha = None , beta = None , e_alpha = None , source_id = None , target_id = None , low_rank = False , num_eig = 100 , * args , ** kwargs ):
30
+ def __init__ (self , e_alpha = None , source_id = None , target_id = None , * args , ** kwargs ):
32
31
super ().__init__ (* args , ** kwargs )
33
- if alpha is not None and (not isinstance (alpha , numbers .Number ) or alpha <= 0 ):
34
- raise ValueError (
35
- "Expected a positive value for regularization parameter alpha. Instead got: {}" .format (alpha ))
36
-
37
- if beta is not None and (not isinstance (beta , numbers .Number ) or beta <= 0 ):
38
- raise ValueError (
39
- "Expected a positive value for the width of the coherent Gaussian kerenl. Instead got: {}" .format (beta ))
40
-
41
32
if e_alpha is not None and (not isinstance (e_alpha , numbers .Number ) or e_alpha <= 0 ):
42
33
raise ValueError (
43
34
"Expected a positive value for regularization parameter e_alpha. Instead got: {}" .format (e_alpha ))
@@ -50,24 +41,13 @@ def __init__(self, alpha=None, beta=None, e_alpha = None, source_id = None, targ
50
41
raise ValueError (
51
42
"The target ids (target_id) must be a 1D numpy array of ints." )
52
43
53
- self .alpha = 2 if alpha is None else alpha
54
- self .beta = 2 if beta is None else beta
55
44
self .e_alpha = 1e-8 if e_alpha is None else e_alpha
56
45
self .source_id = source_id
57
46
self .target_id = target_id
58
47
self .P_tilde = np .zeros ((self .M , self .N ))
59
48
self .P_tilde [self .source_id , self .target_id ] = 1
60
49
self .P1_tilde = np .sum (self .P_tilde , axis = 1 )
61
50
self .PX_tilde = np .dot (self .P_tilde , self .X )
62
- self .W = np .zeros ((self .M , self .D ))
63
- self .G = gaussian_kernel (self .Y , self .beta )
64
- self .low_rank = low_rank
65
- self .num_eig = num_eig
66
- if self .low_rank is True :
67
- self .Q , self .S = low_rank_eigen (self .G , self .num_eig )
68
- self .inv_S = np .diag (1. / self .S )
69
- self .S = np .diag (self .S )
70
- self .E = 0.
71
51
72
52
def update_transform (self ):
73
53
"""
@@ -93,56 +73,4 @@ def update_transform(self):
93
73
np .linalg .solve ((self .alpha * self .sigma2 * self .inv_S + np .matmul (self .Q .T , dPQ )),
94
74
(np .matmul (self .Q .T , F ))))))
95
75
QtW = np .matmul (self .Q .T , self .W )
96
- self .E = self .E + self .alpha / 2 * np .trace (np .matmul (QtW .T , np .matmul (self .S , QtW )))
97
-
98
- def transform_point_cloud (self , Y = None ):
99
- """
100
- Update a point cloud using the new estimate of the deformable transformation.
101
-
102
- """
103
- if Y is not None :
104
- G = gaussian_kernel (X = Y , beta = self .beta , Y = self .Y )
105
- return Y + np .dot (G , self .W )
106
- else :
107
- if self .low_rank is False :
108
- self .TY = self .Y + np .dot (self .G , self .W )
109
-
110
- elif self .low_rank is True :
111
- self .TY = self .Y + np .matmul (self .Q , np .matmul (self .S , np .matmul (self .Q .T , self .W )))
112
- return
113
-
114
-
115
- def update_variance (self ):
116
- """
117
- Update the variance of the mixture model using the new estimate of the deformable transformation.
118
- See the update rule for sigma2 in Eq. 23 of of https://arxiv.org/pdf/0905.2635.pdf.
119
-
120
- """
121
- qprev = self .sigma2
122
-
123
- # The original CPD paper does not explicitly calculate the objective functional.
124
- # This functional will include terms from both the negative log-likelihood and
125
- # the Gaussian kernel used for regularization.
126
- self .q = np .inf
127
-
128
- xPx = np .dot (np .transpose (self .Pt1 ), np .sum (
129
- np .multiply (self .X , self .X ), axis = 1 ))
130
- yPy = np .dot (np .transpose (self .P1 ), np .sum (
131
- np .multiply (self .TY , self .TY ), axis = 1 ))
132
- trPXY = np .sum (np .multiply (self .TY , self .PX ))
133
-
134
- self .sigma2 = (xPx - 2 * trPXY + yPy ) / (self .Np * self .D )
135
-
136
- if self .sigma2 <= 0 :
137
- self .sigma2 = self .tolerance / 10
138
-
139
- # Here we use the difference between the current and previous
140
- # estimate of the variance as a proxy to test for convergence.
141
- self .diff = np .abs (self .sigma2 - qprev )
142
-
143
- def get_registration_parameters (self ):
144
- """
145
- Return the current estimate of the deformable transformation parameters.
146
-
147
- """
148
- return self .G , self .W
76
+ self .E = self .E + self .alpha / 2 * np .trace (np .matmul (QtW .T , np .matmul (self .S , QtW )))
0 commit comments