Skip to content

Commit

Permalink
Merge branch 'master' of github.com:balbasty/nitorch
Browse files Browse the repository at this point in the history
  • Loading branch information
balbasty committed Sep 5, 2023
2 parents aa7562f + bd6d180 commit d67915e
Showing 1 changed file with 42 additions and 10 deletions.
52 changes: 42 additions & 10 deletions nitorch/vb/mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, num_class=2):

# Functions
def fit(self, X, verbose=1, max_iter=10000, tol=1e-8, fig_num=1, W=None,
show_fit=False):
show_fit=False, samp=1):
""" Fit mixture model.
Args:
X (torch.tensor): Observed data (N, C).
Expand All @@ -50,17 +50,34 @@ def fit(self, X, verbose=1, max_iter=10000, tol=1e-8, fig_num=1, W=None,
fig_num (int, optional): Defaults to 1.
W (torch.tensor, optional): Observation weights (N, 1). Defaults to no weights.
show_fit (bool, optional): Plot mixture fit, defaults to False.
samp (int, optional): Sub-sampling, defaults to 1.
Returns:
Z (torch.tensor): Responsibilities (N, K).
"""
if not isinstance(samp, int) or samp < 1:
raise ValueError(f"samp parameter needs to be an int >= 1, got {samp}")

if verbose:
t0 = timer() # Start timer

# Set random seed
torch.manual_seed(1)

if torch.is_floating_point(X) == False:
# Integer data type -> convert to float and add some noise
X = X.type(torch.float)
X += (0.001*X.max())*torch.randn_like(X)

self.dev = X.device
self.dt = X.dtype

if samp > 1:
# Sub-sample
X0 = X[::1, :] # Ensures copy
X = X[::samp, :]
if W is not None:
W0 = W[::1, :] # Ensures copy
W = W[::samp, :]

if len(X.shape) == 1:
X = X[:, None]
Expand All @@ -86,6 +103,16 @@ def fit(self, X, verbose=1, max_iter=10000, tol=1e-8, fig_num=1, W=None,
# EM loop
Z, lb = self._em(X, max_iter=max_iter, tol=tol, verbose=verbose, W=W)

if samp > 1:
# Create original resolution responsibilites
X = X0
N = X.shape[0]
Z = torch.zeros((N, K), dtype=self.dt, device=self.dev)
for k in range(K):
Z[:, k] = torch.log(self.mp[k]) + self._log_likelihood(X, k)
if W is not None: W = W0
Z, _ = softmax_lse(Z, lse=True, weights=W)

# Print algorithm info
if verbose >= 1:
print('Algorithm finished in {} iterations, '
Expand Down Expand Up @@ -125,6 +152,7 @@ def _em(self, X, max_iter, tol, verbose, W):
# Start EM algorithm
Z = torch.zeros((N, K), dtype=dtype, device=device) # responsibility
lb = torch.zeros(max_iter, dtype=torch.float64, device=device)
gain_count = 0
for n_iter in range(max_iter): # EM loop
# ==========
# E-step
Expand All @@ -143,7 +171,11 @@ def _em(self, X, max_iter, tol, verbose, W):
print('n_iter: {}, lb: {}, gain: {}'
.format(n_iter + 1, lb[n_iter], gain))
if gain < tol:
break # Finished
gain_count += 1
if gain_count >= 6:
break # Finished
else:
gain_count = 0

if W is not None: # Weight responsibilities
Z = Z * W
Expand Down Expand Up @@ -273,12 +305,12 @@ def reshape_input(img):
return X, N0, C

@staticmethod
def full_resp(Z, msk, dm=[]):
def full_resp(Z, msk, dm=None):
""" Converts masked responsibilities to full.
Args:
Z (torch.tensor): Masked responsibilities (N, K).
msk (torch.tensor): Mask of original data (N0, 1).
dm (torch.Size, optional): Reshapes Z_full using dm. Defaults to [].
msk (torch.tensor): Mask of original data (N0, ).
dm (tuple/list, optional): Reshapes Z_full using dm.
Returns:
Z_full (torch.tensor): Full responsibilities (N0, K).
"""
Expand All @@ -287,20 +319,20 @@ def full_resp(Z, msk, dm=[]):
Z_full = torch.zeros((N0, K), dtype=Z.dtype, device=Z.device)
for k in range(K):
Z_full[msk, k] = Z[:, k]
if len(dm) >= 3:
Z_full = torch.reshape(Z_full, (dm[0], dm[1], dm[2], K))
if isinstance(dm, (list, tuple)):
Z_full = torch.reshape(Z_full, tuple(dm) + (K,))

return Z_full

@staticmethod
def maximum_likelihood(Z):
""" Return maximum likelihood map.
Args:
Z (torch.tensor): Responsibilities (N, K).
Z (torch.tensor): Responsibilities (N/*spatial, K).
Returns:
(torch.tensor): Maximum likelihood map (N, 1).
(torch.tensor): Maximum likelihood map (N/*spatial, 1).
"""
return torch.argmax(Z, dim=3)
return torch.argmax(Z, dim=-1)


class GMM(Mixture):
Expand Down

0 comments on commit d67915e

Please sign in to comment.