Improvements to GMM
* Sub-sampling of the data for faster fitting (the returned responsibilities still cover the full-resolution data); disabled by default.
* If the input data type is integer, cast to float and add a small amount of noise (a usage sketch of both options follows below).
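A minimal usage sketch of the new `samp` argument. This assumes the module exposes a `GMM` class with the `fit` signature changed below; only `num_class`, `fit`, `verbose`, and `samp` are confirmed by the diff, the rest is illustrative:

```python
import torch
from nitorch.vb.mixtures import GMM  # assuming a GMM class is exposed here

# Integer intensities: fit() now casts to float and adds a little noise
X = torch.randint(0, 255, (100000, 1))

model = GMM(num_class=2)
# samp=4 runs EM on every 4th observation only; the returned
# responsibilities Z still have one row per original observation
Z = model.fit(X, verbose=0, samp=4)
```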
brudfors committed Aug 16, 2023
1 parent 915438c · commit 34924de
Showing 1 changed file with 29 additions and 2 deletions.
nitorch/vb/mixtures.py
@@ -33,7 +33,7 @@ def __init__(self, num_class=2):

     # Functions
     def fit(self, X, verbose=1, max_iter=10000, tol=1e-8, fig_num=1, W=None,
-            show_fit=False):
+            show_fit=False, samp=1):
         """ Fit mixture model.
         Args:
             X (torch.tensor): Observed data (N, C).
@@ -50,17 +50,34 @@ def fit(self, X, verbose=1, max_iter=10000, tol=1e-8, fig_num=1, W=None,
             fig_num (int, optional): Defaults to 1.
             W (torch.tensor, optional): Observation weights (N, 1). Defaults to no weights.
             show_fit (bool, optional): Plot mixture fit, defaults to False.
+            samp (int, optional): Sub-sampling factor, defaults to 1 (no sub-sampling).
         Returns:
             Z (torch.tensor): Responsibilities (N, K).
         """
+        if not isinstance(samp, int) or samp < 1:
+            raise ValueError(f"samp must be an int >= 1, got {samp}")
+
         if verbose:
             t0 = timer()  # Start timer

         # Set random seed
         torch.manual_seed(1)

+        if not torch.is_floating_point(X):
+            # Integer data type -> cast to float and add a little noise
+            X = X.float()
+            X += (0.001 * X.max()) * torch.randn_like(X)
+
         self.dev = X.device
         self.dt = X.dtype

+        if samp > 1:
+            # Sub-sample the data, keeping the full-resolution originals
+            X0 = X.clone()
+            X = X[::samp]
+            if W is not None:
+                W0 = W.clone()
+                W = W[::samp]
+
         if len(X.shape) == 1:
             X = X[:, None]
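The two pre-processing steps in this hunk are easy to reproduce in isolation. A standalone sketch (not nitorch code) of the jitter and the strided sub-sampling:

```python
import torch

X = torch.randint(0, 255, (12, 1))            # integer input

# Cast and jitter, as in the hunk above: noise at 0.1% of the maximum
# intensity breaks ties between quantised values, so no mixture
# component can collapse onto one repeated value with zero variance
Xf = X.float()
Xf += (0.001 * Xf.max()) * torch.randn_like(Xf)

# Strided slicing keeps every samp-th row; note that it returns a view,
# which is why fit() clones the full-resolution tensor before slicing
samp = 3
Xs = Xf[::samp]
print(Xf.shape, Xs.shape)                     # (12, 1) and (4, 1)
```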
@@ -86,6 +103,16 @@ def fit(self, X, verbose=1, max_iter=10000, tol=1e-8, fig_num=1, W=None,
         # EM loop
         Z, lb = self._em(X, max_iter=max_iter, tol=tol, verbose=verbose, W=W)

+        if samp > 1:
+            # Create original-resolution responsibilities
+            X = X0
+            N = X.shape[0]
+            Z = torch.zeros((N, K), dtype=self.dt, device=self.dev)
+            for k in range(K):
+                Z[:, k] = torch.log(self.mp[k]) + self._log_likelihood(X, k)
+            if W is not None: W = W0
+            Z, _ = softmax_lse(Z, lse=True, weights=W)
+
         # Print algorithm info
         if verbose >= 1:
             print('Algorithm finished in {} iterations, '
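The added block is a single E-step evaluated on the full-resolution data: column k holds the log mixing proportion plus the log-likelihood of component k, and softmax_lse (a nitorch helper) normalises the rows in log-space. A minimal plain-PyTorch sketch of that normalisation (illustrative; softmax_lse additionally returns the log-sum-exp term and handles the optional weights):

```python
import torch

# Stand-in for the (N, K) matrix of log mix-proportion + log-likelihood
logits = torch.randn(5, 3)

# Row-wise softmax computed in log-space for numerical stability:
# subtracting the log-sum-exp before exponentiating never overflows
log_Z = logits - torch.logsumexp(logits, dim=1, keepdim=True)
Z = log_Z.exp()
print(Z.sum(dim=1))                           # every row sums to 1
```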
@@ -277,7 +304,7 @@ def full_resp(Z, msk, dm=[]):
     """ Converts masked responsibilities to full.
     Args:
         Z (torch.tensor): Masked responsibilities (N, K).
-        msk (torch.tensor): Mask of original data (N0, 1).
+        msk (torch.tensor): Mask of original data (N0, ).
         dm (torch.Size, optional): Reshapes Z_full using dm. Defaults to [].
     Returns:
         Z_full (torch.tensor): Full responsibilities (N0, K).
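The docstring fix reflects that msk is a flat boolean mask of shape (N0,). For reference, a sketch of what such a masked-to-full conversion typically does (illustrative only, not nitorch's actual implementation): scatter the N masked rows back into an (N0, K) tensor.

```python
import torch

def full_resp_sketch(Z, msk, dm=None):
    """Scatter masked responsibilities (N, K) back to full size (N0, K)."""
    N0, K = msk.numel(), Z.shape[1]
    Z_full = torch.zeros((N0, K), dtype=Z.dtype, device=Z.device)
    Z_full[msk, :] = Z                        # rows where the mask is True
    if dm is not None:
        Z_full = Z_full.reshape(list(dm) + [K])
    return Z_full
```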