-
Notifications
You must be signed in to change notification settings - Fork 2
Added dictionary learning (DL) functionality to LSQML, cleaned up tensor operations in RPIE + DL #42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Added dictionary learning (DL) functionality to LSQML, cleaned up tensor operations in RPIE + DL #42
Changes from all commits
3c1aeb7
cb2f641
ec900f7
f0ed9dd
519b961
91773fe
5979d66
be8ea9a
42b5187
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -97,6 +97,7 @@ def intensity_variation_optimization_enabled(self, epoch: int): | |
| def update_variable_probe( | ||
| self, | ||
| probe: "Probe", | ||
| adjoint_shift_probe_update_direction, # what do I do for type hint here? | ||
| indices: Tensor, | ||
| chi: Tensor, | ||
| delta_p_i: Tensor, | ||
|
|
@@ -117,8 +118,14 @@ def update_variable_probe( | |
| probe.optimization_enabled(current_epoch) | ||
| or (self.eigenmode_weight_optimization_enabled(current_epoch)) | ||
| ): | ||
| self.update_opr_probe_modes_and_weights( | ||
| probe, indices, chi, delta_p_i, delta_p_hat, obj_patches, current_epoch | ||
| self.update_opr_probe_modes_and_weights(probe, | ||
| adjoint_shift_probe_update_direction, | ||
| indices, | ||
| chi, | ||
| delta_p_i, | ||
| delta_p_hat, | ||
| obj_patches, | ||
| current_epoch | ||
| ) | ||
|
|
||
| if self.intensity_variation_optimization_enabled(current_epoch): | ||
|
|
@@ -134,6 +141,7 @@ def update_variable_probe( | |
| def update_opr_probe_modes_and_weights( | ||
| self, | ||
| probe: "Probe", | ||
| adjoint_shift_probe_update_direction, # what do I do for type hint here? | ||
| indices: Tensor, | ||
| chi: Tensor, | ||
| delta_p_i: Tensor, | ||
|
|
@@ -144,12 +152,12 @@ def update_opr_probe_modes_and_weights( | |
| """ | ||
| Update the eigenmodes of the first incoherent mode of the probe, and update the OPR mode weights. | ||
|
|
||
| This implementation is adapted from PtychoShelves code (update_variable_probe.m) and has some | ||
| differences from Eq. 31 of Odstrcil (2018). | ||
| The default (for self.options.use_optimal_update = False) implementation below is adapted from | ||
| PtychoShelves code (update_variable_probe.m) and has some differences from Eq. 31 of Odstrcil (2018). | ||
| """ | ||
| probe_data = probe.data | ||
| weights_data = self.data | ||
|
|
||
| batch_size = len(delta_p_i) | ||
| n_points_total = self.n_scan_points | ||
|
|
||
|
|
@@ -158,44 +166,165 @@ def update_opr_probe_modes_and_weights( | |
| if batch_size == 1: | ||
| return | ||
|
|
||
| # FIXME: reduced relax_u/v by a factor of 10 for stability, but PtychoShelves works without this. | ||
| relax_u = min(0.1, batch_size / n_points_total) * probe.options.eigenmode_update_relaxation | ||
| relax_v = self.options.update_relaxation | ||
| # Shape of delta_p_i: (batch_size, n_probe_modes, h, w) | ||
| # Use only the first incoherent mode | ||
| delta_p_i = delta_p_i[:, 0, :, :] | ||
| delta_p_hat = delta_p_hat[0, :, :] | ||
| residue_update = delta_p_i - delta_p_hat | ||
|
|
||
| # Start from the second OPR mode which is the first after the main mode - i.e., the first eigenmode. | ||
| for i_opr_mode in range(1, probe.n_opr_modes): | ||
| # Just take the first incoherent mode. | ||
| eigenmode_i = probe.get_mode_and_opr_mode(mode=0, opr_mode=i_opr_mode) | ||
| weights_i = self.get_weights(indices)[:, i_opr_mode] | ||
| eigenmode_i, weights_i = self._update_first_eigenmode_and_weight( | ||
| residue_update, | ||
| eigenmode_i, | ||
| weights_i, | ||
| relax_u, | ||
| relax_v, | ||
| obj_patches, | ||
| chi, | ||
| update_eigenmode=probe.optimization_enabled(current_epoch), | ||
| update_weights=self.eigenmode_weight_optimization_enabled(current_epoch), | ||
| ) | ||
|
|
||
| # Project residue on this eigenmode, then subtract it. | ||
| if i_opr_mode < probe.n_opr_modes - 1: | ||
| residue_update = residue_update - pmath.project( | ||
| residue_update, eigenmode_i, dim=(-2, -1) | ||
| update_eigenmode = probe.optimization_enabled(current_epoch) # why is this needed again? To even get into this function, we need this to already be true? | ||
| update_eigenmode_weights = self.eigenmode_weight_optimization_enabled(current_epoch) | ||
|
|
||
| if self.options.use_optimal_update: | ||
|
|
||
| rc = obj_patches.shape[-2] * obj_patches.shape[-1] | ||
| n_spos = obj_patches.shape[0] | ||
|
|
||
| U = probe_data[1:, 0, ...] | ||
|
|
||
| Ws = (weights_data[ indices, 1:]).to(torch.complex64) | ||
|
|
||
| Tsconj_chi = (obj_patches[:,0,...].conj() * chi[:,0,...]) | ||
| Tsconj_chi = adjoint_shift_probe_update_direction( indices, Tsconj_chi[:,None,...], first_mode_only=True) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This one should be replaceable by delta_p_i |
||
|
|
||
| chi = adjoint_shift_probe_update_direction( indices, chi, first_mode_only=True) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably ignorable because we have been mixing shifted and unshifted variables before and it works. In that case the |
||
|
|
||
| U = torch.reshape(U, (U.shape[0], rc)) | ||
| chi_vec = torch.reshape(chi[:,0,...], (n_spos, rc)) | ||
| Ts = torch.reshape(obj_patches[:,0,...], (n_spos, rc)) | ||
| Tsconj_chi = torch.reshape(Tsconj_chi[:,0,...], (n_spos, rc)).T | ||
|
|
||
| # Optimal OPR weight updates | ||
|
|
||
| if update_eigenmode_weights: | ||
|
|
||
| delta_Ws = -2 * torch.real(U.conj() @ Tsconj_chi).to(torch.complex64) | ||
|
|
||
| Ts_U_deltaWs = Ts.T * (U.T @ delta_Ws) | ||
| numer = torch.sum(torch.real(chi_vec * Ts_U_deltaWs.H)) | ||
| denom = torch.sum(torch.real( Ts_U_deltaWs.conj() * Ts_U_deltaWs )) | ||
| optimal_step_deltaWs = self.options.update_relaxation * (numer / denom) | ||
|
|
||
| Ws = (Ws + optimal_step_deltaWs * delta_Ws.T) | ||
|
|
||
| if (probe.representation == "sparse_code" | ||
| and probe.options.experimental.sdl_probe_options.enabled_opr): | ||
|
|
||
| # Optimal sparse code OPR mode updates | ||
|
|
||
| delta_U = -1 * Tsconj_chi @ Ws | ||
|
|
||
| delta_sparse_code_probe_opr = probe.dictionary_matrix.H @ delta_U | ||
|
|
||
| Gs = probe.dictionary_matrix @ delta_sparse_code_probe_opr @ Ws.T | ||
| TsHGsH = Ts.H * Gs.conj() | ||
| numer = torch.sum( torch.real(TsHGsH * chi_vec.T)) | ||
| denom = torch.sum( torch.real(TsHGsH * TsHGsH.conj())) | ||
| optimal_step_sparse_code_probe_opr = probe.options.eigenmode_update_relaxation * (numer / denom) | ||
|
|
||
| sparse_code_probe_opr = probe.get_sparse_code_probe_opr_weights() | ||
|
|
||
| optimal_sparse_code_probe_opr = (sparse_code_probe_opr | ||
| + optimal_step_sparse_code_probe_opr * delta_sparse_code_probe_opr.T) | ||
|
|
||
| # Enforce sparsity constraint on sparse code | ||
| abs_sparse_code = torch.abs(optimal_sparse_code_probe_opr) | ||
| abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True) | ||
| sel = abs_sparse_code_sorted[0][:, probe.sparse_code_probe_nnz] | ||
| sparse_code_mask = (abs_sparse_code >= sel[:,None]) | ||
|
|
||
| # Hard or Soft thresholding | ||
| if probe.options.experimental.sdl_probe_options.thresholding_type_opr == 'hard': | ||
| optimal_sparse_code_probe_opr = optimal_sparse_code_probe_opr * sparse_code_mask | ||
| elif probe.options.experimental.sdl_probe_options.thresholding_type_opr == 'soft': | ||
| optimal_sparse_code_probe_opr = ( abs_sparse_code - sel[:,None] ) * sparse_code_mask * torch.exp(1j * torch.angle(optimal_sparse_code_probe_opr)) | ||
|
|
||
| probe.set_sparse_code_probe_opr(optimal_sparse_code_probe_opr) | ||
|
|
||
| # Back to dense OPR representation | ||
| U = (probe.dictionary_matrix @ optimal_sparse_code_probe_opr.T).T | ||
|
|
||
| # the OPR modes must have L2 norm = torch.sqrt(torch.tensor(rc)) | ||
| U = U * torch.sqrt(torch.tensor(rc)) / torch.sqrt(torch.sum(torch.abs(U)**2, -1))[:,None] | ||
|
|
||
| U = torch.reshape(U, (U.shape[0], obj_patches.shape[-2], obj_patches.shape[-1])) | ||
|
|
||
| probe_data[1:, 0, :, :] = U | ||
| weights_data[indices, 1:] = Ws.real | ||
|
|
||
| # DELETE THIS FOR FINAL MERGING | ||
| # DELETE THIS FOR FINAL MERGING | ||
|
|
||
| # Test the rank of the new scan position dependent probe: | ||
|
|
||
| # probe_data_TEST = torch.reshape(probe_data[:,0,...], (probe_data.shape[0], probe_data.shape[-1] * probe_data.shape[-2])) | ||
| # Z1 = torch.sum(probe_data[:, 0, :, :][None,...] * weights_data[indices][...,None,None], 1) | ||
| # Z1 = torch.reshape(Z1, (Z1.shape[0], Z1.shape[1] * Z1.shape[2])) | ||
| # Z2 = probe_data_TEST.T @ weights_data[indices, :].T.to(torch.complex64) | ||
| # print( torch.linalg.matrix_rank(Z1) ) | ||
| # print( torch.linalg.matrix_rank(Z2) ) | ||
|
|
||
| # DELETE THIS FOR FINAL MERGING | ||
| # DELETE THIS FOR FINAL MERGING | ||
|
|
||
| else: | ||
|
|
||
| # Optimal dense OPR mode updates: | ||
|
|
||
| delta_U = -1 * Tsconj_chi @ Ws | ||
|
|
||
| Ts_deltaU_Ws = Ts.T * (delta_U @ Ws.T) | ||
| numer = torch.sum(torch.real(chi_vec * Ts_deltaU_Ws.H)) | ||
| denom = torch.sum(torch.real( Ts_deltaU_Ws.conj() * Ts_deltaU_Ws )) | ||
| optimal_step_deltaU = probe.options.eigenmode_update_relaxation * (numer / denom) | ||
|
|
||
| U = U + optimal_step_deltaU * delta_U.T | ||
|
|
||
| # the OPR modes must have L2 norm = torch.sqrt(torch.tensor(rc)) | ||
| U = U * torch.sqrt(torch.tensor(rc)) / torch.sqrt(torch.sum(torch.abs(U)**2, -1))[:,None] | ||
|
|
||
| U = torch.reshape(U, (U.shape[0], obj_patches.shape[-2], obj_patches.shape[-1])) | ||
|
|
||
| probe_data[1:, 0, :, :] = U | ||
| weights_data[indices, 1:] = Ws.real | ||
|
|
||
| else: | ||
|
|
||
| # Ptychoshelves method for OPR updates | ||
|
|
||
| # FIXME: reduced relax_u/v by a factor of 10 for stability, but PtychoShelves works without this. | ||
| relax_u = min(0.1, batch_size / n_points_total) * probe.options.eigenmode_update_relaxation | ||
| relax_v = self.options.update_relaxation | ||
| # Shape of delta_p_i: (batch_size, n_probe_modes, h, w) | ||
| # Use only the first incoherent mode | ||
| delta_p_i = delta_p_i[:, 0, :, :] | ||
| delta_p_hat = delta_p_hat[0, :, :] | ||
| residue_update = delta_p_i - delta_p_hat | ||
|
|
||
| # Start from the second OPR mode which is the first after the main mode - i.e., the first eigenmode. | ||
| for i_opr_mode in range(1, probe.n_opr_modes): | ||
| # Just take the first incoherent mode. | ||
| eigenmode_i = probe.get_mode_and_opr_mode(mode=0, opr_mode=i_opr_mode) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here you can regenerate the eigenmodes using update SDL coefficients ( |
||
| weights_i = self.get_weights(indices)[:, i_opr_mode] | ||
| eigenmode_i, weights_i = self._update_first_eigenmode_and_weight( | ||
| residue_update, | ||
| eigenmode_i, | ||
| weights_i, | ||
| relax_u, | ||
| relax_v, | ||
| obj_patches, | ||
| chi, | ||
| update_eigenmode=update_eigenmode, | ||
| update_weights=self.eigenmode_weight_optimization_enabled(current_epoch), | ||
| ) | ||
|
|
||
| probe_data[i_opr_mode, 0, :, :] = eigenmode_i | ||
| weights_data[indices, i_opr_mode] = weights_i | ||
| # Project residue on this eigenmode, then subtract it. | ||
| if i_opr_mode < probe.n_opr_modes - 1: | ||
| residue_update = residue_update - pmath.project( | ||
| residue_update, eigenmode_i, dim=(-2, -1) | ||
| ) | ||
|
|
||
| if probe.optimization_enabled(current_epoch): | ||
| probe.set_data(probe_data) | ||
| if self.eigenmode_weight_optimization_enabled(current_epoch): | ||
| probe_data[i_opr_mode, 0, :, :] = eigenmode_i | ||
| weights_data[indices, i_opr_mode] = weights_i | ||
|
|
||
| if update_eigenmode: | ||
| probe.set_data(probe_data) | ||
|
|
||
| if update_eigenmode_weights: | ||
| self.set_data(weights_data) | ||
|
|
||
| @timer() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Split SDL and optimal OPR weight and eigenmode updates into 2 pull requests