Skip to content
42 changes: 30 additions & 12 deletions src/ptychi/api/options/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,9 @@ class ProbeOrthogonalizeIncoherentModesOptions(FeatureOptions):

method: enums.OrthogonalizationMethods = enums.OrthogonalizationMethods.SVD
"""The method to use for incoherent_mode orthogonalization."""


sort_by_occupancy: bool = False
"""Keep the probes sorted so that mode with highest occupancy is the 0th shared mode"""

@dataclasses.dataclass
class ProbeOrthogonalizeOPRModesOptions(FeatureOptions):
Expand Down Expand Up @@ -648,25 +650,35 @@ def get_non_data_fields(self) -> dict:

@dataclasses.dataclass
class SynthesisDictLearnProbeOptions(Options):

d_mat: Union[ndarray, Tensor] = None

enabled: bool = False
enabled_shared: bool = False
enabled_opr: bool = False

thresholding_type_shared: str = 'hard'
thresholding_type_opr: str = 'hard'
"""Choose between 'hard' or 'soft' thresholding."""

dictionary_matrix: Union[ndarray, Tensor] = None
"""The synthesis sparse dictionary matrix; contains the basis functions
that will be used to represent the probe via the sparse code weights."""

d_mat_conj_transpose: Union[ndarray, Tensor] = None
"""Conjugate transpose of the synthesis sparse dictionary matrix."""

d_mat_pinv: Union[ndarray, Tensor] = None
dictionary_matrix_pinv: Union[ndarray, Tensor] = None
"""Moore-Penrose pseudoinverse of the synthesis sparse dictionary matrix."""

probe_sparse_code: Union[ndarray, Tensor] = None
"""Sparse code weights vector."""
sparse_code_probe_shared: Union[ndarray, Tensor] = None
"""Sparse code weights vector for the shared modes."""

probe_sparse_code_nnz: float = None
sparse_code_probe_shared_nnz: float = None
"""Number of non-zeros we will keep when enforcing sparsity constraint on
the sparse code weights vector probe_sparse_code."""
the SHARED sparse code weights vector sparse_code_probe_shared."""

sparse_code_probe_opr: Union[ndarray, Tensor] = None
"""Sparse code weights vector for the OPRs."""

enabled: bool = False
sparse_code_probe_opr_nnz: float = None
"""Number of non-zeros we will keep when enforcing sparsity constraint on
the OPR sparse code weights vector sparse_code_probe_opr."""

@dataclasses.dataclass
class PositionCorrectionOptions(Options):
Expand Down Expand Up @@ -869,6 +881,12 @@ class OPRModeWeightsOptions(ParameterOptions):
A separate step size for eigenmode weight update.
"""

use_optimal_update: bool = False
"""
We do not compute an optimal update step for OPR weights and eigenmodes
using the default method; this adds the option to do so.
"""

def check(self, options: "task_options.PtychographyTaskOptions"):
super().check(options)
if self.optimizable:
Expand Down
9 changes: 7 additions & 2 deletions src/ptychi/api/options/lsqml.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,20 @@ class LSQMLObjectOptions(base.ObjectOptions):
propagation always uses all probe modes regardless of this option.
"""

@dataclasses.dataclass
class LSQMLProbeExperimentalOptions(base.Options):
sdl_probe_options: base.SynthesisDictLearnProbeOptions = dataclasses.field(default_factory=base.SynthesisDictLearnProbeOptions)


@dataclasses.dataclass
class LSQMLProbeOptions(base.ProbeOptions):
optimal_step_size_scaler: float = 0.9
"""
A scaler for the solved optimal step size (beta_LSQ in PtychoShelves).
"""


experimental: LSQMLProbeExperimentalOptions = dataclasses.field(default_factory=LSQMLProbeExperimentalOptions)


@dataclasses.dataclass
class LSQMLProbePositionOptions(base.ProbePositionOptions):
pass
Expand Down
4 changes: 3 additions & 1 deletion src/ptychi/api/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def build_probe(self):
):
self.probe = probe.DIPProbe(**kwargs)
elif (
isinstance(self.probe_options, api.options.PIEProbeOptions)
isinstance(self.probe_options, api.options.PIEProbeOptions)
or
isinstance(self.probe_options, api.options.LSQMLProbeOptions)
) and (
self.probe_options.experimental.sdl_probe_options.enabled
):
Expand Down
209 changes: 169 additions & 40 deletions src/ptychi/data_structures/opr_mode_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def intensity_variation_optimization_enabled(self, epoch: int):
def update_variable_probe(
self,
probe: "Probe",
adjoint_shift_probe_update_direction, # what do I do for type hint here?
indices: Tensor,
chi: Tensor,
delta_p_i: Tensor,
Expand All @@ -117,8 +118,14 @@ def update_variable_probe(
probe.optimization_enabled(current_epoch)
or (self.eigenmode_weight_optimization_enabled(current_epoch))
):
self.update_opr_probe_modes_and_weights(
probe, indices, chi, delta_p_i, delta_p_hat, obj_patches, current_epoch
self.update_opr_probe_modes_and_weights(probe,
adjoint_shift_probe_update_direction,
indices,
chi,
delta_p_i,
delta_p_hat,
obj_patches,
current_epoch
)

if self.intensity_variation_optimization_enabled(current_epoch):
Expand All @@ -134,6 +141,7 @@ def update_variable_probe(
def update_opr_probe_modes_and_weights(
self,
probe: "Probe",
adjoint_shift_probe_update_direction, # what do I do for type hint here?
indices: Tensor,
chi: Tensor,
delta_p_i: Tensor,
Expand All @@ -144,12 +152,12 @@ def update_opr_probe_modes_and_weights(
"""
Update the eigenmodes of the first incoherent mode of the probe, and update the OPR mode weights.

This implementation is adapted from PtychoShelves code (update_variable_probe.m) and has some
differences from Eq. 31 of Odstrcil (2018).
The default implementation below (used when self.options.use_optimal_update is False) is adapted
from the PtychoShelves code (update_variable_probe.m) and has some differences from Eq. 31 of Odstrcil (2018).
"""
probe_data = probe.data
weights_data = self.data

batch_size = len(delta_p_i)
n_points_total = self.n_scan_points

Expand All @@ -158,44 +166,165 @@ def update_opr_probe_modes_and_weights(
if batch_size == 1:
return

# FIXME: reduced relax_u/v by a factor of 10 for stability, but PtychoShelves works without this.
relax_u = min(0.1, batch_size / n_points_total) * probe.options.eigenmode_update_relaxation
relax_v = self.options.update_relaxation
# Shape of delta_p_i: (batch_size, n_probe_modes, h, w)
# Use only the first incoherent mode
delta_p_i = delta_p_i[:, 0, :, :]
delta_p_hat = delta_p_hat[0, :, :]
residue_update = delta_p_i - delta_p_hat

# Start from the second OPR mode which is the first after the main mode - i.e., the first eigenmode.
for i_opr_mode in range(1, probe.n_opr_modes):
# Just take the first incoherent mode.
eigenmode_i = probe.get_mode_and_opr_mode(mode=0, opr_mode=i_opr_mode)
weights_i = self.get_weights(indices)[:, i_opr_mode]
eigenmode_i, weights_i = self._update_first_eigenmode_and_weight(
residue_update,
eigenmode_i,
weights_i,
relax_u,
relax_v,
obj_patches,
chi,
update_eigenmode=probe.optimization_enabled(current_epoch),
update_weights=self.eigenmode_weight_optimization_enabled(current_epoch),
)

# Project residue on this eigenmode, then subtract it.
if i_opr_mode < probe.n_opr_modes - 1:
residue_update = residue_update - pmath.project(
residue_update, eigenmode_i, dim=(-2, -1)
update_eigenmode = probe.optimization_enabled(current_epoch) # why is this needed again? To even get into this function, we need this to already be true?
update_eigenmode_weights = self.eigenmode_weight_optimization_enabled(current_epoch)

if self.options.use_optimal_update:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Split SDL and optimal OPR weight and eigenmode updates into 2 pull requests


rc = obj_patches.shape[-2] * obj_patches.shape[-1]
n_spos = obj_patches.shape[0]

U = probe_data[1:, 0, ...]

Ws = (weights_data[ indices, 1:]).to(torch.complex64)

Tsconj_chi = (obj_patches[:,0,...].conj() * chi[:,0,...])
Tsconj_chi = adjoint_shift_probe_update_direction( indices, Tsconj_chi[:,None,...], first_mode_only=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one should be replaceable by delta_p_i


chi = adjoint_shift_probe_update_direction( indices, chi, first_mode_only=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably ignorable because we have been mixing shifted and unshifted variables before and it works. In that case, the adjoint_shift_probe_update_direction argument can be removed.


U = torch.reshape(U, (U.shape[0], rc))
chi_vec = torch.reshape(chi[:,0,...], (n_spos, rc))
Ts = torch.reshape(obj_patches[:,0,...], (n_spos, rc))
Tsconj_chi = torch.reshape(Tsconj_chi[:,0,...], (n_spos, rc)).T

# Optimal OPR weight updates

if update_eigenmode_weights:

delta_Ws = -2 * torch.real(U.conj() @ Tsconj_chi).to(torch.complex64)

Ts_U_deltaWs = Ts.T * (U.T @ delta_Ws)
numer = torch.sum(torch.real(chi_vec * Ts_U_deltaWs.H))
denom = torch.sum(torch.real( Ts_U_deltaWs.conj() * Ts_U_deltaWs ))
optimal_step_deltaWs = self.options.update_relaxation * (numer / denom)

Ws = (Ws + optimal_step_deltaWs * delta_Ws.T)

if (probe.representation == "sparse_code"
and probe.options.experimental.sdl_probe_options.enabled_opr):

# Optimal sparse code OPR mode updates

delta_U = -1 * Tsconj_chi @ Ws

delta_sparse_code_probe_opr = probe.dictionary_matrix.H @ delta_U

Gs = probe.dictionary_matrix @ delta_sparse_code_probe_opr @ Ws.T
TsHGsH = Ts.H * Gs.conj()
numer = torch.sum( torch.real(TsHGsH * chi_vec.T))
denom = torch.sum( torch.real(TsHGsH * TsHGsH.conj()))
optimal_step_sparse_code_probe_opr = probe.options.eigenmode_update_relaxation * (numer / denom)

sparse_code_probe_opr = probe.get_sparse_code_probe_opr_weights()

optimal_sparse_code_probe_opr = (sparse_code_probe_opr
+ optimal_step_sparse_code_probe_opr * delta_sparse_code_probe_opr.T)

# Enforce sparsity constraint on sparse code
abs_sparse_code = torch.abs(optimal_sparse_code_probe_opr)
abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=-1, descending=True)
sel = abs_sparse_code_sorted[0][:, probe.sparse_code_probe_nnz]
sparse_code_mask = (abs_sparse_code >= sel[:,None])

# Hard or Soft thresholding
if probe.options.experimental.sdl_probe_options.thresholding_type_opr == 'hard':
optimal_sparse_code_probe_opr = optimal_sparse_code_probe_opr * sparse_code_mask
elif probe.options.experimental.sdl_probe_options.thresholding_type_opr == 'soft':
optimal_sparse_code_probe_opr = ( abs_sparse_code - sel[:,None] ) * sparse_code_mask * torch.exp(1j * torch.angle(optimal_sparse_code_probe_opr))

probe.set_sparse_code_probe_opr(optimal_sparse_code_probe_opr)

# Back to dense OPR representation
U = (probe.dictionary_matrix @ optimal_sparse_code_probe_opr.T).T

# the OPR modes must have L2 norm = torch.sqrt(torch.tensor(rc))
U = U * torch.sqrt(torch.tensor(rc)) / torch.sqrt(torch.sum(torch.abs(U)**2, -1))[:,None]

U = torch.reshape(U, (U.shape[0], obj_patches.shape[-2], obj_patches.shape[-1]))

probe_data[1:, 0, :, :] = U
weights_data[indices, 1:] = Ws.real

# DELETE THIS FOR FINAL MERGING
# DELETE THIS FOR FINAL MERGING

# Test the rank of the new scan position dependent probe:

# probe_data_TEST = torch.reshape(probe_data[:,0,...], (probe_data.shape[0], probe_data.shape[-1] * probe_data.shape[-2]))
# Z1 = torch.sum(probe_data[:, 0, :, :][None,...] * weights_data[indices][...,None,None], 1)
# Z1 = torch.reshape(Z1, (Z1.shape[0], Z1.shape[1] * Z1.shape[2]))
# Z2 = probe_data_TEST.T @ weights_data[indices, :].T.to(torch.complex64)
# print( torch.linalg.matrix_rank(Z1) )
# print( torch.linalg.matrix_rank(Z2) )

# DELETE THIS FOR FINAL MERGING
# DELETE THIS FOR FINAL MERGING

else:

# Optimal dense OPR mode updates:

delta_U = -1 * Tsconj_chi @ Ws

Ts_deltaU_Ws = Ts.T * (delta_U @ Ws.T)
numer = torch.sum(torch.real(chi_vec * Ts_deltaU_Ws.H))
denom = torch.sum(torch.real( Ts_deltaU_Ws.conj() * Ts_deltaU_Ws ))
optimal_step_deltaU = probe.options.eigenmode_update_relaxation * (numer / denom)

U = U + optimal_step_deltaU * delta_U.T

# the OPR modes must have L2 norm = torch.sqrt(torch.tensor(rc))
U = U * torch.sqrt(torch.tensor(rc)) / torch.sqrt(torch.sum(torch.abs(U)**2, -1))[:,None]

U = torch.reshape(U, (U.shape[0], obj_patches.shape[-2], obj_patches.shape[-1]))

probe_data[1:, 0, :, :] = U
weights_data[indices, 1:] = Ws.real

else:

# PtychoShelves method for OPR updates

# FIXME: reduced relax_u/v by a factor of 10 for stability, but PtychoShelves works without this.
relax_u = min(0.1, batch_size / n_points_total) * probe.options.eigenmode_update_relaxation
relax_v = self.options.update_relaxation
# Shape of delta_p_i: (batch_size, n_probe_modes, h, w)
# Use only the first incoherent mode
delta_p_i = delta_p_i[:, 0, :, :]
delta_p_hat = delta_p_hat[0, :, :]
residue_update = delta_p_i - delta_p_hat

# Start from the second OPR mode which is the first after the main mode - i.e., the first eigenmode.
for i_opr_mode in range(1, probe.n_opr_modes):
# Just take the first incoherent mode.
eigenmode_i = probe.get_mode_and_opr_mode(mode=0, opr_mode=i_opr_mode)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you can regenerate the eigenmodes using the updated SDL coefficients (probe.generate() to update the probe.data attribute) so that you can reuse the existing code.

weights_i = self.get_weights(indices)[:, i_opr_mode]
eigenmode_i, weights_i = self._update_first_eigenmode_and_weight(
residue_update,
eigenmode_i,
weights_i,
relax_u,
relax_v,
obj_patches,
chi,
update_eigenmode=update_eigenmode,
update_weights=self.eigenmode_weight_optimization_enabled(current_epoch),
)

probe_data[i_opr_mode, 0, :, :] = eigenmode_i
weights_data[indices, i_opr_mode] = weights_i
# Project residue on this eigenmode, then subtract it.
if i_opr_mode < probe.n_opr_modes - 1:
residue_update = residue_update - pmath.project(
residue_update, eigenmode_i, dim=(-2, -1)
)

if probe.optimization_enabled(current_epoch):
probe.set_data(probe_data)
if self.eigenmode_weight_optimization_enabled(current_epoch):
probe_data[i_opr_mode, 0, :, :] = eigenmode_i
weights_data[indices, i_opr_mode] = weights_i

if update_eigenmode:
probe.set_data(probe_data)

if update_eigenmode_weights:
self.set_data(weights_data)

@timer()
Expand Down
Loading
Loading