5 changes: 5 additions & 0 deletions src/ptychi/api/options/base.py
@@ -649,6 +649,11 @@ def get_non_data_fields(self) -> dict:
@dataclasses.dataclass
class SynthesisDictLearnProbeOptions(Options):

use_avg_spos_sparse_code: bool = True
"""When computing the sparse code updates, we can either solve for
sparse codes that are scan position dependent or we can use the average
over scan positions before solving for the average sparse code."""

d_mat: Union[ndarray, Tensor] = None
"""The synthesis sparse dictionary matrix; contains the basis functions
that will be used to represent the probe via the sparse code weights."""
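
For context, a minimal sketch of the two behaviors selected by use_avg_spos_sparse_code (shapes and names are illustrative, not code from this PR): solve for a code at every scan position, or reduce over scan positions first and solve once.

import torch

# Illustrative sizes: n_spos scan positions, n_scpm shared probe modes,
# rc flattened probe pixels, n_atoms dictionary atoms.
n_spos, n_scpm, rc, n_atoms = 4, 2, 64, 16
dictionary_matrix_pinv = torch.randn(n_atoms, rc, dtype=torch.complex64)
probe_vec = torch.randn(n_spos, n_scpm, rc, dtype=torch.complex64)

# Scan-position-dependent: one sparse code per (scan position, probe mode).
codes_per_pos = torch.einsum('ij,klj->kli', dictionary_matrix_pinv, probe_vec)  # (n_spos, n_scpm, n_atoms)

# Averaged: reduce over scan positions first, then solve for a single shared code per mode.
codes_avg = torch.einsum('ij,lj->li', dictionary_matrix_pinv, probe_vec.mean(dim=0))  # (n_scpm, n_atoms)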
9 changes: 7 additions & 2 deletions src/ptychi/api/options/lsqml.py
@@ -100,15 +100,20 @@ class LSQMLObjectOptions(base.ObjectOptions):
propagation always uses all probe modes regardless of this option.
"""

@dataclasses.dataclass
class LSQMLProbeExperimentalOptions(base.Options):
sdl_probe_options: base.SynthesisDictLearnProbeOptions = dataclasses.field(default_factory=base.SynthesisDictLearnProbeOptions)


@dataclasses.dataclass
class LSQMLProbeOptions(base.ProbeOptions):
optimal_step_size_scaler: float = 0.9
"""
A scaler for the solved optimal step size (beta_LSQ in PtychoShelves).
"""


experimental: LSQMLProbeExperimentalOptions = dataclasses.field(default_factory=LSQMLProbeExperimentalOptions)


@dataclasses.dataclass
class LSQMLProbePositionOptions(base.ProbePositionOptions):
pass
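
A hypothetical usage sketch of the experimental field added above (assuming probe_options is an already-constructed LSQMLProbeOptions instance; only the attribute path comes from this diff):

# Reach the synthesis-dictionary-learning settings through the new experimental group.
probe_options.experimental.sdl_probe_options.use_avg_spos_sparse_code = True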
4 changes: 3 additions & 1 deletion src/ptychi/api/task.py
@@ -169,7 +169,9 @@ def build_probe(self):
):
self.probe = probe.DIPProbe(**kwargs)
elif (
isinstance(self.probe_options, api.options.PIEProbeOptions)
isinstance(self.probe_options, api.options.PIEProbeOptions)
or
isinstance(self.probe_options, api.options.LSQMLProbeOptions)
) and (
self.probe_options.experimental.sdl_probe_options.enabled
):
18 changes: 9 additions & 9 deletions src/ptychi/data_structures/probe.py
@@ -473,6 +473,9 @@ def __init__(self, name = "probe", options = None, *args, **kwargs):
sparse_code_probe = self.get_sparse_code_weights()
self.register_parameter("sparse_code_probe", torch.nn.Parameter(sparse_code_probe))

use_avg_spos_sparse_code = self.options.experimental.sdl_probe_options.use_avg_spos_sparse_code
Collaborator

Since it's just a boolean, you don't have to register it as a buffer. Also, instead of setting it as an attribute in __init__, let's just reference it from self.options on the fly in the method where it is used. This way, if the user changes the value in the options object in the middle of a reconstruction, the new value can take effect dynamically.
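
A minimal sketch of that pattern (a property is one way to do it; the name mirrors the option field):

@property
def use_avg_spos_sparse_code(self) -> bool:
    # Looked up on every access, so a mid-reconstruction change to the options takes effect.
    return self.options.experimental.sdl_probe_options.use_avg_spos_sparse_code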

self.register_buffer("use_avg_spos_sparse_code", torch.tensor(use_avg_spos_sparse_code, dtype=torch.bool))

self.build_optimizer()

def get_dictionary(self):
@@ -482,10 +485,11 @@ def get_dictionary(self):
return dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H

def get_sparse_code_weights(self):

sz = self.data.shape
probe_vec = torch.reshape( self.data[0,...], (sz[1], sz[2] * sz[3]))
probe_vec = torch.swapaxes( probe_vec, 0, -1)
sparse_code_probe = self.dictionary_matrix_pinv @ probe_vec
probe_vec = torch.reshape(self.data, (sz[0], sz[1], sz[2] * sz[3]))
sparse_code_probe = torch.einsum('ij,klj->kli', self.dictionary_matrix_pinv, probe_vec)

return sparse_code_probe

def generate(self):
@@ -497,13 +501,9 @@ def generate(self):
Tensor
A (n_opr_modes, n_modes, h, w) tensor giving the generated probe.
"""
probe_vec = self.dictionary_matrix @ self.sparse_code_probe
probe_vec = torch.swapaxes( probe_vec, 0, -1)
probe = torch.reshape(probe_vec, *[self.data[0,...].shape])
probe = probe[None,...]

# we only use sparse codes for the shared modes, not the OPRs
probe = torch.cat((probe, self.data[1:,...]), 0)
probe = torch.einsum('ij,klj->kli', self.dictionary_matrix, self.sparse_code_probe)
probe = torch.reshape( probe, *[self.data.shape] )

self.set_data(probe)
return probe
110 changes: 103 additions & 7 deletions src/ptychi/reconstructors/lsqml.py
@@ -225,13 +225,109 @@ def update_reconstruction_parameters(self, indices, chi, obj_patches, positions)
)
)

# Calculate probe update direction.
delta_p_i_unshifted = self._calculate_probe_update_direction(
chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None
) # Eq. 24a
delta_p_i = self.adjoint_shift_probe_update_direction(
indices, delta_p_i_unshifted, first_mode_only=True
)
if (self.parameter_group.probe.representation == "sparse_code"):

rc = chi.shape[-1] * chi.shape[-2]
Collaborator

I'd also make everything inside the if a separate method and call it here. Also avoid using ASCII dividers.
If the routine for the sparse code update is the same across all reconstructors, or if there is at least something common between them, please put the common parts in SynthesisDictLearnProbe.
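
A rough sketch of that shape (the helper name _get_sparse_code_probe_update is hypothetical; the else branch is the existing code below):

if self.parameter_group.probe.representation == "sparse_code":
    # All of the sparse-code logic below would move into this helper (or into SynthesisDictLearnProbe).
    delta_p_i, delta_p_i_unshifted = self._get_sparse_code_probe_update(indices, chi, obj_patches, i_slice)
else:
    delta_p_i_unshifted = self._calculate_probe_update_direction(
        chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None
    )  # Eq. 24a
    delta_p_i = self.adjoint_shift_probe_update_direction(
        indices, delta_p_i_unshifted, first_mode_only=True
    )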

n_scpm = chi.shape[-3]
n_spos = chi.shape[-4]

#======================================================================
# sparse code update directions vs scan position and shared probe modes

obj_patches_slice_i_conj = torch.conj( obj_patches[:, i_slice, ...] )
Collaborator

Please remove the extra spaces inside the brackets. Make the code more compact by removing unnecessary blank spaces.


delta_sparse_code = chi * obj_patches_slice_i_conj[:, None, ... ]
delta_sparse_code = self.adjoint_shift_probe_update_direction(indices, delta_sparse_code, first_mode_only=True)

delta_sparse_code = torch.reshape( delta_sparse_code,
( n_spos, n_scpm, rc ))

delta_sparse_code = torch.einsum('ijk,kl->lij',
delta_sparse_code,
self.parameter_group.probe.dictionary_matrix_H.T)

#===================================================
# compute optimal step length for sparse code update

dict_delta_sparse_code = torch.einsum('ij,jkl->ikl',
self.parameter_group.probe.dictionary_matrix,
delta_sparse_code)

obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, rc ))

denom = torch.abs( dict_delta_sparse_code )**2 * obj_patches_vec.swapaxes(0,-1)[...,None]

denom = torch.einsum('ij,jik->ik',
torch.conj( obj_patches_vec ),
denom)

#=====

chi_rm_subpx_shft = self.adjoint_shift_probe_update_direction(indices, chi, first_mode_only=True)

numer = torch.conj( dict_delta_sparse_code ) * torch.reshape( chi_rm_subpx_shft,
( n_spos, n_scpm, rc )).permute(2,0,1)
numer = torch.einsum('ij,jik->ik',
torch.conj( obj_patches_vec ),
numer)

# real is used to throw away small imag part due to numerical precision errors

Copilot AI Aug 5, 2025

The comment about using .real to handle numerical precision errors should be more specific about when this occurs and potential alternatives.

Suggested change
# real is used to throw away small imag part due to numerical precision errors
# In theory, numer/denom should be real, but small imaginary parts can arise due to floating-point
# precision errors in complex arithmetic. We use .real to discard these. Alternatively, one could use
# torch.real_if_close or check that the imaginary part is negligible before discarding it.
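
One way to make that check explicit (a sketch; the tolerance is illustrative):

ratio = numer / denom
# The imaginary part should be negligible relative to the real part before it is discarded.
assert ratio.imag.abs().max() <= 1e-6 * ratio.real.abs().max()
optimal_step_sparse_code = ratio.real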

optimal_step_sparse_code = ( numer / denom ).real

#=====

optimal_delta_sparse_code = optimal_step_sparse_code[None,...] * delta_sparse_code

optimal_delta_sparse_code_mean_spos = ( optimal_delta_sparse_code.mean(1).T )[None, ...]

# sparse code update
sparse_code = self.parameter_group.probe.get_sparse_code_weights()
sparse_code = sparse_code + optimal_delta_sparse_code_mean_spos

#===========================================
# Enforce sparsity constraint on sparse code

abs_sparse_code = torch.abs(sparse_code)
sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True)

sel = sparse_code_sorted[0][...,self.parameter_group.probe.probe_sparse_code_nnz]

# hard thresholding:
sparse_code = sparse_code * (abs_sparse_code >= sel[...,None])

#(TODO: soft thresholding option)

#==============================================
# Update the new sparse code in the probe class

self.parameter_group.probe.set_sparse_code(sparse_code)

#===============================================================
# Create the probe update vs scan position using the sparse code

delta_p_i = torch.einsum('ij,jlk->ilk', self.parameter_group.probe.dictionary_matrix,
optimal_delta_sparse_code)
delta_p_i = delta_p_i.permute(1,2,0)

# if self.parameter_group.probe.use_avg_spos_sparse_code:
# delta_p_i = torch.tile( delta_p_i, ( n_spos, 1, 1 ) )

Copilot AI Aug 5, 2025

This commented-out code should either be implemented or removed. Leaving commented code in production reduces maintainability and clarity.

Suggested change
# if self.parameter_group.probe.use_avg_spos_sparse_code:
# delta_p_i = torch.tile( delta_p_i, ( n_spos, 1, 1 ) )


delta_p_i = torch.reshape(delta_p_i, ( n_spos, n_scpm,
chi.shape[-1],
chi.shape[-2] ))

delta_p_i_unshifted = self.forward_model.shift_unique_probes(indices, delta_p_i, first_mode_only=True)

else:
# Calculate probe update direction.
delta_p_i_unshifted = self._calculate_probe_update_direction(
chi, obj_patches=obj_patches, slice_index=i_slice, probe_mode_index=None
) # Eq. 24a

delta_p_i = self.adjoint_shift_probe_update_direction(
indices, delta_p_i_unshifted, first_mode_only=True
)

delta_p_hat = self._precondition_probe_update_direction(delta_p_i) # Eq. 25a

# Update OPR modes and weights.
89 changes: 64 additions & 25 deletions src/ptychi/reconstructors/pie.py
@@ -135,50 +135,89 @@ def compute_updates(

delta_p_i = None
if (i_slice == 0) and (probe.optimization_enabled(self.current_epoch)):

if (self.parameter_group.probe.representation == "sparse_code"):
# TODO: move this into SynthesisDictLearnProbe class

Collaborator

Make everything inside the if a separate method. If the routine for the sparse code update is the same across all reconstructors, or if there is at least something common between them, please put the common parts in SynthesisDictLearnProbe.
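
As one example of a common piece, the sparsity constraint could become a method on SynthesisDictLearnProbe (the method name is hypothetical; the body mirrors the thresholding code in this PR):

def apply_sparsity_constraint(self, sparse_code: torch.Tensor) -> torch.Tensor:
    # Hard thresholding: keep only the probe_sparse_code_nnz largest-magnitude atoms.
    abs_code = torch.abs(sparse_code)
    sel = torch.sort(abs_code, dim=-1, descending=True)[0][..., self.probe_sparse_code_nnz]
    return sparse_code * (abs_code >= sel[..., None])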

# TODO: move these into SynthesisDictLearnProbe class

Copilot AI Aug 5, 2025

TODO comments should include more specific information about implementation timeline, requirements, or be converted to proper issue tracking.

Suggested change
# TODO: move these into SynthesisDictLearnProbe class
# TODO: Move these calculations into SynthesisDictLearnProbe class for better modularity.
# See issue tracker: https://github.com/AdvancedPhotonSource/pty-chi/issues/XXX
# Target: Refactor by Q3 2025. Requirements: Move rc, n_scpm, n_spos, obj_patches_conj, conjT_i_delta_exwv_i, and related logic.

rc = delta_exwv_i.shape[-1] * delta_exwv_i.shape[-2]
n_scpm = delta_exwv_i.shape[-3]
n_spos = delta_exwv_i.shape[-4]

obj_patches_vec = torch.reshape(obj_patches[:, i_slice, ...], (n_spos, 1, rc ))
abs2_obj_patches = torch.abs(obj_patches_vec) ** 2
obj_patches_conj = torch.conj( obj_patches[:, i_slice, ...])
conjT_i_delta_exwv_i = obj_patches_conj[:, None,...] * delta_exwv_i

# undo subpixel shifts and reshape
conjT_delta_exwv = self.adjoint_shift_probe_update_direction(indices, conjT_i_delta_exwv_i, first_mode_only=True)
conjT_delta_exwv_vec = torch.reshape( conjT_delta_exwv.permute( 2, 3, 0, 1 ), (rc, n_spos, n_scpm ) )

z = torch.sum(abs2_obj_patches, dim = 0)
z_max = torch.max(z)
w = self.parameter_group.probe.options.alpha * (z_max - z)
z_plus_w = torch.swapaxes(z + w, 0, 1)
obj_patches_vec = torch.reshape( obj_patches[:, i_slice, ...], ( n_spos, rc ))
abs2_obj_patches = torch.abs( obj_patches_vec )**2

z_plus_w = torch.max(abs2_obj_patches, dim=0, keepdim=True)[0]
z_plus_w = self.parameter_group.probe.options.alpha * (z_plus_w - abs2_obj_patches)
z_plus_w = abs2_obj_patches + z_plus_w

#================================
# use average over scan positions

delta_exwv = self.adjoint_shift_probe_update_direction(indices, delta_exwv_i, first_mode_only=True)
delta_exwv = torch.sum(delta_exwv, 0)
delta_exwv = torch.reshape( delta_exwv, (n_scpm, rc)).T
if self.parameter_group.probe.use_avg_spos_sparse_code:

z_plus_w = torch.sum( z_plus_w, 0 )[None,:]
conjT_delta_exwv_vec = torch.sum( conjT_delta_exwv_vec, 1 )[:,None,:]

denom = (self.parameter_group.probe.dictionary_matrix_H @ (z_plus_w * self.parameter_group.probe.dictionary_matrix))
numer = self.parameter_group.probe.dictionary_matrix_H @ delta_exwv
#=====

denom = torch.einsum('ij,ik,lki->jkl',
self.parameter_group.probe.dictionary_matrix.conj(),
self.parameter_group.probe.dictionary_matrix,
z_plus_w[:,None,...].to(torch.complex64))

numer = torch.einsum('ij,jlk->ilk',
self.parameter_group.probe.dictionary_matrix_H,
conjT_delta_exwv_vec)

delta_sparse_code = torch.linalg.solve(denom.permute(2, 0, 1), numer.permute(1, 0, 2))

delta_sparse_code = torch.linalg.solve(denom, numer)
# # If dictionary has bad condition number, use Tikhonov regularization?
# delta_sparse_code, _, _, _ = torch.linalg.lstsq(denom.permute(2, 0, 1), numer.permute(1, 0, 2), rcond=1e-6)
# delta_sparse_code = delta_sparse_code.permute(1, 0, 2)

delta_p = self.parameter_group.probe.dictionary_matrix @ delta_sparse_code
delta_p = torch.reshape( delta_p.T, ( n_scpm, delta_exwv_i.shape[-1] , delta_exwv_i.shape[-2]))
delta_p_i = torch.tile(delta_p, (n_spos,1,1,1))

# sparse code update
delta_sparse_code_mean_spos = ( delta_sparse_code.mean(0).T )[None, ...]

sparse_code = self.parameter_group.probe.get_sparse_code_weights()
sparse_code = sparse_code + delta_sparse_code
sparse_code = sparse_code + delta_sparse_code_mean_spos

#===========================================
# Enforce sparsity constraint on sparse code

abs_sparse_code = torch.abs(sparse_code)
sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True)
sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True)

sel = sparse_code_sorted[0][self.parameter_group.probe.probe_sparse_code_nnz, :]
sel = sparse_code_sorted[0][..., self.parameter_group.probe.probe_sparse_code_nnz]

#(TODO: soft thresholding option as default?)

Copilot AI Aug 5, 2025

TODO comments should include more specific information about implementation timeline, requirements, or be converted to proper issue tracking.

Suggested change
#(TODO: soft thresholding option as default?)
# TODO: Consider implementing soft thresholding as the default option for enforcing sparsity.
# See issue #123 on GitHub (https://github.com/AdvancedPhotonSource/pty-chi/issues/123) for requirements and discussion.
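
For reference, a minimal sketch of a soft-thresholding variant (tau is a hypothetical tuning parameter; this is not part of the PR):

import torch

def soft_threshold(sparse_code: torch.Tensor, tau: float) -> torch.Tensor:
    # Shrink complex magnitudes toward zero by tau while preserving phase.
    mag = torch.abs(sparse_code)
    scale = torch.clamp(mag - tau, min=0.0) / torch.clamp(mag, min=1e-12)
    return sparse_code * scale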

# hard thresholding:
sparse_code = sparse_code * (abs_sparse_code >= sel)

#(TODO: soft thresholding option)

sparse_code = sparse_code * (abs_sparse_code >= sel[...,None])

#==============================================
# Update the new sparse code in the probe class

self.parameter_group.probe.set_sparse_code(sparse_code)

#===============================================================
# Create the probe update vs scan position using the sparse code

delta_p_i = torch.einsum('ij,ljk->ilk', self.parameter_group.probe.dictionary_matrix,
delta_sparse_code)
delta_p_i = delta_p_i.permute(1,2,0)

if self.parameter_group.probe.use_avg_spos_sparse_code:
delta_p_i = torch.tile( delta_p_i, ( n_spos, 1, 1 ) )

delta_p_i = torch.reshape(delta_p_i, ( n_spos, n_scpm,
delta_exwv_i.shape[-1],
delta_exwv_i.shape[-2] ))

else:
step_weight = self.calculate_probe_step_weight((obj_patches[:, [i_slice], ...]))
delta_p_i = step_weight * delta_exwv_i # get delta p at each position