Skip to content

Commit a4fa128

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0a3fd79 commit a4fa128

File tree

8 files changed: +572 additions, -565 deletions.

monailabel/datastore/utils/convert.py

Lines changed: 223 additions & 198 deletions
Large diffs are not rendered by default.

monailabel/transform/reader.py

Lines changed: 47 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -17,13 +17,14 @@
1717
import warnings
1818
from collections.abc import Sequence
1919
from typing import TYPE_CHECKING, Any
20-
from packaging import version
20+
2121
import numpy as np
2222
from monai.config import PathLike
2323
from monai.data import ImageReader
2424
from monai.data.image_reader import _copy_compatible_dict, _stack_images
2525
from monai.data.utils import orientation_ras_lps
2626
from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg
27+
from packaging import version
2728
from torch.utils.data._utils.collate import np_str_obj_array_pattern
2829

2930
logger = logging.getLogger(__name__)
@@ -56,11 +57,11 @@ def _get_nvimgcodec_decoder():
5657
"""Get or create a thread-local nvimgcodec decoder singleton."""
5758
if not has_nvimgcodec:
5859
raise RuntimeError("nvimgcodec is not available. Cannot create decoder.")
59-
60-
if not hasattr(_thread_local, 'decoder') or _thread_local.decoder is None:
60+
61+
if not hasattr(_thread_local, "decoder") or _thread_local.decoder is None:
6162
_thread_local.decoder = nvimgcodec.Decoder()
6263
logger.debug(f"Initialized thread-local nvimgcodec.Decoder for thread {threading.current_thread().name}")
63-
64+
6465
return _thread_local.decoder
6566

6667

@@ -215,28 +216,28 @@ def _dir_contains_dcm(path):
215216
def _apply_rescale_and_dtype(self, pixel_data, ds, original_dtype):
216217
"""
217218
Apply DICOM rescale slope/intercept and handle dtype preservation.
218-
219+
219220
Args:
220221
pixel_data: numpy or cupy array of pixel data
221222
ds: pydicom dataset containing RescaleSlope/RescaleIntercept tags
222223
original_dtype: original dtype before any processing
223-
224+
224225
Returns:
225226
Processed pixel data array (potentially rescaled and dtype converted)
226227
"""
227228
# Detect array library (numpy or cupy)
228229
xp = cp if hasattr(pixel_data, "__cuda_array_interface__") else np
229-
230+
230231
# Check if rescaling is needed
231232
has_rescale = hasattr(ds, "RescaleSlope") and hasattr(ds, "RescaleIntercept")
232-
233+
233234
if has_rescale:
234235
slope = float(ds.RescaleSlope)
235236
intercept = float(ds.RescaleIntercept)
236237
slope = xp.asarray(slope, dtype=xp.float32)
237238
intercept = xp.asarray(intercept, dtype=xp.float32)
238239
pixel_data = pixel_data.astype(xp.float32) * slope + intercept
239-
240+
240241
# Convert back to original dtype if requested (matching ITK behavior)
241242
if self.preserve_dtype:
242243
# Determine target dtype based on original and rescale
@@ -254,7 +255,7 @@ def _apply_rescale_and_dtype(self, pixel_data, ds, original_dtype):
254255
# Preserve original dtype for other types
255256
target_dtype = original_dtype
256257
pixel_data = pixel_data.astype(target_dtype)
257-
258+
258259
return pixel_data
259260

260261
def _is_nvimgcodec_supported_syntax(self, img):
@@ -298,8 +299,8 @@ def _is_nvimgcodec_supported_syntax(self, img):
298299
]
299300

300301
jpeg_lossless_syntaxes = [
301-
'1.2.840.10008.1.2.4.57', # JPEG Lossless, Non-Hierarchical (Process 14)
302-
'1.2.840.10008.1.2.4.70', # JPEG Lossless, Non-Hierarchical, First-Order Prediction
302+
"1.2.840.10008.1.2.4.57", # JPEG Lossless, Non-Hierarchical (Process 14)
303+
"1.2.840.10008.1.2.4.70", # JPEG Lossless, Non-Hierarchical, First-Order Prediction
303304
]
304305

305306
return str(transfer_syntax) in jpeg2000_syntaxes + htj2k_syntaxes + jpeg_lossy_syntaxes + jpeg_lossless_syntaxes
@@ -526,15 +527,15 @@ def series_sort_key(series_uid):
526527
slices_no_pos.append((inst_num, fp, ds))
527528
slices_no_pos.sort(key=lambda s: s[0])
528529
sorted_filepaths = [fp for _, fp, _ in slices_no_pos]
529-
530+
530531
# Read all DICOM files for the series and store as a list of Datasets
531532
# This allows _process_dicom_series() to handle the series as a whole
532533
logger.info(f"NvDicomReader: Series contains {len(sorted_filepaths)} slices")
533534
series_datasets = []
534535
for fpath in sorted_filepaths:
535536
ds = pydicom.dcmread(fpath, **kwargs_)
536537
series_datasets.append(ds)
537-
538+
538539
# Append the list of datasets as a single series
539540
img_.append(series_datasets)
540541
self.filenames.extend(sorted_filepaths)
@@ -601,7 +602,7 @@ def get_data(self, img) -> tuple[np.ndarray, dict]:
601602
data_array = self._get_array_data(ds_or_list)
602603
metadata = self._get_meta_dict(ds_or_list)
603604
metadata[MetaKeys.SPATIAL_SHAPE] = np.asarray(data_array.shape)
604-
605+
605606
# Calculate spacing for single-frame images
606607
pixel_spacing = ds_or_list.PixelSpacing if hasattr(ds_or_list, "PixelSpacing") else [1.0, 1.0]
607608
slice_spacing = float(ds_or_list.SliceThickness) if hasattr(ds_or_list, "SliceThickness") else 1.0
@@ -645,7 +646,7 @@ def _process_dicom_series(self, datasets: list) -> tuple[np.ndarray, dict]:
645646
needs_rescale = hasattr(first_ds, "RescaleSlope") and hasattr(first_ds, "RescaleIntercept")
646647
rows = first_ds.Rows
647648
cols = first_ds.Columns
648-
649+
649650
# For multi-frame DICOMs, depth is the total number of frames, not the number of files
650651
# For single-frame DICOMs, depth is the number of files
651652
depth = 0
@@ -786,46 +787,48 @@ def _process_dicom_series(self, datasets: list) -> tuple[np.ndarray, dict]:
786787
if depth > 1:
787788
# For multi-frame DICOM, calculate spacing from per-frame positions
788789
is_multiframe = len(datasets) == 1 and hasattr(first_ds, "NumberOfFrames") and first_ds.NumberOfFrames > 1
789-
790+
790791
if is_multiframe and hasattr(first_ds, "PerFrameFunctionalGroupsSequence"):
791792
# Multi-frame DICOM: extract positions from PerFrameFunctionalGroupsSequence
792793
average_distance = 0.0
793794
positions = []
794-
795+
795796
try:
796797
# Extract all frame positions
797798
for frame_idx, frame in enumerate(first_ds.PerFrameFunctionalGroupsSequence):
798799
# Try to get PlanePositionSequence
799800
plane_pos_seq = None
800801
if hasattr(frame, "PlanePositionSequence"):
801802
plane_pos_seq = frame.PlanePositionSequence
802-
elif hasattr(frame, 'get'):
803+
elif hasattr(frame, "get"):
803804
plane_pos_seq = frame.get("PlanePositionSequence")
804-
805+
805806
if plane_pos_seq and len(plane_pos_seq) > 0:
806807
plane_pos_item = plane_pos_seq[0]
807808
if hasattr(plane_pos_item, "ImagePositionPatient"):
808809
ipp = plane_pos_item.ImagePositionPatient
809810
z_pos = float(ipp[2])
810811
positions.append(z_pos)
811-
812+
812813
# Calculate average distance between consecutive positions
813814
if len(positions) > 1:
814815
for i in range(1, len(positions)):
815-
average_distance += abs(positions[i] - positions[i-1])
816+
average_distance += abs(positions[i] - positions[i - 1])
816817
slice_spacing = average_distance / (len(positions) - 1)
817818
else:
818-
logger.warning(f"NvDicomReader: Only found {len(positions)} positions, cannot calculate spacing")
819+
logger.warning(
820+
f"NvDicomReader: Only found {len(positions)} positions, cannot calculate spacing"
821+
)
819822
slice_spacing = 1.0
820-
823+
821824
except Exception as e:
822825
logger.warning(f"NvDicomReader: Failed to calculate spacing from per-frame positions: {e}")
823826
# Fallback to SliceThickness or default
824827
if hasattr(first_ds, "SliceThickness"):
825828
slice_spacing = float(first_ds.SliceThickness)
826829
else:
827830
slice_spacing = 1.0
828-
831+
829832
elif len(datasets) > 1 and hasattr(first_ds, "ImagePositionPatient"):
830833
# Multiple single-frame DICOMs: calculate from dataset positions
831834
average_distance = 0.0
@@ -836,8 +839,10 @@ def _process_dicom_series(self, datasets: list) -> tuple[np.ndarray, dict]:
836839
average_distance += abs(curr_pos - prev_pos)
837840
prev_pos = curr_pos
838841
slice_spacing = average_distance / (len(datasets) - 1)
839-
logger.info(f"NvDicomReader: Calculated slice spacing from {len(datasets)} datasets: {slice_spacing:.4f}")
840-
842+
logger.info(
843+
f"NvDicomReader: Calculated slice spacing from {len(datasets)} datasets: {slice_spacing:.4f}"
844+
)
845+
841846
elif hasattr(first_ds, "SliceThickness"):
842847
# Fallback to SliceThickness tag if positions unavailable
843848
slice_spacing = float(first_ds.SliceThickness)
@@ -850,14 +855,14 @@ def _process_dicom_series(self, datasets: list) -> tuple[np.ndarray, dict]:
850855

851856
# Build metadata
852857
metadata = self._get_meta_dict(first_ds)
853-
858+
854859
metadata["spacing"] = np.array([float(pixel_spacing[1]), float(pixel_spacing[0]), slice_spacing])
855860
# Metadata should always use numpy arrays, even if data is on GPU
856861
metadata[MetaKeys.SPATIAL_SHAPE] = np.asarray(volume.shape)
857862

858863
# Store last position for affine calculation
859864
last_ds = datasets[-1]
860-
865+
861866
# For multi-frame DICOM, try to get the last frame's position from PerFrameFunctionalGroupsSequence
862867
is_multiframe = hasattr(last_ds, "NumberOfFrames") and last_ds.NumberOfFrames > 1
863868
if is_multiframe and hasattr(last_ds, "PerFrameFunctionalGroupsSequence"):
@@ -901,9 +906,7 @@ def _get_array_data(self, ds):
901906
original_dtype = pixel_array.dtype
902907
logger.info(f"NvDicomReader: Successfully decoded with nvImageCodec")
903908
except Exception as e:
904-
logger.warning(
905-
f"NvDicomReader: nvImageCodec decoding failed: {e}, falling back to pydicom"
906-
)
909+
logger.warning(f"NvDicomReader: nvImageCodec decoding failed: {e}, falling back to pydicom")
907910
pixel_array = ds.pixel_array
908911
original_dtype = pixel_array.dtype
909912
else:
@@ -965,13 +968,13 @@ def _get_meta_dict(self, ds) -> dict:
965968

966969
# Also store essential spatial tags with readable names
967970
# (for convenience and backward compatibility)
968-
971+
969972
# For multi-frame (Enhanced) DICOM, extract per-frame metadata from the first frame
970973
is_multiframe = hasattr(ds, "NumberOfFrames") and ds.NumberOfFrames > 1
971974
if is_multiframe and hasattr(ds, "PerFrameFunctionalGroupsSequence"):
972975
try:
973976
first_frame = ds.PerFrameFunctionalGroupsSequence[0]
974-
977+
975978
# Helper function to safely access sequence items (handles both attribute and dict access)
976979
def get_sequence_item(obj, seq_name, item_idx=0):
977980
"""Get item from a sequence, handling both attribute and dict access."""
@@ -980,48 +983,49 @@ def get_sequence_item(obj, seq_name, item_idx=0):
980983
if hasattr(obj, seq_name):
981984
seq = getattr(obj, seq_name, None)
982985
# Try dict-style access
983-
elif hasattr(obj, 'get'):
986+
elif hasattr(obj, "get"):
984987
seq = obj.get(seq_name)
985-
elif hasattr(obj, '__getitem__'):
988+
elif hasattr(obj, "__getitem__"):
986989
try:
987990
seq = obj[seq_name]
988991
except (KeyError, TypeError):
989992
pass
990-
993+
991994
if seq and len(seq) > item_idx:
992995
return seq[item_idx]
993996
return None
994-
997+
995998
# Extract ImageOrientationPatient from per-frame sequence
996999
plane_orient_item = get_sequence_item(first_frame, "PlaneOrientationSequence")
9971000
if plane_orient_item and hasattr(plane_orient_item, "ImageOrientationPatient"):
9981001
iop = plane_orient_item.ImageOrientationPatient
9991002
metadata["ImageOrientationPatient"] = list(iop)
1000-
1003+
10011004
# Extract ImagePositionPatient from per-frame sequence
10021005
plane_pos_item = get_sequence_item(first_frame, "PlanePositionSequence")
10031006
if plane_pos_item and hasattr(plane_pos_item, "ImagePositionPatient"):
10041007
ipp = plane_pos_item.ImagePositionPatient
10051008
metadata["ImagePositionPatient"] = list(ipp)
10061009
else:
10071010
logger.warning(f"NvDicomReader: PlanePositionSequence not found or empty")
1008-
1011+
10091012
# Extract PixelSpacing from per-frame sequence
10101013
pixel_measures_item = get_sequence_item(first_frame, "PixelMeasuresSequence")
10111014
if pixel_measures_item and hasattr(pixel_measures_item, "PixelSpacing"):
10121015
ps = pixel_measures_item.PixelSpacing
10131016
metadata["PixelSpacing"] = list(ps)
1014-
1017+
10151018
# Also check SliceThickness from PixelMeasuresSequence
10161019
if pixel_measures_item and hasattr(pixel_measures_item, "SliceThickness"):
10171020
st = pixel_measures_item.SliceThickness
10181021
metadata["SliceThickness"] = float(st)
1019-
1022+
10201023
except Exception as e:
10211024
logger.warning(f"NvDicomReader: Failed to extract per-frame metadata: {e}, falling back to top-level")
10221025
import traceback
1026+
10231027
logger.warning(f"NvDicomReader: Traceback: {traceback.format_exc()}")
1024-
1028+
10251029
# Fall back to top-level attributes if not extracted from per-frame sequence
10261030
if hasattr(ds, "ImageOrientationPatient") and "ImageOrientationPatient" not in metadata:
10271031
metadata["ImageOrientationPatient"] = list(ds.ImageOrientationPatient)

0 commit comments

Comments
 (0)