From c3f7d20aed6f2fc66f271e367f8f5b4c8a257589 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Mon, 8 Sep 2025 14:38:26 -0500 Subject: [PATCH 01/13] attempting to merge dictionary learning into multiGPU LSQML --- src/ptychi/api/options/base.py | 36 +++-- src/ptychi/api/options/lsqml.py | 9 +- src/ptychi/api/task.py | 8 +- src/ptychi/data_structures/probe.py | 197 +++++++++++++++++++++++----- src/ptychi/maths.py | 4 + src/ptychi/reconstructors/lsqml.py | 127 +++++++++++++++++- 6 files changed, 326 insertions(+), 55 deletions(-) diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 710f93e..c08356d 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -545,7 +545,9 @@ class ProbeOrthogonalizeIncoherentModesOptions(FeatureOptions): method: enums.OrthogonalizationMethods = enums.OrthogonalizationMethods.SVD """The method to use for incoherent_mode orthogonalization.""" - + + sort_by_occupancy: bool = False + """Keep the probes sorted so that mode with highest occupancy is the 0th shared mode""" @dataclasses.dataclass class ProbeOrthogonalizeOPRModesOptions(FeatureOptions): @@ -651,25 +653,35 @@ def get_non_data_fields(self) -> dict: @dataclasses.dataclass class SynthesisDictLearnProbeOptions(Options): - - d_mat: Union[ndarray, Tensor] = None + + enabled: bool = False + enabled_shared: bool = False + enabled_opr: bool = False + + thresholding_type_shared: str = 'hard' + thresholding_type_opr: str = 'hard' + """Choose between 'hard' or 'soft' thresholding.""" + + dictionary_matrix: Union[ndarray, Tensor] = None """The synthesis sparse dictionary matrix; contains the basis functions that will be used to represent the probe via the sparse code weights.""" - d_mat_conj_transpose: Union[ndarray, Tensor] = None - """Conjugate transpose of the synthesis sparse dictionary matrix.""" - - d_mat_pinv: Union[ndarray, Tensor] = None + dictionary_matrix_pinv: Union[ndarray, Tensor] = None """Moore-Penrose pseudoinverse of the synthesis sparse dictionary matrix.""" - probe_sparse_code: Union[ndarray, Tensor] = None - """Sparse code weights vector.""" + sparse_code_probe_shared: Union[ndarray, Tensor] = None + """Sparse code weights vector for the shared modes.""" - probe_sparse_code_nnz: float = None + sparse_code_probe_shared_nnz: float = None """Number of non-zeros we will keep when enforcing sparsity constraint on - the sparse code weights vector probe_sparse_code.""" + the SHARED sparse code weights vector sparse_code_probe_shared.""" + + sparse_code_probe_opr: Union[ndarray, Tensor] = None + """Sparse code weights vector for the OPRs.""" - enabled: bool = False + sparse_code_probe_opr_nnz: float = None + """Number of non-zeros we will keep when enforcing sparsity constraint on + the OPR sparse code weights vector sparse_code_probe_opr.""" @dataclasses.dataclass class PositionCorrectionOptions(Options): diff --git a/src/ptychi/api/options/lsqml.py b/src/ptychi/api/options/lsqml.py index 2a2f063..31c6364 100644 --- a/src/ptychi/api/options/lsqml.py +++ b/src/ptychi/api/options/lsqml.py @@ -100,6 +100,10 @@ class LSQMLObjectOptions(base.ObjectOptions): propagation always uses all probe modes regardless of this option. """ +@dataclasses.dataclass +class LSQMLProbeExperimentalOptions(base.Options): + sdl_probe_options: base.SynthesisDictLearnProbeOptions = dataclasses.field(default_factory=base.SynthesisDictLearnProbeOptions) + @dataclasses.dataclass class LSQMLProbeOptions(base.ProbeOptions): @@ -107,8 +111,9 @@ class LSQMLProbeOptions(base.ProbeOptions): """ A scaler for the solved optimal step size (beta_LSQ in PtychoShelves). """ - - + experimental: LSQMLProbeExperimentalOptions = dataclasses.field(default_factory=LSQMLProbeExperimentalOptions) + + @dataclasses.dataclass class LSQMLProbePositionOptions(base.ProbePositionOptions): pass diff --git a/src/ptychi/api/task.py b/src/ptychi/api/task.py index b60a2b9..f04e225 100644 --- a/src/ptychi/api/task.py +++ b/src/ptychi/api/task.py @@ -192,9 +192,11 @@ def build_probe(self): self.probe_options.experimental.deep_image_prior_options.enabled ): self.probe = probe.DIPProbe(**kwargs) - elif ( - isinstance(self.probe_options, api.options.PIEProbeOptions) - ) and ( + elif ( + isinstance(self.probe_options, api.options.PIEProbeOptions) + or + isinstance(self.probe_options, api.options.LSQMLProbeOptions) + ) and ( self.probe_options.experimental.sdl_probe_options.enabled ): self.probe = probe.SynthesisDictLearnProbe(**kwargs) diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index b730b64..e60bacf 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -221,6 +221,11 @@ def constrain_incoherent_modes_orthogonality(self): return probe = self.data + + if self.options.orthogonalize_incoherent_modes.sort_by_occupancy: + shared_occupancy = torch.sum(torch.abs(probe[0,...])**2,(-2,-1)) / torch.sum(torch.abs(probe[0,...])**2) + shared_occupancy = torch.sort(shared_occupancy, dim=0, descending=True) + probe[0,...] = probe[ 0, shared_occupancy[1],...] norm_first_mode_orig = pmath.norm(probe[0, 0], dim=(-2, -1)) @@ -470,31 +475,52 @@ def __init__(self, name = "probe", options = None, *args, **kwargs): super().__init__(name, options, build_optimizer=False, data_as_parameter=False, *args, **kwargs) - dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H = self.get_dictionary() + dictionary_matrix, dictionary_matrix_pinv = self.get_dictionary() self.register_buffer("dictionary_matrix", dictionary_matrix) self.register_buffer("dictionary_matrix_pinv", dictionary_matrix_pinv) - self.register_buffer("dictionary_matrix_H", dictionary_matrix_H) + + sparse_code_probe_shared_nnz = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_nnz, dtype=torch.uint32 ) + sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights() + self.register_buffer("sparse_code_probe_nnz", sparse_code_probe_shared_nnz ) + self.register_parameter("sparse_code_probe_shared", torch.nn.Parameter(sparse_code_probe_shared)) - probe_sparse_code_nnz = torch.tensor( self.options.experimental.sdl_probe_options.probe_sparse_code_nnz, dtype=torch.uint32 ) - self.register_buffer("probe_sparse_code_nnz", probe_sparse_code_nnz ) + sparse_code_probe_opr_nnz = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_opr_nnz, dtype=torch.uint32 ) + sparse_code_probe_opr = self.get_sparse_code_probe_opr_weights() + self.register_buffer("sparse_code_opr_nnz", sparse_code_probe_opr_nnz ) + self.register_parameter("sparse_code_probe_opr", torch.nn.Parameter(sparse_code_probe_opr)) - sparse_code_probe = self.get_sparse_code_weights() - self.register_parameter("sparse_code_probe", torch.nn.Parameter(sparse_code_probe)) - self.build_optimizer() def get_dictionary(self): - dictionary_matrix = torch.tensor( self.options.experimental.sdl_probe_options.d_mat, dtype=torch.complex64 ) - dictionary_matrix_pinv = torch.tensor( self.options.experimental.sdl_probe_options.d_mat_pinv, dtype=torch.complex64 ) - dictionary_matrix_H = torch.tensor( self.options.experimental.sdl_probe_options.d_mat_conj_transpose, dtype=torch.complex64 ) - return dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H - - def get_sparse_code_weights(self): - sz = self.data.shape - probe_vec = torch.reshape( self.data[0,...], (sz[1], sz[2] * sz[3])) - probe_vec = torch.swapaxes( probe_vec, 0, -1) - sparse_code_probe = self.dictionary_matrix_pinv @ probe_vec - return sparse_code_probe + dictionary_matrix = torch.tensor( self.options.experimental.sdl_probe_options.dictionary_matrix, dtype=torch.complex64 ) + dictionary_matrix_pinv = torch.tensor( self.options.experimental.sdl_probe_options.dictionary_matrix_pinv, dtype=torch.complex64 ) + return dictionary_matrix, dictionary_matrix_pinv + + def get_sparse_code_weights_vs_scanpositions(self, probe_vs_scanpositions ): + + sz = probe_vs_scanpositions.shape + probe_vec = torch.reshape(probe_vs_scanpositions, (sz[0], sz[1], sz[2]*sz[3])) + sparse_code_vs_scanpositions = torch.einsum('ij,klj->ikl', self.dictionary_matrix_pinv, probe_vec) + + return sparse_code_vs_scanpositions + + def get_sparse_code_probe_shared_weights(self): + + probe_shared = self.data[0,...] + sz = probe_shared.shape + probe_vec = torch.reshape(probe_shared, (sz[0], sz[1]*sz[2])) + sparse_code_probe_shared = self.dictionary_matrix_pinv @ probe_vec.T + + return sparse_code_probe_shared.T + + def get_sparse_code_probe_opr_weights(self): + + probe_opr = self.data[1:,0,...] + sz = probe_opr.shape + probe_vec = torch.reshape(probe_opr, (sz[0], sz[1]*sz[2])) + sparse_code_probe_opr = self.dictionary_matrix_pinv @ probe_vec.T + + return sparse_code_probe_opr.T def generate(self): """Generate the probe using the sparse code, and set the @@ -505,16 +531,49 @@ def generate(self): Tensor A (n_opr_modes, n_modes, h, w) tensor giving the generated probe. """ - probe_vec = self.dictionary_matrix @ self.sparse_code_probe - probe_vec = torch.swapaxes( probe_vec, 0, -1) - probe = torch.reshape(probe_vec, *[self.data[0,...].shape]) - probe = probe[None,...] - - # we only use sparse codes for the shared modes, not the OPRs - probe = torch.cat((probe, self.data[1:,...]), 0) - - self.set_data(probe) - return probe + if (self.options.experimental.sdl_probe_options.enabled_shared + and self.options.experimental.sdl_probe_options.enabled_opr): + + sz = self.data.shape + probe = torch.zeros( *[sz], dtype = torch.complex64 ) + + probe_shared = self.dictionary_matrix @ self.sparse_code_probe_shared.T + probe_opr = self.dictionary_matrix @ self.sparse_code_probe_opr.T + + probe[0,...] = torch.reshape( probe_shared.T, *[sz[1:]] ) + probe[1:,0,...] = torch.reshape( probe_opr.T, [sz[0] - 1, sz[-2], sz[-1]] ) + + self.set_data(probe) + + elif (self.options.experimental.sdl_probe_options.enabled_shared + and not self.options.experimental.sdl_probe_options.enabled_opr): + + sz = self.data.shape + probe = torch.zeros( *[sz], dtype = torch.complex64 ) + + probe_shared = self.dictionary_matrix @ self.sparse_code_probe_shared.T + + probe[0,...] = torch.reshape( probe_shared.T, *[sz[1:]] ) + probe[1:,0,...] = self.data[1:,0,...] + + self.set_data(probe) + + elif (self.options.experimental.sdl_probe_options.enabled_opr + and not self.options.experimental.sdl_probe_options.enabled_shared): + + sz = self.data.shape + probe = torch.zeros( *[sz], dtype = torch.complex64 ) + + probe_opr = self.dictionary_matrix @ self.sparse_code_probe_opr.T + + probe[0,...] = self.data[0,...] + probe[1:,0,...] = torch.reshape( probe_opr.T, [sz[0] - 1, sz[-2], sz[-1]] ) + + self.set_data(probe) + + else: + + probe = self.data def build_optimizer(self): if self.optimizable and self.optimizer_class is None: @@ -522,11 +581,87 @@ def build_optimizer(self): "Parameter {} is optimizable but no optimizer is specified.".format(self.name) ) if self.optimizable: - self.optimizer = self.optimizer_class([self.sparse_code_probe], **self.optimizer_params) + self.optimizer = self.optimizer_class([self.sparse_code_probe_shared], **self.optimizer_params) + + def set_sparse_code_probe_shared(self, data): + self.sparse_code_probe_shared.data = data + + def set_sparse_code_probe_opr(self, data): + self.sparse_code_probe_opr.data = data + + def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, obj_patches): + + nr = chi.shape[-2] + nc = chi.shape[-1] + nrnc = nr*nc + n_scpm = chi.shape[-3] + n_spos = chi.shape[-4] + + chi_rm_subpx_shft = torch.reshape(chi, (n_spos, n_scpm, nrnc)).permute(2,0,1) + obj_patches_vec = torch.reshape(obj_patches, (n_spos, nrnc)) + + # get sparse code update direction + delta_sparse_code = torch.einsum('ijk,kl->lij', + torch.reshape(delta_p_i, (n_spos, n_scpm, nrnc)), + self.dictionary_matrix.conj() + ) + + # compute optimal step length for sparse code update + dict_delta_sparse_code = torch.einsum('ij,jkl->ikl', + self.dictionary_matrix, + delta_sparse_code + ) + + denom = (torch.abs(dict_delta_sparse_code)**2)*obj_patches_vec.swapaxes(0,-1)[...,None] + denom = torch.einsum('ij,jik->ik', + torch.conj(obj_patches_vec), + denom + ) - def set_sparse_code(self, data): - self.sparse_code_probe.data = data + numer = torch.conj(dict_delta_sparse_code)*torch.reshape( + chi_rm_subpx_shft, (n_spos, n_scpm, nrnc) + ).permute(2,0,1) + numer = torch.einsum('ij,jik->ik', + torch.conj(obj_patches_vec), + numer) + + # real is used to throw away small imag part due to numerical precision errors + optimal_step_sparse_code = (numer/denom).real + + optimal_delta_sparse_code = optimal_step_sparse_code[None,...]*delta_sparse_code + + # enforce sparsity constraint on sparse code + abs_sparse_code = torch.abs(optimal_delta_sparse_code) + abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) + + sel = abs_sparse_code_sorted[0][self.sparse_code_probe_nnz, ...] + sparse_code_mask = (abs_sparse_code >= sel[None,...]) + + # hard or soft thresholding + if self.options.experimental.sdl_probe_options.thresholding_type_shared == 'hard': + + optimal_delta_sparse_code=optimal_delta_sparse_code*sparse_code_mask + + elif self.options.experimental.sdl_probe_options.thresholding_type_shared == 'soft': + + optimal_delta_sparse_code=((abs_sparse_code - sel[None,...])*sparse_code_mask + *torch.exp(1j*torch.angle(optimal_delta_sparse_code)) + ) + + # update the shared probe sparse codes using the average over scan positions + sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights() + + sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T + + self.set_sparse_code_probe_shared(sparse_code_probe_shared) + + delta_p_i = torch.einsum('ij,jlk->ilk', self.dictionary_matrix, + optimal_delta_sparse_code + ).permute(1,2,0) + + delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, nr, nc)) + return delta_p_i class DIPProbe(Probe): diff --git a/src/ptychi/maths.py b/src/ptychi/maths.py index d19efc0..292aa86 100644 --- a/src/ptychi/maths.py +++ b/src/ptychi/maths.py @@ -213,6 +213,10 @@ def orthogonalize_svd( def project(a, b, dim=None): """Return complex vector projection of a onto b for along given axis.""" projected_length = inner(a, b, dim=dim, keepdims=True) / inner(b, b, dim=dim, keepdims=True) + + # if the inner product of b with itself has any zeros: + projected_length = torch.nan_to_num(projected_length, nan=0.0) + return projected_length * b def inner(x, y, dim=None, keepdims=False): diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 0260295..e2d11d7 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -255,16 +255,129 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): else: self._record_object_slice_gradient(i_slice, delta_o_comb, add_to_existing=False) - # Calculate probe update direction. - delta_p_i_unshifted = self._calculate_probe_update_direction( - chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None - ) # Eq. 24a - delta_p_i = self.adjoint_shift_probe_update_direction( - indices, delta_p_i_unshifted, first_mode_only=True - ) + # TODO: move this to SynthesisDictLearnProbe class methods, so it can be used in rPIE as well + if (self.parameter_group.probe.representation == "sparse_code" + and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared): + + # Calculate probe update direction using the sparse code representation + + # delta_p_i_unshifted = self._calculate_probe_update_direction( + # chi, obj_patches = obj_patches, slice_index=i_slice, probe_mode_index=None + # ) + + # delta_p_i = self.adjoint_shift_probe_update_direction( + # indices, delta_p_i_unshifted, first_mode_only=True + # ) + + chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction( + indices, chi, first_mode_only=True + ) + + # delta_p_i = self.parameter_group.probe.get_probe_update_direction_sparse_code_probe_shared( + # delta_p_i, chi_rm_subpx_shft, obj_patches[:, i_slice, ...] + # ) + + # delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) + + + + + nr = chi.shape[-2] + nc = chi.shape[-1] + nrnc = nr*nc + n_scpm = chi.shape[-3] + n_spos = chi.shape[-4] + + + + + chi_rm_subpx_shft = torch.reshape( chi_rm_subpx_shft, (n_spos, n_scpm, nrnc)).permute(2,0,1) + obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, nrnc )) + + + + # sparse code update directions vs scan position and shared probe modes + obj_patches_slice_i_conj = torch.conj( obj_patches[:, i_slice, ...] ) + + delta_p_i = self.adjoint_shift_probe_update_direction(indices, chi * obj_patches_slice_i_conj[:, None, ... ], first_mode_only=True) + + + + + + delta_sparse_code = torch.reshape( delta_p_i, + ( n_spos, n_scpm, nrnc )) + + delta_sparse_code = torch.einsum('ijk,kl->lij', + delta_sparse_code, + self.parameter_group.probe.dictionary_matrix.conj()) + + # compute optimal step length for sparse code update + dict_delta_sparse_code = torch.einsum('ij,jkl->ikl', + self.parameter_group.probe.dictionary_matrix, + delta_sparse_code) + + + denom = torch.abs( dict_delta_sparse_code )**2 * obj_patches_vec.swapaxes(0,-1)[...,None] + denom = torch.einsum('ij,jik->ik', + torch.conj( obj_patches_vec ), + denom) + + numer = torch.conj( dict_delta_sparse_code ) * torch.reshape( chi_rm_subpx_shft, + ( n_spos, n_scpm, nrnc )).permute(2,0,1) + numer = torch.einsum('ij,jik->ik', + torch.conj( obj_patches_vec ), + numer) + + # real is used to throw away small imag part due to numerical precision errors + optimal_step_sparse_code = ( numer / denom ).real + + optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code + + # Enforce sparsity constraint on sparse code + abs_sparse_code = torch.abs(optimal_delta_sparse_code) + abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) + + sel = abs_sparse_code_sorted[0][self.parameter_group.probe.sparse_code_probe_nnz, ...] + sparse_code_mask = (abs_sparse_code >= sel[None,...]) + + # Hard or Soft thresholding + if self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'hard': + optimal_delta_sparse_code = optimal_delta_sparse_code * sparse_code_mask + elif self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'soft': + optimal_delta_sparse_code = ( abs_sparse_code - sel[None,...] ) * sparse_code_mask * torch.exp(1j * torch.angle(optimal_delta_sparse_code)) + + # update the shared probe sparse codes using the average over scan positions + sparse_code_probe_shared = self.parameter_group.probe.get_sparse_code_probe_shared_weights() + sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T + self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code_probe_shared) + + delta_p_i = torch.einsum('ij,jlk->ilk', self.parameter_group.probe.dictionary_matrix, + optimal_delta_sparse_code).permute(1, 2, 0) + delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, chi.shape[-1], chi.shape[-2])) + + delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) + + else: + + # Calculate probe update direction (dense representation) + delta_p_i_unshifted = self._calculate_probe_update_direction( + chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None + ) # Eq. 24a + delta_p_i = self.adjoint_shift_probe_update_direction( + indices, delta_p_i_unshifted, first_mode_only=True + ) + delta_p_hat = self._precondition_probe_update_direction(delta_p_i) # Eq. 25a self._record_probe_gradient(delta_p_hat) + + + + + + + # Calculate update vectors for OPR modes and weights. if i_slice == 0: if self.parameter_group.opr_mode_weights.optimization_enabled(self.current_epoch): From 7eaa075b1c69b98a114045739dc06746210f75b3 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Mon, 8 Sep 2025 17:16:25 -0500 Subject: [PATCH 02/13] redo of shared probe dictionary learning, removed OPR dictionary learning, will do different pull request for OPR --- src/ptychi/api/options/base.py | 7 -- src/ptychi/api/task.py | 6 +- src/ptychi/data_structures/probe.py | 76 +++----------- src/ptychi/reconstructors/lsqml.py | 154 +++++++++++++++------------- src/ptychi/reconstructors/pie.py | 101 ++++++++++++------ 5 files changed, 169 insertions(+), 175 deletions(-) diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index c08356d..66f3d37 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -676,13 +676,6 @@ class SynthesisDictLearnProbeOptions(Options): """Number of non-zeros we will keep when enforcing sparsity constraint on the SHARED sparse code weights vector sparse_code_probe_shared.""" - sparse_code_probe_opr: Union[ndarray, Tensor] = None - """Sparse code weights vector for the OPRs.""" - - sparse_code_probe_opr_nnz: float = None - """Number of non-zeros we will keep when enforcing sparsity constraint on - the OPR sparse code weights vector sparse_code_probe_opr.""" - @dataclasses.dataclass class PositionCorrectionOptions(Options): """Options used for specifying the position correction function.""" diff --git a/src/ptychi/api/task.py b/src/ptychi/api/task.py index f04e225..16925a5 100644 --- a/src/ptychi/api/task.py +++ b/src/ptychi/api/task.py @@ -193,9 +193,9 @@ def build_probe(self): ): self.probe = probe.DIPProbe(**kwargs) elif ( - isinstance(self.probe_options, api.options.PIEProbeOptions) - or - isinstance(self.probe_options, api.options.LSQMLProbeOptions) + isinstance(self.probe_options, api.options.PIEProbeOptions) + or + isinstance(self.probe_options, api.options.LSQMLProbeOptions) ) and ( self.probe_options.experimental.sdl_probe_options.enabled ): diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index e60bacf..49f80f1 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -481,14 +481,9 @@ def __init__(self, name = "probe", options = None, *args, **kwargs): sparse_code_probe_shared_nnz = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_nnz, dtype=torch.uint32 ) sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights() - self.register_buffer("sparse_code_probe_nnz", sparse_code_probe_shared_nnz ) + self.register_buffer("sparse_code_probe_shared_nnz", sparse_code_probe_shared_nnz ) self.register_parameter("sparse_code_probe_shared", torch.nn.Parameter(sparse_code_probe_shared)) - sparse_code_probe_opr_nnz = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_opr_nnz, dtype=torch.uint32 ) - sparse_code_probe_opr = self.get_sparse_code_probe_opr_weights() - self.register_buffer("sparse_code_opr_nnz", sparse_code_probe_opr_nnz ) - self.register_parameter("sparse_code_probe_opr", torch.nn.Parameter(sparse_code_probe_opr)) - self.build_optimizer() def get_dictionary(self): @@ -496,7 +491,7 @@ def get_dictionary(self): dictionary_matrix_pinv = torch.tensor( self.options.experimental.sdl_probe_options.dictionary_matrix_pinv, dtype=torch.complex64 ) return dictionary_matrix, dictionary_matrix_pinv - def get_sparse_code_weights_vs_scanpositions(self, probe_vs_scanpositions ): + def get_sparse_code_weights_vs_scanpositions(self, probe_vs_scanpositions): sz = probe_vs_scanpositions.shape probe_vec = torch.reshape(probe_vs_scanpositions, (sz[0], sz[1], sz[2]*sz[3])) @@ -513,15 +508,6 @@ def get_sparse_code_probe_shared_weights(self): return sparse_code_probe_shared.T - def get_sparse_code_probe_opr_weights(self): - - probe_opr = self.data[1:,0,...] - sz = probe_opr.shape - probe_vec = torch.reshape(probe_opr, (sz[0], sz[1]*sz[2])) - sparse_code_probe_opr = self.dictionary_matrix_pinv @ probe_vec.T - - return sparse_code_probe_opr.T - def generate(self): """Generate the probe using the sparse code, and set the generated probe to self.data. @@ -531,22 +517,8 @@ def generate(self): Tensor A (n_opr_modes, n_modes, h, w) tensor giving the generated probe. """ - if (self.options.experimental.sdl_probe_options.enabled_shared - and self.options.experimental.sdl_probe_options.enabled_opr): - - sz = self.data.shape - probe = torch.zeros( *[sz], dtype = torch.complex64 ) - - probe_shared = self.dictionary_matrix @ self.sparse_code_probe_shared.T - probe_opr = self.dictionary_matrix @ self.sparse_code_probe_opr.T - - probe[0,...] = torch.reshape( probe_shared.T, *[sz[1:]] ) - probe[1:,0,...] = torch.reshape( probe_opr.T, [sz[0] - 1, sz[-2], sz[-1]] ) - self.set_data(probe) - - elif (self.options.experimental.sdl_probe_options.enabled_shared - and not self.options.experimental.sdl_probe_options.enabled_opr): + if (self.options.experimental.sdl_probe_options.enabled_shared): sz = self.data.shape probe = torch.zeros( *[sz], dtype = torch.complex64 ) @@ -558,19 +530,6 @@ def generate(self): self.set_data(probe) - elif (self.options.experimental.sdl_probe_options.enabled_opr - and not self.options.experimental.sdl_probe_options.enabled_shared): - - sz = self.data.shape - probe = torch.zeros( *[sz], dtype = torch.complex64 ) - - probe_opr = self.dictionary_matrix @ self.sparse_code_probe_opr.T - - probe[0,...] = self.data[0,...] - probe[1:,0,...] = torch.reshape( probe_opr.T, [sz[0] - 1, sz[-2], sz[-1]] ) - - self.set_data(probe) - else: probe = self.data @@ -586,19 +545,16 @@ def build_optimizer(self): def set_sparse_code_probe_shared(self, data): self.sparse_code_probe_shared.data = data - def set_sparse_code_probe_opr(self, data): - self.sparse_code_probe_opr.data = data - def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, obj_patches): nr = chi.shape[-2] nc = chi.shape[-1] - nrnc = nr*nc + nrnc = nr * nc n_scpm = chi.shape[-3] n_spos = chi.shape[-4] - chi_rm_subpx_shft = torch.reshape(chi, (n_spos, n_scpm, nrnc)).permute(2,0,1) - obj_patches_vec = torch.reshape(obj_patches, (n_spos, nrnc)) + obj_patches = torch.reshape(obj_patches, (n_spos, nrnc)) + chi = torch.reshape(chi, (n_spos, n_scpm, nrnc)).permute(2,0,1) # get sparse code update direction delta_sparse_code = torch.einsum('ijk,kl->lij', @@ -612,40 +568,38 @@ def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, ob delta_sparse_code ) - denom = (torch.abs(dict_delta_sparse_code)**2)*obj_patches_vec.swapaxes(0,-1)[...,None] + denom = (torch.abs(dict_delta_sparse_code)**2) * obj_patches.swapaxes(0,-1)[...,None] denom = torch.einsum('ij,jik->ik', - torch.conj(obj_patches_vec), + torch.conj(obj_patches), denom ) - numer = torch.conj(dict_delta_sparse_code)*torch.reshape( - chi_rm_subpx_shft, (n_spos, n_scpm, nrnc) - ).permute(2,0,1) + numer = torch.conj(dict_delta_sparse_code) * chi numer = torch.einsum('ij,jik->ik', - torch.conj(obj_patches_vec), + torch.conj(obj_patches), numer) # real is used to throw away small imag part due to numerical precision errors optimal_step_sparse_code = (numer/denom).real - optimal_delta_sparse_code = optimal_step_sparse_code[None,...]*delta_sparse_code + optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code # enforce sparsity constraint on sparse code abs_sparse_code = torch.abs(optimal_delta_sparse_code) abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) - sel = abs_sparse_code_sorted[0][self.sparse_code_probe_nnz, ...] + sel = abs_sparse_code_sorted[0][self.sparse_code_probe_shared_nnz, ...] sparse_code_mask = (abs_sparse_code >= sel[None,...]) # hard or soft thresholding if self.options.experimental.sdl_probe_options.thresholding_type_shared == 'hard': - optimal_delta_sparse_code=optimal_delta_sparse_code*sparse_code_mask + optimal_delta_sparse_code=optimal_delta_sparse_code * sparse_code_mask elif self.options.experimental.sdl_probe_options.thresholding_type_shared == 'soft': - optimal_delta_sparse_code=((abs_sparse_code - sel[None,...])*sparse_code_mask - *torch.exp(1j*torch.angle(optimal_delta_sparse_code)) + optimal_delta_sparse_code=((abs_sparse_code - sel[None,...]) * sparse_code_mask + * torch.exp(1j*torch.angle(optimal_delta_sparse_code)) ) # update the shared probe sparse codes using the average over scan positions diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index e2d11d7..21353a4 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -255,108 +255,112 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): else: self._record_object_slice_gradient(i_slice, delta_o_comb, add_to_existing=False) - # TODO: move this to SynthesisDictLearnProbe class methods, so it can be used in rPIE as well - if (self.parameter_group.probe.representation == "sparse_code" - and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared): + use_sparse_probe_shared_update = (self.parameter_group.probe.representation == "sparse_code" + and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared) + + if use_sparse_probe_shared_update: # Calculate probe update direction using the sparse code representation - # delta_p_i_unshifted = self._calculate_probe_update_direction( - # chi, obj_patches = obj_patches, slice_index=i_slice, probe_mode_index=None - # ) + delta_p_i_unshifted = self._calculate_probe_update_direction( + chi, obj_patches = obj_patches, slice_index=i_slice, probe_mode_index=None + ) - # delta_p_i = self.adjoint_shift_probe_update_direction( - # indices, delta_p_i_unshifted, first_mode_only=True - # ) + delta_p_i = self.adjoint_shift_probe_update_direction( + indices, delta_p_i_unshifted, first_mode_only=True + ) chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction( indices, chi, first_mode_only=True ) - # delta_p_i = self.parameter_group.probe.get_probe_update_direction_sparse_code_probe_shared( - # delta_p_i, chi_rm_subpx_shft, obj_patches[:, i_slice, ...] - # ) + delta_p_i = self.parameter_group.probe.get_probe_update_direction_sparse_code_probe_shared( + delta_p_i, chi_rm_subpx_shft, obj_patches[:, i_slice, ...] + ) - # delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) + delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) - nr = chi.shape[-2] - nc = chi.shape[-1] - nrnc = nr*nc - n_scpm = chi.shape[-3] - n_spos = chi.shape[-4] + # nr = chi.shape[-2] + # nc = chi.shape[-1] + # nrnc = nr*nc + # n_scpm = chi.shape[-3] + # n_spos = chi.shape[-4] - chi_rm_subpx_shft = torch.reshape( chi_rm_subpx_shft, (n_spos, n_scpm, nrnc)).permute(2,0,1) - obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, nrnc )) + # #chi_rm_subpx_shft = torch.reshape( chi_rm_subpx_shft, (n_spos, n_scpm, nrnc)).permute(2,0,1) + # obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, nrnc )) - # sparse code update directions vs scan position and shared probe modes - obj_patches_slice_i_conj = torch.conj( obj_patches[:, i_slice, ...] ) - - delta_p_i = self.adjoint_shift_probe_update_direction(indices, chi * obj_patches_slice_i_conj[:, None, ... ], first_mode_only=True) + # # delta_p_i_unshifted = self._calculate_probe_update_direction( + # # chi, obj_patches = obj_patches, slice_index=i_slice, probe_mode_index=None + # # ) + # delta_p_i = self.adjoint_shift_probe_update_direction(indices, + # chi * torch.conj( obj_patches[:, i_slice, ...] )[:, None, ... ], + # first_mode_only=True) + - delta_sparse_code = torch.reshape( delta_p_i, - ( n_spos, n_scpm, nrnc )) + # delta_sparse_code = torch.reshape( delta_p_i, + # ( n_spos, n_scpm, nrnc )) - delta_sparse_code = torch.einsum('ijk,kl->lij', - delta_sparse_code, - self.parameter_group.probe.dictionary_matrix.conj()) + # delta_sparse_code = torch.einsum('ijk,kl->lij', + # delta_sparse_code, + # self.parameter_group.probe.dictionary_matrix.conj()) - # compute optimal step length for sparse code update - dict_delta_sparse_code = torch.einsum('ij,jkl->ikl', - self.parameter_group.probe.dictionary_matrix, - delta_sparse_code) + # # compute optimal step length for sparse code update + # dict_delta_sparse_code = torch.einsum('ij,jkl->ikl', + # self.parameter_group.probe.dictionary_matrix, + # delta_sparse_code) - denom = torch.abs( dict_delta_sparse_code )**2 * obj_patches_vec.swapaxes(0,-1)[...,None] - denom = torch.einsum('ij,jik->ik', - torch.conj( obj_patches_vec ), - denom) - - numer = torch.conj( dict_delta_sparse_code ) * torch.reshape( chi_rm_subpx_shft, - ( n_spos, n_scpm, nrnc )).permute(2,0,1) - numer = torch.einsum('ij,jik->ik', - torch.conj( obj_patches_vec ), - numer) - - # real is used to throw away small imag part due to numerical precision errors - optimal_step_sparse_code = ( numer / denom ).real - - optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code - - # Enforce sparsity constraint on sparse code - abs_sparse_code = torch.abs(optimal_delta_sparse_code) - abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) - - sel = abs_sparse_code_sorted[0][self.parameter_group.probe.sparse_code_probe_nnz, ...] - sparse_code_mask = (abs_sparse_code >= sel[None,...]) - - # Hard or Soft thresholding - if self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'hard': - optimal_delta_sparse_code = optimal_delta_sparse_code * sparse_code_mask - elif self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'soft': - optimal_delta_sparse_code = ( abs_sparse_code - sel[None,...] ) * sparse_code_mask * torch.exp(1j * torch.angle(optimal_delta_sparse_code)) - - # update the shared probe sparse codes using the average over scan positions - sparse_code_probe_shared = self.parameter_group.probe.get_sparse_code_probe_shared_weights() - sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T - self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code_probe_shared) - - delta_p_i = torch.einsum('ij,jlk->ilk', self.parameter_group.probe.dictionary_matrix, - optimal_delta_sparse_code).permute(1, 2, 0) - delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, chi.shape[-1], chi.shape[-2])) + # denom = torch.abs( dict_delta_sparse_code )**2 * obj_patches_vec.swapaxes(0,-1)[...,None] + # denom = torch.einsum('ij,jik->ik', + # torch.conj( obj_patches_vec ), + # denom) + + # numer = torch.conj( dict_delta_sparse_code ) * torch.reshape( chi_rm_subpx_shft, + # ( n_spos, n_scpm, nrnc )).permute(2,0,1) + # numer = torch.einsum('ij,jik->ik', + # torch.conj( obj_patches_vec ), + # numer) + + # # real is used to throw away small imag part due to numerical precision errors + # optimal_step_sparse_code = ( numer / denom ).real + + # optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code + + # # Enforce sparsity constraint on sparse code + # abs_sparse_code = torch.abs(optimal_delta_sparse_code) + # abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) + + # sel = abs_sparse_code_sorted[0][self.parameter_group.probe.sparse_code_probe_nnz, ...] + # sparse_code_mask = (abs_sparse_code >= sel[None,...]) + + # # Hard or Soft thresholding + # if self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'hard': + # optimal_delta_sparse_code = optimal_delta_sparse_code * sparse_code_mask + # elif self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'soft': + # optimal_delta_sparse_code = ( abs_sparse_code - sel[None,...] ) * sparse_code_mask * torch.exp(1j * torch.angle(optimal_delta_sparse_code)) + + # # update the shared probe sparse codes using the average over scan positions + # sparse_code_probe_shared = self.parameter_group.probe.get_sparse_code_probe_shared_weights() + # sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T + # self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code_probe_shared) + + # delta_p_i = torch.einsum('ij,jlk->ilk', self.parameter_group.probe.dictionary_matrix, + # optimal_delta_sparse_code).permute(1, 2, 0) + # delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, chi.shape[-1], chi.shape[-2])) - delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) + # delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) else: @@ -381,6 +385,12 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): # Calculate update vectors for OPR modes and weights. if i_slice == 0: if self.parameter_group.opr_mode_weights.optimization_enabled(self.current_epoch): + + if use_sparse_probe_shared_update: + apply_updates = True + else: + apply_updates = False + self.parameter_group.opr_mode_weights.update_variable_probe( self.parameter_group.probe, indices, @@ -390,7 +400,7 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): obj_patches, self.current_epoch, probe_mode_index=0, - apply_updates=False, + apply_updates=apply_updates, ) # Update buffered data for momentum acceleration. diff --git a/src/ptychi/reconstructors/pie.py b/src/ptychi/reconstructors/pie.py index ac032ac..fe69511 100644 --- a/src/ptychi/reconstructors/pie.py +++ b/src/ptychi/reconstructors/pie.py @@ -135,50 +135,87 @@ def compute_updates( delta_p_i = None if (i_slice == 0) and (probe.optimization_enabled(self.current_epoch)): - if (self.parameter_group.probe.representation == "sparse_code"): - # TODO: move this into SynthesisDictLearnProbe class - rc = delta_exwv_i.shape[-1] * delta_exwv_i.shape[-2] - n_scpm = delta_exwv_i.shape[-3] - n_spos = delta_exwv_i.shape[-4] + if (self.parameter_group.probe.representation == "sparse_code" + and + self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared + ): + + # Calculate probe update direction using the sparse code representation + + step_weight = self.calculate_probe_step_weight((obj_patches[:, [i_slice], ...])) + delta_p_i = step_weight * delta_exwv_i # get delta p at each position + + # Undo subpixel shift in probe update directions. + delta_p_i = self.adjoint_shift_probe_update_direction(indices, delta_p_i, first_mode_only=True) + + chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction( + indices, delta_exwv_i, first_mode_only=True + ) + + delta_p_i = self.parameter_group.probe.get_probe_update_direction_sparse_code_probe_shared( + delta_p_i, chi_rm_subpx_shft, obj_patches[:, i_slice, ...] + ) + + + + + + + + + + + + # # TODO: move this into SynthesisDictLearnProbe class + # rc = delta_exwv_i.shape[-1] * delta_exwv_i.shape[-2] + # n_scpm = delta_exwv_i.shape[-3] + # n_spos = delta_exwv_i.shape[-4] - obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, 1, rc )) - abs2_obj_patches = torch.abs(obj_patches_vec) ** 2 + # obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, 1, rc )) + # abs2_obj_patches = torch.abs(obj_patches_vec) ** 2 - z = torch.sum(abs2_obj_patches, dim = 0) - z_max = torch.max(z) - w = self.parameter_group.probe.options.alpha * (z_max - z) - z_plus_w = torch.swapaxes(z + w, 0, 1) + # z = torch.sum(abs2_obj_patches, dim = 0) + # z_max = torch.max(z) + # w = self.parameter_group.probe.options.alpha * (z_max - z) + # z_plus_w = torch.swapaxes(z + w, 0, 1) - delta_exwv = self.adjoint_shift_probe_update_direction(indices, delta_exwv_i, first_mode_only=True) - delta_exwv = torch.sum(delta_exwv, 0) - delta_exwv = torch.reshape( delta_exwv, (n_scpm, rc)).T + # delta_exwv = self.adjoint_shift_probe_update_direction(indices, delta_exwv_i, first_mode_only=True) + # delta_exwv = torch.sum(delta_exwv, 0) + # delta_exwv = torch.reshape( delta_exwv, (n_scpm, rc)).T - denom = (self.parameter_group.probe.dictionary_matrix_H @ (z_plus_w * self.parameter_group.probe.dictionary_matrix)) - numer = self.parameter_group.probe.dictionary_matrix_H @ delta_exwv + # denom = (self.parameter_group.probe.dictionary_matrix_H @ (z_plus_w * self.parameter_group.probe.dictionary_matrix)) + # numer = self.parameter_group.probe.dictionary_matrix_H @ delta_exwv - delta_sparse_code = torch.linalg.solve(denom, numer) + # delta_sparse_code = torch.linalg.solve(denom, numer) - delta_p = self.parameter_group.probe.dictionary_matrix @ delta_sparse_code - delta_p = torch.reshape( delta_p.T, ( n_scpm, delta_exwv_i.shape[-1] , delta_exwv_i.shape[-2])) - delta_p_i = torch.tile(delta_p, (n_spos,1,1,1)) + # delta_p = self.parameter_group.probe.dictionary_matrix @ delta_sparse_code + # delta_p = torch.reshape( delta_p.T, ( n_scpm, delta_exwv_i.shape[-1] , delta_exwv_i.shape[-2])) + # delta_p_i = torch.tile(delta_p, (n_spos,1,1,1)) - # sparse code update - sparse_code = self.parameter_group.probe.get_sparse_code_weights() - sparse_code = sparse_code + delta_sparse_code + # # sparse code update + # sparse_code = self.parameter_group.probe.get_sparse_code_weights() + # sparse_code = sparse_code + delta_sparse_code - # Enforce sparsity constraint on sparse code - abs_sparse_code = torch.abs(sparse_code) - sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) + # # Enforce sparsity constraint on sparse code + # abs_sparse_code = torch.abs(sparse_code) + # sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) + + # sel = sparse_code_sorted[0][self.parameter_group.probe.probe_sparse_code_nnz, :] + + # # hard thresholding: + # sparse_code = sparse_code * (abs_sparse_code >= sel) + + # #(TODO: soft thresholding option) + + # # Update the new sparse code in the probe class + # self.parameter_group.probe.set_sparse_code(sparse_code) + + + - sel = sparse_code_sorted[0][self.parameter_group.probe.probe_sparse_code_nnz, :] - # hard thresholding: - sparse_code = sparse_code * (abs_sparse_code >= sel) - #(TODO: soft thresholding option) - # Update the new sparse code in the probe class - self.parameter_group.probe.set_sparse_code(sparse_code) else: step_weight = self.calculate_probe_step_weight((obj_patches[:, [i_slice], ...])) delta_p_i = step_weight * delta_exwv_i # get delta p at each position From f830aa4e08c6c7485bf49458d6c18699e6f62ca7 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Mon, 8 Sep 2025 17:22:03 -0500 Subject: [PATCH 03/13] get rid of debugging comments and white spaces --- src/ptychi/reconstructors/lsqml.py | 89 ------------------------------ 1 file changed, 89 deletions(-) diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 21353a4..097b83c 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -279,89 +279,7 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): ) delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) - - - - - # nr = chi.shape[-2] - # nc = chi.shape[-1] - # nrnc = nr*nc - # n_scpm = chi.shape[-3] - # n_spos = chi.shape[-4] - - - - - # #chi_rm_subpx_shft = torch.reshape( chi_rm_subpx_shft, (n_spos, n_scpm, nrnc)).permute(2,0,1) - # obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, nrnc )) - - - - # # delta_p_i_unshifted = self._calculate_probe_update_direction( - # # chi, obj_patches = obj_patches, slice_index=i_slice, probe_mode_index=None - # # ) - # delta_p_i = self.adjoint_shift_probe_update_direction(indices, - # chi * torch.conj( obj_patches[:, i_slice, ...] )[:, None, ... ], - # first_mode_only=True) - - - - - - # delta_sparse_code = torch.reshape( delta_p_i, - # ( n_spos, n_scpm, nrnc )) - - # delta_sparse_code = torch.einsum('ijk,kl->lij', - # delta_sparse_code, - # self.parameter_group.probe.dictionary_matrix.conj()) - - # # compute optimal step length for sparse code update - # dict_delta_sparse_code = torch.einsum('ij,jkl->ikl', - # self.parameter_group.probe.dictionary_matrix, - # delta_sparse_code) - - - # denom = torch.abs( dict_delta_sparse_code )**2 * obj_patches_vec.swapaxes(0,-1)[...,None] - # denom = torch.einsum('ij,jik->ik', - # torch.conj( obj_patches_vec ), - # denom) - - # numer = torch.conj( dict_delta_sparse_code ) * torch.reshape( chi_rm_subpx_shft, - # ( n_spos, n_scpm, nrnc )).permute(2,0,1) - # numer = torch.einsum('ij,jik->ik', - # torch.conj( obj_patches_vec ), - # numer) - - # # real is used to throw away small imag part due to numerical precision errors - # optimal_step_sparse_code = ( numer / denom ).real - - # optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code - - # # Enforce sparsity constraint on sparse code - # abs_sparse_code = torch.abs(optimal_delta_sparse_code) - # abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) - - # sel = abs_sparse_code_sorted[0][self.parameter_group.probe.sparse_code_probe_nnz, ...] - # sparse_code_mask = (abs_sparse_code >= sel[None,...]) - - # # Hard or Soft thresholding - # if self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'hard': - # optimal_delta_sparse_code = optimal_delta_sparse_code * sparse_code_mask - # elif self.parameter_group.probe.options.experimental.sdl_probe_options.thresholding_type_shared == 'soft': - # optimal_delta_sparse_code = ( abs_sparse_code - sel[None,...] ) * sparse_code_mask * torch.exp(1j * torch.angle(optimal_delta_sparse_code)) - - # # update the shared probe sparse codes using the average over scan positions - # sparse_code_probe_shared = self.parameter_group.probe.get_sparse_code_probe_shared_weights() - # sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T - # self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code_probe_shared) - - # delta_p_i = torch.einsum('ij,jlk->ilk', self.parameter_group.probe.dictionary_matrix, - # optimal_delta_sparse_code).permute(1, 2, 0) - # delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, chi.shape[-1], chi.shape[-2])) - - # delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) - else: # Calculate probe update direction (dense representation) @@ -375,13 +293,6 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): delta_p_hat = self._precondition_probe_update_direction(delta_p_i) # Eq. 25a self._record_probe_gradient(delta_p_hat) - - - - - - - # Calculate update vectors for OPR modes and weights. if i_slice == 0: if self.parameter_group.opr_mode_weights.optimization_enabled(self.current_epoch): From 8e0f4177629264709599cbd1e97e19f8769db57a Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Mon, 8 Sep 2025 17:24:21 -0500 Subject: [PATCH 04/13] remove OPR related bool switches --- src/ptychi/api/options/base.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/ptychi/api/options/base.py b/src/ptychi/api/options/base.py index 66f3d37..06ab65a 100644 --- a/src/ptychi/api/options/base.py +++ b/src/ptychi/api/options/base.py @@ -656,10 +656,8 @@ class SynthesisDictLearnProbeOptions(Options): enabled: bool = False enabled_shared: bool = False - enabled_opr: bool = False thresholding_type_shared: str = 'hard' - thresholding_type_opr: str = 'hard' """Choose between 'hard' or 'soft' thresholding.""" dictionary_matrix: Union[ndarray, Tensor] = None From 80872238ae13f3a1513e13b9977fb8050fb53446 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Mon, 8 Sep 2025 17:28:50 -0500 Subject: [PATCH 05/13] get rid of old comments and white spaces in pir reconstructor --- src/ptychi/reconstructors/pie.py | 60 -------------------------------- 1 file changed, 60 deletions(-) diff --git a/src/ptychi/reconstructors/pie.py b/src/ptychi/reconstructors/pie.py index fe69511..dd9c4ad 100644 --- a/src/ptychi/reconstructors/pie.py +++ b/src/ptychi/reconstructors/pie.py @@ -156,66 +156,6 @@ def compute_updates( delta_p_i, chi_rm_subpx_shft, obj_patches[:, i_slice, ...] ) - - - - - - - - - - - # # TODO: move this into SynthesisDictLearnProbe class - # rc = delta_exwv_i.shape[-1] * delta_exwv_i.shape[-2] - # n_scpm = delta_exwv_i.shape[-3] - # n_spos = delta_exwv_i.shape[-4] - - # obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, 1, rc )) - # abs2_obj_patches = torch.abs(obj_patches_vec) ** 2 - - # z = torch.sum(abs2_obj_patches, dim = 0) - # z_max = torch.max(z) - # w = self.parameter_group.probe.options.alpha * (z_max - z) - # z_plus_w = torch.swapaxes(z + w, 0, 1) - - # delta_exwv = self.adjoint_shift_probe_update_direction(indices, delta_exwv_i, first_mode_only=True) - # delta_exwv = torch.sum(delta_exwv, 0) - # delta_exwv = torch.reshape( delta_exwv, (n_scpm, rc)).T - - # denom = (self.parameter_group.probe.dictionary_matrix_H @ (z_plus_w * self.parameter_group.probe.dictionary_matrix)) - # numer = self.parameter_group.probe.dictionary_matrix_H @ delta_exwv - - # delta_sparse_code = torch.linalg.solve(denom, numer) - - # delta_p = self.parameter_group.probe.dictionary_matrix @ delta_sparse_code - # delta_p = torch.reshape( delta_p.T, ( n_scpm, delta_exwv_i.shape[-1] , delta_exwv_i.shape[-2])) - # delta_p_i = torch.tile(delta_p, (n_spos,1,1,1)) - - # # sparse code update - # sparse_code = self.parameter_group.probe.get_sparse_code_weights() - # sparse_code = sparse_code + delta_sparse_code - - # # Enforce sparsity constraint on sparse code - # abs_sparse_code = torch.abs(sparse_code) - # sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) - - # sel = sparse_code_sorted[0][self.parameter_group.probe.probe_sparse_code_nnz, :] - - # # hard thresholding: - # sparse_code = sparse_code * (abs_sparse_code >= sel) - - # #(TODO: soft thresholding option) - - # # Update the new sparse code in the probe class - # self.parameter_group.probe.set_sparse_code(sparse_code) - - - - - - - else: step_weight = self.calculate_probe_step_weight((obj_patches[:, [i_slice], ...])) delta_p_i = step_weight * delta_exwv_i # get delta p at each position From 5d8c60100bf196fe3b90498e57ee084364394a3d Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 12 Sep 2025 10:57:41 -0500 Subject: [PATCH 06/13] STYLE: fix style --- src/ptychi/data_structures/probe.py | 53 +++++++++++++++++++++---- src/ptychi/reconstructors/lsqml.py | 61 ++++++++++++++++------------- 2 files changed, 79 insertions(+), 35 deletions(-) diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index 49f80f1..c38b479 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -492,7 +492,18 @@ def get_dictionary(self): return dictionary_matrix, dictionary_matrix_pinv def get_sparse_code_weights_vs_scanpositions(self, probe_vs_scanpositions): + """Get the sparse code weights for a given probe vs scan positions. + Parameters + ---------- + probe_vs_scanpositions : Tensor + A (n_pos, 1, h, w) tensor giving the probe vs scan positions. + + Returns + ------- + Tensor + A tensor giving the sparse code weights for the given probe vs scan positions. + """ sz = probe_vs_scanpositions.shape probe_vec = torch.reshape(probe_vs_scanpositions, (sz[0], sz[1], sz[2]*sz[3])) sparse_code_vs_scanpositions = torch.einsum('ij,klj->ikl', self.dictionary_matrix_pinv, probe_vec) @@ -501,7 +512,7 @@ def get_sparse_code_weights_vs_scanpositions(self, probe_vs_scanpositions): def get_sparse_code_probe_shared_weights(self): - probe_shared = self.data[0,...] + probe_shared = self.data[0, ...] sz = probe_shared.shape probe_vec = torch.reshape(probe_shared, (sz[0], sz[1]*sz[2])) sparse_code_probe_shared = self.dictionary_matrix_pinv @ probe_vec.T @@ -543,7 +554,36 @@ def build_optimizer(self): self.optimizer = self.optimizer_class([self.sparse_code_probe_shared], **self.optimizer_params) def set_sparse_code_probe_shared(self, data): + """ + Set the sparse code weights for the shared probe. + + Parameters + ---------- + data : Tensor + A (n_dict_bases, n_modes) tensor giving the sparse code weights for the shared probe. + """ self.sparse_code_probe_shared.data = data + + def set_sparse_code_weights_vs_scanpositions( + self, + sparse_code_vs_scanpositions: Tensor, + indices: tuple | Tensor = None + ): + """ + Set the sparse code weights for a given probe vs scan positions. + + Parameters + ---------- + sparse_code_vs_scanpositions : Tensor + A (n_pos, n_opr_modes, n_scpm) tensor giving the sparse code weights for the + given probe vs scan positions. + indices : tuple | Tensor + The indices to apply to the sparse code weights. + """ + raise NotImplementedError("This method is not implemented yet.") + if indices is None: + indices = slice(None) + self.sparse_code_weights_vs_scanpositions[indices] = sparse_code_vs_scanpositions def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, obj_patches): @@ -593,24 +633,21 @@ def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, ob # hard or soft thresholding if self.options.experimental.sdl_probe_options.thresholding_type_shared == 'hard': - optimal_delta_sparse_code=optimal_delta_sparse_code * sparse_code_mask - elif self.options.experimental.sdl_probe_options.thresholding_type_shared == 'soft': - optimal_delta_sparse_code=((abs_sparse_code - sel[None,...]) * sparse_code_mask * torch.exp(1j*torch.angle(optimal_delta_sparse_code)) ) # update the shared probe sparse codes using the average over scan positions sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights() - sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T - self.set_sparse_code_probe_shared(sparse_code_probe_shared) - delta_p_i = torch.einsum('ij,jlk->ilk', self.dictionary_matrix, - optimal_delta_sparse_code + delta_p_i = torch.einsum( + "ij,jlk->ilk", + self.dictionary_matrix, + optimal_delta_sparse_code ).permute(1,2,0) delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, nr, nc)) diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 097b83c..2553de4 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -199,6 +199,13 @@ def run_real_space_step(self, psi_opt, indices): obj_patches = self.forward_model.intermediate_variables["obj_patches"] self.calculate_update_vectors(indices, chi, obj_patches, positions) + + @property + def use_sparse_probe_shared_update(self): + return ( + self.parameter_group.probe.representation == "sparse_code" + and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared + ) @timer() def calculate_update_vectors(self, indices, chi, obj_patches, positions): @@ -254,34 +261,14 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): self._record_object_slice_gradient(i_slice, delta_o_precond, add_to_existing=False) else: self._record_object_slice_gradient(i_slice, delta_o_comb, add_to_existing=False) - - use_sparse_probe_shared_update = (self.parameter_group.probe.representation == "sparse_code" - and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared) - if use_sparse_probe_shared_update: - - # Calculate probe update direction using the sparse code representation - - delta_p_i_unshifted = self._calculate_probe_update_direction( - chi, obj_patches = obj_patches, slice_index=i_slice, probe_mode_index=None - ) - - delta_p_i = self.adjoint_shift_probe_update_direction( - indices, delta_p_i_unshifted, first_mode_only=True - ) - - chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction( - indices, chi, first_mode_only=True - ) - - delta_p_i = self.parameter_group.probe.get_probe_update_direction_sparse_code_probe_shared( - delta_p_i, chi_rm_subpx_shft, obj_patches[:, i_slice, ...] - ) - - delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) - + if self.use_sparse_probe_shared_update: + ( + delta_p_i_unshifted, delta_p_i + ) = self.calculate_probe_update_direction_sparse_code_probe_shared( + indices, chi, obj_patches, i_slice + ) else: - # Calculate probe update direction (dense representation) delta_p_i_unshifted = self._calculate_probe_update_direction( chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None @@ -297,7 +284,7 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): if i_slice == 0: if self.parameter_group.opr_mode_weights.optimization_enabled(self.current_epoch): - if use_sparse_probe_shared_update: + if self.use_sparse_probe_shared_update: apply_updates = True else: apply_updates = False @@ -349,6 +336,26 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): # Set chi to conjugate-modulated wavefield. chi = delta_p_i_unshifted + + def calculate_probe_update_direction_sparse_code_probe_shared( + self, indices, chi, obj_patches, i_slice=None + ): + """Calculate probe update direction using the sparse code representation. + """ + delta_p_i_unshifted = self._calculate_probe_update_direction( + chi, obj_patches = obj_patches, slice_index=i_slice, probe_mode_index=None + ) + delta_p_i = self.adjoint_shift_probe_update_direction( + indices, delta_p_i_unshifted, first_mode_only=True + ) + chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction( + indices, chi, first_mode_only=True + ) + delta_p_i = self.parameter_group.probe.get_probe_update_direction_sparse_code_probe_shared( + delta_p_i, chi_rm_subpx_shft, obj_patches[:, i_slice, ...] + ) + delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) + return delta_p_i_unshifted, delta_p_i @timer() def apply_reconstruction_parameter_updates(self, indices: torch.Tensor): From 10e7faef2db81e19f569e0c4786be420c8e2a7e9 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 12 Sep 2025 10:59:36 -0500 Subject: [PATCH 07/13] REFACTOR: move use_sparse_probe_shared_update to parent class --- src/ptychi/reconstructors/base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/ptychi/reconstructors/base.py b/src/ptychi/reconstructors/base.py index 4a4fc31..27b22de 100644 --- a/src/ptychi/reconstructors/base.py +++ b/src/ptychi/reconstructors/base.py @@ -642,6 +642,13 @@ def __init__( ) self.forward_model = None self.build_forward_model() + + @property + def use_sparse_probe_shared_update(self): + return ( + self.parameter_group.probe.representation == "sparse_code" + and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared + ) def build_forward_model(self): self.forward_model = fm.PlanarPtychographyForwardModel( From a613664d39596ef1480b2bf5c7d9b9751dcab5c0 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 12 Sep 2025 11:00:19 -0500 Subject: [PATCH 08/13] REFACTOR: move use_sparse_probe_shared_update to parent class --- src/ptychi/reconstructors/lsqml.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 2553de4..5fd7dbd 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -199,13 +199,6 @@ def run_real_space_step(self, psi_opt, indices): obj_patches = self.forward_model.intermediate_variables["obj_patches"] self.calculate_update_vectors(indices, chi, obj_patches, positions) - - @property - def use_sparse_probe_shared_update(self): - return ( - self.parameter_group.probe.representation == "sparse_code" - and self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared - ) @timer() def calculate_update_vectors(self, indices, chi, obj_patches, positions): From a380779676560dfe7bb204437fab38f44ba52ecf Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 12 Sep 2025 11:18:35 -0500 Subject: [PATCH 09/13] STYLE: format `probe.py` with RUFF --- src/ptychi/data_structures/probe.py | 225 +++++++++++++++------------- 1 file changed, 121 insertions(+), 104 deletions(-) diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index c38b479..442b74c 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -37,7 +37,7 @@ class Probe(dsbase.ReconstructParameter): # to contain additional options for ReconstructParameter classes, and subclass them for specific # reconstruction algorithms - for example, ProbeOptions -> LSQMLProbeOptions. options: "api.options.base.ProbeOptions" - + representation: ProbeRepresentation = ProbeRepresentation.NORMAL def __init__( @@ -104,7 +104,7 @@ def n_opr_modes(self): @property def has_multiple_opr_modes(self): return self.n_opr_modes > 1 - + @property def lateral_shape(self): return self.shape[-2:] @@ -162,7 +162,9 @@ def get_all_mode_intensity( return torch.sum((p.abs()) ** 2, dim=0) def get_unique_probes( - self, weights: Union[Tensor, "dsbase.ReconstructParameter"], mode_to_apply: Optional[int] = None + self, + weights: Union[Tensor, "dsbase.ReconstructParameter"], + mode_to_apply: Optional[int] = None, ) -> Tensor: """ Parameters @@ -221,11 +223,13 @@ def constrain_incoherent_modes_orthogonality(self): return probe = self.data - + if self.options.orthogonalize_incoherent_modes.sort_by_occupancy: - shared_occupancy = torch.sum(torch.abs(probe[0,...])**2,(-2,-1)) / torch.sum(torch.abs(probe[0,...])**2) + shared_occupancy = torch.sum(torch.abs(probe[0, ...]) ** 2, (-2, -1)) / torch.sum( + torch.abs(probe[0, ...]) ** 2 + ) shared_occupancy = torch.sort(shared_occupancy, dim=0, descending=True) - probe[0,...] = probe[ 0, shared_occupancy[1],...] + probe[0, ...] = probe[0, shared_occupancy[1], ...] norm_first_mode_orig = pmath.norm(probe[0, 0], dim=(-2, -1)) @@ -373,9 +377,13 @@ def constrain_probe_power( if isinstance(propagator, FourierPropagator): # Cancel the normalization factor so that the power is conserved. if propagator.norm == "backward" or propagator.norm is None: - propagated_probe_power = torch.sum(propagated_probe.abs() ** 2) / self.data.size().numel() + propagated_probe_power = ( + torch.sum(propagated_probe.abs() ** 2) / self.data.size().numel() + ) elif propagator.norm == "forward": - propagated_probe_power = torch.sum(propagated_probe.abs() ** 2) * self.data.size().numel() + propagated_probe_power = ( + torch.sum(propagated_probe.abs() ** 2) * self.data.size().numel() + ) else: propagated_probe_power = torch.sum(propagated_probe.abs() ** 2) else: @@ -407,12 +415,12 @@ def center_probe(self): """ Move the probe's center of mass to the center of the probe array. """ - + if self.options.center_constraint.use_intensity_for_com: probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0) else: probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0) - + com = ip.find_center_of_mass(probe_to_be_shifted) shift = utils.to_tensor(self.shape[-2:]) // 2 - com shifted_probe = self.shift(shift) @@ -467,28 +475,39 @@ def save_tiff(self, path: str): tifffile.imwrite(fname + "_mag.tif", mag_img) tifffile.imwrite(fname + "_phase.tif", phase_img) -class SynthesisDictLearnProbe( Probe ): - + +class SynthesisDictLearnProbe(Probe): representation: ProbeRepresentation = ProbeRepresentation.SPARSE_CODE - - def __init__(self, name = "probe", options = None, *args, **kwargs): - - super().__init__(name, options, build_optimizer=False, data_as_parameter=False, *args, **kwargs) + + def __init__(self, name="probe", options=None, *args, **kwargs): + super().__init__( + name, options, build_optimizer=False, data_as_parameter=False, *args, **kwargs + ) dictionary_matrix, dictionary_matrix_pinv = self.get_dictionary() self.register_buffer("dictionary_matrix", dictionary_matrix) self.register_buffer("dictionary_matrix_pinv", dictionary_matrix_pinv) - sparse_code_probe_shared_nnz = torch.tensor( self.options.experimental.sdl_probe_options.sparse_code_probe_shared_nnz, dtype=torch.uint32 ) + sparse_code_probe_shared_nnz = torch.tensor( + self.options.experimental.sdl_probe_options.sparse_code_probe_shared_nnz, + dtype=torch.uint32, + ) sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights() - self.register_buffer("sparse_code_probe_shared_nnz", sparse_code_probe_shared_nnz ) - self.register_parameter("sparse_code_probe_shared", torch.nn.Parameter(sparse_code_probe_shared)) - + self.register_buffer("sparse_code_probe_shared_nnz", sparse_code_probe_shared_nnz) + self.register_parameter( + "sparse_code_probe_shared", torch.nn.Parameter(sparse_code_probe_shared) + ) + self.build_optimizer() def get_dictionary(self): - dictionary_matrix = torch.tensor( self.options.experimental.sdl_probe_options.dictionary_matrix, dtype=torch.complex64 ) - dictionary_matrix_pinv = torch.tensor( self.options.experimental.sdl_probe_options.dictionary_matrix_pinv, dtype=torch.complex64 ) + dictionary_matrix = torch.tensor( + self.options.experimental.sdl_probe_options.dictionary_matrix, dtype=torch.complex64 + ) + dictionary_matrix_pinv = torch.tensor( + self.options.experimental.sdl_probe_options.dictionary_matrix_pinv, + dtype=torch.complex64, + ) return dictionary_matrix, dictionary_matrix_pinv def get_sparse_code_weights_vs_scanpositions(self, probe_vs_scanpositions): @@ -505,16 +524,17 @@ def get_sparse_code_weights_vs_scanpositions(self, probe_vs_scanpositions): A tensor giving the sparse code weights for the given probe vs scan positions. """ sz = probe_vs_scanpositions.shape - probe_vec = torch.reshape(probe_vs_scanpositions, (sz[0], sz[1], sz[2]*sz[3])) - sparse_code_vs_scanpositions = torch.einsum('ij,klj->ikl', self.dictionary_matrix_pinv, probe_vec) + probe_vec = torch.reshape(probe_vs_scanpositions, (sz[0], sz[1], sz[2] * sz[3])) + sparse_code_vs_scanpositions = torch.einsum( + "ij,klj->ikl", self.dictionary_matrix_pinv, probe_vec + ) return sparse_code_vs_scanpositions def get_sparse_code_probe_shared_weights(self): - probe_shared = self.data[0, ...] sz = probe_shared.shape - probe_vec = torch.reshape(probe_shared, (sz[0], sz[1]*sz[2])) + probe_vec = torch.reshape(probe_shared, (sz[0], sz[1] * sz[2])) sparse_code_probe_shared = self.dictionary_matrix_pinv @ probe_vec.T return sparse_code_probe_shared.T @@ -522,60 +542,58 @@ def get_sparse_code_probe_shared_weights(self): def generate(self): """Generate the probe using the sparse code, and set the generated probe to self.data. - + Returns ------- Tensor A (n_opr_modes, n_modes, h, w) tensor giving the generated probe. """ - if (self.options.experimental.sdl_probe_options.enabled_shared): - + if self.options.experimental.sdl_probe_options.enabled_shared: sz = self.data.shape - probe = torch.zeros( *[sz], dtype = torch.complex64 ) + probe = torch.zeros(*[sz], dtype=torch.complex64) probe_shared = self.dictionary_matrix @ self.sparse_code_probe_shared.T - probe[0,...] = torch.reshape( probe_shared.T, *[sz[1:]] ) - probe[1:,0,...] = self.data[1:,0,...] + probe[0, ...] = torch.reshape(probe_shared.T, *[sz[1:]]) + probe[1:, 0, ...] = self.data[1:, 0, ...] self.set_data(probe) else: + probe = self.data - probe = self.data - def build_optimizer(self): if self.optimizable and self.optimizer_class is None: raise ValueError( "Parameter {} is optimizable but no optimizer is specified.".format(self.name) ) if self.optimizable: - self.optimizer = self.optimizer_class([self.sparse_code_probe_shared], **self.optimizer_params) + self.optimizer = self.optimizer_class( + [self.sparse_code_probe_shared], **self.optimizer_params + ) def set_sparse_code_probe_shared(self, data): """ Set the sparse code weights for the shared probe. - + Parameters ---------- data : Tensor A (n_dict_bases, n_modes) tensor giving the sparse code weights for the shared probe. """ self.sparse_code_probe_shared.data = data - + def set_sparse_code_weights_vs_scanpositions( - self, - sparse_code_vs_scanpositions: Tensor, - indices: tuple | Tensor = None + self, sparse_code_vs_scanpositions: Tensor, indices: tuple | Tensor = None ): """ Set the sparse code weights for a given probe vs scan positions. - + Parameters ---------- sparse_code_vs_scanpositions : Tensor - A (n_pos, n_opr_modes, n_scpm) tensor giving the sparse code weights for the + A (n_pos, n_opr_modes, n_scpm) tensor giving the sparse code weights for the given probe vs scan positions. indices : tuple | Tensor The indices to apply to the sparse code weights. @@ -586,85 +604,79 @@ def set_sparse_code_weights_vs_scanpositions( self.sparse_code_weights_vs_scanpositions[indices] = sparse_code_vs_scanpositions def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, obj_patches): - nr = chi.shape[-2] nc = chi.shape[-1] nrnc = nr * nc n_scpm = chi.shape[-3] n_spos = chi.shape[-4] - - obj_patches = torch.reshape(obj_patches, (n_spos, nrnc)) - chi = torch.reshape(chi, (n_spos, n_scpm, nrnc)).permute(2,0,1) + + obj_patches = torch.reshape(obj_patches, (n_spos, nrnc)) + chi = torch.reshape(chi, (n_spos, n_scpm, nrnc)).permute(2, 0, 1) # get sparse code update direction - delta_sparse_code = torch.einsum('ijk,kl->lij', - torch.reshape(delta_p_i, (n_spos, n_scpm, nrnc)), - self.dictionary_matrix.conj() + delta_sparse_code = torch.einsum( + "ijk,kl->lij", + torch.reshape(delta_p_i, (n_spos, n_scpm, nrnc)), + self.dictionary_matrix.conj(), ) - # compute optimal step length for sparse code update - dict_delta_sparse_code = torch.einsum('ij,jkl->ikl', - self.dictionary_matrix, - delta_sparse_code + # compute optimal step length for sparse code update + dict_delta_sparse_code = torch.einsum( + "ij,jkl->ikl", self.dictionary_matrix, delta_sparse_code ) - denom = (torch.abs(dict_delta_sparse_code)**2) * obj_patches.swapaxes(0,-1)[...,None] - denom = torch.einsum('ij,jik->ik', - torch.conj(obj_patches), - denom - ) + denom = (torch.abs(dict_delta_sparse_code) ** 2) * obj_patches.swapaxes(0, -1)[..., None] + denom = torch.einsum("ij,jik->ik", torch.conj(obj_patches), denom) - numer = torch.conj(dict_delta_sparse_code) * chi - numer = torch.einsum('ij,jik->ik', - torch.conj(obj_patches), - numer) + numer = torch.conj(dict_delta_sparse_code) * chi + numer = torch.einsum("ij,jik->ik", torch.conj(obj_patches), numer) - # real is used to throw away small imag part due to numerical precision errors - optimal_step_sparse_code = (numer/denom).real + # real is used to throw away small imag part due to numerical precision errors + optimal_step_sparse_code = (numer / denom).real - optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code + optimal_delta_sparse_code = optimal_step_sparse_code[None, ...] * delta_sparse_code # enforce sparsity constraint on sparse code abs_sparse_code = torch.abs(optimal_delta_sparse_code) abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True) sel = abs_sparse_code_sorted[0][self.sparse_code_probe_shared_nnz, ...] - sparse_code_mask = (abs_sparse_code >= sel[None,...]) + sparse_code_mask = abs_sparse_code >= sel[None, ...] # hard or soft thresholding - if self.options.experimental.sdl_probe_options.thresholding_type_shared == 'hard': - optimal_delta_sparse_code=optimal_delta_sparse_code * sparse_code_mask - elif self.options.experimental.sdl_probe_options.thresholding_type_shared == 'soft': - optimal_delta_sparse_code=((abs_sparse_code - sel[None,...]) * sparse_code_mask - * torch.exp(1j*torch.angle(optimal_delta_sparse_code)) + if self.options.experimental.sdl_probe_options.thresholding_type_shared == "hard": + optimal_delta_sparse_code = optimal_delta_sparse_code * sparse_code_mask + elif self.options.experimental.sdl_probe_options.thresholding_type_shared == "soft": + optimal_delta_sparse_code = ( + (abs_sparse_code - sel[None, ...]) + * sparse_code_mask + * torch.exp(1j * torch.angle(optimal_delta_sparse_code)) ) # update the shared probe sparse codes using the average over scan positions sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights() sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T self.set_sparse_code_probe_shared(sparse_code_probe_shared) - + delta_p_i = torch.einsum( - "ij,jlk->ilk", - self.dictionary_matrix, - optimal_delta_sparse_code - ).permute(1,2,0) - + "ij,jlk->ilk", self.dictionary_matrix, optimal_delta_sparse_code + ).permute(1, 2, 0) + delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, nr, nc)) return delta_p_i + class DIPProbe(Probe): - options: "api.options.ad_ptychography.AutodiffPtychographyProbeOptions" representation: ProbeRepresentation = ProbeRepresentation.DIP - + def __init__( self, name: str = "probe", options: "api.options.ad_ptychography.AutodiffPtychographyProbeOptions" = None, *args, - **kwargs + **kwargs, ) -> None: """Deep image prior object. @@ -679,29 +691,29 @@ def __init__( self.model = None self.dip_output_magnitude = None self.dip_output_phase = None - + self.build_model() self.build_dip_optimizer() - + # `self.tensor` is used to hold the object generated by the DIP model and # is not trainable. self.tensor.requires_grad_(False) - + nn_input = self.get_nn_input() self.register_buffer("nn_input", nn_input) - + self.initial_data = None if self.options.experimental.deep_image_prior_options.residual_generation: - self.initial_data = self.data.clone() - + self.initial_data = self.data.clone() + def build_model(self): if not self.options.experimental.deep_image_prior_options.enabled: return - model_class = maps.get_nn_model_by_enum(self.options.experimental.deep_image_prior_options.model) - self.model = model_class( - **self.options.experimental.deep_image_prior_options.model_params + model_class = maps.get_nn_model_by_enum( + self.options.experimental.deep_image_prior_options.model ) - + self.model = model_class(**self.options.experimental.deep_image_prior_options.model_params) + def build_dip_optimizer(self): if self.optimizable and self.optimizer_class is None: raise ValueError( @@ -711,17 +723,22 @@ def build_dip_optimizer(self): self.optimizer = self.optimizer_class(self.model.parameters(), **self.optimizer_params) def get_nn_input(self): - z = torch.rand( - [self.n_opr_modes * self.n_modes, - self.options.experimental.deep_image_prior_options.net_input_channels, - *self.lateral_shape], - ) * 0.1 + z = ( + torch.rand( + [ + self.n_opr_modes * self.n_modes, + self.options.experimental.deep_image_prior_options.net_input_channels, + *self.lateral_shape, + ], + ) + * 0.1 + ) return z def generate(self) -> Tensor: """Generate the probe using the deep image prior model, and set the generated probe to self.data. - + Returns ------- Tensor @@ -730,13 +747,13 @@ def generate(self) -> Tensor: if self.model is None: raise ValueError("Model is not built.") p = self.model(self.nn_input) - + p, mag, phase = self.process_net_output(p) - + with torch.no_grad(): self.dip_output_magnitude = mag.clone() self.dip_output_phase = phase.clone() - + if self.options.experimental.deep_image_prior_options.residual_generation: init_data = torch.stack([self.initial_data.real, self.initial_data.imag], dim=-1) p = p + init_data @@ -749,15 +766,15 @@ def process_net_output(self, p): Parameters ---------- o : Tensor | tuple[Tensor, Tensor] - The output of the DIP network. It should either be a [n_modes * n_opr_modes, 2, h, w] + The output of the DIP network. It should either be a [n_modes * n_opr_modes, 2, h, w] tensor with the channels giving the magnitude and phase of the probe, or a tuple of two [n_modes * n_opr_modes, h, w] tensors giving the magnitude and phase of the probe. - + Returns ------- Tensor - A [n_opr_modes, n_modes, h, w, 2] tensor representing the real and imaginary parts + A [n_opr_modes, n_modes, h, w, 2] tensor representing the real and imaginary parts of the probe. Tensor The magnitude of the probe. @@ -771,7 +788,7 @@ def process_net_output(self, p): else: mag = p[:, 0] phase = p[:, 1] - + expected_phase_shape = (self.n_opr_modes * self.n_modes, *self.lateral_shape) if tuple(phase.shape) != expected_phase_shape: logger.warning( @@ -785,7 +802,7 @@ def process_net_output(self, p): phase_resized.append(ip.central_crop_or_pad(phase[i_img], expected_phase_shape[1:])) mag = torch.stack(mag_resized) phase = torch.stack(phase_resized) - + p_complex = mag * torch.exp(1j * phase) p = torch.stack([p_complex.real, p_complex.imag], dim=-1) p = p.reshape([*self.shape, 2]) From 424f899c1917322ff31d66eaeefcc2d84354a8de Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Wed, 17 Sep 2025 19:51:11 -0500 Subject: [PATCH 10/13] split up sparse code update in LSQML: 1) compute/set updates, 2) apply --- src/ptychi/data_structures/probe.py | 36 +++++++++++++++++++++++++---- src/ptychi/reconstructors/lsqml.py | 20 ++++++++++------ src/ptychi/reconstructors/pie.py | 5 +--- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index 442b74c..d2ad7ca 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -584,6 +584,28 @@ def set_sparse_code_probe_shared(self, data): """ self.sparse_code_probe_shared.data = data + def initialize_grad_sparse_code_probe_shared(self): + """ + Initialize the gradient of the sparse code weights update for the shared probe. + + Parameters + ---------- + data : Tensor + A (n_dict_bases, n_modes) tensor giving the sparse code weights for the shared probe. + """ + self.sparse_code_probe_shared.grad = torch.zeros_like(self.sparse_code_probe_shared.data) + + def set_gradient_sparse_code_probe_shared(self, grad): + """ + Set the gradient of the sparse code weights update for the shared probe. + + Parameters + ---------- + data : Tensor + A (n_dict_bases, n_modes) tensor giving the sparse code weights for the shared probe. + """ + self.sparse_code_probe_shared.grad = grad + def set_sparse_code_weights_vs_scanpositions( self, sparse_code_vs_scanpositions: Tensor, indices: tuple | Tensor = None ): @@ -653,18 +675,22 @@ def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, ob * torch.exp(1j * torch.angle(optimal_delta_sparse_code)) ) - # update the shared probe sparse codes using the average over scan positions - sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights() - sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T - self.set_sparse_code_probe_shared(sparse_code_probe_shared) + # OLD WAY: + # # update the shared probe sparse codes using the average over scan positions + # sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights() + # sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T + # self.set_sparse_code_probe_shared(sparse_code_probe_shared) + # DON'T KNOW HOW TO USE set_grad() METHOD FROM PROBE CLASS + self.set_gradient_sparse_code_probe_shared(optimal_delta_sparse_code.mean(1).T) + delta_p_i = torch.einsum( "ij,jlk->ilk", self.dictionary_matrix, optimal_delta_sparse_code ).permute(1, 2, 0) delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, nr, nc)) - return delta_p_i + return delta_p_i, optimal_delta_sparse_code class DIPProbe(Probe): diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index 5fd7dbd..b5de37a 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -255,12 +255,12 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): else: self._record_object_slice_gradient(i_slice, delta_o_comb, add_to_existing=False) - if self.use_sparse_probe_shared_update: + if self.use_sparse_probe_shared_update and self.parameter_group.probe.optimization_enabled(self.current_epoch): ( - delta_p_i_unshifted, delta_p_i + delta_p_i_unshifted, delta_p_i, _ ) = self.calculate_probe_update_direction_sparse_code_probe_shared( indices, chi, obj_patches, i_slice - ) + ) else: # Calculate probe update direction (dense representation) delta_p_i_unshifted = self._calculate_probe_update_direction( @@ -344,11 +344,11 @@ def calculate_probe_update_direction_sparse_code_probe_shared( chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction( indices, chi, first_mode_only=True ) - delta_p_i = self.parameter_group.probe.get_probe_update_direction_sparse_code_probe_shared( + delta_p_i, optimal_delta_sparse_code_vs_spos = self.parameter_group.probe.get_probe_update_direction_sparse_code_probe_shared( delta_p_i, chi_rm_subpx_shft, obj_patches[:, i_slice, ...] ) delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) - return delta_p_i_unshifted, delta_p_i + return delta_p_i_unshifted, delta_p_i, optimal_delta_sparse_code_vs_spos @timer() def apply_reconstruction_parameter_updates(self, indices: torch.Tensor): @@ -377,7 +377,12 @@ def apply_reconstruction_parameter_updates(self, indices: torch.Tensor): alpha_p_i = self.reconstructor_buffers.alpha_probe_all_pos[indices] if self.parameter_group.probe.optimization_enabled(self.current_epoch): self._apply_probe_update(alpha_p_i, -self.parameter_group.probe.get_grad()[0]) - + # update the shared probe sparse code if enabled + if self.use_sparse_probe_shared_update: + sparse_code_probe_shared = self.parameter_group.probe.get_sparse_code_probe_shared_weights() + sparse_code_probe_shared = sparse_code_probe_shared + self.parameter_group.probe.sparse_code_probe_shared.grad + self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code_probe_shared) + # Update probe positions. if self.parameter_group.probe_positions.optimization_enabled(self.current_epoch): self.parameter_group.probe_positions.step_optimizer() @@ -1014,7 +1019,8 @@ def _initialize_object_gradient(self): @timer() def _initialize_probe_gradient(self): self.parameter_group.probe.initialize_grad() - + if self.use_sparse_probe_shared_update: + self.parameter_group.probe.initialize_grad_sparse_code_probe_shared() @timer() def _initialize_probe_position_gradient(self): self.parameter_group.probe_positions.initialize_grad() diff --git a/src/ptychi/reconstructors/pie.py b/src/ptychi/reconstructors/pie.py index dd9c4ad..048fcd7 100644 --- a/src/ptychi/reconstructors/pie.py +++ b/src/ptychi/reconstructors/pie.py @@ -135,10 +135,7 @@ def compute_updates( delta_p_i = None if (i_slice == 0) and (probe.optimization_enabled(self.current_epoch)): - if (self.parameter_group.probe.representation == "sparse_code" - and - self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared - ): + if self.use_sparse_probe_shared_update: # Calculate probe update direction using the sparse code representation From 99ff79e795a33dd39f5ecdfd4b654b862b0d01ad Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 19 Sep 2025 16:07:19 -0500 Subject: [PATCH 11/13] Override get_grad and set_grad in SyntheticDictProbe class so that `parameter_group.synchronize_optimizable_parameter_gradients` works with sparse code --- src/ptychi/data_structures/probe.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index d2ad7ca..1cfc035 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -674,15 +674,6 @@ def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, ob * sparse_code_mask * torch.exp(1j * torch.angle(optimal_delta_sparse_code)) ) - - # OLD WAY: - # # update the shared probe sparse codes using the average over scan positions - # sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights() - # sparse_code_probe_shared = sparse_code_probe_shared + optimal_delta_sparse_code.mean(1).T - # self.set_sparse_code_probe_shared(sparse_code_probe_shared) - - # DON'T KNOW HOW TO USE set_grad() METHOD FROM PROBE CLASS - self.set_gradient_sparse_code_probe_shared(optimal_delta_sparse_code.mean(1).T) delta_p_i = torch.einsum( "ij,jlk->ilk", self.dictionary_matrix, optimal_delta_sparse_code @@ -691,6 +682,25 @@ def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, ob delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, nr, nc)) return delta_p_i, optimal_delta_sparse_code + + def get_grad(self) -> torch.Tensor: + """Get the gradient of the sparse code weights for the shared probe. + This method overrides the method in the base class, which returns + the `.grad` attribute of the tensor. + + Returns + ------- + Tensor + The gradient of the sparse code weights for the shared probe. + """ + return self.sparse_code_probe_shared.grad + + def set_grad(self, grad: torch.Tensor): + """Set the gradient of the sparse code weights for the shared probe. + This method overrides the method in the base class, which sets the `.grad` + attribute of the tensor. + """ + self.set_gradient_sparse_code_probe_shared(grad) class DIPProbe(Probe): From daf61f88ffdbacf56d5b5ec14ffe0f812ac856b5 Mon Sep 17 00:00:00 2001 From: Ming Du Date: Fri, 19 Sep 2025 16:08:18 -0500 Subject: [PATCH 12/13] Encapsulate sparse code update application into `_apply_probe_sparse_code_shared_updates`; split set sparse code update out of get sparse code method --- src/ptychi/reconstructors/lsqml.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index b5de37a..379ee13 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -347,6 +347,7 @@ def calculate_probe_update_direction_sparse_code_probe_shared( delta_p_i, optimal_delta_sparse_code_vs_spos = self.parameter_group.probe.get_probe_update_direction_sparse_code_probe_shared( delta_p_i, chi_rm_subpx_shft, obj_patches[:, i_slice, ...] ) + self.parameter_group.probe.set_grad(optimal_delta_sparse_code_vs_spos.mean(1).T) delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True) return delta_p_i_unshifted, delta_p_i, optimal_delta_sparse_code_vs_spos @@ -379,9 +380,7 @@ def apply_reconstruction_parameter_updates(self, indices: torch.Tensor): self._apply_probe_update(alpha_p_i, -self.parameter_group.probe.get_grad()[0]) # update the shared probe sparse code if enabled if self.use_sparse_probe_shared_update: - sparse_code_probe_shared = self.parameter_group.probe.get_sparse_code_probe_shared_weights() - sparse_code_probe_shared = sparse_code_probe_shared + self.parameter_group.probe.sparse_code_probe_shared.grad - self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code_probe_shared) + self._apply_probe_sparse_code_shared_updates() # Update probe positions. if self.parameter_group.probe_positions.optimization_enabled(self.current_epoch): @@ -788,6 +787,11 @@ def _apply_probe_update(self, alpha_p_i, delta_p_hat, probe_mode_index=None): alpha_p_mean = torch.mean(alpha_p_i) self.parameter_group.probe.set_grad(-delta_p_hat * alpha_p_mean, slicer=(0, mode_slicer)) self.parameter_group.probe.optimizer.step() + + def _apply_probe_sparse_code_shared_updates(self): + sparse_code_probe_shared = self.parameter_group.probe.get_sparse_code_probe_shared_weights() + sparse_code_probe_shared = sparse_code_probe_shared + self.parameter_group.probe.sparse_code_probe_shared.grad + self.parameter_group.probe.set_sparse_code_probe_shared(sparse_code_probe_shared) @timer() def _apply_probe_momentum(self, alpha_p_mean, delta_p_hat): @@ -1021,6 +1025,7 @@ def _initialize_probe_gradient(self): self.parameter_group.probe.initialize_grad() if self.use_sparse_probe_shared_update: self.parameter_group.probe.initialize_grad_sparse_code_probe_shared() + @timer() def _initialize_probe_position_gradient(self): self.parameter_group.probe_positions.initialize_grad() From e9cecc425054c166f47715bcb8c412cfc1721065 Mon Sep 17 00:00:00 2001 From: Ashish Tripathi Date: Fri, 10 Oct 2025 16:54:06 -0500 Subject: [PATCH 13/13] merge main into probe_sparse_dict_learning --- src/ptychi/api/task.py | 14 ++-- src/ptychi/data_structures/probe.py | 93 ++++++++++------------ src/ptychi/image_proc.py | 2 +- src/ptychi/maps.py | 1 + src/ptychi/parallel.py | 2 + src/ptychi/reconstructors/ad_general.py | 22 ++++- src/ptychi/reconstructors/lsqml.py | 17 ++-- tests/test_2d_ptycho_lsqml_multiprocess.py | 41 +++++++++- tests/test_2d_ptycho_lsqml_multiscan.py | 4 +- tests/test_remove_grid_artifacts.py | 30 +++++++ 10 files changed, 153 insertions(+), 73 deletions(-) diff --git a/src/ptychi/api/task.py b/src/ptychi/api/task.py index 16925a5..ffdab78 100644 --- a/src/ptychi/api/task.py +++ b/src/ptychi/api/task.py @@ -192,11 +192,9 @@ def build_probe(self): self.probe_options.experimental.deep_image_prior_options.enabled ): self.probe = probe.DIPProbe(**kwargs) - elif ( - isinstance(self.probe_options, api.options.PIEProbeOptions) - or - isinstance(self.probe_options, api.options.LSQMLProbeOptions) - ) and ( + elif ( + isinstance(self.probe_options, api.options.PIEProbeOptions) + ) and ( self.probe_options.experimental.sdl_probe_options.enabled ): self.probe = probe.SynthesisDictLearnProbe(**kwargs) @@ -281,6 +279,12 @@ def get_data( Tensor The data of the given name. """ + # Deep image prior objects and probes need to be generated + # before fetching to avoid issues with multi-GPU. + if name == "object" and isinstance(self.object, object.DIPPlanarObject): + self.object.generate() + elif name == "probe" and isinstance(self.probe, probe.DIPProbe): + self.probe.generate() return getattr(self, name).data.detach() def get_data_to_cpu( diff --git a/src/ptychi/data_structures/probe.py b/src/ptychi/data_structures/probe.py index d2ad7ca..6ee5a45 100644 --- a/src/ptychi/data_structures/probe.py +++ b/src/ptychi/data_structures/probe.py @@ -37,7 +37,7 @@ class Probe(dsbase.ReconstructParameter): # to contain additional options for ReconstructParameter classes, and subclass them for specific # reconstruction algorithms - for example, ProbeOptions -> LSQMLProbeOptions. options: "api.options.base.ProbeOptions" - + representation: ProbeRepresentation = ProbeRepresentation.NORMAL def __init__( @@ -104,7 +104,7 @@ def n_opr_modes(self): @property def has_multiple_opr_modes(self): return self.n_opr_modes > 1 - + @property def lateral_shape(self): return self.shape[-2:] @@ -162,9 +162,7 @@ def get_all_mode_intensity( return torch.sum((p.abs()) ** 2, dim=0) def get_unique_probes( - self, - weights: Union[Tensor, "dsbase.ReconstructParameter"], - mode_to_apply: Optional[int] = None, + self, weights: Union[Tensor, "dsbase.ReconstructParameter"], mode_to_apply: Optional[int] = None ) -> Tensor: """ Parameters @@ -377,13 +375,9 @@ def constrain_probe_power( if isinstance(propagator, FourierPropagator): # Cancel the normalization factor so that the power is conserved. if propagator.norm == "backward" or propagator.norm is None: - propagated_probe_power = ( - torch.sum(propagated_probe.abs() ** 2) / self.data.size().numel() - ) + propagated_probe_power = torch.sum(propagated_probe.abs() ** 2) / self.data.size().numel() elif propagator.norm == "forward": - propagated_probe_power = ( - torch.sum(propagated_probe.abs() ** 2) * self.data.size().numel() - ) + propagated_probe_power = torch.sum(propagated_probe.abs() ** 2) * self.data.size().numel() else: propagated_probe_power = torch.sum(propagated_probe.abs() ** 2) else: @@ -415,12 +409,12 @@ def center_probe(self): """ Move the probe's center of mass to the center of the probe array. """ - + if self.options.center_constraint.use_intensity_for_com: probe_to_be_shifted = torch.sum(torch.abs(self.data[0, ...]) ** 2, dim=0) else: probe_to_be_shifted = self.get_mode_and_opr_mode(0, 0) - + com = ip.find_center_of_mass(probe_to_be_shifted) shift = utils.to_tensor(self.shape[-2:]) // 2 - com shifted_probe = self.shift(shift) @@ -475,14 +469,13 @@ def save_tiff(self, path: str): tifffile.imwrite(fname + "_mag.tif", mag_img) tifffile.imwrite(fname + "_phase.tif", phase_img) - -class SynthesisDictLearnProbe(Probe): +class SynthesisDictLearnProbe( Probe ): + representation: ProbeRepresentation = ProbeRepresentation.SPARSE_CODE - - def __init__(self, name="probe", options=None, *args, **kwargs): - super().__init__( - name, options, build_optimizer=False, data_as_parameter=False, *args, **kwargs - ) + + def __init__(self, name = "probe", options = None, *args, **kwargs): + + super().__init__(name, options, build_optimizer=False, data_as_parameter=False, *args, **kwargs) dictionary_matrix, dictionary_matrix_pinv = self.get_dictionary() self.register_buffer("dictionary_matrix", dictionary_matrix) @@ -542,7 +535,7 @@ def get_sparse_code_probe_shared_weights(self): def generate(self): """Generate the probe using the sparse code, and set the generated probe to self.data. - + Returns ------- Tensor @@ -694,15 +687,16 @@ def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, ob class DIPProbe(Probe): + options: "api.options.ad_ptychography.AutodiffPtychographyProbeOptions" representation: ProbeRepresentation = ProbeRepresentation.DIP - + def __init__( self, name: str = "probe", options: "api.options.ad_ptychography.AutodiffPtychographyProbeOptions" = None, *args, - **kwargs, + **kwargs ) -> None: """Deep image prior object. @@ -717,29 +711,29 @@ def __init__( self.model = None self.dip_output_magnitude = None self.dip_output_phase = None - + self.build_model() self.build_dip_optimizer() - + # `self.tensor` is used to hold the object generated by the DIP model and # is not trainable. self.tensor.requires_grad_(False) - + nn_input = self.get_nn_input() self.register_buffer("nn_input", nn_input) - + self.initial_data = None if self.options.experimental.deep_image_prior_options.residual_generation: - self.initial_data = self.data.clone() - + self.initial_data = self.data.clone() + def build_model(self): if not self.options.experimental.deep_image_prior_options.enabled: return - model_class = maps.get_nn_model_by_enum( - self.options.experimental.deep_image_prior_options.model + model_class = maps.get_nn_model_by_enum(self.options.experimental.deep_image_prior_options.model) + self.model = model_class( + **self.options.experimental.deep_image_prior_options.model_params ) - self.model = model_class(**self.options.experimental.deep_image_prior_options.model_params) - + def build_dip_optimizer(self): if self.optimizable and self.optimizer_class is None: raise ValueError( @@ -749,22 +743,17 @@ def build_dip_optimizer(self): self.optimizer = self.optimizer_class(self.model.parameters(), **self.optimizer_params) def get_nn_input(self): - z = ( - torch.rand( - [ - self.n_opr_modes * self.n_modes, - self.options.experimental.deep_image_prior_options.net_input_channels, - *self.lateral_shape, - ], - ) - * 0.1 - ) + z = torch.rand( + [self.n_opr_modes * self.n_modes, + self.options.experimental.deep_image_prior_options.net_input_channels, + *self.lateral_shape], + ) * 0.1 return z def generate(self) -> Tensor: """Generate the probe using the deep image prior model, and set the generated probe to self.data. - + Returns ------- Tensor @@ -773,13 +762,13 @@ def generate(self) -> Tensor: if self.model is None: raise ValueError("Model is not built.") p = self.model(self.nn_input) - + p, mag, phase = self.process_net_output(p) - + with torch.no_grad(): self.dip_output_magnitude = mag.clone() self.dip_output_phase = phase.clone() - + if self.options.experimental.deep_image_prior_options.residual_generation: init_data = torch.stack([self.initial_data.real, self.initial_data.imag], dim=-1) p = p + init_data @@ -792,15 +781,15 @@ def process_net_output(self, p): Parameters ---------- o : Tensor | tuple[Tensor, Tensor] - The output of the DIP network. It should either be a [n_modes * n_opr_modes, 2, h, w] + The output of the DIP network. It should either be a [n_modes * n_opr_modes, 2, h, w] tensor with the channels giving the magnitude and phase of the probe, or a tuple of two [n_modes * n_opr_modes, h, w] tensors giving the magnitude and phase of the probe. - + Returns ------- Tensor - A [n_opr_modes, n_modes, h, w, 2] tensor representing the real and imaginary parts + A [n_opr_modes, n_modes, h, w, 2] tensor representing the real and imaginary parts of the probe. Tensor The magnitude of the probe. @@ -814,7 +803,7 @@ def process_net_output(self, p): else: mag = p[:, 0] phase = p[:, 1] - + expected_phase_shape = (self.n_opr_modes * self.n_modes, *self.lateral_shape) if tuple(phase.shape) != expected_phase_shape: logger.warning( @@ -828,7 +817,7 @@ def process_net_output(self, p): phase_resized.append(ip.central_crop_or_pad(phase[i_img], expected_phase_shape[1:])) mag = torch.stack(mag_resized) phase = torch.stack(phase_resized) - + p_complex = mag * torch.exp(1j * phase) p = torch.stack([p_complex.real, p_complex.imag], dim=-1) p = p.reshape([*self.shape, 2]) diff --git a/src/ptychi/image_proc.py b/src/ptychi/image_proc.py index 76f4d4f..f41062a 100644 --- a/src/ptychi/image_proc.py +++ b/src/ptychi/image_proc.py @@ -1351,7 +1351,7 @@ def remove_grid_artifacts( # Frequencies of the artifacts. dk_s_y, dk_s_x = 1 / period_y_m, 1 / period_x_m - x_range, y_range = 0, 0 + x_range, y_range = torch.zeros([1], dtype=torch.int32), torch.zeros([1], dtype=torch.int32) # Get the frequencies of all harmonic peaks. if "x" in direction: x_range = torch.arange(math.ceil(-k_max / dk_s_x), math.floor(k_max / dk_s_x)) diff --git a/src/ptychi/maps.py b/src/ptychi/maps.py index 47136d6..8ec2c25 100644 --- a/src/ptychi/maps.py +++ b/src/ptychi/maps.py @@ -65,6 +65,7 @@ def get_reconstructor_by_enum(key: enums.Reconstructors) -> Type["Reconstructor" def get_multiprocess_reconstructor_by_enum(key: enums.Reconstructors) -> Type["Reconstructor"]: d = { enums.Reconstructors.LSQML: reconstructors.MultiprocessLSQMLReconstructor, + enums.Reconstructors.AD_PTYCHO: reconstructors.AutodiffPtychographyReconstructor, } reconstructor_class = d.get(key, None) if reconstructor_class is None: diff --git a/src/ptychi/parallel.py b/src/ptychi/parallel.py index 8c1fb8b..53689f8 100644 --- a/src/ptychi/parallel.py +++ b/src/ptychi/parallel.py @@ -126,6 +126,8 @@ def sync_buffer( return buffer def init_process_group(self, backend: str = "nccl") -> None: + if dist.is_initialized(): + return dist.init_process_group(backend=backend, init_method="env://") def detect_launcher(self) -> str | None: diff --git a/src/ptychi/reconstructors/ad_general.py b/src/ptychi/reconstructors/ad_general.py index 8cfbecf..80d82da 100644 --- a/src/ptychi/reconstructors/ad_general.py +++ b/src/ptychi/reconstructors/ad_general.py @@ -2,11 +2,12 @@ # Full license accessible at https://github.com//AdvancedPhotonSource/pty-chi/blob/main/LICENSE from typing import Optional, TYPE_CHECKING +import logging import torch +import torch.distributed as dist from torch.utils.data import Dataset - import ptychi.forward_models as fm from ptychi.reconstructors.base import IterativeReconstructor, LossTracker import ptychi.maps as maps @@ -14,6 +15,8 @@ import ptychi.data_structures.parameter_group as pg import ptychi.api as api +logger = logging.getLogger(__name__) + class AutodiffReconstructor(IterativeReconstructor): def __init__( @@ -55,9 +58,20 @@ def build_forward_model(self): self.forward_model = self.forward_model_class( self.parameter_group, **self.forward_model_params ) - if not torch.get_default_device().type == "cpu": - self.forward_model = torch.nn.DataParallel(self.forward_model) + if dist.is_initialized(): + self.forward_model = torch.nn.parallel.DistributedDataParallel(self.forward_model) self.forward_model.to(torch.get_default_device()) + else: + logger.warning( + "The default parallelization behavior of AutodiffReconstructor has been changed. " + "Now, multi-GPU support must be enabled by launching the script with `torchrun`. " + "If you are not the PtychographyTask wrapper, you also have to initialize " + "the process group with `torch.distributed.init_process_group`. Directly " + "launching the script that uses AutodiffReconstructor with `python` will " + "only use a single GPU. For more information, see " + "https://pty-chi.readthedocs.io/en/latest/using_pty_chi/devices.html#multi-gpu-and-multi-processing" + ) + def run_post_differentiation_hooks(self, input_data, y_true): self.get_forward_model().post_differentiation_hook(*input_data, y_true) @@ -92,7 +106,7 @@ def step_all_optimizers(self): sub_module.step_optimizer() def get_forward_model(self) -> "fm.ForwardModel": - if isinstance(self.forward_model, torch.nn.DataParallel): + if isinstance(self.forward_model, torch.nn.parallel.DistributedDataParallel): return self.forward_model.module else: return self.forward_model diff --git a/src/ptychi/reconstructors/lsqml.py b/src/ptychi/reconstructors/lsqml.py index b5de37a..8611897 100644 --- a/src/ptychi/reconstructors/lsqml.py +++ b/src/ptychi/reconstructors/lsqml.py @@ -224,9 +224,9 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): """ object_ = self.parameter_group.object self._initialize_object_gradient() - self._initialize_probe_gradient() - self._initialize_probe_position_gradient() - self._initialize_opr_mode_weights_gradient() + self.parameter_group.probe.initialize_grad() + self.parameter_group.probe_positions.initialize_grad() + self.parameter_group.opr_mode_weights.initialize_grad() self._initialize_object_step_size_buffer() self._initialize_probe_step_size_buffer() self._initialize_momentum_buffers() @@ -257,17 +257,17 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): if self.use_sparse_probe_shared_update and self.parameter_group.probe.optimization_enabled(self.current_epoch): ( - delta_p_i_unshifted, delta_p_i, _ + delta_p_i_before_adj_shift, delta_p_i, _ ) = self.calculate_probe_update_direction_sparse_code_probe_shared( indices, chi, obj_patches, i_slice ) else: # Calculate probe update direction (dense representation) - delta_p_i_unshifted = self._calculate_probe_update_direction( + delta_p_i_before_adj_shift = self._calculate_probe_update_direction( chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None ) # Eq. 24a delta_p_i = self.adjoint_shift_probe_update_direction( - indices, delta_p_i_unshifted, first_mode_only=True + indices, delta_p_i_before_adj_shift, first_mode_only=True ) delta_p_hat = self._precondition_probe_update_direction(delta_p_i) # Eq. 25a @@ -328,7 +328,7 @@ def calculate_update_vectors(self, indices, chi, obj_patches, positions): ) # Set chi to conjugate-modulated wavefield. - chi = delta_p_i_unshifted + chi = delta_p_i_before_adj_shift def calculate_probe_update_direction_sparse_code_probe_shared( self, indices, chi, obj_patches, i_slice=None @@ -1360,7 +1360,8 @@ def run_minibatch(self, input_data, y_true, *args, **kwargs) -> None: if self.current_epoch == 0 and self.options.rescale_probe_intensity_in_first_epoch: self.update_accumulated_intensities(y_true, y_pred) - self.reconstructor_buffers.synchronize(["accumulated_true_intensity", "accumulated_pred_intensity"]) + if self.current_minibatch == len(self.dataloader) - 1: + self.reconstructor_buffers.synchronize(["accumulated_true_intensity", "accumulated_pred_intensity"]) else: self.compute_reconstruction_parameter_updates(y_pred, y_true, indices) self.reconstructor_buffers.synchronize( diff --git a/tests/test_2d_ptycho_lsqml_multiprocess.py b/tests/test_2d_ptycho_lsqml_multiprocess.py index 29fe752..bb875aa 100644 --- a/tests/test_2d_ptycho_lsqml_multiprocess.py +++ b/tests/test_2d_ptycho_lsqml_multiprocess.py @@ -55,6 +55,45 @@ def test_2d_ptycho_lsqml_multiprocess(self): recon = task.get_data_to_cpu('object', as_numpy=True)[0] return recon + @pytest.mark.local + @tutils.TungstenDataTester.wrap_recon_tester(name='test_2d_ptycho_lsqml_compact_multiprocess') + def test_2d_ptycho_lsqml_compact_multiprocess(self): + self.setup_ptychi(cpu_only=False, gpu_indices=(0, 1)) + + data, probe, pixel_size_m, positions_px = self.load_tungsten_data(pos_type='true') + + options = api.LSQMLOptions() + options.data_options.data = data + + options.object_options.initial_guess = torch.ones([1, *get_suggested_object_size(positions_px, probe.shape[-2:], extra=100)], dtype=get_default_complex_dtype()) + options.object_options.pixel_size_m = pixel_size_m + options.object_options.optimizable = True + options.object_options.optimizer = api.Optimizers.SGD + options.object_options.step_size = 1 + options.object_options.build_preconditioner_with_all_modes = True + + options.probe_options.initial_guess = probe + options.probe_options.optimizable = True + options.probe_options.optimizer = api.Optimizers.SGD + options.probe_options.step_size = 1 + + options.probe_position_options.position_x_px = positions_px[:, 1] + options.probe_position_options.position_y_px = positions_px[:, 0] + options.probe_position_options.optimizable = False + + options.reconstructor_options.batch_size = 100 + options.reconstructor_options.noise_model = api.NoiseModels.GAUSSIAN + options.reconstructor_options.num_epochs = 8 + options.reconstructor_options.allow_nondeterministic_algorithms = False + options.reconstructor_options.batching_mode = api.BatchingModes.COMPACT + options.reconstructor_options.momentum_acceleration_gain = 0.5 + + task = PtychographyTask(options) + task.run() + + recon = task.get_data_to_cpu('object', as_numpy=True)[0] + return recon + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -64,4 +103,4 @@ def test_2d_ptycho_lsqml_multiprocess(self): tester = Test2dPtychoLsqmlMultiprocess() tester.setup_method(name="", generate_data=False, generate_gold=args.generate_gold, debug=True) tester.test_2d_ptycho_lsqml_multiprocess() - + tester.test_2d_ptycho_lsqml_compact_multiprocess() diff --git a/tests/test_2d_ptycho_lsqml_multiscan.py b/tests/test_2d_ptycho_lsqml_multiscan.py index e457bf1..21e69f0 100644 --- a/tests/test_2d_ptycho_lsqml_multiscan.py +++ b/tests/test_2d_ptycho_lsqml_multiscan.py @@ -87,8 +87,8 @@ def test_2d_ptycho_lsqml_multiscan(self): for i_task, task in enumerate(all_tasks): task.run(1) # Copy object to next task - if i_task < len(all_tasks) - 1: - all_tasks[i_task + 1].copy_data_from_task(task, params_to_copy=("object",)) + i_next_task = (i_task + 1) % len(all_tasks) + all_tasks[i_next_task].copy_data_from_task(task, params_to_copy=("object",)) recon = all_tasks[-1].get_data_to_cpu('object', as_numpy=True)[0] return recon diff --git a/tests/test_remove_grid_artifacts.py b/tests/test_remove_grid_artifacts.py index 99642e0..e6c7394 100644 --- a/tests/test_remove_grid_artifacts.py +++ b/tests/test_remove_grid_artifacts.py @@ -43,7 +43,36 @@ def test_remove_grid_artifacts_phase(self): plt.show() assert np.max(np.abs((np.angle(data)))) < 0.1 + + def test_remove_grid_artifacts_phase_y(self): + phase = torch.zeros([64, 64]) + phase[::10, ::10] = 1 + data = torch.ones([1, 64, 64]) * torch.exp(1j * phase) + object = PlanarObject( + data=data, + options=ObjectOptions( + pixel_size_m=1, + remove_grid_artifacts=RemoveGridArtifactsOptions( + enabled=True, + component=api.MagPhaseComponents.PHASE, + period_y_m=10, + window_size=3, + direction=api.Directions.Y, + ), + ), + ) + with torch.no_grad(): + object.remove_grid_artifacts() + + data = object.data.detach().cpu().numpy()[0] + if self.debug: + import matplotlib.pyplot as plt + + _, ax = plt.subplots() + ax.imshow(np.angle(data)) + plt.show() + def test_remove_grid_artifacts_both(self): phase = torch.zeros([64, 64]) phase[::10, ::10] = 1 @@ -84,4 +113,5 @@ def test_remove_grid_artifacts_both(self): tester = TestRemoveGridArtifacts() tester.setup_method(name="", generate_data=False, generate_gold=args.generate_gold, debug=True) tester.test_remove_grid_artifacts_phase() + tester.test_remove_grid_artifacts_phase_y() tester.test_remove_grid_artifacts_both()