Added dictionary learning (DL) functionality to LSQML, cleaned up tensor operations in RPIE + DL #42
base: main
@@ -225,13 +225,109 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions)

```python
# Calculate probe update direction.
delta_p_i_unshifted = self._calculate_probe_update_direction(
    chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None
)  # Eq. 24a
delta_p_i = self.adjoint_shift_probe_update_direction(
    indices, delta_p_i_unshifted, first_mode_only=True
)
if self.parameter_group.probe.representation == "sparse_code":
    rc = chi.shape[-1] * chi.shape[-2]
    n_scpm = chi.shape[-3]
    n_spos = chi.shape[-4]

    # ======================================================================
    # Sparse code update directions vs. scan position and shared probe modes

    obj_patches_slice_i_conj = torch.conj(obj_patches[:, i_slice, ...])

    delta_sparse_code = chi * obj_patches_slice_i_conj[:, None, ...]
    delta_sparse_code = self.adjoint_shift_probe_update_direction(
        indices, delta_sparse_code, first_mode_only=True
    )
    delta_sparse_code = torch.reshape(delta_sparse_code, (n_spos, n_scpm, rc))
    delta_sparse_code = torch.einsum(
        'ijk,kl->lij',
        delta_sparse_code,
        self.parameter_group.probe.dictionary_matrix_H.T,
    )

    # ===================================================
    # Compute optimal step length for sparse code update

    dict_delta_sparse_code = torch.einsum(
        'ij,jkl->ikl',
        self.parameter_group.probe.dictionary_matrix,
        delta_sparse_code,
    )

    obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, rc))

    denom = torch.abs(dict_delta_sparse_code) ** 2 * obj_patches_vec.swapaxes(0, -1)[..., None]
    denom = torch.einsum('ij,jik->ik', torch.conj(obj_patches_vec), denom)

    chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction(
        indices, chi, first_mode_only=True
    )

    numer = torch.conj(dict_delta_sparse_code) * torch.reshape(
        chi_rm_subpx_shft, (n_spos, n_scpm, rc)
    ).permute(2, 0, 1)
    numer = torch.einsum('ij,jik->ik', torch.conj(obj_patches_vec), numer)

    # real is used to throw away small imag part due to numerical precision errors
```
**Suggested change** (replace the final comment above):

```python
# In theory, numer/denom should be real, but small imaginary parts can arise due to floating-point
# precision errors in complex arithmetic. We use .real to discard these. Alternatively, one could use
# torch.real_if_close or check that the imaginary part is negligible before discarding it.
```
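To make the rationale concrete, here is a minimal sketch (hypothetical toy data, NumPy standing in for the PR's torch code) of why discarding the imaginary part is safe: a Hermitian quadratic form `x^H H x` — the same structure as `numer` and `denom` above — is exactly real in exact arithmetic, but finite-precision complex arithmetic leaves a vanishing imaginary residue.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy Hermitian quadratic form; the PR's real tensors are the sparse-code updates.
A = rng.standard_normal((64, 64)) + 1j * rng.standard_normal((64, 64))
H = A + A.conj().T                 # Hermitian: x^H H x is real in exact arithmetic
x = rng.standard_normal(64) + 1j * rng.standard_normal(64)
q = np.vdot(x, H @ x)              # conj(x) . (H x)
# Complex rounding leaves a tiny imaginary residue; .real discards it,
# which is the same reason the PR takes numer.real / denom.real.
step = q.real
```

The relative size of `q.imag` here is on the order of machine epsilon, so dropping it changes nothing physically meaningful.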
**Copilot** (AI) commented on Aug 5, 2025 (outdated):

> This commented-out code should either be implemented or removed. Leaving commented code in production reduces maintainability and clarity.
The commented-out code in question:

```python
# if self.parameter_group.probe.use_avg_spos_sparse_code:
#     delta_p_i = torch.tile( delta_p_i, ( n_spos, 1, 1 ) )
```
@@ -135,50 +135,89 @@ def compute_updates(
```python
delta_p_i = None
if (i_slice == 0) and probe.optimization_enabled(self.current_epoch):
    if self.parameter_group.probe.representation == "sparse_code":
        # TODO: move this into SynthesisDictLearnProbe class
```
**Collaborator:**

> Separate method for everything inside the `if`.
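A hypothetical sketch of the refactor being suggested — hoisting the dictionary and its sparse-code algebra behind one class instead of inlining it in `compute_updates`. Class, method names, and shapes are illustrative, not the PR's final API; NumPy stands in for torch so the sketch is self-contained.

```python
import numpy as np

class SynthesisDictLearnProbe:
    """Illustrative container for the dictionary and its sparse-code algebra."""

    def __init__(self, dictionary_matrix):
        self.dictionary_matrix = dictionary_matrix              # (rc, n_atoms)
        self.dictionary_matrix_H = dictionary_matrix.conj().T   # (n_atoms, rc)

    def synthesize(self, sparse_code):
        # Probe pixels are a linear combination of dictionary atoms.
        return self.dictionary_matrix @ sparse_code

    def sparse_code_update_direction(self, chi_flat):
        # (n_spos, n_scpm, rc) -> (n_atoms, n_spos, n_scpm); the same
        # 'ijk,kl->lij' contraction the PR currently inlines in LSQML.
        return np.einsum('ijk,kl->lij', chi_flat, self.dictionary_matrix_H.T)

# Tiny demo shapes: rc = 16 pixels, 5 atoms, 2 scan positions, 3 probe modes.
rng = np.random.default_rng(0)
D = rng.standard_normal((16, 5)) + 1j * rng.standard_normal((16, 5))
probe = SynthesisDictLearnProbe(D)
delta = probe.sparse_code_update_direction(np.ones((2, 3, 16), dtype=complex))
```

With this layout, the reconstructor would only call `probe.sparse_code_update_direction(...)`, keeping the `if representation == "sparse_code"` branch to a few lines.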
```python
# TODO: move these into SynthesisDictLearnProbe class
```

**Suggested change:**

```python
# TODO: Move these calculations into SynthesisDictLearnProbe class for better modularity.
# See issue tracker: https://github.com/AdvancedPhotonSource/pty-chi/issues/XXX
# Target: Refactor by Q3 2025. Requirements: Move rc, n_scpm, n_spos, obj_patches_conj, conjT_i_delta_exwv_i, and related logic.
```
**Copilot** (AI) commented on Aug 5, 2025:

> TODO comments should include more specific information about implementation timeline, requirements, or be converted to proper issue tracking.
```python
# (TODO: soft thresholding option as default?)
```

**Suggested change:**

```python
# TODO: Consider implementing soft thresholding as the default option for enforcing sparsity.
# See issue #123 on GitHub (https://github.com/AdvancedPhotonSource/pty-chi/issues/123) for requirements and discussion.
```
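For context, soft thresholding is the proximal operator of the l1 norm — the standard way to enforce sparsity on the code coefficients. A minimal sketch (function name illustrative, NumPy standing in for torch); for complex coefficients it shrinks the magnitude and preserves the phase:

```python
import numpy as np

def soft_threshold(x, lam):
    """Proximal operator of lam * ||x||_1: shrink each coefficient's
    magnitude by lam toward zero, keeping its phase. A sketch of the
    'soft thresholding option' the TODO refers to."""
    mag = np.abs(x)
    # Small epsilon in the denominator avoids 0/0 for zero coefficients.
    scale = np.maximum(mag - lam, 0.0) / np.maximum(mag, 1e-12)
    return x * scale

# Coefficients with magnitude below lam collapse to exactly zero.
shrunk = soft_threshold(np.array([3.0, -0.5, 0.2]), lam=1.0)
```

Applied after each sparse-code update, this yields exactly-zero coefficients, unlike hard clipping of small values, and is the update used in ISTA/FISTA-style sparse coding.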
**Reviewer:**

> Since it's just a boolean, you don't have to register it as a buffer. Also, instead of setting it as an attribute in `__init__`, let's just reference it from `self.options` on the fly in the method where it is used. This way, if the user changes the value in the options object in the middle of a reconstruction, the new value can take effect dynamically.
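A small sketch of the pattern the reviewer describes (class and attribute names are hypothetical): hold a reference to the options object and read the flag at the point of use, rather than copying it into an attribute in `__init__`.

```python
# Hypothetical names; the point is the lookup-on-use pattern, not the API.
class Options:
    def __init__(self):
        self.use_avg_spos_sparse_code = False

class Reconstructor:
    def __init__(self, options):
        self.options = options          # keep a reference, not a snapshot of the flag

    def step(self):
        # Looked up on the fly: a mid-reconstruction change takes effect here.
        if self.options.use_avg_spos_sparse_code:
            return "averaged sparse code"
        return "per-position sparse code"

opts = Options()
recon = Reconstructor(opts)
first = recon.step()
opts.use_avg_spos_sparse_code = True    # user flips the option mid-run
second = recon.step()
```

Because `step` dereferences `self.options` each call, the flipped flag changes behavior on the very next iteration, which a cached boolean attribute would not.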