Skip to content

Commit

Permalink
refactor: consistent name in functions
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Feb 10, 2025
1 parent 56601e6 commit c56f1ba
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 30 deletions.
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,9 @@ live_mode = false

[tool.mypy]
ignore_missing_imports = true

[tool.pyright]
reportPossiblyUnboundVariable = false
typeCheckingMode = "basic"
reportOptionalSubscript = false
reportOptionalMemberAccess = false
42 changes: 22 additions & 20 deletions src/mrinufft/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,16 @@

from abc import ABC, abstractmethod
from functools import partial

from typing import ClassVar, Callable
import numpy as np
from numpy.typing import NDArray

from mrinufft._array_compat import with_numpy, with_numpy_cupy, AUTOGRAD_AVAILABLE
from mrinufft._utils import auto_cast, power_method
from mrinufft.density import get_density
from mrinufft.extras import get_smaps
from mrinufft.operators.interfaces.utils import is_cuda_array, is_host_array

if AUTOGRAD_AVAILABLE:
from mrinufft.operators.autodiff import MRINufftAutoGrad


# Mapping between numpy float and complex types.
DTYPE_R2C = {"float32": "complex64", "float64": "complex128"}

Expand Down Expand Up @@ -122,6 +119,9 @@ class FourierOperatorBase(ABC):
_grad_wrt_data = False
_grad_wrt_traj = False

backend: ClassVar[str]
available: ClassVar[bool]

def __init__(self):
if not self.available:
raise RuntimeError(f"'{self.backend}' backend is not available.")
Expand Down Expand Up @@ -207,21 +207,21 @@ def adj_op(self, coeffs):
"""
pass

def data_consistency(self, image, obs_data):
def data_consistency(self, image_data, obs_data):
"""Compute the gradient data consistency.
This is the naive implementation using adj_op(op(x)-y).
Specific backend can (and should!) implement a more efficient version.
"""
return self.adj_op(self.op(image) - obs_data)
return self.adj_op(self.op(image_data) - obs_data)

def with_off_resonance_correction(self, B, C, indices):
"""Return a new operator with Off Resonnance Correction."""
from ..off_resonance import MRIFourierCorrected
from .off_resonance import MRIFourierCorrected

return MRIFourierCorrected(self, B, C, indices)

def compute_smaps(self, method=None):
def compute_smaps(self, method: NDArray | Callable | str | dict | None = None):
"""Compute the sensitivity maps and set it.
Parameters
Expand Down Expand Up @@ -286,6 +286,8 @@ def make_autograd(self, wrt_data=True, wrt_traj=False):
if not self.autograd_available:
raise ValueError("Backend does not support auto-differentiation.")

from mrinufft.operators.autodiff import MRINufftAutoGrad

return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj)

def compute_density(self, method=None):
Expand Down Expand Up @@ -401,9 +403,9 @@ def smaps(self):
return self._smaps

@smaps.setter
def smaps(self, smaps):
self._check_smaps_shape(smaps)
self._smaps = smaps
def smaps(self, new_smaps):
self._check_smaps_shape(new_smaps)
self._smaps = new_smaps

def _check_smaps_shape(self, smaps):
"""Check the shape of the sensitivity maps."""
Expand All @@ -421,22 +423,22 @@ def density(self):
return self._density

@density.setter
def density(self, density):
if density is None:
def density(self, new_density):
if new_density is None:
self._density = None
elif len(density) != self.n_samples:
elif len(new_density) != self.n_samples:
raise ValueError("Density and samples should have the same length")
else:
self._density = density
self._density = new_density

@property
def dtype(self):
"""Return floating precision of the operator."""
return self._dtype

@dtype.setter
def dtype(self, dtype):
self._dtype = np.dtype(dtype)
def dtype(self, new_dtype):
self._dtype = np.dtype(new_dtype)

@property
def cpx_dtype(self):
Expand All @@ -449,8 +451,8 @@ def samples(self):
return self._samples

@samples.setter
def samples(self, samples):
self._samples = samples
def samples(self, new_samples):
self._samples = new_samples

@property
def n_samples(self):
Expand Down
7 changes: 4 additions & 3 deletions src/mrinufft/operators/interfaces/cufinufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
except ImportError:
CUFINUFFT_AVAILABLE = False


OPTS_FIELD_DECODE = {
"gpu_method": {1: "nonuniform pts driven", 2: "shared memory"},
"gpu_sort": {0: "no sort (GM)", 1: "sort (GM-sort)"},
Expand Down Expand Up @@ -269,10 +268,12 @@ def smaps(self, new_smaps):
self._smaps = new_smaps

@FourierOperatorBase.samples.setter
def samples(self, samples):
def samples(self, new_samples):
"""Update the plans when changing the samples."""
self._samples = np.asfortranarray(
proper_trajectory(samples, normalize="pi").astype(np.float32, copy=False)
proper_trajectory(new_samples, normalize="pi").astype(
np.float32, copy=False
)
)
for typ in [1, 2, "grad"]:
if typ == "grad" and not self._grad_wrt_traj:
Expand Down
10 changes: 5 additions & 5 deletions src/mrinufft/operators/interfaces/gpunufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,7 @@ def smaps(self, new_smaps):
self.raw_op.set_smaps(smaps=new_smaps)

@FourierOperatorBase.samples.setter
def samples(self, samples):
def samples(self, new_samples):
"""Set the samples for the Fourier Operator.
Parameters
Expand All @@ -541,7 +541,7 @@ def samples(self, samples):
The samples for the Fourier Operator.
"""
self._samples = proper_trajectory(
samples.astype(np.float32, copy=False), normalize="unit"
new_samples.astype(np.float32, copy=False), normalize="unit"
)
# TODO: gpuNUFFT needs to sort the points twice in this case.
# It could help to have access to directly dorted arrays from gpuNUFFT.
Expand All @@ -552,19 +552,19 @@ def samples(self, samples):
)

@FourierOperatorBase.density.setter
def density(self, density):
def density(self, new_density):
"""Set the density for the Fourier Operator.
Parameters
----------
density: np.ndarray
The density for the Fourier Operator.
"""
self._density = density
self._density = new_density
if hasattr(self, "raw_op"): # edge case for init
self.raw_op.set_pts(
self._samples,
density=density,
density=new_density,
)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions src/mrinufft/operators/interfaces/tfnufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def norm_factor(self):
return np.sqrt(np.prod(self.shape) * 2 ** len(self.shape))

@with_tensorflow
def data_consistency(self, data, obs_data):
def data_consistency(self, image_data, obs_data):
"""Compute the data consistency.
Parameters
Expand All @@ -149,7 +149,7 @@ def data_consistency(self, data, obs_data):
Tensor
The data consistency error in image space.
"""
return self.adj_op(self.op(data) - obs_data)
return self.adj_op(self.op(image_data) - obs_data)

@classmethod
def pipe(
Expand Down

1 comment on commit c56f1ba

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Artifacts

Please sign in to comment.