Skip to content

Commit

Permalink
Update mixtures.py
Browse files Browse the repository at this point in the history
Added gain counter to smooth convergence threshold
  • Loading branch information
brudfors authored Aug 23, 2023
1 parent b7f73ef commit bd6d180
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion nitorch/vb/mixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ def _em(self, X, max_iter, tol, verbose, W):
# Start EM algorithm
Z = torch.zeros((N, K), dtype=dtype, device=device) # responsibility
lb = torch.zeros(max_iter, dtype=torch.float64, device=device)
gain_count = 0
for n_iter in range(max_iter): # EM loop
# ==========
# E-step
Expand All @@ -170,7 +171,11 @@ def _em(self, X, max_iter, tol, verbose, W):
print('n_iter: {}, lb: {}, gain: {}'
.format(n_iter + 1, lb[n_iter], gain))
if gain < tol:
break # Finished
gain_count += 1
if gain_count >= 6:
break # Finished
else:
gain_count = 0

if W is not None: # Weight responsibilities
Z = Z * W
Expand Down

0 comments on commit bd6d180

Please sign in to comment.