diff --git a/model_zoo/TCIA_PROSTATEx_Prostate_MRI_Anatomy_Model.ipynb b/model_zoo/TCIA_PROSTATEx_Prostate_MRI_Anatomy_Model.ipynb
index a9fc2da09..351ef703c 100644
--- a/model_zoo/TCIA_PROSTATEx_Prostate_MRI_Anatomy_Model.ipynb
+++ b/model_zoo/TCIA_PROSTATEx_Prostate_MRI_Anatomy_Model.ipynb
@@ -115,18 +115,20 @@
"!python -c \"import monai\" || pip install -q \"monai[nibabel,itk,tqdm,pandas,skimage]\"\n",
"!python -c \"import cv2\" || pip install -q opencv-python-headless\n",
"\n",
- "# These are the libraries used to read DICOM Seg objects.\n",
- "!python -m pip install -q pydicom==2.4.4 pydicom-seg\n",
+ "# Install pydicom to read DICOM data.\n",
+ "!python -m pip install pydicom\n",
"\n",
"# Install tcia_utils to download the datasets.\n",
- "!python -m pip install --upgrade -q --no-deps tcia_utils\n",
+ "!python -m pip install --upgrade -q tcia_utils\n",
"\n",
- "# Install the dependency manually to avoid installing opencv-python.\n",
- "!python -m pip install -q plotly bs4 ipywidgets unidecode jsonschema\n",
- "!python -m pip install -q --no-deps rt-utils\n",
"\n",
"# This is the installation required for itkWidgets.\n",
- "!python -m pip install --upgrade --pre -q \"itkwidgets[all]==1.0a23\" imjoy_elfinder"
+ "# Setup for the imjoy-jupyter-extension for itkWidgets varies in Google Colab,\n",
+ "# Jupyter Notebook and Jupyter Lab. See https://github.com/InsightSoftwareConsortium/itkwidgets/tree/main\n",
+ "# for more details, but the following should install what you need for all 3 environments.\n",
+ "# Afterwards, the ImJoy icon should appear in the top icon bar of Jupyter\n",
+ "# Notebook and Jupyter Lab. No icon will appear in Google Colab.\n",
+ "!python -m pip install --upgrade -q \"itkwidgets[all]>=1.0a55\" \"zarr<3\""
]
},
{
@@ -159,12 +161,9 @@
"# Numpy for numpy.arrays\n",
"import numpy as np\n",
"\n",
- "# Include ITK for DICOM reading.\n",
+ "# Include ITK and pydicom for DICOM reading.\n",
"import itk\n",
- "\n",
- "# Include pydicom_seg for DICOM SEG objects\n",
"import pydicom\n",
- "import pydicom_seg\n",
"\n",
"# for downloading data from TCIA\n",
"from tcia_utils import nbia\n",
@@ -246,7 +245,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"metadata": {
"id": "7wR4HJu-emJP"
},
@@ -255,7 +254,7 @@
"dicom_data_dir = \"tciaDownload\"\n",
"\n",
"# The series_uid defines their directory where the MR data was stored on disk.\n",
- "mr_series_uid = df.at[df.Modality.eq(\"MR\").idxmax(), \"Series UID\"]\n",
+ "mr_series_uid = df.at[df.Modality.eq(\"MR\").idxmax(), \"Series ID\"]\n",
"mr_dir = os.path.join(dicom_data_dir, mr_series_uid)\n",
"\n",
"# Read the DICOM MR series' objects and reconstruct them into a 3D ITK image.\n",
@@ -272,19 +271,66 @@
},
"outputs": [],
"source": [
- "# The series_uid defines where the RTSTRUCT was stored on disk. It is stored in a single file.\n",
- "seg_series_uid = df.at[df.Modality.eq(\"SEG\").idxmax(), \"Series UID\"]\n",
+ "# The series UID defines the directory where the SEG was stored on disk; the SEG object is a single file.\n",
+ "seg_series_uid = df.at[df.Modality.eq(\"SEG\").idxmax(), \"Series ID\"]\n",
"seg_dir = os.path.join(dicom_data_dir, seg_series_uid)\n",
"seg_file = glob.glob(os.path.join(seg_dir, \"*.dcm\"))[0]\n",
"\n",
- "# Read the DICOM SEG object using pydicom and pydicom_seg.\n",
- "seg_dicom = pydicom.dcmread(seg_file)\n",
- "seg_reader = pydicom_seg.MultiClassReader()\n",
- "seg_obj = seg_reader.read(seg_dicom)\n",
- "\n",
- "# Convert the DICOM SEG object into an itk image, with correct voxel origin, spacing, and directions in physical space.\n",
- "seg_image = itk.GetImageFromArray(seg_obj.data.astype(np.float32))\n",
- "seg_image.CopyInformation(mr_image)"
+ "# Read the DICOM SEG object using pydicom.\n",
+ "seg_dcm = pydicom.dcmread(seg_file, force=True)\n",
+ "\n",
+ "# Define the DICOM SEG SOP Class UID\n",
+ "SEG_UID = \"1.2.840.10008.5.1.4.1.1.66.4\"\n",
+ "\n",
+ "# Extract segment metadata and create frame mapping\n",
+ "segment_metadata = {}\n",
+ "frame_mapping = {}\n",
+ "if seg_dcm.SOPClassUID == SEG_UID:\n",
+ " print(\"Detected DICOM SEG file. Loading...\")\n",
+ " stacked_mask_np = seg_dcm.pixel_array\n",
+ " for seg_item in seg_dcm.SegmentSequence:\n",
+ " seg_num, seg_label = seg_item.SegmentNumber, seg_item.SegmentLabel\n",
+ " segment_metadata[seg_num] = seg_label\n",
+ " for i, frame_item in enumerate(seg_dcm.PerFrameFunctionalGroupsSequence):\n",
+ " seg_id_item = frame_item.SegmentIdentificationSequence[0]\n",
+ " segment_number = seg_id_item.ReferencedSegmentNumber\n",
+ " if segment_number not in frame_mapping:\n",
+ " frame_mapping[segment_number] = []\n",
+ " frame_mapping[segment_number].append(i)\n",
+ "\n",
+ " # Create a zero-filled numpy array with the same shape as the MR image and integer type\n",
+ " mr_image_shape = itk.GetArrayViewFromImage(mr_image).shape\n",
+ " label_map_np = np.zeros(mr_image_shape, dtype=np.uint8) # Use uint8 for labels\n",
+ "\n",
+ " # Reconstruct the label map from stacked masks\n",
+ " for segment_number, frame_indices in frame_mapping.items():\n",
+ " # Ensure frame_indices are within the bounds of stacked_mask_np\n",
+ " valid_frame_indices = [idx for idx in frame_indices if idx < stacked_mask_np.shape[0]]\n",
+ " if valid_frame_indices:\n",
+ " segment_mask = stacked_mask_np[valid_frame_indices] > 0\n",
+ " # Assuming the mask shape is compatible with assigning to label_map_np slice\n",
+ " # Need to handle potential shape mismatches or different image orientations/spacings\n",
+ " # For simplicity here, assuming direct slice assignment is possible.\n",
+ " # A more robust approach would involve resampling the segmentation to the MR image grid.\n",
+ " label_map_np[segment_mask] = segment_number\n",
+ " print(\"Successfully created label map from DICOM SEG.\")\n",
+ "\n",
+ " # Convert the numpy array into an itk image, with correct voxel origin, spacing, and directions in physical space.\n",
+ " # Use an appropriate integer data type for the ITK image.\n",
+ " seg_image = itk.GetImageFromArray(label_map_np, is_vector=False)\n",
+ " seg_image.CopyInformation(mr_image)\n",
+ "\n",
+ "else:\n",
+ " print(f\"File is not a DICOM SEG file (SOP Class UID: {seg_dcm.SOPClassUID}). Skipping segmentation loading.\")\n",
+ " # Create an empty ITK image if it's not a SEG file\n",
+ " seg_image = itk.Image[itk.UC, mr_image.GetImageDimension()].New()\n",
+ " seg_image.CopyInformation(mr_image)\n",
+ "\n",
+ "# Display segment metadata if available\n",
+ "if segment_metadata:\n",
+ " print(\"\\nSegment Metadata:\")\n",
+ " for seg_num, seg_label in segment_metadata.items():\n",
+ " print(f\" Segment Number {seg_num}: {seg_label}\")"
]
},
{
@@ -314,7 +360,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 7,
"metadata": {
"id": "xcj8cA_ZemJQ"
},
@@ -392,7 +438,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
"metadata": {
"id": "CqoHulOhemJR"
},
@@ -417,7 +463,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
"metadata": {
"id": "RF4LqDs-emJR",
"tags": []
@@ -427,11 +473,11 @@
"data": {
"text/html": [
"\n",
- "
\n",
+ "
\n",
" \n",
@@ -446,7 +492,9 @@
},
{
"data": {
- "application/javascript": "window.connectPlugin && window.connectPlugin(\"f2c5a353-92e5-49ea-8f0a-4a5232883799\")",
+ "application/javascript": [
+ "window.connectPlugin && window.connectPlugin(\"7ba53bd6-ca11-4e56-a0dd-f80199194b50\")"
+ ],
"text/plain": [
""
]
@@ -457,7 +505,7 @@
{
"data": {
"text/html": [
- ""
+ ""
],
"text/plain": [
""
@@ -469,10 +517,10 @@
{
"data": {
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 11,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
@@ -521,7 +569,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"metadata": {
"id": "cjLs9PjDemJS"
},
@@ -545,7 +593,7 @@
"# for inference. As a result, the result image may not be in the same\n",
"# spacing, orientation, etc as the original input data. So, we resample the results\n",
"# image to match the physical properties of the original input data.\n",
- "interpolator = itk.NearestNeighborInterpolateImageFunction.New(seg_image)\n",
+ "interpolator = itk.NearestNeighborInterpolateImageFunction.New(seg_image.astype(result_image.dtype))\n",
"result_image_resampled = itk.resample_image_filter(\n",
" Input=result_image, Interpolator=interpolator, reference_image=seg_image_prep, use_reference_image=True\n",
")"
@@ -553,7 +601,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 28,
"metadata": {
"id": "9bmgCBL9emJS",
"tags": []
@@ -563,11 +611,11 @@
"data": {
"text/html": [
"\n",
- "
\n",
+ "
\n",
" \n",
@@ -582,7 +630,9 @@
},
{
"data": {
- "application/javascript": "window.connectPlugin && window.connectPlugin(\"f2c5a353-92e5-49ea-8f0a-4a5232883799\")",
+ "application/javascript": [
+ "window.connectPlugin && window.connectPlugin(\"7ba53bd6-ca11-4e56-a0dd-f80199194b50\")"
+ ],
"text/plain": [
""
]
@@ -593,7 +643,7 @@
{
"data": {
"text/html": [
- ""
+ ""
],
"text/plain": [
""
@@ -605,12 +655,37 @@
],
"source": [
"# View the image with results overlaid in an interactive 2D slice viewer.\n",
- "viewer_b = view(image=mr_image_prep, label_image=result_image_resampled)"
+ "\n",
+ "# Get the unique label values in the model result image (excluding background 0)\n",
+ "unique_result_labels = sorted(\n",
+ " [label for label in np.unique(itk.GetArrayViewFromImage(result_image_resampled)) if label != 0]\n",
+ ")\n",
+ "\n",
+ "# Create lists for itkWidgets for the model results\n",
+ "# The model output labels are 0, 1, 2 for Background, Central Gland, Peripheral Zone\n",
+ "# Based on the background info, the model labels are: 0: background, 1: central gland, 2: peripheral zone.\n",
+ "# We need to map these model labels to the original segment names.\n",
+ "# Assuming model output labels 1->Central Gland, 2->Peripheral Zone based on background info.\n",
+ "# The expert labels are 1: Peripheral zone, 2: Transition zone, 3: Prostatic Urethra, 4: Anterior fibromuscular stroma.\n",
+ "# The comparison image labels are 1, 2, 3.\n",
+ "\n",
+ "# Let's create names for the model output labels based on the model's documentation/description.\n",
+ "# Assuming model labels 1, 2 correspond to Central Gland and Peripheral Zone respectively.\n",
+ "model_label_names = [\"Background\", \"Central Gland\", \"Peripheral Zone\"]\n",
+ "model_label_weights = [0.0, 1.0, 1.0] # Weight 0 for background, 1 for others\n",
+ "\n",
+ "\n",
+ "viewer_b = view(\n",
+ " image=mr_image_prep,\n",
+ " label_image=result_image_resampled,\n",
+ " label_image_names=model_label_names,\n",
+ " label_image_weights=model_label_weights,\n",
+ ")"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 29,
"metadata": {
"id": "mYRlAbKusf-L",
"tags": []
@@ -626,7 +701,7 @@
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": 30,
"metadata": {
"id": "2wHOPVkWemJS",
"tags": []
@@ -636,11 +711,11 @@
"data": {
"text/html": [
"\n",
- "
\n",
+ "
\n",
" \n",
@@ -655,7 +730,9 @@
},
{
"data": {
- "application/javascript": "window.connectPlugin && window.connectPlugin(\"f2c5a353-92e5-49ea-8f0a-4a5232883799\")",
+ "application/javascript": [
+ "window.connectPlugin && window.connectPlugin(\"7ba53bd6-ca11-4e56-a0dd-f80199194b50\")"
+ ],
"text/plain": [
""
]
@@ -666,7 +743,7 @@
{
"data": {
"text/html": [
- ""
+ ""
],
"text/plain": [
""
@@ -696,7 +773,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": 31,
"metadata": {
"id": "diqDQW4PemJS"
},
@@ -763,9 +840,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.12"
+ "version": "3.12.11"
}
},
"nbformat": 4,
- "nbformat_minor": 1
+ "nbformat_minor": 4
}