Skip to content

Commit

Permalink
Added siemens data reader and phase shifter (#89)
Browse files Browse the repository at this point in the history
* Added siemens and add_raw shifts

* Update src/mrinufft/io/siemens.py

Co-authored-by: Pierre-Antoine Comby <[email protected]>

* Update

* Fixed some more

* Moved codes around

* Added np.ndarray

* Fix movement

* Fix movement

* Fix flake

* ruff fix

* Fix

* Remove bymistake add

* style: ruff

* feat(io.utils): add basic test.

---------

Co-authored-by: Pierre-Antoine Comby <[email protected]>
  • Loading branch information
chaithyagr and paquiteau authored Apr 29, 2024
1 parent 37c6921 commit 8e15f05
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 1 deletion.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cufinufft = ["cufinufft", "cupy-cuda11x"]
finufft = ["finufft"]
pynfft = ["pynfft2", "cython<3.0.0"]
pynufft = ["pynufft"]
io = ["pymapvbvd"]

test = ["pytest<8.0.0", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases"]
dev = ["black", "isort", "ruff"]
Expand Down
88 changes: 87 additions & 1 deletion src/mrinufft/io/nsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def read_trajectory(
grad_filename: str,
dwell_time: float = DEFAULT_RASTER_TIME,
num_adc_samples: int = None,
gamma: float = Gammas.HYDROGEN,
gamma: Gammas | float = Gammas.HYDROGEN,
raster_time: float = DEFAULT_RASTER_TIME,
read_shots: bool = False,
normalize_factor: float = KMAX,
Expand Down Expand Up @@ -390,3 +390,89 @@ def read_trajectory(
Kmax = img_size / 2 / fov
kspace_loc = kspace_loc / Kmax * normalize_factor
return kspace_loc, params


def read_siemens_rawdat(
filename: str,
removeOS: bool = False,
squeeze: bool = True,
data_type: str = "ARBGRAD_VE11C",
): # pragma: no cover
"""Read raw data from a Siemens MRI file.
Parameters
----------
filename : str
The path to the Siemens MRI file.
removeOS : bool, optional
Whether to remove the oversampling, by default False.
squeeze : bool, optional
Whether to squeeze the dimensions of the data, by default True.
data_type : str, optional
The type of data to read, by default 'ARBGRAD_VE11C'.
Returns
-------
data: ndarray
Imported data formatted as n_coils X n_samples X n_slices X n_contrasts
hdr: dict
Extra information about the data parsed from the twix file
Raises
------
ImportError
If the mapVBVD module is not available.
Notes
-----
This function requires the mapVBVD module to be installed.
You can install it using the following command:
`pip install pymapVBVD`
"""
try:
from mapvbvd import mapVBVD
except ImportError as err:
raise ImportError(
"The mapVBVD module is not available. Please install it using "
"the following command: pip install pymapVBVD"
) from err
twixObj = mapVBVD(filename)
if isinstance(twixObj, list):
twixObj = twixObj[-1]
twixObj.image.flagRemoveOS = removeOS
twixObj.image.squeeze = squeeze
raw_kspace = twixObj.image[""]
data = np.moveaxis(raw_kspace, 0, 2)
hdr = {
"n_coils": int(twixObj.image.NCha),
"n_shots": int(twixObj.image.NLin),
"n_contrasts": int(twixObj.image.NSet),
"n_adc_samples": int(twixObj.image.NCol),
"n_slices": int(twixObj.image.NSli),
}
data = data.reshape(
hdr["n_coils"],
hdr["n_shots"] * hdr["n_adc_samples"],
hdr["n_slices"],
hdr["n_contrasts"],
)
if "ARBGRAD_VE11C" in data_type:
hdr["type"] = "ARBGRAD_GRE"
hdr["shifts"] = ()
for s in [7, 6, 8]:
shift = twixObj.search_header_for_val(
"Phoenix", ("sWiPMemBlock", "adFree", str(s))
)
hdr["shifts"] += (0,) if shift == [] else (shift[0],)
hdr["oversampling_factor"] = twixObj.search_header_for_val(
"Phoenix", ("sWiPMemBlock", "alFree", "4")
)[0]
hdr["trajectory_name"] = twixObj.search_header_for_val(
"Phoenix", ("sWipMemBlock", "tFree")
)[0][1:-1]
if hdr["n_contrasts"] > 1:
hdr["turboFactor"] = twixObj.search_header_for_val(
"Phoenix", ("sFastImaging", "lTurboFactor")
)[0]
hdr["type"] = "ARBGRAD_MP2RAGE"
return data, hdr
37 changes: 37 additions & 0 deletions src/mrinufft/io/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Module containing utility functions for IO in MRI NUFFT."""

import numpy as np


def add_phase_to_kspace_with_shifts(kspace_data, kspace_loc, normalized_shifts):
"""
Add phase shifts to k-space data.
Parameters
----------
kspace_data : np.ndarray
The k-space data.
kspace_loc : np.ndarray
The k-space locations.
normalized_shifts : tuple
The normalized shifts to apply to each dimension of k-space.
Returns
-------
ndarray
The k-space data with phase shifts applied.
Raises
------
ValueError
If the dimension of normalized_shifts does not match the number of
dimensions in kspace_loc.
"""
if len(normalized_shifts) != kspace_loc.shape[1]:
raise ValueError(
"Dimension mismatch between shift and kspace locations! "
"Ensure that shifts are right"
)
phi = np.sum(kspace_loc * normalized_shifts, axis=-1)
phase = np.exp(-2 * np.pi * 1j * phi)
return kspace_data * phase
20 changes: 20 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import numpy as np
from mrinufft.io import read_trajectory, write_trajectory
from mrinufft.io.utils import add_phase_to_kspace_with_shifts
from mrinufft.trajectories.trajectory2D import initialize_2D_radial
from mrinufft.trajectories.trajectory3D import initialize_3D_cones
from pytest_cases import parametrize_with_cases
from case_trajectories import CasesTrajectories


class CasesIO:
Expand Down Expand Up @@ -67,3 +69,21 @@ def test_write_n_read(
np.testing.assert_almost_equal(params["FOV"], FOV, decimal=6)
np.testing.assert_equal(params["img_size"], img_size)
np.testing.assert_almost_equal(read_traj, trajectory, decimal=5)


@parametrize_with_cases(
"kspace_loc, shape",
cases=[CasesTrajectories.case_random2D, CasesTrajectories.case_random3D],
)
def test_add_shift(kspace_loc, shape):
"""Test the add_phase_to_kspace_with_shifts function."""
n_samples = np.prod(kspace_loc.shape[:-1])
kspace_data = np.random.randn(n_samples) + 1j * np.random.randn(n_samples)
shifts = np.random.rand(kspace_loc.shape[-1])

shifted_data = add_phase_to_kspace_with_shifts(kspace_data, kspace_loc, shifts)

assert np.allclose(np.abs(shifted_data), np.abs(kspace_data))

phase = np.exp(-2 * np.pi * 1j * np.sum(kspace_loc * shifts, axis=-1))
np.testing.assert_almost_equal(shifted_data / phase, kspace_data, decimal=5)

0 comments on commit 8e15f05

Please sign in to comment.