diff --git a/monailabel/config.py b/monailabel/config.py index 4de6c896f..ea8d1c37e 100644 --- a/monailabel/config.py +++ b/monailabel/config.py @@ -18,7 +18,7 @@ def is_package_installed(name): - return name in (x.metadata.get("Name") for x in distributions()) + return name in (x.metadata.get("Name") for x in distributions() if x.metadata is not None) class Settings(BaseSettings): diff --git a/monailabel/datastore/utils/convert.py b/monailabel/datastore/utils/convert.py index f5429a1ef..0b11f7337 100644 --- a/monailabel/datastore/utils/convert.py +++ b/monailabel/datastore/utils/convert.py @@ -9,55 +9,220 @@ # See the License for the specific language governing permissions and # limitations under the License. +import datetime import json import logging import os import pathlib import tempfile import time +from random import randint import numpy as np import pydicom -import pydicom_seg import SimpleITK from monai.transforms import LoadImage from pydicom.filereader import dcmread +from pydicom.sr.codedict import codes +try: + import highdicom as hd + from pydicom.sr.coding import Code + + HIGHDICOM_AVAILABLE = True +except ImportError: + HIGHDICOM_AVAILABLE = False + hd = None + Code = None + +from monailabel import __version__ from monailabel.config import settings from monailabel.datastore.utils.colors import GENERIC_ANATOMY_COLORS from monailabel.transform.writer import write_itk -from monailabel.utils.others.generic import run_command logger = logging.getLogger(__name__) +class SegmentDescription: + """Wrapper class for segment description following MONAI Deploy pattern. + + This class encapsulates segment metadata and can convert to either: + - highdicom.seg.SegmentDescription for the primary highdicom-based conversion + - dcmqi JSON dict for ITK/dcmqi-based conversion (legacy fallback) + """ + + def __init__( + self, + segment_label, + segmented_property_category=None, + segmented_property_type=None, + algorithm_name="MONAILABEL", + algorithm_version="1.0", + segment_description=None, + recommended_display_rgb_value=None, + label_id=None, + ): + """Initialize segment description. + + Args: + segment_label: Label for the segment (e.g., "Spleen") + segmented_property_category: Code for category (e.g., codes.SCT.Organ) + segmented_property_type: Code for type (e.g., codes.SCT.Spleen) + algorithm_name: Name of the algorithm + algorithm_version: Version of the algorithm + segment_description: Optional description text + recommended_display_rgb_value: RGB color tuple [R, G, B] + label_id: Numeric label ID + """ + self.segment_label = segment_label + # Use default category if not provided (safe fallback) + if segmented_property_category is None: + try: + self.segmented_property_category = codes.SCT.Organ + except Exception: + self.segmented_property_category = None + else: + self.segmented_property_category = segmented_property_category + self.segmented_property_type = segmented_property_type + self.algorithm_name = algorithm_name + self.algorithm_version = algorithm_version + self.segment_description = segment_description or segment_label + self.recommended_display_rgb_value = recommended_display_rgb_value or [255, 0, 0] + self.label_id = label_id + + def to_highdicom_description(self, segment_number): + """Convert to highdicom SegmentDescription object. 
+ + Args: + segment_number: Segment number (1-based) + + Returns: + hd.seg.SegmentDescription object + """ + if not HIGHDICOM_AVAILABLE: + raise ImportError("highdicom is not available") + + return hd.seg.SegmentDescription( + segment_number=segment_number, + segment_label=self.segment_label, + segmented_property_category=self.segmented_property_category, + segmented_property_type=self.segmented_property_type, + algorithm_identification=hd.AlgorithmIdentificationSequence( + name=self.algorithm_name, + family=codes.DCM.ArtificialIntelligence, + version=self.algorithm_version, + ), + algorithm_type="AUTOMATIC", + ) + + def to_dcmqi_dict(self): + """Convert to dcmqi JSON dict for ITK-based conversion. + + Returns: + Dictionary compatible with dcmqi itkimage2segimage + """ + # Extract code values from pydicom Code objects + if hasattr(self.segmented_property_type, "value"): + type_code_value = self.segmented_property_type.value + type_scheme = self.segmented_property_type.scheme_designator + type_meaning = self.segmented_property_type.meaning + else: + type_code_value = "78961009" + type_scheme = "SCT" + type_meaning = self.segment_label + + return { + "labelID": self.label_id if self.label_id is not None else 1, + "SegmentLabel": self.segment_label, + "SegmentDescription": self.segment_description, + "SegmentAlgorithmType": "AUTOMATIC", + "SegmentAlgorithmName": self.algorithm_name, + "SegmentedPropertyCategoryCodeSequence": { + "CodeValue": "123037004", + "CodingSchemeDesignator": "SCT", + "CodeMeaning": "Anatomical Structure", + }, + "SegmentedPropertyTypeCodeSequence": { + "CodeValue": type_code_value, + "CodingSchemeDesignator": type_scheme, + "CodeMeaning": type_meaning, + }, + "recommendedDisplayRGBValue": self.recommended_display_rgb_value, + } + + +def random_with_n_digits(n): + """Generate a random number with n digits.""" + n = n if n >= 1 else 1 + range_start = 10 ** (n - 1) + range_end = (10**n) - 1 + return randint(range_start, range_end) + + def dicom_to_nifti(series_dir, is_seg=False): start = time.time() + t_load = t_cpu = t_write = None if is_seg: output_file = dicom_seg_to_itk_image(series_dir) else: - # https://simpleitk.readthedocs.io/en/master/link_DicomConvert_docs.html - if os.path.isdir(series_dir) and len(os.listdir(series_dir)) > 1: - reader = SimpleITK.ImageSeriesReader() - dicom_names = reader.GetGDCMSeriesFileNames(series_dir) - reader.SetFileNames(dicom_names) - image = reader.Execute() - else: - filename = ( - series_dir if not os.path.isdir(series_dir) else os.path.join(series_dir, os.listdir(series_dir)[0]) - ) - - file_reader = SimpleITK.ImageFileReader() - file_reader.SetImageIO("GDCMImageIO") - file_reader.SetFileName(filename) - image = file_reader.Execute() - - logger.info(f"Image size: {image.GetSize()}") - output_file = tempfile.NamedTemporaryFile(suffix=".nii.gz").name - SimpleITK.WriteImage(image, output_file) - - logger.info(f"dicom_to_nifti latency : {time.time() - start} (sec)") + # Use NvDicomReader for better DICOM handling with GPU acceleration + logger.info(f"dicom_to_nifti: Converting DICOM from {series_dir} using NvDicomReader") + + try: + from monailabel.transform.reader import NvDicomReader + + # Use NvDicomReader with LoadImage + reader = NvDicomReader() + loader = LoadImage(reader=reader, image_only=False) + + # Load the DICOM (supports both directories and single files) + t0 = time.time() + image_data, metadata = loader(series_dir) + t_load = time.time() - t0 + logger.info(f"dicom_to_nifti: LoadImage time: {t_load:.3f} sec") + + t1 = 
time.time() + image_data = image_data.cpu().numpy() + t_cpu = time.time() - t1 + logger.info(f"dicom_to_nifti: to.cpu().numpy() time: {t_cpu:.3f} sec") + + # Save as NIfTI using MONAI's write_itk + output_file = tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False).name + + # Get affine from metadata if available + affine = metadata.get("affine", metadata.get("original_affine", np.eye(4))) + + t2 = time.time() + # Use write_itk which handles the conversion properly + write_itk(image_data, output_file, affine, image_data.dtype, compress=True) + t_write = time.time() - t2 + logger.info(f"dicom_to_nifti: write_itk time: {t_write:.3f} sec") + + except Exception as e: + logger.warning(f"dicom_to_nifti: NvDicomReader failed: {e}, falling back to SimpleITK") + + # Fallback to SimpleITK + if os.path.isdir(series_dir) and len(os.listdir(series_dir)) > 1: + reader = SimpleITK.ImageSeriesReader() + dicom_names = reader.GetGDCMSeriesFileNames(series_dir) + reader.SetFileNames(dicom_names) + image = reader.Execute() + else: + filename = ( + series_dir if not os.path.isdir(series_dir) else os.path.join(series_dir, os.listdir(series_dir)[0]) + ) + file_reader = SimpleITK.ImageFileReader() + file_reader.SetImageIO("GDCMImageIO") + file_reader.SetFileName(filename) + image = file_reader.Execute() + + logger.info(f"dicom_to_nifti: Image size: {image.GetSize()}") + output_file = tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False).name + SimpleITK.WriteImage(image, output_file) + + latency = time.time() - start + logger.info(f"dicom_to_nifti latency: {latency:.3f} sec") return output_file @@ -81,14 +246,38 @@ def binary_to_image(reference_image, label, dtype=np.uint8, file_ext=".nii.gz"): return output_file -def nifti_to_dicom_seg(series_dir, label, label_info, file_ext="*", use_itk=None) -> str: +def nifti_to_dicom_seg( + series_dir, label, label_info, file_ext="*", use_itk=None, omit_empty_frames=False, custom_tags=None +) -> str: + """Convert NIfTI segmentation to DICOM SEG format using highdicom or ITK (fallback). + This function uses highdicom by default for creating DICOM SEG objects. + The ITK/dcmqi method is available as a fallback option (use_itk=True). + + Args: + series_dir: Directory containing source DICOM images + label: Path to NIfTI label file + label_info: List of dictionaries containing segment information + file_ext: File extension pattern for DICOM files (default: "*") + use_itk: If True, use ITK/dcmqi-based conversion (fallback). If False or None, use highdicom (default). 
+ omit_empty_frames: If True, omit frames with no segmented pixels (default: False to match legacy behavior) + custom_tags: Optional dictionary of custom DICOM tags to add (keyword: value) + + Returns: + Path to output DICOM SEG file + """ # Only use config if no explicit override if use_itk is None: use_itk = settings.MONAI_LABEL_USE_ITK_FOR_DICOM_SEG start = time.time() + # Check if highdicom is available (unless using ITK fallback) + if not use_itk and not HIGHDICOM_AVAILABLE: + logger.warning("highdicom not available, falling back to ITK method") + use_itk = True + + # Load label and get unique segments label_np, meta_dict = LoadImage(image_only=False)(label) unique_labels = np.unique(label_np.flatten()).astype(np.int_) unique_labels = unique_labels[unique_labels != 0] @@ -96,93 +285,259 @@ def nifti_to_dicom_seg(series_dir, label, label_info, file_ext="*", use_itk=None info = label_info[0] if label_info and 0 < len(label_info) else {} model_name = info.get("model_name", "AIName") - segment_attributes = [] + if not unique_labels.size: + logger.error("No non-zero labels found in segmentation") + return "" + + # Build segment descriptions + segment_descriptions = [] for i, idx in enumerate(unique_labels): info = label_info[i] if label_info and i < len(label_info) else {} - name = info.get("name", "unknown") - description = info.get("description", "Unknown") - rgb = list(info.get("color", GENERIC_ANATOMY_COLORS.get(name, (255, 0, 0))))[0:3] - rgb = [int(x) for x in rgb] - - logger.info(f"{i} => {idx} => {name}") - - segment_attribute = info.get( - "segmentAttribute", - { - "labelID": int(idx), - "SegmentLabel": name, - "SegmentDescription": description, - "SegmentAlgorithmType": "AUTOMATIC", - "SegmentAlgorithmName": "MONAILABEL", - "SegmentedPropertyCategoryCodeSequence": { - "CodeValue": "123037004", - "CodingSchemeDesignator": "SCT", - "CodeMeaning": "Anatomical Structure", + name = info.get("name", f"Segment_{idx}") + description = info.get("description", name) + + logger.info(f"Segment {i}: idx={idx}, name={name}") + + if use_itk: + # Build template for ITK method + rgb = list(info.get("color", GENERIC_ANATOMY_COLORS.get(name, (255, 0, 0))))[0:3] + rgb = [int(x) for x in rgb] + + segment_attr = info.get( + "segmentAttribute", + { + "labelID": int(idx), + "SegmentLabel": name, + "SegmentDescription": description, + "SegmentAlgorithmType": "AUTOMATIC", + "SegmentAlgorithmName": "MONAILABEL", + "SegmentedPropertyCategoryCodeSequence": { + "CodeValue": "123037004", + "CodingSchemeDesignator": "SCT", + "CodeMeaning": "Anatomical Structure", + }, + "SegmentedPropertyTypeCodeSequence": { + "CodeValue": "78961009", + "CodingSchemeDesignator": "SCT", + "CodeMeaning": name, + }, + "recommendedDisplayRGBValue": rgb, }, - "SegmentedPropertyTypeCodeSequence": { - "CodeValue": "78961009", - "CodingSchemeDesignator": "SCT", - "CodeMeaning": name, - }, - "recommendedDisplayRGBValue": rgb, - }, - ) - segment_attributes.append(segment_attribute) - - template = { - "ContentCreatorName": "Reader1", - "ClinicalTrialSeriesID": "Session1", - "ClinicalTrialTimePointID": "1", - "SeriesDescription": model_name, - "SeriesNumber": "300", - "InstanceNumber": "1", - "segmentAttributes": [segment_attributes], - "ContentLabel": "SEGMENTATION", - "ContentDescription": "MONAI Label - Image segmentation", - "ClinicalTrialCoordinatingCenterName": "MONAI", - "BodyPartExamined": "", - } - - logger.info(json.dumps(template, indent=2)) - if not segment_attributes: - logger.error("Missing Attributes/Empty Label provided") + 
)
+            segment_descriptions.append(segment_attr)
+        else:
+            # Build highdicom SegmentDescription
+            # Get codes from label_info or use defaults
+            category_code = codes.SCT.Organ  # Default: Organ
+            type_code_dict = info.get("SegmentedPropertyTypeCodeSequence", {})
+
+            if type_code_dict and isinstance(type_code_dict, dict):
+                type_code = Code(
+                    value=type_code_dict.get("CodeValue", "78961009"),
+                    scheme_designator=type_code_dict.get("CodingSchemeDesignator", "SCT"),
+                    meaning=type_code_dict.get("CodeMeaning", name),
+                )
+            else:
+                # Default type code
+                type_code = Code("78961009", "SCT", name)
+
+            # Create highdicom segment description
+            # Use the 1-based loop index as the segment number so it stays sequential (1, 2, 3...)
+            # and matches the remapped pixel values below, even when label values are non-contiguous
+            seg_desc = hd.seg.SegmentDescription(
+                segment_number=i + 1,
+                segment_label=name,
+                segmented_property_category=category_code,
+                segmented_property_type=type_code,
+                algorithm_identification=hd.AlgorithmIdentificationSequence(
+                    name="MONAILABEL", family=codes.DCM.ArtificialIntelligence, version=model_name
+                ),
+                algorithm_type="AUTOMATIC",
+            )
+            segment_descriptions.append(seg_desc)
+
+    if not segment_descriptions:
+        logger.error("Missing segment descriptions")
         return ""

     if use_itk:
+        # Use ITK method
+        template = {
+            "ContentCreatorName": "Reader1",
+            "ClinicalTrialSeriesID": "Session1",
+            "ClinicalTrialTimePointID": "1",
+            "SeriesDescription": model_name,
+            "SeriesNumber": "300",
+            "InstanceNumber": "1",
+            "segmentAttributes": [segment_descriptions],
+            "ContentLabel": "SEGMENTATION",
+            "ContentDescription": "MONAI Label - Image segmentation",
+            "ClinicalTrialCoordinatingCenterName": "MONAI",
+            "BodyPartExamined": "",
+        }
+        logger.info(json.dumps(template, indent=2))
         output_file = itk_image_to_dicom_seg(label, series_dir, template)
     else:
-        template = pydicom_seg.template.from_dcmqi_metainfo(template)
-        writer = pydicom_seg.MultiClassWriter(
-            template=template,
-            inplane_cropping=False,
-            skip_empty_slices=False,
-            skip_missing_segment=False,
-        )
-
-        # Read source Images
+        # Use highdicom method
+        # Read source DICOM images (headers only for memory efficiency)
         series_dir = pathlib.Path(series_dir)
-        image_files = series_dir.glob(file_ext)
-        image_datasets = [dcmread(str(f), stop_before_pixels=True) for f in image_files]
+        image_files = list(series_dir.glob(file_ext))
+        image_datasets = [dcmread(str(f), stop_before_pixels=True) for f in sorted(image_files)]
         logger.info(f"Total Source Images: {len(image_datasets)}")

+        if not image_datasets:
+            logger.error(f"No DICOM images found in {series_dir} with pattern {file_ext}")
+            return ""
+
+        # Load label using SimpleITK and convert to numpy array
+        # Use uint16 to support up to 65,535 segments
         mask = SimpleITK.ReadImage(label)
         mask = SimpleITK.Cast(mask, SimpleITK.sitkUInt16)

-        output_file = tempfile.NamedTemporaryFile(suffix=".dcm").name
-        dcm = writer.write(mask, image_datasets)
-        dcm.save_as(output_file)
+        # Convert to numpy array for highdicom
+        seg_array = SimpleITK.GetArrayFromImage(mask)
+
+        # Remap label values to sequential 1, 2, 3... 
as required by highdicom + # (highdicom requires explicit sequential remapping) + remapped_array = np.zeros_like(seg_array, dtype=np.uint16) + for new_idx, orig_idx in enumerate(unique_labels, start=1): + remapped_array[seg_array == orig_idx] = new_idx + seg_array = remapped_array + + # Generate SOP instance UID + seg_sop_instance_uid = hd.UID() + + # Create DICOM SEG using highdicom + try: + # Get software version + try: + software_version = f"MONAI Label {__version__}" + except Exception: + software_version = "MONAI Label" + + seg = hd.seg.Segmentation( + source_images=image_datasets, + pixel_array=seg_array, + segmentation_type=hd.seg.SegmentationTypeValues.BINARY, + segment_descriptions=segment_descriptions, + series_instance_uid=hd.UID(), + series_number=random_with_n_digits(4), + sop_instance_uid=seg_sop_instance_uid, + instance_number=1, + manufacturer="MONAI Consortium", + manufacturer_model_name="MONAI Label", + software_versions=software_version, + device_serial_number="0000", + omit_empty_frames=omit_empty_frames, + ) - logger.info(f"nifti_to_dicom_seg latency : {time.time() - start} (sec)") + # Add timestamp and timezone + dt_now = datetime.datetime.now() + seg.SeriesDate = dt_now.strftime("%Y%m%d") + seg.SeriesTime = dt_now.strftime("%H%M%S") + seg.TimezoneOffsetFromUTC = dt_now.astimezone().isoformat()[-6:].replace(":", "") # Format: +0000 or -0700 + seg.SeriesDescription = model_name + + # Add Contributing Equipment Sequence (following MONAI Deploy pattern) + try: + from pydicom.dataset import Dataset + from pydicom.sequence import Sequence as PyDicomSequence + + # Create Purpose of Reference Code Sequence + seq_purpose_of_reference_code = PyDicomSequence() + seg_purpose_of_reference_code = Dataset() + seg_purpose_of_reference_code.CodeValue = "Newcode1" + seg_purpose_of_reference_code.CodingSchemeDesignator = "99IHE" + seg_purpose_of_reference_code.CodeMeaning = "Processing Algorithm" + seq_purpose_of_reference_code.append(seg_purpose_of_reference_code) + + # Create Contributing Equipment Sequence + seq_contributing_equipment = PyDicomSequence() + seg_contributing_equipment = Dataset() + seg_contributing_equipment.PurposeOfReferenceCodeSequence = seq_purpose_of_reference_code + seg_contributing_equipment.Manufacturer = "MONAI Consortium" + seg_contributing_equipment.ManufacturerModelName = model_name + seg_contributing_equipment.SoftwareVersions = software_version + seg_contributing_equipment.DeviceUID = hd.UID() + seq_contributing_equipment.append(seg_contributing_equipment) + seg.ContributingEquipmentSequence = seq_contributing_equipment + except Exception as e: + logger.warning(f"Could not add ContributingEquipmentSequence: {e}") + + # Add custom tags if provided (following MONAI Deploy pattern) + if custom_tags: + for k, v in custom_tags.items(): + if isinstance(k, str) and isinstance(v, str): + try: + if k in seg: + data_element = seg.data_element(k) + if data_element: + data_element.value = v + else: + seg.update({k: v}) + except Exception as ex: + logger.warning(f"Custom tag {k} was not written, due to {ex}") + + # Save DICOM SEG + output_file = tempfile.NamedTemporaryFile(suffix=".dcm", delete=False).name + seg.save_as(output_file) + logger.info(f"DICOM SEG saved to: {output_file}") + + except Exception as e: + logger.error(f"Failed to create DICOM SEG with highdicom: {e}") + logger.info("Falling back to ITK method") + # Fallback to ITK method + template = { + "ContentCreatorName": "Reader1", + "SeriesDescription": model_name, + "SeriesNumber": "300", + 
"InstanceNumber": "1", + "segmentAttributes": [ + [ + { + "labelID": int(idx), + "SegmentLabel": info.get("name", f"Segment_{idx}"), + "SegmentDescription": info.get("description", ""), + "SegmentAlgorithmType": "AUTOMATIC", + "SegmentAlgorithmName": "MONAILABEL", + } + for idx, info in zip(unique_labels, label_info or []) + ] + ], + "ContentLabel": "SEGMENTATION", + "ContentDescription": "MONAI Label - Image segmentation", + } + output_file = itk_image_to_dicom_seg(label, str(series_dir), template) + + logger.info(f"nifti_to_dicom_seg latency: {time.time() - start:.3f} sec") return output_file def itk_image_to_dicom_seg(label, series_dir, template) -> str: + import shutil + + from monailabel.utils.others.generic import run_command + + command = "itkimage2segimage" + if not shutil.which(command): + error_msg = ( + f"\n{'='*80}\n" + f"ERROR: {command} command-line tool not found\n" + f"{'='*80}\n\n" + f"The ITK-based DICOM SEG conversion requires the dcmqi package.\n\n" + f"Install dcmqi:\n" + f" pip install dcmqi\n\n" + f"For more information:\n" + f" https://github.com/QIICR/dcmqi\n\n" + f"Note: Consider using the default highdicom-based conversion (use_itk=False)\n" + f"which doesn't require dcmqi.\n" + f"{'='*80}\n" + ) + raise RuntimeError(error_msg) + output_file = tempfile.NamedTemporaryFile(suffix=".dcm").name meta_data = tempfile.NamedTemporaryFile(suffix=".json").name with open(meta_data, "w") as fp: json.dump(template, fp) - command = "itkimage2segimage" args = [ "--inputImageList", label, @@ -199,15 +554,42 @@ def itk_image_to_dicom_seg(label, series_dir, template) -> str: def dicom_seg_to_itk_image(label, output_ext=".seg.nrrd"): + """Convert DICOM SEG to ITK image format using highdicom. + + Args: + label: Path to DICOM SEG file or directory containing it + output_ext: Output file extension (default: ".seg.nrrd") + + Returns: + Path to output file, or None if conversion fails + """ filename = label if not os.path.isdir(label) else os.path.join(label, os.listdir(label)[0]) - dcm = pydicom.dcmread(filename) - reader = pydicom_seg.MultiClassReader() - result = reader.read(dcm) - image = result.image + if not HIGHDICOM_AVAILABLE: + raise ImportError("highdicom is not available") - output_file = tempfile.NamedTemporaryFile(suffix=output_ext).name + # Use pydicom to read DICOM SEG + dcm = pydicom.dcmread(filename) + # Extract pixel array from DICOM SEG + seg_dataset = hd.seg.Segmentation.from_dataset(dcm) + pixel_array = seg_dataset.get_total_pixel_matrix() + + # Convert to SimpleITK image + image = SimpleITK.GetImageFromArray(pixel_array) + + # Try to get spacing and other metadata from original DICOM + if hasattr(dcm, "SharedFunctionalGroupsSequence") and len(dcm.SharedFunctionalGroupsSequence) > 0: + shared_func_groups = dcm.SharedFunctionalGroupsSequence[0] + if hasattr(shared_func_groups, "PixelMeasuresSequence"): + pixel_measures = shared_func_groups.PixelMeasuresSequence[0] + if hasattr(pixel_measures, "PixelSpacing"): + spacing = list(pixel_measures.PixelSpacing) + if hasattr(pixel_measures, "SliceThickness"): + spacing.append(float(pixel_measures.SliceThickness)) + image.SetSpacing(spacing) + + output_file = tempfile.NamedTemporaryFile(suffix=output_ext, delete=False).name SimpleITK.WriteImage(image, output_file, True) if not os.path.exists(output_file): diff --git a/monailabel/datastore/utils/convert_htj2k.py b/monailabel/datastore/utils/convert_htj2k.py new file mode 100644 index 000000000..91b5396b6 --- /dev/null +++ b/monailabel/datastore/utils/convert_htj2k.py @@ -0,0 
+1,1345 @@ +import logging +import os +import tempfile +import time +from typing import Iterable + +import numpy as np +import pydicom + +logger = logging.getLogger(__name__) + +# Global singleton instances for nvimgcodec encoder/decoder +# These are initialized lazily on first use to avoid import errors +# when nvimgcodec is not available +_NVIMGCODEC_ENCODER = None +_NVIMGCODEC_DECODER = None + + +def _get_nvimgcodec_encoder(): + """Get or create the global nvimgcodec encoder instance.""" + global _NVIMGCODEC_ENCODER + if _NVIMGCODEC_ENCODER is None: + from nvidia import nvimgcodec + + _NVIMGCODEC_ENCODER = nvimgcodec.Encoder() + return _NVIMGCODEC_ENCODER + + +def _get_nvimgcodec_decoder(): + """Get or create the global nvimgcodec decoder instance.""" + global _NVIMGCODEC_DECODER + if _NVIMGCODEC_DECODER is None: + from nvidia import nvimgcodec + + _NVIMGCODEC_DECODER = nvimgcodec.Decoder(options=":fancy_upsampling=1") + return _NVIMGCODEC_DECODER + + +def _setup_htj2k_decode_params(color_spec=None): + """ + Create nvimgcodec decoding parameters for DICOM images. + + Args: + color_spec: Color specification to use. If None, defaults to UNCHANGED. + + Returns: + nvimgcodec.DecodeParams: Decode parameters configured for DICOM + """ + from nvidia import nvimgcodec + + if color_spec is None: + color_spec = nvimgcodec.ColorSpec.UNCHANGED + decode_params = nvimgcodec.DecodeParams( + allow_any_depth=True, + color_spec=color_spec, + ) + return decode_params + + +def _setup_htj2k_encode_params( + num_resolutions: int = 6, code_block_size: tuple = (64, 64), progression_order: str = "RPCL" +): + """ + Create nvimgcodec encoding parameters for HTJ2K lossless compression. + + Args: + num_resolutions: Number of wavelet decomposition levels + code_block_size: Code block size as (height, width) tuple + progression_order: Progression order for encoding. Must be one of: + - "LRCP": Layer-Resolution-Component-Position (quality scalability) + - "RLCP": Resolution-Layer-Component-Position (resolution scalability) + - "RPCL": Resolution-Position-Component-Layer (progressive by resolution) + - "PCRL": Position-Component-Resolution-Layer (progressive by spatial area) + - "CPRL": Component-Position-Resolution-Layer (component scalability) + + Returns: + tuple: (encode_params, target_transfer_syntax) + + Raises: + ValueError: If progression_order is not one of the valid values + """ + from nvidia import nvimgcodec + + # Valid progression orders and their mappings + VALID_PROG_ORDERS = { + "LRCP": (nvimgcodec.Jpeg2kProgOrder.LRCP, "1.2.840.10008.1.2.4.201"), # HTJ2K (Lossless Only) + "RLCP": (nvimgcodec.Jpeg2kProgOrder.RLCP, "1.2.840.10008.1.2.4.201"), # HTJ2K (Lossless Only) + "RPCL": (nvimgcodec.Jpeg2kProgOrder.RPCL, "1.2.840.10008.1.2.4.202"), # HTJ2K with RPCL Options + "PCRL": (nvimgcodec.Jpeg2kProgOrder.PCRL, "1.2.840.10008.1.2.4.201"), # HTJ2K (Lossless Only) + "CPRL": (nvimgcodec.Jpeg2kProgOrder.CPRL, "1.2.840.10008.1.2.4.201"), # HTJ2K (Lossless Only) + } + + # Validate progression order + if progression_order not in VALID_PROG_ORDERS: + valid_orders = ", ".join(f"'{o}'" for o in VALID_PROG_ORDERS.keys()) + raise ValueError(f"Invalid progression_order '{progression_order}'. 
" f"Must be one of: {valid_orders}") + + # Get progression order enum and transfer syntax + prog_order_enum, target_transfer_syntax = VALID_PROG_ORDERS[progression_order] + + quality_type = nvimgcodec.QualityType.LOSSLESS + + # Configure JPEG2K encoding parameters + jpeg2k_encode_params = nvimgcodec.Jpeg2kEncodeParams() + jpeg2k_encode_params.num_resolutions = num_resolutions + jpeg2k_encode_params.code_block_size = code_block_size + jpeg2k_encode_params.bitstream_type = nvimgcodec.Jpeg2kBitstreamType.J2K + jpeg2k_encode_params.prog_order = prog_order_enum + jpeg2k_encode_params.ht = True # Enable High Throughput mode + + encode_params = nvimgcodec.EncodeParams( + quality_type=quality_type, + jpeg2k_encode_params=jpeg2k_encode_params, + ) + + return encode_params, target_transfer_syntax + + +def _extract_frames_from_compressed(ds, number_of_frames=None): + """ + Extract frames from encapsulated (compressed) DICOM pixel data. + + Args: + ds: pydicom Dataset with encapsulated PixelData + number_of_frames: Expected number of frames (from NumberOfFrames tag) + + Returns: + list: List of compressed frame data (bytes) + """ + # Default to 1 frame if not specified (for single-frame images without NumberOfFrames tag) + if number_of_frames is None: + number_of_frames = 1 + + frames = list(pydicom.encaps.generate_frames(ds.PixelData, number_of_frames=number_of_frames)) + return frames + + +def _extract_frames_from_uncompressed(pixel_array, num_frames_tag): + """ + Extract individual frames from uncompressed pixel array. + + Handles different array shapes: + - 2D (H, W): single frame grayscale + - 3D (N, H, W): multi-frame grayscale OR (H, W, C): single frame color + - 4D (N, H, W, C): multi-frame color + + Args: + pixel_array: Numpy array of pixel data + num_frames_tag: NumberOfFrames value from DICOM tag + + Returns: + list: List of frame arrays + """ + if not isinstance(pixel_array, np.ndarray): + pixel_array = np.array(pixel_array) + + # 2D: single frame grayscale + if pixel_array.ndim == 2: + return [pixel_array] + + # 3D: multi-frame grayscale OR single-frame color + if pixel_array.ndim == 3: + if num_frames_tag > 1 or pixel_array.shape[0] == num_frames_tag: + # Multi-frame grayscale: (N, H, W) + return [pixel_array[i] for i in range(pixel_array.shape[0])] + # Single-frame color: (H, W, C) + return [pixel_array] + + # 4D: multi-frame color + if pixel_array.ndim == 4: + return [pixel_array[i] for i in range(pixel_array.shape[0])] + + raise ValueError(f"Unexpected pixel array dimensions: {pixel_array.ndim}") + + +def _validate_frames(frames, context_msg="Frame"): + """ + Check for None values in decoded/encoded frames. + + Args: + frames: List of frames to validate + context_msg: Context message for error reporting + + Raises: + ValueError: If any frame is None + """ + for idx, frame in enumerate(frames): + if frame is None: + raise ValueError(f"{context_msg} {idx} failed (returned None)") + + +def _find_dicom_files(input_dir): + """ + Recursively find all valid DICOM files in a directory. 
+ + Args: + input_dir: Directory to search + + Returns: + list: Sorted list of DICOM file paths + """ + valid_dicom_files = [] + for root, dirs, files in os.walk(input_dir): + for f in files: + file_path = os.path.join(root, f) + if os.path.isfile(file_path): + try: + with open(file_path, "rb") as fp: + fp.seek(128) + if fp.read(4) == b"DICM": + valid_dicom_files.append(file_path) + except Exception: + continue + + valid_dicom_files.sort() # For reproducible processing order + return valid_dicom_files + + +def _get_transfer_syntax_constants(): + """ + Get transfer syntax UID constants for categorizing DICOM files. + + Returns: + dict: Dictionary with keys 'JPEG2000', 'HTJ2K', 'JPEG', 'NVIMGCODEC' (combined set) + """ + JPEG2000_SYNTAXES = frozenset( + [ + "1.2.840.10008.1.2.4.90", # JPEG 2000 Image Compression (Lossless Only) + "1.2.840.10008.1.2.4.91", # JPEG 2000 Image Compression + ] + ) + + HTJ2K_SYNTAXES = frozenset( + [ + "1.2.840.10008.1.2.4.201", # High-Throughput JPEG 2000 Image Compression (Lossless Only) + "1.2.840.10008.1.2.4.202", # High-Throughput JPEG 2000 with RPCL Options Image Compression (Lossless Only) + "1.2.840.10008.1.2.4.203", # High-Throughput JPEG 2000 Image Compression + ] + ) + + JPEG_SYNTAXES = frozenset( + [ + "1.2.840.10008.1.2.4.50", # JPEG Baseline (Process 1) + "1.2.840.10008.1.2.4.51", # JPEG Extended (Process 2 & 4) + "1.2.840.10008.1.2.4.57", # JPEG Lossless, Non-Hierarchical (Process 14) + "1.2.840.10008.1.2.4.70", # JPEG Lossless, Non-Hierarchical, First-Order Prediction + ] + ) + + return { + "JPEG2000": JPEG2000_SYNTAXES, + "HTJ2K": HTJ2K_SYNTAXES, + "JPEG": JPEG_SYNTAXES, + "NVIMGCODEC": JPEG2000_SYNTAXES | HTJ2K_SYNTAXES | JPEG_SYNTAXES, + } + + +class DicomFileLoader: + """ + Simple iterable that auto-discovers DICOM files from a directory and yields batches. + + This class provides a simple interface for batch processing DICOM files without + requiring external dependencies like PyTorch. It can be used with any function + that accepts an iterable of (input_batch, output_batch) tuples. + + Args: + input_dir: Path to directory containing DICOM files to process + output_dir: Path to output directory. Output paths will preserve the directory + structure relative to input_dir. + batch_size: Number of files to include in each batch (default: 256) + + Yields: + tuple: (batch_input, batch_output) where both are lists of file paths + batch_input contains source file paths + batch_output contains corresponding output file paths with preserved directory structure + + Example: + >>> loader = DicomFileLoader("/path/to/dicoms", "/path/to/output", batch_size=50) + >>> for batch_in, batch_out in loader: + ... print(f"Processing {len(batch_in)} files") + ... print(f"Input: {batch_in[0]}") + ... 
print(f"Output: {batch_out[0]}") + """ + + def __init__(self, input_dir: str, output_dir: str, batch_size: int = 256): + self.input_dir = input_dir + self.output_dir = output_dir + self.batch_size = batch_size + self._files = None + + def _discover_files(self): + """Discover DICOM files in the input directory.""" + if self._files is None: + # Validate input + if not os.path.exists(self.input_dir): + raise ValueError(f"Input directory does not exist: {self.input_dir}") + + if not os.path.isdir(self.input_dir): + raise ValueError(f"Input path is not a directory: {self.input_dir}") + + # Find all valid DICOM files + self._files = _find_dicom_files(self.input_dir) + if not self._files: + raise ValueError(f"No valid DICOM files found in {self.input_dir}") + + logger.info(f"Found {len(self._files)} DICOM files to process") + + def __iter__(self): + """Iterate over batches of DICOM files.""" + self._discover_files() + + total_files = len(self._files) + for batch_start in range(0, total_files, self.batch_size): + batch_end = min(batch_start + self.batch_size, total_files) + batch_input = self._files[batch_start:batch_end] + + # Compute output paths preserving directory structure + batch_output = [] + for input_path in batch_input: + relative_path = os.path.relpath(input_path, self.input_dir) + output_path = os.path.join(self.output_dir, relative_path) + batch_output.append(output_path) + + yield batch_input, batch_output + + +def transcode_dicom_to_htj2k( + file_loader: Iterable[tuple[list[str], list[str]]], + num_resolutions: int = 6, + code_block_size: tuple = (64, 64), + progression_order: str = "RPCL", + max_batch_size: int = 256, + add_basic_offset_table: bool = True, + skip_transfer_syntaxes: list = ( + _get_transfer_syntax_constants()["HTJ2K"] + | frozenset( + [ + # Lossy JPEG 2000 + "1.2.840.10008.1.2.4.91", # JPEG 2000 Image Compression (lossy allowed) + # Lossy JPEG + "1.2.840.10008.1.2.4.50", # JPEG Baseline (Process 1) - always lossy + "1.2.840.10008.1.2.4.51", # JPEG Extended (Process 2 & 4, can be lossy) + ] + ) + ), +): + """ + Transcode DICOM files to HTJ2K (High Throughput JPEG 2000) lossless compression. + + HTJ2K is a faster variant of JPEG 2000 that provides better compression performance + for medical imaging applications. This function uses NVIDIA's nvimgcodec for hardware- + accelerated decoding and encoding with batch processing for optimal performance. + All transcoding is performed using lossless compression to preserve image quality. + + The function processes files with streaming decode-encode batches: + 1. Categorizes files by transfer syntax (HTJ2K/JPEG2000/JPEG/uncompressed) + 2. Extracts all frames from source files + 3. Processes frames in batches for efficient encoding: + - Decodes batch using nvimgcodec (compressed) or pydicom (uncompressed) + - Immediately encodes batch to HTJ2K + - Discards decoded frames to save memory (streaming) + 4. Saves transcoded files with updated transfer syntax and optional Basic Offset Table + + This streaming approach minimizes memory usage by never holding all decoded frames + in memory simultaneously. + + Supported source transfer syntaxes: + - HTJ2K (High-Throughput JPEG 2000) - decoded and re-encoded (add bot if needed) + - JPEG 2000 (lossless and lossy) + - JPEG (baseline, extended, lossless) + - Uncompressed (Explicit/Implicit VR Little/Big Endian) + + Typical compression ratios of 60-70% with lossless quality. + Processing speed depends on batch size and GPU capabilities. 
+ + Args: + file_loader: + Iterable of (input_files, output_files) tuples, where: + - input_files: List[str] of input DICOM file paths to transcode as a batch. + - output_files: List[str] of output file paths to write the transcoded DICOMs. + The recommended usage is to provide a DicomFileLoader instance, which automatically yields + appropriately sized batches of file paths for efficient streaming. Custom iterables can also + be used to precisely control batching or file selection. + + Each yielded tuple should contain two lists of identical length, specifying the correspondence + between input and output files for each batch. The function will read each input file, + transcode to HTJ2K if necessary, and write the result to the corresponding output file. + + Example: + for batch_input, batch_output in file_loader: + # len(batch_input) == len(batch_output) + # batch_input: ['a.dcm', 'b.dcm'], batch_output: ['a_out.dcm', 'b_out.dcm'] + The loader should guarantee that input and output lists are aligned and consistent across batches. + num_resolutions: Number of wavelet decomposition levels (default: 6) + Higher values = better compression but slower encoding + code_block_size: Code block size as (height, width) tuple (default: (64, 64)) + Must be powers of 2. Common values: (32,32), (64,64), (128,128) + progression_order: Progression order for HTJ2K encoding (default: "RPCL") + Must be one of: "LRCP", "RLCP", "RPCL", "PCRL", "CPRL" + - "LRCP": Layer-Resolution-Component-Position (quality scalability) + - "RLCP": Resolution-Layer-Component-Position (resolution scalability) + - "RPCL": Resolution-Position-Component-Layer (progressive by resolution) + - "PCRL": Position-Component-Resolution-Layer (progressive by spatial area) + - "CPRL": Component-Position-Resolution-Layer (component scalability) + max_batch_size: Maximum number of frames to decode/encode in parallel (default: 256) + This controls internal frame-level batching for GPU operations, not file-level batching. + Lower values reduce memory usage, higher values may improve GPU utilization. + Note: File-level batching is controlled by the DicomFileLoader's batch_size parameter. + add_basic_offset_table: If True, creates Basic Offset Table for multi-frame DICOMs (default: True) + BOT enables O(1) frame access without parsing entire pixel data stream + Per DICOM Part 5 Section A.4. Only affects multi-frame files. + skip_transfer_syntaxes: Optional list of Transfer Syntax UIDs to skip transcoding (default: HTJ2K, lossy JPEG 2000, and lossy JPEG) + Files with these transfer syntaxes will be copied directly to output + without transcoding. Useful for preserving already-compressed formats. + Example: ["1.2.840.10008.1.2.4.201", "1.2.840.10008.1.2.4.202"] + + Raises: + ImportError: If nvidia-nvimgcodec is not available + ValueError: If DICOM files are missing required attributes (TransferSyntaxUID, PixelData) + ValueError: If progression_order is not one of: "LRCP", "RLCP", "RPCL", "PCRL", "CPRL" + + Example: + >>> # Basic usage with DicomFileLoader + >>> loader = DicomFileLoader("/path/to/input", "/path/to/output") + >>> transcode_dicom_to_htj2k(loader) + >>> print(f"Transcoded files saved to: {loader.output_dir}") + + >>> # Custom settings with DicomFileLoader + >>> loader = DicomFileLoader("/path/to/input", "/path/to/output", batch_size=50) + >>> transcode_dicom_to_htj2k( + ... file_loader=loader, + ... num_resolutions=5, + ... code_block_size=(32, 32) + ... 
) + + >>> # Skip transcoding for files already in HTJ2K format + >>> loader = DicomFileLoader("/path/to/input", "/path/to/output") + >>> transcode_dicom_to_htj2k( + ... file_loader=loader, + ... skip_transfer_syntaxes=["1.2.840.10008.1.2.4.201", "1.2.840.10008.1.2.4.202"] + ... ) + + Note: + Requires nvidia-nvimgcodec to be installed: + pip install nvidia-nvimgcodec-cu{XX}[all] + Replace {XX} with your CUDA version (e.g., cu13 for CUDA 13.x) + + The function preserves all DICOM metadata including Patient, Study, and Series + information. Only the transfer syntax and pixel data encoding are modified. + """ + import shutil + + # Check for nvidia-nvimgcodec + try: + from nvidia import nvimgcodec + except ImportError: + raise ImportError( + "nvidia-nvimgcodec is required for HTJ2K transcoding. " + "Install it with: pip install nvidia-nvimgcodec-cu{XX}[all] " + "(replace {XX} with your CUDA version, e.g., cu13)" + ) + + # Create encoder and decoder instances (reused for all files) + encoder = _get_nvimgcodec_encoder() + decoder = _get_nvimgcodec_decoder() # Always needed for decoding input DICOM images + + # Setup HTJ2K encoding parameters + encode_params, target_transfer_syntax = _setup_htj2k_encode_params( + num_resolutions=num_resolutions, code_block_size=code_block_size, progression_order=progression_order + ) + # Note: decode_params is created per-PhotometricInterpretation group in the batch processing + logger.info("Using lossless HTJ2K compression") + + # Get transfer syntax constants + ts_constants = _get_transfer_syntax_constants() + NVIMGCODEC_SYNTAXES = ts_constants["NVIMGCODEC"] + + # Initialize skip list + if skip_transfer_syntaxes is None: + skip_transfer_syntaxes = [] + else: + # Convert to set of strings for faster lookup + skip_transfer_syntaxes = {str(ts) for ts in skip_transfer_syntaxes} + logger.info(f"Files with these transfer syntaxes will be copied without transcoding: {skip_transfer_syntaxes}") + + start_time = time.time() + transcoded_count = 0 + skipped_count = 0 + total_files = 0 + + # Iterate over batches from file_loader + for batch_in, batch_out in file_loader: + batch_datasets = [pydicom.dcmread(file) for file in batch_in] + total_files += len(batch_datasets) + nvimgcodec_batch = [] + pydicom_batch = [] + skip_batch = [] # Indices of files to skip (copy directly) + + for idx, ds in enumerate(batch_datasets): + current_ts = getattr(ds, "file_meta", {}).get("TransferSyntaxUID", None) + if current_ts is None: + raise ValueError(f"DICOM file {os.path.basename(batch_in[idx])} does not have a Transfer Syntax UID") + + ts_str = str(current_ts) + + # Check if this transfer syntax should be skipped + has_pixel_data = hasattr(ds, "PixelData") and ds.PixelData is not None + if ts_str in skip_transfer_syntaxes or not has_pixel_data: + skip_batch.append(idx) + logger.info( + f" Skipping {os.path.basename(batch_in[idx])} (Transfer Syntax: {ts_str}, has_pixel_data: {has_pixel_data})" + ) + continue + + assert has_pixel_data, f"DICOM file {os.path.basename(batch_in[idx])} does not have a PixelData member" + if ts_str in NVIMGCODEC_SYNTAXES: + nvimgcodec_batch.append(idx) + else: + pydicom_batch.append(idx) + + # Handle skip_batch: copy files directly to output + if skip_batch: + for idx in skip_batch: + source_file = batch_in[idx] + output_file = batch_out[idx] + + # Ensure output directory exists + os.makedirs(os.path.dirname(output_file), exist_ok=True) + shutil.copy2(source_file, output_file) + skipped_count += 1 + logger.info(f" Copied {os.path.basename(source_file)} to 
output (skipped transcoding)") + + num_frames = [] + encoded_data = [] + + # Process nvimgcodec_batch: extract frames, decode, encode in streaming batches + if nvimgcodec_batch: + from collections import defaultdict + + # First, extract all compressed frames and group by PhotometricInterpretation + grouped_frames = defaultdict(list) # Key: PhotometricInterpretation, Value: list of (file_idx, frame_data) + frame_counts = {} # Track number of frames per file + successful_nvimgcodec_batch = [] # Track successfully processed files + + logger.info(f" Extracting frames from {len(nvimgcodec_batch)} nvimgcodec files:") + for idx in nvimgcodec_batch: + try: + ds = batch_datasets[idx] + number_of_frames = int(ds.NumberOfFrames) if hasattr(ds, "NumberOfFrames") else None + + if "PixelData" not in ds: + logger.warning(f"Skipping file {batch_in[idx]} (index {idx}): no PixelData found") + skipped_count += 1 + continue + + frames = _extract_frames_from_compressed(ds, number_of_frames) + logger.info( + f" File idx={idx} ({os.path.basename(batch_in[idx])}): extracted {len(frames)} frames (expected: {number_of_frames})" + ) + + # Get PhotometricInterpretation for this file + photometric = getattr(ds, "PhotometricInterpretation", "UNKNOWN") + + # Store frames grouped by PhotometricInterpretation + for frame in frames: + grouped_frames[photometric].append((idx, frame)) + + frame_counts[idx] = len(frames) + num_frames.append(len(frames)) + successful_nvimgcodec_batch.append(idx) # Only add if successful + except Exception as e: + logger.warning(f"Skipping file {batch_in[idx]} (index {idx}): {e}") + skipped_count += 1 + + # Update nvimgcodec_batch to only include successfully processed files + nvimgcodec_batch = successful_nvimgcodec_batch + + # Process each PhotometricInterpretation group separately + logger.info(f" Found {len(grouped_frames)} unique PhotometricInterpretation groups") + + # Track encoded frames per file to maintain order + encoded_frames_by_file = {idx: [] for idx in nvimgcodec_batch} + + for photometric, frame_list in grouped_frames.items(): + # Determine color_spec based on PhotometricInterpretation + if photometric.startswith("YBR"): + color_spec = nvimgcodec.ColorSpec.RGB + logger.info( + f" Processing {len(frame_list)} frames with PhotometricInterpretation={photometric} using color_spec=RGB" + ) + else: + color_spec = nvimgcodec.ColorSpec.UNCHANGED + logger.info( + f" Processing {len(frame_list)} frames with PhotometricInterpretation={photometric} using color_spec=UNCHANGED" + ) + + # Create decode params for this group + group_decode_params = _setup_htj2k_decode_params(color_spec=color_spec) + + # Extract just the frame data (without file index) + compressed_frames = [frame_data for _, frame_data in frame_list] + + # Decode and encode in batches (streaming to reduce memory) + total_frames = len(compressed_frames) + + for frame_batch_start in range(0, total_frames, max_batch_size): + frame_batch_end = min(frame_batch_start + max_batch_size, total_frames) + compressed_batch = compressed_frames[frame_batch_start:frame_batch_end] + file_indices_batch = [file_idx for file_idx, _ in frame_list[frame_batch_start:frame_batch_end]] + + if total_frames > max_batch_size: + logger.info( + f" Processing frames [{frame_batch_start}..{frame_batch_end}) of {total_frames} for {photometric}" + ) + + # Decode batch with appropriate color_spec + decoded_batch = decoder.decode(compressed_batch, params=group_decode_params) + _validate_frames(decoded_batch, f"Decoded frame [{frame_batch_start}+") + + # Encode 
batch immediately (streaming - no need to keep decoded data) + encoded_batch = encoder.encode(decoded_batch, codec="jpeg2k", params=encode_params) + _validate_frames(encoded_batch, f"Encoded frame [{frame_batch_start}+") + + # Store encoded frames by file index to maintain order + for file_idx, encoded_frame in zip(file_indices_batch, encoded_batch): + encoded_frames_by_file[file_idx].append(encoded_frame) + + # decoded_batch is automatically freed here + + # Reconstruct encoded_data in original file order + for idx in nvimgcodec_batch: + encoded_data.extend(encoded_frames_by_file[idx]) + + # Process pydicom_batch: extract frames and encode in streaming batches + if pydicom_batch: + # Extract all frames from uncompressed files + all_decoded_frames = [] + successful_pydicom_batch = [] # Track successfully processed files + + for idx in pydicom_batch: + try: + ds = batch_datasets[idx] + num_frames_tag = int(ds.NumberOfFrames) if hasattr(ds, "NumberOfFrames") else 1 + if "PixelData" in ds: + frames = _extract_frames_from_uncompressed(ds.pixel_array, num_frames_tag) + all_decoded_frames.extend(frames) + num_frames.append(len(frames)) + successful_pydicom_batch.append(idx) # Only add if successful + else: + # No PixelData - log warning and skip file completely + logger.warning(f"Skipping file {batch_in[idx]} (index {idx}): no PixelData found") + skipped_count += 1 + except Exception as e: + logger.warning(f"Skipping file {batch_in[idx]} (index {idx}): {e}") + skipped_count += 1 + + # Encode in batches (streaming) + total_frames = len(all_decoded_frames) + if total_frames > 0: + logger.info(f" Encoding {total_frames} uncompressed frames in batches of {max_batch_size}") + + for frame_batch_start in range(0, total_frames, max_batch_size): + frame_batch_end = min(frame_batch_start + max_batch_size, total_frames) + decoded_batch = all_decoded_frames[frame_batch_start:frame_batch_end] + + if total_frames > max_batch_size: + logger.info(f" Encoding frames [{frame_batch_start}..{frame_batch_end}) of {total_frames}") + + # Encode batch + encoded_batch = encoder.encode(decoded_batch, codec="jpeg2k", params=encode_params) + _validate_frames(encoded_batch, f"Encoded frame [{frame_batch_start}+") + + # Store encoded frames + encoded_data.extend(encoded_batch) + + # Update pydicom_batch to only include successfully processed files + pydicom_batch = successful_pydicom_batch + + # Reassemble and save transcoded files + frame_offset = 0 + files_to_process = nvimgcodec_batch + pydicom_batch + + for list_idx, dataset_idx in enumerate(files_to_process): + nframes = num_frames[list_idx] + encoded_frames = [bytes(enc) for enc in encoded_data[frame_offset : frame_offset + nframes]] + frame_offset += nframes + + # Update dataset with HTJ2K encoded data + # Create Basic Offset Table for multi-frame files if requested + batch_datasets[dataset_idx].PixelData = pydicom.encaps.encapsulate( + encoded_frames, has_bot=add_basic_offset_table + ) + batch_datasets[dataset_idx].file_meta.TransferSyntaxUID = pydicom.uid.UID(target_transfer_syntax) + + # Update PhotometricInterpretation to RGB for YBR images since we decoded with RGB color_spec + # The pixel data is now in RGB color space, so the metadata must reflect this + # to prevent double conversion by DICOM readers + if hasattr(batch_datasets[dataset_idx], "PhotometricInterpretation"): + original_pi = batch_datasets[dataset_idx].PhotometricInterpretation + if original_pi.startswith("YBR"): + batch_datasets[dataset_idx].PhotometricInterpretation = "RGB" + logger.info(f" 
Updated PhotometricInterpretation: {original_pi} -> RGB") + + try: + # Save transcoded file using output path from file_loader + input_file = batch_in[dataset_idx] + output_file = batch_out[dataset_idx] + + # Ensure output directory exists + os.makedirs(os.path.dirname(output_file), exist_ok=True) + + batch_datasets[dataset_idx].save_as(output_file) + transcoded_count += 1 + logger.info(f"#{transcoded_count}: Transcoded {input_file}, saving as: {output_file}") + except Exception as e: + logger.error(f"Error saving transcoded file {batch_in[dataset_idx]}: {output_file}") + logger.error(f"Error: {e}") + + elapsed_time = time.time() - start_time + + logger.info(f"Transcoding complete:") + logger.info(f" Total files: {total_files}") + logger.info(f" Successfully transcoded: {transcoded_count}") + logger.info(f" Skipped (copied without transcoding): {skipped_count}") + logger.info(f" Time elapsed: {elapsed_time:.2f} seconds") + + +def convert_single_frame_dicom_series_to_multiframe( + input_dir: str, + output_dir: str = None, + convert_to_htj2k: bool = False, + num_resolutions: int = 6, + code_block_size: tuple = (64, 64), + progression_order: str = "RPCL", + add_basic_offset_table: bool = True, +) -> str: + """ + Convert single-frame DICOM series to multi-frame DICOM files, optionally with HTJ2K compression. + + This function groups DICOM files by SeriesInstanceUID and combines all frames from each series + into a single multi-frame DICOM file. This is useful for: + - Reducing file count (one file per series instead of many) + - Improving storage efficiency + - Enabling more efficient frame-level access patterns + + The function: + 1. Scans input directory recursively for DICOM files + 2. Groups files by StudyInstanceUID and SeriesInstanceUID + 3. For each series, decodes all frames and combines them + 4. Optionally encodes combined frames to HTJ2K (if convert_to_htj2k=True) + 5. Creates a Basic Offset Table for efficient frame access (per DICOM Part 5 Section A.4) + 6. Saves as a single multi-frame DICOM file per series + + Args: + input_dir: Path to directory containing DICOM files (will scan recursively) + output_dir: Path to output directory for transcoded files. If None, creates temp directory + convert_to_htj2k: If True, convert frames to HTJ2K compression; if False, use uncompressed format (default: False) + num_resolutions: Number of wavelet decomposition levels (default: 6, only used if convert_to_htj2k=True) + code_block_size: Code block size as (height, width) tuple (default: (64, 64), only used if convert_to_htj2k=True) + progression_order: Progression order for HTJ2K encoding (default: "RPCL", only used if convert_to_htj2k=True) + Must be one of: "LRCP", "RLCP", "RPCL", "PCRL", "CPRL" + - "LRCP": Layer-Resolution-Component-Position (quality scalability) + - "RLCP": Resolution-Layer-Component-Position (resolution scalability) + - "RPCL": Resolution-Position-Component-Layer (progressive by resolution) + - "PCRL": Position-Component-Resolution-Layer (progressive by spatial area) + - "CPRL": Component-Position-Resolution-Layer (component scalability) + add_basic_offset_table: If True, creates Basic Offset Table for multi-frame DICOMs (default: True) + BOT enables O(1) frame access without parsing entire pixel data stream + Per DICOM Part 5 Section A.4. Only affects multi-frame files. 
+ + Returns: + str: Path to output directory containing multi-frame DICOM files + + Raises: + ImportError: If nvidia-nvimgcodec is not available and convert_to_htj2k=True + ValueError: If input directory doesn't exist or contains no valid DICOM files + ValueError: If progression_order is not one of: "LRCP", "RLCP", "RPCL", "PCRL", "CPRL" + + Example: + >>> # Combine series without HTJ2K conversion (uncompressed) + >>> output_dir = convert_single_frame_dicom_series_to_multiframe("/path/to/dicoms") + >>> print(f"Multi-frame files saved to: {output_dir}") + + >>> # Combine series with HTJ2K conversion + >>> output_dir = convert_single_frame_dicom_series_to_multiframe( + ... "/path/to/dicoms", + ... convert_to_htj2k=True + ... ) + + Note: + Each output file is named using the SeriesInstanceUID: + /.dcm + + The NumberOfFrames tag is set to the total frame count. + All other DICOM metadata is preserved from the first instance in each series. + + Basic Offset Table: + A Basic Offset Table is automatically created containing byte offsets to each frame. + This allows DICOM readers to quickly locate and extract individual frames without + parsing the entire encapsulated pixel data stream. The offsets are 32-bit unsigned + integers measured from the first byte of the first Item Tag following the BOT. + """ + import glob + import shutil + import tempfile + from collections import defaultdict + from pathlib import Path + + # Check for nvidia-nvimgcodec only if HTJ2K conversion is requested + if convert_to_htj2k: + try: + from nvidia import nvimgcodec + except ImportError: + raise ImportError( + "nvidia-nvimgcodec is required for HTJ2K transcoding. " + "Install it with: pip install nvidia-nvimgcodec-cu{XX}[all] " + "(replace {XX} with your CUDA version, e.g., cu13)" + ) + + import time + + import numpy as np + import pydicom + + # Validate input + if not os.path.exists(input_dir): + raise ValueError(f"Input directory does not exist: {input_dir}") + + if not os.path.isdir(input_dir): + raise ValueError(f"Input path is not a directory: {input_dir}") + + # Get all DICOM files recursively + dicom_files = [] + for root, dirs, files in os.walk(input_dir): + for file in files: + if file.endswith(".dcm") or file.endswith(".DCM"): + dicom_files.append(os.path.join(root, file)) + + # Also check for files without extension + for pattern in ["*"]: + found_files = glob.glob(os.path.join(input_dir, "**", pattern), recursive=True) + for file_path in found_files: + if os.path.isfile(file_path) and file_path not in dicom_files: + try: + with open(file_path, "rb") as f: + f.seek(128) + magic = f.read(4) + if magic == b"DICM": + dicom_files.append(file_path) + except Exception: + continue + + if not dicom_files: + raise ValueError(f"No valid DICOM files found in {input_dir}") + + logger.info(f"Found {len(dicom_files)} DICOM files to process") + + # Group files by study and series + series_groups = defaultdict(list) # Key: (StudyUID, SeriesUID), Value: list of file paths + + logger.info("Grouping DICOM files by series...") + for file_path in dicom_files: + try: + ds = pydicom.dcmread(file_path, stop_before_pixels=True) + study_uid = str(ds.StudyInstanceUID) + series_uid = str(ds.SeriesInstanceUID) + instance_number = int(getattr(ds, "InstanceNumber", 0)) + series_groups[(study_uid, series_uid)].append((instance_number, file_path)) + except Exception as e: + logger.warning(f"Failed to read metadata from {file_path}: {e}") + continue + + # Sort files within each series by InstanceNumber + for key in series_groups: + 
series_groups[key].sort(key=lambda x: x[0]) # Sort by instance number + + logger.info(f"Found {len(series_groups)} unique series") + + # Create output directory + if output_dir is None: + prefix = "htj2k_multiframe_" if convert_to_htj2k else "multiframe_" + output_dir = tempfile.mkdtemp(prefix=prefix) + else: + os.makedirs(output_dir, exist_ok=True) + + # Setup encoder/decoder and parameters based on conversion mode + if convert_to_htj2k: + # Create encoder and decoder instances for HTJ2K + encoder = _get_nvimgcodec_encoder() + decoder = _get_nvimgcodec_decoder() + + # Setup HTJ2K encoding parameters + encode_params, target_transfer_syntax = _setup_htj2k_encode_params( + num_resolutions=num_resolutions, code_block_size=code_block_size, progression_order=progression_order + ) + # Note: decode_params is created per-series based on PhotometricInterpretation + logger.info("HTJ2K conversion enabled") + else: + # No conversion - preserve original transfer syntax + encoder = None + decoder = None + encode_params = None + target_transfer_syntax = None # Will be determined from first dataset + logger.info("Preserving original transfer syntax (no HTJ2K conversion)") + + # Get transfer syntax constants + ts_constants = _get_transfer_syntax_constants() + NVIMGCODEC_SYNTAXES = ts_constants["NVIMGCODEC"] + + start_time = time.time() + processed_series = 0 + total_frames = 0 + + # Process each series + for (study_uid, series_uid), file_list in series_groups.items(): + try: + logger.info(f"Processing series {series_uid} ({len(file_list)} instances)") + + # Load all datasets for this series + file_paths = [fp for _, fp in file_list] + datasets = [pydicom.dcmread(fp) for fp in file_paths] + + # Filter out datasets without PixelData (e.g., DICOM SR, Presentation States, corrupted files) + datasets_with_pixels = [] + for idx, ds in enumerate(datasets): + if hasattr(ds, "PixelData") and ds.PixelData is not None: + datasets_with_pixels.append(ds) + else: + logger.warning(f" Skipping file {file_paths[idx]} (no PixelData found)") + + if not datasets_with_pixels: + logger.error(f" Series {series_uid}: No valid datasets with PixelData found, skipping series") + continue + + # Replace datasets with filtered list + datasets = datasets_with_pixels + logger.info(f" Loaded {len(datasets)} valid datasets with PixelData") + + # CRITICAL: Sort datasets by ImagePositionPatient Z-coordinate + # This ensures Frame[0] is the first slice, Frame[N] is the last slice + if all(hasattr(ds, "ImagePositionPatient") for ds in datasets): + # Sort by Z coordinate (3rd element of ImagePositionPatient) + datasets.sort(key=lambda ds: float(ds.ImagePositionPatient[2])) + logger.info(f" ✓ Sorted {len(datasets)} frames by ImagePositionPatient Z-coordinate") + logger.info(f" First frame Z: {datasets[0].ImagePositionPatient[2]}") + logger.info(f" Last frame Z: {datasets[-1].ImagePositionPatient[2]}") + + # NOTE: We keep anatomically correct order (Z-ascending) + # Cornerstone3D should use per-frame ImagePositionPatient from PerFrameFunctionalGroupsSequence + # We provide complete per-frame metadata (PlanePositionSequence + PlaneOrientationSequence) + logger.info(f" ✓ Frames in anatomical order (lowest Z first)") + logger.info( + f" Cornerstone3D should use per-frame ImagePositionPatient for correct volume reconstruction" + ) + else: + logger.warning(f" ⚠️ Some frames missing ImagePositionPatient, using file order") + + # Use first dataset as template + template_ds = datasets[0] + + # Determine transfer syntax from first dataset + if 
target_transfer_syntax is None: + target_transfer_syntax = str(getattr(template_ds.file_meta, "TransferSyntaxUID", "1.2.840.10008.1.2.1")) + logger.info(f" Using original transfer syntax: {target_transfer_syntax}") + + # Check if we're dealing with encapsulated (compressed) data + has_pixel_data = hasattr(template_ds, "PixelData") and template_ds.PixelData is not None + # At this point we have filtered out datasets without PixelData, so this should never happen + assert has_pixel_data, f"Template dataset {file_paths[0]} does not have a PixelData member" + is_encapsulated = ( + has_pixel_data and template_ds.file_meta.TransferSyntaxUID != pydicom.uid.ExplicitVRLittleEndian + ) + + # Determine color_spec for this series based on PhotometricInterpretation + if convert_to_htj2k: + photometric = getattr(template_ds, "PhotometricInterpretation", "UNKNOWN") + if photometric.startswith("YBR"): + series_color_spec = nvimgcodec.ColorSpec.RGB + logger.info(f" Series PhotometricInterpretation={photometric}, using color_spec=RGB") + else: + series_color_spec = nvimgcodec.ColorSpec.UNCHANGED + logger.info(f" Series PhotometricInterpretation={photometric}, using color_spec=UNCHANGED") + series_decode_params = _setup_htj2k_decode_params(color_spec=series_color_spec) + + # Collect all frames from all instances + all_frames = [] # Will contain either numpy arrays (for HTJ2K) or bytes (for preserving) + + if convert_to_htj2k: + # HTJ2K mode: decode all frames + for ds in datasets: + current_ts = str(getattr(ds.file_meta, "TransferSyntaxUID", None)) + + if current_ts in NVIMGCODEC_SYNTAXES: + # Compressed format - use nvimgcodec decoder + frames = [fragment for fragment in pydicom.encaps.generate_frames(ds.PixelData)] + decoded = decoder.decode(frames, params=series_decode_params) + all_frames.extend(decoded) + else: + # Uncompressed format - use pydicom + pixel_array = ds.pixel_array + if not isinstance(pixel_array, np.ndarray): + pixel_array = np.array(pixel_array) + + # Handle single frame vs multi-frame + if pixel_array.ndim == 2: + all_frames.append(pixel_array) + elif pixel_array.ndim == 3: + for frame_idx in range(pixel_array.shape[0]): + all_frames.append(pixel_array[frame_idx, :, :]) + else: + # Preserve original encoding: extract frames without decoding + first_ts = str(getattr(datasets[0].file_meta, "TransferSyntaxUID", None)) + + if first_ts in NVIMGCODEC_SYNTAXES or pydicom.encaps.encapsulate_extended: + # Encapsulated data - extract compressed frames + for ds in datasets: + has_pixel_data = hasattr(ds, "PixelData") and ds.PixelData is not None + assert has_pixel_data, f"Dataset {file_paths[idx]} does not have a PixelData member" + try: + # Extract compressed frames + frames = [fragment for fragment in pydicom.encaps.generate_frames(ds.PixelData)] + all_frames.extend(frames) + except: + # Fall back to pixel_array for uncompressed + pixel_array = ds.pixel_array + if not isinstance(pixel_array, np.ndarray): + pixel_array = np.array(pixel_array) + if pixel_array.ndim == 2: + all_frames.append(pixel_array) + elif pixel_array.ndim == 3: + for frame_idx in range(pixel_array.shape[0]): + all_frames.append(pixel_array[frame_idx, :, :]) + else: + # Uncompressed data - use pixel arrays + for ds in datasets: + pixel_array = ds.pixel_array + if not isinstance(pixel_array, np.ndarray): + pixel_array = np.array(pixel_array) + if pixel_array.ndim == 2: + all_frames.append(pixel_array) + elif pixel_array.ndim == 3: + for frame_idx in range(pixel_array.shape[0]): + all_frames.append(pixel_array[frame_idx, :, :]) 
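For reference, a minimal standalone sketch of the packing step that follows: per-frame compressed byte streams are wrapped into encapsulated multi-frame PixelData with a Basic Offset Table via pydicom. The dataset `ds` and the `compressed_frames` list are placeholders, not values from this patch.

```python
from typing import List

import pydicom
from pydicom.encaps import encapsulate


def pack_frames(ds: pydicom.Dataset, compressed_frames: List[bytes]) -> None:
    """Assemble encapsulated multi-frame PixelData from per-frame codestreams."""
    # has_bot=True prepends a Basic Offset Table item holding the byte offset of
    # each frame, so readers can seek to any frame without scanning all fragments.
    ds.PixelData = encapsulate(compressed_frames, has_bot=True)
    ds.NumberOfFrames = len(compressed_frames)
    # Encapsulated pixel data is written as VR OB with undefined length.
    ds["PixelData"].VR = "OB"
    ds["PixelData"].is_undefined_length = True
```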
+ + total_frame_count = len(all_frames) + logger.info(f" Total frames in series: {total_frame_count}") + + # Encode frames based on conversion mode + if convert_to_htj2k: + logger.info(f" Encoding {total_frame_count} frames to HTJ2K...") + # Ensure frames have channel dimension for encoder + frames_for_encoding = [] + for frame in all_frames: + if frame.ndim == 2: + frame = frame[:, :, np.newaxis] + frames_for_encoding.append(frame) + encoded_frames = encoder.encode(frames_for_encoding, codec="jpeg2k", params=encode_params) + # Convert to bytes + encoded_frames_bytes = [bytes(enc) for enc in encoded_frames] + else: + logger.info(f" Preserving original encoding for {total_frame_count} frames...") + # Check if frames are already bytes (encapsulated) or numpy arrays (uncompressed) + if len(all_frames) > 0 and isinstance(all_frames[0], bytes): + # Already encapsulated - use as-is + encoded_frames_bytes = all_frames + else: + # Uncompressed numpy arrays + encoded_frames_bytes = None + + # Save ImageOrientationPatient and ImagePositionPatient BEFORE creating output_ds + # The shallow copy + delattr will affect the original datasets objects + # Save these values now so we can use them in functional groups later + original_image_orientation = ( + datasets[0].ImageOrientationPatient if hasattr(datasets[0], "ImageOrientationPatient") else None + ) + original_image_positions = [ + ds.ImagePositionPatient if hasattr(ds, "ImagePositionPatient") else None for ds in datasets + ] + + # Create SIMPLE multi-frame DICOM file (like the user's example) + # Use first dataset as template, keeping its metadata + logger.info(f" Creating simple multi-frame DICOM from {total_frame_count} frames...") + output_ds = datasets[0].copy() # shallow copy + + # CRITICAL: Set SOP Instance UID to match the SeriesInstanceUID (which will be the filename) + # This ensures the file's internal SOP Instance UID matches its filename + output_ds.SOPInstanceUID = series_uid + + # Update pixel data based on conversion mode + if encoded_frames_bytes is not None: + # Encapsulated data (HTJ2K or preserved compressed format) + # Use Basic Offset Table for multi-frame efficiency + output_ds.PixelData = pydicom.encaps.encapsulate(encoded_frames_bytes, has_bot=add_basic_offset_table) + else: + # Uncompressed mode: combine all frames into a 3D array + # Stack frames: (frames, rows, cols) + combined_pixel_array = np.stack(all_frames, axis=0) + output_ds.PixelData = combined_pixel_array.tobytes() + + output_ds.file_meta.TransferSyntaxUID = pydicom.uid.UID(target_transfer_syntax) + + # Update PhotometricInterpretation if we converted from YBR to RGB + if convert_to_htj2k and hasattr(output_ds, "PhotometricInterpretation"): + original_pi = output_ds.PhotometricInterpretation + if original_pi.startswith("YBR"): + output_ds.PhotometricInterpretation = "RGB" + logger.info(f" Updated PhotometricInterpretation: {original_pi} -> RGB") + + # Set NumberOfFrames (critical!) 
+ output_ds.NumberOfFrames = total_frame_count + + # DICOM Multi-frame Module (C.7.6.6) - Mandatory attributes + + # FrameIncrementPointer - REQUIRED to tell viewers how frames are ordered + # Points to ImagePositionPatient (0020,0032) which varies per frame + output_ds.FrameIncrementPointer = 0x00200032 + logger.info(f" ✓ Set FrameIncrementPointer to ImagePositionPatient") + + # Ensure all Image Pixel Module attributes are present (C.7.6.3) + # These should be inherited from first frame, but verify: + required_pixel_attrs = [ + ("SamplesPerPixel", 1), + ("PhotometricInterpretation", "MONOCHROME2"), + ("Rows", 512), + ("Columns", 512), + ] + + for attr, default in required_pixel_attrs: + if not hasattr(output_ds, attr): + setattr(output_ds, attr, default) + logger.warning(f" ⚠️ Added missing {attr} = {default}") + + # CRITICAL: Remove top-level ImagePositionPatient and ImageOrientationPatient + # Working files (that display correctly in OHIF MPR) have NEITHER at top level + # These should ONLY exist in functional groups for Enhanced CT + + if hasattr(output_ds, "ImagePositionPatient"): + delattr(output_ds, "ImagePositionPatient") + logger.info(f" ✓ Removed top-level ImagePositionPatient (use per-frame only)") + + if hasattr(output_ds, "ImageOrientationPatient"): + delattr(output_ds, "ImageOrientationPatient") + logger.info(f" ✓ Removed top-level ImageOrientationPatient (use SharedFunctionalGroupsSequence only)") + # Set correct SOPClassUID for multi-frame (Enhanced/Multiframe) conversion + sopclass_map = { + "1.2.840.10008.5.1.4.1.1.2": ( + "1.2.840.10008.5.1.4.1.1.2.1", + "Enhanced CT Image Storage", + ), # CT -> Enhanced CT + "1.2.840.10008.5.1.4.1.1.4": ( + "1.2.840.10008.5.1.4.1.1.4.1", + "Enhanced MR Image Storage", + ), # MR -> Enhanced MR + "1.2.840.10008.5.1.4.1.1.6.1": ( + "1.2.840.10008.5.1.4.1.1.3.1", + "Ultrasound Multi-frame Image Storage", + ), # US -> Ultrasound Multi-frame + } + + original_sopclass = getattr(datasets[0], "SOPClassUID", None) + if original_sopclass and str(original_sopclass) in sopclass_map: + new_uid, desc = sopclass_map[str(original_sopclass)] + output_ds.SOPClassUID = new_uid + logger.info(f" ✓ Set SOPClassUID to {desc}") + else: + logger.info(f" Keeping original SOPClassUID: {original_sopclass}") + + # Keep pixel spacing and slice thickness + if hasattr(datasets[0], "PixelSpacing"): + output_ds.PixelSpacing = datasets[0].PixelSpacing + logger.info(f" ✓ PixelSpacing: {output_ds.PixelSpacing}") + + if hasattr(datasets[0], "SliceThickness"): + output_ds.SliceThickness = datasets[0].SliceThickness + logger.info(f" ✓ SliceThickness: {output_ds.SliceThickness}") + + # Fix InstanceNumber (should be >= 1) + output_ds.InstanceNumber = 1 + + # Ensure SeriesNumber is present + if not hasattr(output_ds, "SeriesNumber"): + output_ds.SeriesNumber = 1 + + # Remove per-frame tags that conflict with multi-frame + if hasattr(output_ds, "SliceLocation"): + delattr(output_ds, "SliceLocation") + logger.info(f" ✓ Removed SliceLocation (per-frame tag)") + + # Add SpacingBetweenSlices + if len(datasets) > 1: + pos0 = datasets[0].ImagePositionPatient if hasattr(datasets[0], "ImagePositionPatient") else None + pos1 = datasets[1].ImagePositionPatient if hasattr(datasets[1], "ImagePositionPatient") else None + + if pos0 and pos1: + # Calculate spacing as distance between consecutive slices + import math + + spacing = math.sqrt(sum((float(pos1[i]) - float(pos0[i])) ** 2 for i in range(3))) + output_ds.SpacingBetweenSlices = spacing + logger.info(f" ✓ Added SpacingBetweenSlices: 
{spacing:.6f} mm") + + # Add PerFrameFunctionalGroupsSequence for OHIF/Cornerstone3D compatibility + # CRITICAL: Structure must be exactly right to avoid "1/Infinity" MPR display bug + # - Per-frame: PlanePositionSequence ONLY (unique position per frame) + # - Shared: PlaneOrientationSequence (common orientation for all frames) + logger.info(f" Adding per-frame functional groups (OHIF-compatible structure)...") + from pydicom.dataset import Dataset as DicomDataset + from pydicom.sequence import Sequence + + per_frame_seq = [] + for frame_idx, ds_frame in enumerate(datasets): + frame_item = DicomDataset() + + # PlanePositionSequence - ImagePositionPatient for this frame + # This is MANDATORY for Enhanced CT multi-frame + plane_pos_item = DicomDataset() + # Use saved value (before it was deleted from datasets) + if original_image_positions[frame_idx] is not None: + plane_pos_item.ImagePositionPatient = original_image_positions[frame_idx] + else: + # If missing, use default (0,0,frame_idx * spacing) + # This shouldn't happen for valid CT series, but ensures MPR compatibility + default_spacing = ( + float(output_ds.SpacingBetweenSlices) if hasattr(output_ds, "SpacingBetweenSlices") else 1.0 + ) + plane_pos_item.ImagePositionPatient = [0.0, 0.0, frame_idx * default_spacing] + logger.warning(f" Frame {frame_idx} missing ImagePositionPatient, using default") + frame_item.PlanePositionSequence = Sequence([plane_pos_item]) + + # CRITICAL: Do NOT add per-frame PlaneOrientationSequence! + # PlaneOrientationSequence should ONLY be in SharedFunctionalGroupsSequence + # Having it per-frame triggers different parsing logic in OHIF/Cornerstone3D + # Result: metadata not read correctly, spacing[2] = 0 + # (The orientation is shared across all frames anyway) + + # FrameContentSequence - helps with frame identification + frame_content_item = DicomDataset() + frame_content_item.StackID = "1" + frame_content_item.InStackPositionNumber = frame_idx + 1 + frame_content_item.DimensionIndexValues = [1, frame_idx + 1] + frame_item.FrameContentSequence = Sequence([frame_content_item]) + + per_frame_seq.append(frame_item) + + output_ds.PerFrameFunctionalGroupsSequence = Sequence(per_frame_seq) + logger.info(f" ✓ Added PerFrameFunctionalGroupsSequence with {len(per_frame_seq)} frame items") + logger.info(f" Each frame includes: PlanePositionSequence only (orientation in shared)") + + # Add SharedFunctionalGroupsSequence for additional Cornerstone3D compatibility + # This defines attributes that are common to ALL frames + shared_item = DicomDataset() + + # PlaneOrientationSequence - MANDATORY for Enhanced CT multi-frame + shared_orient_item = DicomDataset() + # Use saved value (before it was deleted from datasets) + if original_image_orientation is not None: + shared_orient_item.ImageOrientationPatient = original_image_orientation + else: + # If missing, use standard axial orientation + # This ensures MPR button is enabled in OHIF + shared_orient_item.ImageOrientationPatient = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0] + logger.warning(f" Source files missing ImageOrientationPatient, using standard axial orientation") + shared_item.PlaneOrientationSequence = Sequence([shared_orient_item]) + + # PixelMeasuresSequence - pixel spacing and slice thickness + if hasattr(datasets[0], "PixelSpacing") or hasattr(datasets[0], "SliceThickness"): + pixel_measures_item = DicomDataset() + if hasattr(datasets[0], "PixelSpacing"): + pixel_measures_item.PixelSpacing = datasets[0].PixelSpacing + if hasattr(datasets[0], "SliceThickness"): + 
pixel_measures_item.SliceThickness = datasets[0].SliceThickness + if hasattr(output_ds, "SpacingBetweenSlices"): + pixel_measures_item.SpacingBetweenSlices = output_ds.SpacingBetweenSlices + shared_item.PixelMeasuresSequence = Sequence([pixel_measures_item]) + + output_ds.SharedFunctionalGroupsSequence = Sequence([shared_item]) + logger.info(f" ✓ Added SharedFunctionalGroupsSequence (common attributes for all frames)") + logger.info(f" Includes PlaneOrientationSequence (ONLY location for orientation!)") + + # Verify frame ordering + if len(per_frame_seq) > 0: + first_frame_pos = ( + per_frame_seq[0].PlanePositionSequence[0].ImagePositionPatient + if hasattr(per_frame_seq[0], "PlanePositionSequence") + else None + ) + last_frame_pos = ( + per_frame_seq[-1].PlanePositionSequence[0].ImagePositionPatient + if hasattr(per_frame_seq[-1], "PlanePositionSequence") + else None + ) + + if first_frame_pos and last_frame_pos: + logger.info(f" ✓ Frame ordering verification:") + logger.info(f" Frame[0] Z = {first_frame_pos[2]} (should match top-level)") + logger.info(f" Frame[{len(per_frame_seq)-1}] Z = {last_frame_pos[2]} (last slice)") + + # Verify top-level matches Frame[0] + if hasattr(output_ds, "ImagePositionPatient"): + if abs(float(output_ds.ImagePositionPatient[2]) - float(first_frame_pos[2])) < 0.001: + logger.info(f" ✅ Top-level ImagePositionPatient matches Frame[0]") + else: + logger.error( + f" ❌ MISMATCH: Top-level Z={output_ds.ImagePositionPatient[2]} != Frame[0] Z={first_frame_pos[2]}" + ) + + logger.info(f" ✓ Created multi-frame with {total_frame_count} frames (OHIF-compatible)") + if encoded_frames_bytes is not None: + logger.info(f" ✓ Basic Offset Table included for efficient frame access") + + # Create output directory structure + study_output_dir = os.path.join(output_dir, study_uid) + os.makedirs(study_output_dir, exist_ok=True) + + # Save as single multi-frame file + output_file = os.path.join(study_output_dir, f"{series_uid}.dcm") + output_ds.save_as(output_file, enforce_file_format=False) + + logger.info(f" ✓ Saved multi-frame file: {output_file}") + processed_series += 1 + total_frames += total_frame_count + + except Exception as e: + logger.error(f"Failed to process series {series_uid}: {e}") + import traceback + + traceback.print_exc() + continue + + elapsed_time = time.time() - start_time + + if convert_to_htj2k: + logger.info(f"\nMulti-frame HTJ2K conversion complete:") + else: + logger.info(f"\nMulti-frame DICOM conversion complete:") + logger.info(f" Total series processed: {processed_series}") + logger.info(f" Total frames combined: {total_frames}") + if convert_to_htj2k: + logger.info(f" Format: HTJ2K compressed") + else: + logger.info(f" Format: Original transfer syntax preserved") + logger.info(f" Time elapsed: {elapsed_time:.2f} seconds") + logger.info(f" Output directory: {output_dir}") + + return output_dir diff --git a/monailabel/datastore/utils/convert_multiframe.py b/monailabel/datastore/utils/convert_multiframe.py new file mode 100644 index 000000000..0cd270a6c --- /dev/null +++ b/monailabel/datastore/utils/convert_multiframe.py @@ -0,0 +1,1465 @@ +# Copyright 2025 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for converting legacy DICOM series to enhanced multi-frame format. + +This module provides tools to convert series of single-frame DICOM files into +single multi-frame enhanced DICOM files, with optional HTJ2K compression for +improved storage efficiency. + +Key Features: +- Convert legacy CT/MR/PT series to enhanced multi-frame format using highdicom +- Optional HTJ2K (High-Throughput JPEG 2000) lossless compression +- Batch processing of multiple series with automatic grouping by SeriesInstanceUID +- Preserve or generate new SeriesInstanceUID +- Handle unsupported modalities (MG, US, XA) by transcoding or copying +- Comprehensive statistics including frame counts and compression ratios + +Enhanced DICOM multi-frame format benefits: +- Single file instead of hundreds of individual files +- Better organization and metadata structure +- More efficient I/O operations +- Standards-compliant with DICOM Part 3 + +Supported modalities for enhanced conversion: +- CT (Computed Tomography) +- MR (Magnetic Resonance) +- PT (Positron Emission Tomography) + +Unsupported modalities (MG, US, XA, etc.) can be: +- Transcoded to HTJ2K (preserving original format) +- Copied without modification + +Example: + >>> from monailabel.datastore.utils.convert_multiframe import ( + ... convert_to_enhanced_dicom, + ... batch_convert_by_series, + ... convert_and_convert_to_htj2k, + ... ) + >>> + >>> # Single series conversion (preserves original SeriesInstanceUID by default) + >>> convert_to_enhanced_dicom( + ... input_source="/path/to/legacy/ct/series", + ... output_file="/path/to/output/enhanced.dcm" + ... ) + >>> + >>> # Convert with HTJ2K compression + >>> convert_and_convert_to_htj2k( + ... input_source="/path/to/legacy/ct/series", + ... output_file="/path/to/output/enhanced_htj2k.dcm", + ... num_resolutions=6 + ... ) + >>> + >>> # Batch convert multiple series with HTJ2K + >>> import pydicom + >>> from pathlib import Path + >>> + >>> # Collect DICOM files + >>> input_dir = Path("/path/to/mixed/dicoms") + >>> input_files = [str(f) for f in input_dir.rglob("*.dcm")] + >>> + >>> # Create file_loader + >>> file_loader = [(input_files, "/path/to/output")] + >>> + >>> # Batch convert + >>> stats = batch_convert_by_series( + ... file_loader=file_loader, + ... compress_htj2k=True, + ... num_resolutions=6 + ... 
) + >>> print(f"Processed {stats['total_frames_input']} frames") + >>> print(f"Converted {stats['converted_to_multiframe']} series to multi-frame") +""" + +import logging +import os +import shutil +import tempfile +import warnings +from contextlib import contextmanager +from pathlib import Path +from typing import Iterable, List, Optional, Union + +import numpy as np +import pydicom +import pydicom.config +from pydicom.uid import generate_uid + +pydicom.config.assume_implicit_vr_switch = True + +logger = logging.getLogger(__name__) + +# Constants for DICOM modalities +SUPPORTED_MODALITIES = {"CT", "MR", "PT"} + +# Transfer syntax UIDs +EXPLICIT_VR_LITTLE_ENDIAN = "1.2.840.10008.1.2.1" +IMPLICIT_VR_LITTLE_ENDIAN = "1.2.840.10008.1.2" + + +def _check_highdicom_available(): + """Check if highdicom is installed.""" + try: + import highdicom + + return True + except ImportError: + return False + + +@contextmanager +def _suppress_highdicom_warnings(): + """ + Context manager to suppress common highdicom warnings. + + Suppresses warnings like: + - "unknown derived pixel contrast" + - Other non-critical highdicom warnings + + This suppresses both Python warnings and logging-based warnings from highdicom. + """ + # Suppress Python warnings + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=".*unknown derived pixel contrast.*") + warnings.filterwarnings("ignore", category=UserWarning, module="highdicom.*") + + # Suppress highdicom logging warnings + highdicom_logger = logging.getLogger("highdicom") + highdicom_legacy_logger = logging.getLogger("highdicom.legacy") + highdicom_sop_logger = logging.getLogger("highdicom.legacy.sop") + + # Save original log levels + original_level = highdicom_logger.level + original_legacy_level = highdicom_legacy_logger.level + original_sop_level = highdicom_sop_logger.level + + try: + # Temporarily set to ERROR to suppress WARNING messages + highdicom_logger.setLevel(logging.ERROR) + highdicom_legacy_logger.setLevel(logging.ERROR) + highdicom_sop_logger.setLevel(logging.ERROR) + yield + finally: + # Restore original log levels + highdicom_logger.setLevel(original_level) + highdicom_legacy_logger.setLevel(original_legacy_level) + highdicom_sop_logger.setLevel(original_sop_level) + + +def _load_dicom_series(input_source: Union[str, Path, List[Union[str, Path]]]) -> List[pydicom.Dataset]: + """ + Load DICOM files from a directory or list of file paths and sort them by spatial position. 
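For intuition, the spatial sort used below projects each slice's ImagePositionPatient onto the slice normal, i.e. the cross product of the row and column direction cosines from ImageOrientationPatient. A small numpy sketch with made-up values:

```python
import numpy as np

# ImageOrientationPatient holds six direction cosines: row direction, then column direction.
row_dir = np.array([1.0, 0.0, 0.0])
col_dir = np.array([0.0, 1.0, 0.0])
normal = np.cross(row_dir, col_dir)          # [0, 0, 1] for a plain axial series

# ImagePositionPatient of one slice (made-up values).
position = np.array([-250.0, -250.0, 42.5])

# Distance of the slice along the normal; sorting slices by this value orders them
# spatially even when the acquisition axis is oblique.
sort_key = float(np.dot(position, normal))   # 42.5 here
print(sort_key)
```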
+ + Args: + input_source: Either: + - Directory path containing DICOM files + - List of DICOM file paths + + Returns: + List of sorted pydicom.Dataset objects + + Raises: + ValueError: If no DICOM files found or files have inconsistent metadata + """ + # Handle different input types + if isinstance(input_source, (list, tuple)): + # List of file paths provided + file_paths = [Path(f) for f in input_source] + source_desc = f"{len(file_paths)} provided files" + else: + # Directory path provided + input_dir = Path(input_source) + if not input_dir.is_dir(): + raise ValueError(f"Input path is not a directory: {input_dir}") + + # Find all DICOM files in directory + file_paths = [f for f in input_dir.iterdir() if f.is_file() and not f.name.startswith(".")] + source_desc = f"directory {input_dir}" + + # Load DICOM files + dicom_files = [] + for filepath in file_paths: + try: + ds = pydicom.dcmread(filepath) + dicom_files.append(ds) + except Exception as e: + logger.debug(f"Skipping non-DICOM file {filepath.name}: {e}") + continue + + if not dicom_files: + raise ValueError(f"No DICOM files found in {source_desc}") + + logger.info(f"Loaded {len(dicom_files)} DICOM files from {source_desc}") + + # Sort by ImagePositionPatient if available + if all(hasattr(ds, "ImagePositionPatient") and hasattr(ds, "ImageOrientationPatient") for ds in dicom_files): + # Calculate distance along normal vector for each slice + first_ds = dicom_files[0] + orientation = np.array(first_ds.ImageOrientationPatient).reshape(2, 3) + normal = np.cross(orientation[0], orientation[1]) + + def get_position_along_normal(ds): + position = np.array(ds.ImagePositionPatient) + return np.dot(position, normal) + + dicom_files.sort(key=get_position_along_normal) + logger.info("Sorted files by spatial position") + elif all(hasattr(ds, "InstanceNumber") for ds in dicom_files): + # Fall back to InstanceNumber + dicom_files.sort(key=lambda ds: ds.InstanceNumber) + logger.info("Sorted files by InstanceNumber") + else: + logger.warning("Could not determine optimal sorting order, using file order") + + return dicom_files + + +def _validate_series_consistency(datasets: List[pydicom.Dataset]) -> dict: + """ + Validate that all datasets in a series are consistent. + + Args: + datasets: List of pydicom.Dataset objects + + Returns: + Dictionary with series metadata + + Raises: + ValueError: If datasets are inconsistent + """ + if not datasets: + raise ValueError("Empty dataset list") + + first_ds = datasets[0] + + # Check modality + modality = getattr(first_ds, "Modality", None) + if not modality: + raise ValueError("First dataset missing Modality tag") + + if modality not in SUPPORTED_MODALITIES: + raise ValueError( + f"Unsupported modality: {modality}. 
" f"Supported modalities are: {', '.join(SUPPORTED_MODALITIES)}" + ) + + # Required attributes that must be consistent + required_attrs = ["Rows", "Columns", "Modality"] + optional_consistent_attrs = [ + "SeriesInstanceUID", + "StudyInstanceUID", + "PatientID", + "PixelSpacing", + "ImageOrientationPatient", + ] + + # Collect metadata from first dataset + metadata = { + "modality": modality, + "rows": first_ds.Rows, + "columns": first_ds.Columns, + "num_frames": len(datasets), + } + + # Check consistency across all datasets + for attr in required_attrs: + if not all(hasattr(ds, attr) and getattr(ds, attr) == getattr(first_ds, attr) for ds in datasets): + raise ValueError(f"Inconsistent {attr} values across series") + + # Collect optional metadata + for attr in optional_consistent_attrs: + if hasattr(first_ds, attr): + metadata[attr.lower()] = getattr(first_ds, attr) + + logger.info( + f"Series validated: {modality} {metadata['rows']}x{metadata['columns']}, " f"{metadata['num_frames']} frames" + ) + + return metadata + + +def _fix_dicom_datetime_attributes(datasets: List[pydicom.Dataset]) -> None: + """ + Fix malformed date/time attributes in DICOM datasets. + + Some legacy DICOM files have date/time values stored as strings in non-standard + formats. This function converts valid date strings to proper Python date objects + and removes invalid ones. This is necessary because highdicom expects proper + date/time objects, not strings. + + Args: + datasets: List of pydicom.Dataset objects to modify in-place + """ + from datetime import date, datetime, time + + fixed_attrs = set() + + for ds in datasets: + # List of date/time attributes that might need fixing + date_attrs = ["StudyDate", "SeriesDate", "AcquisitionDate", "ContentDate"] + time_attrs = ["StudyTime", "SeriesTime", "AcquisitionTime", "ContentTime"] + + # Fix date attributes - convert strings to date objects + for attr in date_attrs: + if hasattr(ds, attr): + value = getattr(ds, attr) + # If it's already a proper date/datetime object, skip + if isinstance(value, (date, datetime)): + continue + # If it's a string, try to convert it to a date object + if isinstance(value, str) and value: + try: + # DICOM date format is YYYYMMDD + if len(value) >= 8 and value[:8].isdigit(): + year = int(value[0:4]) + month = int(value[4:6]) + day = int(value[6:8]) + date_obj = date(year, month, day) + setattr(ds, attr, date_obj) + fixed_attrs.add(f"{attr} (converted to date)") + else: + # Invalid format, remove it + delattr(ds, attr) + fixed_attrs.add(f"{attr} (removed)") + except (ValueError, IndexError) as e: + # Invalid date values, remove it + delattr(ds, attr) + fixed_attrs.add(f"{attr} (removed - invalid)") + elif not value: + # Empty string, remove it + delattr(ds, attr) + fixed_attrs.add(f"{attr} (removed - empty)") + + # Fix time attributes - convert strings to time objects + for attr in time_attrs: + if hasattr(ds, attr): + value = getattr(ds, attr) + # If it's already a proper time/datetime object, skip + if isinstance(value, (time, datetime)): + continue + # If it's a string, try to convert it to a time object + if isinstance(value, str) and value: + try: + # DICOM time format is HHMMSS.FFFFFF or HHMMSS + # Clean up the string + time_str = value.replace(":", "") + + if "." 
in time_str: + parts = time_str.split(".") + main_part = parts[0] + frac_part = parts[1] if len(parts) > 1 else "0" + else: + main_part = time_str + frac_part = "0" + + # Parse hours, minutes, seconds + if len(main_part) >= 2: + hour = int(main_part[0:2]) + minute = int(main_part[2:4]) if len(main_part) >= 4 else 0 + second = int(main_part[4:6]) if len(main_part) >= 6 else 0 + microsecond = int(frac_part[:6].ljust(6, "0")) if frac_part else 0 + + time_obj = time(hour, minute, second, microsecond) + setattr(ds, attr, time_obj) + fixed_attrs.add(f"{attr} (converted to time)") + else: + # Too short to be valid, remove it + delattr(ds, attr) + fixed_attrs.add(f"{attr} (removed)") + except (ValueError, IndexError) as e: + # Invalid time values, remove it + delattr(ds, attr) + fixed_attrs.add(f"{attr} (removed - invalid)") + elif not value: + # Empty string, remove it + delattr(ds, attr) + fixed_attrs.add(f"{attr} (removed - empty)") + + if fixed_attrs: + logger.info( + f"Converted/fixed date/time attributes: {len([a for a in fixed_attrs if 'converted' in a])} converted, " + f"{len([a for a in fixed_attrs if 'removed' in a])} removed" + ) + + +def _ensure_required_attributes(datasets: List[pydicom.Dataset]) -> None: + """ + Ensure that all datasets have the required attributes for enhanced multi-frame conversion. + + If required attributes are missing, they are added with sensible default values. + This is necessary because the DICOM enhanced multi-frame standard requires certain + attributes that may be missing from legacy DICOM files. + + Args: + datasets: List of pydicom.Dataset objects to modify in-place + """ + # Required attributes and their default values + required_attrs = { + "Manufacturer": "Unknown", + "ManufacturerModelName": "Unknown", + "DeviceSerialNumber": "Unknown", + "SoftwareVersions": "Unknown", + } + + # Check and add missing attributes to all datasets + added_attrs = set() + for ds in datasets: + for attr, default_value in required_attrs.items(): + if not hasattr(ds, attr): + setattr(ds, attr, default_value) + added_attrs.add(attr) + + if added_attrs: + logger.info(f"Added missing required attributes with default values: {', '.join(sorted(added_attrs))}") + + +def _transcode_files_to_htj2k( + file_paths: List[Path], + output_dir: Path, + compression_kwargs: dict, +) -> tuple[bool, float]: + """ + Transcode DICOM files to HTJ2K format (helper function). + + This function handles HTJ2K transcoding for files that cannot be converted to + enhanced multi-frame format (e.g., unsupported modalities like MG, US, XA). + The original file format is preserved, only the pixel data is compressed with HTJ2K. 
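A hypothetical invocation of this helper, with the documented defaults spelled out; the input and output paths are placeholders:

```python
from pathlib import Path

success, output_mb = _transcode_files_to_htj2k(
    file_paths=[Path("/data/mg/img_001.dcm"), Path("/data/mg/img_002.dcm")],
    output_dir=Path("/output/mg_htj2k"),
    compression_kwargs={
        "num_resolutions": 6,         # wavelet decomposition levels
        "code_block_size": (64, 64),  # HTJ2K code block size
        "progression_order": "RPCL",  # resolution-position-component-layer
    },
)
if success:
    print(f"Transcoded series size: {output_mb:.2f} MB")
```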
+ + Args: + file_paths: List of input DICOM file paths to transcode + output_dir: Output directory for transcoded files + compression_kwargs: Dictionary with HTJ2K compression parameters: + - num_resolutions (int): Wavelet decomposition levels (default: 6) + - code_block_size (tuple): Code block size (default: (64, 64)) + - progression_order (str): JPEG2K progression order (default: 'RPCL') + + Returns: + Tuple of (success: bool, output_size_mb: float): + - success: True if transcoding succeeded, False otherwise + - output_size_mb: Total size of transcoded files in MB (0.0 if failed) + """ + try: + from monailabel.datastore.utils.convert_htj2k import transcode_dicom_to_htj2k + + # Prepare file pairs for transcoding (input -> output) + input_files = [] + output_files = [] + + for file_path in file_paths: + output_path = output_dir / file_path.name + output_path.parent.mkdir(parents=True, exist_ok=True) + + input_files.append(str(file_path)) + output_files.append(str(output_path)) + + # Transcode with HTJ2K + file_loader = [(input_files, output_files)] + num_resolutions = compression_kwargs.get("num_resolutions", 6) + code_block_size = compression_kwargs.get("code_block_size", (64, 64)) + progression_order = compression_kwargs.get("progression_order", "RPCL") + + transcode_dicom_to_htj2k( + file_loader=file_loader, + num_resolutions=num_resolutions, + code_block_size=code_block_size, + progression_order=progression_order, + ) + + # Calculate output size + transcoded_size = sum(Path(f).stat().st_size for f in output_files if Path(f).exists()) + transcoded_size_mb = transcoded_size / (1024 * 1024) + + return True, transcoded_size_mb + + except Exception as e: + logger.error(f"HTJ2K transcoding failed: {e}") + return False, 0.0 + + +def _copy_files( + file_paths: List[Path], + output_dir: Path, +) -> int: + """ + Copy DICOM files to output directory. + + Args: + file_paths: List of input file paths + output_dir: Output directory + + Returns: + Number of files successfully copied + """ + copied_count = 0 + for file_path in file_paths: + try: + output_path = output_dir / file_path.name + output_path.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(file_path, output_path) + copied_count += 1 + except Exception as e: + logger.error(f"Failed to copy {file_path.name}: {e}") + + return copied_count + + +def convert_to_enhanced_dicom( + input_source: Union[str, Path, List[Union[str, Path]]], + output_file: Union[str, Path], + transfer_syntax_uid: Optional[str] = None, + validate_only: bool = False, + preserve_series_uid: bool = True, + show_stats: bool = True, +) -> bool: + """ + Convert legacy DICOM series to enhanced multi-frame DICOM. + + Args: + input_source: Either: + - Directory path containing legacy DICOM files (single-frame per file) + - List of DICOM file paths to convert as a single series + output_file: Path for output enhanced DICOM file + transfer_syntax_uid: Transfer syntax for output. If None, uses Explicit VR Little Endian. + validate_only: If True, only validate series without creating output file. + preserve_series_uid: If True (default), preserve the original SeriesInstanceUID from + the legacy datasets. If False, generate a new SeriesInstanceUID for the enhanced series. + show_stats: If True (default), display conversion statistics. Set to False to suppress output. 
+ + Returns: + True if successful, False otherwise + + Raises: + ImportError: If highdicom is not installed + ValueError: If series is invalid or inconsistent + FileNotFoundError: If input directory doesn't exist + + Example: + >>> # Convert CT series from directory + >>> convert_to_enhanced_dicom( + ... input_source="./ct_series/", + ... output_file="./enhanced_ct.dcm" + ... ) + + >>> # Convert from list of files (file_loader pattern) + >>> file_paths = ['/data/ct_001.dcm', '/data/ct_002.dcm', '/data/ct_003.dcm'] + >>> convert_to_enhanced_dicom( + ... input_source=file_paths, + ... output_file="./enhanced_ct.dcm" + ... ) + + >>> # Convert with specific transfer syntax + >>> convert_to_enhanced_dicom( + ... input_source="./mr_series/", + ... output_file="./enhanced_mr.dcm", + ... transfer_syntax_uid="1.2.840.10008.1.2.4.202" # HTJ2K + ... ) + """ + if not _check_highdicom_available(): + raise ImportError("highdicom is not installed. Install it with: pip install highdicom") + + import highdicom + from highdicom.legacy import ( + LegacyConvertedEnhancedCTImage, + LegacyConvertedEnhancedMRImage, + LegacyConvertedEnhancedPETImage, + ) + + output_file = Path(output_file) + + # Set default transfer syntax + if transfer_syntax_uid is None: + transfer_syntax_uid = EXPLICIT_VR_LITTLE_ENDIAN + + # Describe input source for logging + if isinstance(input_source, (list, tuple)): + input_desc = f"{len(input_source)} files" + else: + input_desc = str(Path(input_source)) + + logger.info(f"Converting legacy DICOM series to enhanced multi-frame format") + logger.info(f" Input: {input_desc}") + if not validate_only: + logger.info(f" Output: {output_file}") + logger.info(f" Transfer Syntax: {transfer_syntax_uid}") + + try: + # Load and sort DICOM files + datasets = _load_dicom_series(input_source) + + # Validate consistency + metadata = _validate_series_consistency(datasets) + detected_modality = metadata["modality"] + + if validate_only: + logger.info("Validation successful (validate_only=True, not creating output file)") + return True + + # Create output directory if needed + output_file.parent.mkdir(parents=True, exist_ok=True) + + # Extract SeriesInstanceUID from legacy datasets (preserve original if requested) + # This maintains traceability between legacy and enhanced series + original_series_uid = metadata.get("seriesinstanceuid") + if preserve_series_uid and original_series_uid: + series_uid = original_series_uid + logger.info(f"Preserving original SeriesInstanceUID: {series_uid}") + else: + series_uid = generate_uid() + if preserve_series_uid and not original_series_uid: + logger.warning("SeriesInstanceUID not found in legacy datasets, generating new UID") + logger.info(f"Generated new SeriesInstanceUID: {series_uid}") + + # Extract SeriesNumber and InstanceNumber from legacy datasets (use original if available) + # Convert to native Python int (highdicom requires Python int, not pydicom IS/DS types) + first_ds = datasets[0] + series_number = int(getattr(first_ds, "SeriesNumber", 1)) + if series_number < 1: + logger.warning(f"SeriesNumber was {series_number}, using default value: 1") + series_number = 1 + instance_number = int(getattr(first_ds, "InstanceNumber", 1)) + if instance_number < 1: + logger.warning(f"InstanceNumber was {instance_number}, using default value: 1") + instance_number = 1 + + # Note: highdicom's LegacyConverted* classes automatically preserve other important + # metadata from the legacy datasets including: + # - StudyInstanceUID + # - PatientID, PatientName, PatientBirthDate, 
PatientSex + # - StudyDate, StudyTime, StudyDescription + # - Pixel spacing, slice spacing, image orientation/position + # - And many other standard DICOM attributes + + # Fix any malformed date/time attributes that might cause issues + _fix_dicom_datetime_attributes(datasets) + + # Add missing required attributes with default values if needed + # The enhanced multi-frame DICOM standard requires these attributes + _ensure_required_attributes(datasets) + + # Convert based on modality + logger.info(f"Converting {detected_modality} series with {len(datasets)} frames...") + + # Generate a NEW SOP Instance UID for the enhanced multi-frame DICOM + # Note: We do NOT use an original SOP Instance UID because: + # 1. This is a new DICOM instance (different SOP Class) + # 2. We're combining multiple instances (each with their own SOP Instance UID) into one + # 3. DICOM standard requires each instance to have a unique identifier + new_sop_instance_uid = generate_uid() + + # Suppress common highdicom warnings during conversion + with _suppress_highdicom_warnings(): + if detected_modality == "CT": + enhanced = LegacyConvertedEnhancedCTImage( + legacy_datasets=datasets, + series_instance_uid=series_uid, + series_number=series_number, + sop_instance_uid=new_sop_instance_uid, + instance_number=instance_number, + ) + elif detected_modality == "MR": + enhanced = LegacyConvertedEnhancedMRImage( + legacy_datasets=datasets, + series_instance_uid=series_uid, + series_number=series_number, + sop_instance_uid=new_sop_instance_uid, + instance_number=instance_number, + ) + elif detected_modality == "PT": + enhanced = LegacyConvertedEnhancedPETImage( + legacy_datasets=datasets, + series_instance_uid=series_uid, + series_number=series_number, + sop_instance_uid=new_sop_instance_uid, + instance_number=instance_number, + ) + else: + raise ValueError(f"Unsupported modality: {detected_modality}") + + # Set transfer syntax + enhanced.file_meta.TransferSyntaxUID = transfer_syntax_uid + + # After highdicom creates the enhanced image, ALSO set top-level tags + # for DICOMweb compatibility + + # Add top-level Window/Level (from first legacy dataset) + if hasattr(datasets[0], "WindowCenter"): + enhanced.WindowCenter = datasets[0].WindowCenter + if hasattr(datasets[0], "WindowWidth"): + enhanced.WindowWidth = datasets[0].WindowWidth + + # Add top-level Rescale parameters + if hasattr(datasets[0], "RescaleSlope"): + enhanced.RescaleSlope = datasets[0].RescaleSlope + if hasattr(datasets[0], "RescaleIntercept"): + enhanced.RescaleIntercept = datasets[0].RescaleIntercept + + # Save the enhanced DICOM file + enhanced.save_as(str(output_file), enforce_file_format=False) + + # Calculate statistics + output_size_bytes = output_file.stat().st_size + output_size_mb = output_size_bytes / (1024 * 1024) + + # Calculate original combined size + original_size_bytes = 0 + for ds in datasets: + if hasattr(ds, "filename") and ds.filename: + try: + original_size_bytes += Path(ds.filename).stat().st_size + except Exception: + pass + + original_size_mb = original_size_bytes / (1024 * 1024) + + # Calculate compression statistics + if original_size_bytes > 0: + compression_ratio = original_size_bytes / output_size_bytes + size_reduction_pct = ((original_size_bytes - output_size_bytes) / original_size_bytes) * 100 + else: + compression_ratio = 0.0 + size_reduction_pct = 0.0 + + # Display results (only if show_stats is True) + if show_stats: + logger.info(f"✓ Successfully created enhanced DICOM file: {output_file}") + logger.info(f"") + logger.info(f" 
Statistics:") + logger.info(f" Original files: {len(datasets)} files, {original_size_mb:.2f} MB") + logger.info(f" Output file: 1 file, {output_size_mb:.2f} MB") + if original_size_bytes > 0: + if output_size_bytes < original_size_bytes: + logger.info(f" Size reduction: {size_reduction_pct:.1f}% smaller") + logger.info(f" Compression: {compression_ratio:.2f}x") + else: + size_increase_pct = ((output_size_bytes - original_size_bytes) / original_size_bytes) * 100 + logger.info(f" Size increase: {size_increase_pct:.1f}% larger") + logger.info(f" Ratio: {1/compression_ratio:.2f}x") + logger.info(f"") + logger.info(f" Image info:") + logger.info(f" Frames: {len(datasets)}") + logger.info(f" Dimensions: {metadata['rows']}x{metadata['columns']}") + + return True + + except AttributeError as e: + logger.error( + f"Failed to convert DICOM series - missing required DICOM attribute: {e}\n" + f"The legacy DICOM files may be missing required attributes such as:\n" + f" - Manufacturer\n" + f" - ManufacturerModelName\n" + f" - SoftwareVersions\n" + f"These attributes are required by the DICOM enhanced multi-frame standard.", + exc_info=True, + ) + return False + except Exception as e: + logger.error(f"Failed to convert DICOM series: {e}", exc_info=True) + return False + + +def validate_dicom_series(input_source: Union[str, Path, List[Union[str, Path]]]) -> bool: + """ + Validate that a DICOM series can be converted to enhanced format. + + Args: + input_source: Either directory path or list of DICOM file paths + + Returns: + True if series is valid, False otherwise + + Example: + >>> # Validate from directory + >>> if validate_dicom_series("./my_series/"): + ... print("Series is ready for conversion") + >>> + >>> # Validate from file list + >>> files = ['/data/ct_001.dcm', '/data/ct_002.dcm'] + >>> if validate_dicom_series(files): + ... print("Files are ready for conversion") + """ + try: + return convert_to_enhanced_dicom( + input_source=input_source, + output_file="dummy.dcm", # Not used in validate mode + validate_only=True, + ) + except Exception as e: + logger.error(f"Validation failed: {e}") + return False + + +def batch_convert_by_series( + file_loader: Iterable[tuple[list[str], str]], + preserve_series_uid: bool = True, + compress_htj2k: bool = False, + **compression_kwargs, +) -> dict: + """ + Group DICOM files by SeriesInstanceUID and convert each series to enhanced multi-frame. + + This function automatically detects all unique DICOM series from the provided files + and converts each series to a separate enhanced multi-frame file. Useful when you have + multiple series mixed together. + + Output filenames are automatically generated based on metadata: + - Format: {Modality}_{SeriesInstanceUID}.dcm + - Examples: + - CT_1.2.826.0.1.3680043.8.274.1.1.8323329.686521.1629744176.620266.dcm + - MR_1.2.840.113619.2.55.3.123456789.101.20231127.143052.0.dcm + - PT_1.3.12.2.1107.5.1.4.12345.30000023110710323456789.dcm + + Unsupported modalities (e.g., MG, US, XA) cannot be converted to enhanced multi-frame + format, but are still processed: + - If compress_htj2k=True: Transcoded to HTJ2K (compressed, original format preserved) + - If compress_htj2k=False: Copied without modification + Original subdirectory structure is preserved in both cases. 
+ + Args: + file_loader: + Iterable of (input_files, output_dir) tuples, where: + - input_files: List[str] of input DICOM file paths to process + - output_dir: str output directory path for this batch + + Each yielded tuple specifies a batch of files to scan and the output directory. + Files from all batches will be grouped by SeriesInstanceUID before conversion. + + Example: + >>> # Simple usage with one batch + >>> file_loader = [ + ... (['/data/ct_001.dcm', '/data/ct_002.dcm', '/data/mr_001.dcm'], '/output') + ... ] + >>> stats = batch_convert_by_series(file_loader) + + >>> # Multiple batches from different sources + >>> file_loader = [ + ... (['/data1/file1.dcm', '/data1/file2.dcm'], '/output'), + ... (['/data2/file3.dcm', '/data2/file4.dcm'], '/output'), + ... ] + >>> stats = batch_convert_by_series(file_loader) + preserve_series_uid: If True, preserve original SeriesInstanceUID + compress_htj2k: If True, compress output with HTJ2K + **compression_kwargs: Additional HTJ2K compression arguments (if compress_htj2k=True) + + Returns: + Dictionary with conversion statistics: + - 'total_series_input': Total number of unique series found + - 'total_series_output': Total number of series in output + - 'total_frames_input': Total number of frames (files) processed + - 'total_frames_output': Total number of frames in output + - 'total_size_input_mb': Total size of all input files in MB + - 'total_size_output_mb': Total size of all output files in MB + - 'converted_to_multiframe': Number of series converted to enhanced multi-frame + - 'transcoded_htj2k': Number of series transcoded to HTJ2K (unsupported modalities) + - 'copied': Number of series copied without compression (unsupported modalities) + - 'failed': Number of failed conversions + - 'series_info': List of dicts with per-series information + + Example: + >>> # Collect DICOM files from directory + >>> from pathlib import Path + >>> import pydicom + >>> + >>> input_dir = Path('/data/mixed_dicoms') + >>> input_files = [] + >>> for filepath in input_dir.rglob('*.dcm'): + ... try: + ... pydicom.dcmread(filepath, stop_before_pixels=True) + ... input_files.append(str(filepath)) + ... except: + ... pass # Skip non-DICOM files + >>> + >>> # Convert all series (uncompressed) + >>> file_loader = [(input_files, '/output')] + >>> stats = batch_convert_by_series(file_loader) + >>> print(f"Processed {stats['total_frames_input']} frames from {stats['total_series_input']} series") + >>> print(f"Converted {stats['converted_to_multiframe']} series to multi-frame") + >>> print(f"Compression: {stats['total_size_input_mb'] / stats['total_size_output_mb']:.2f}x") + + >>> # Convert with HTJ2K compression + >>> file_loader = [(input_files, '/output')] + >>> stats = batch_convert_by_series( + ... file_loader=file_loader, + ... compress_htj2k=True, + ... num_resolutions=6, + ... code_block_size=(64, 64), + ... progression_order='RPCL' + ... 
) + >>> print(f"HTJ2K compressed: {stats['converted_to_multiframe']} series") + >>> print(f"Transcoded only: {stats['transcoded_htj2k']} series (unsupported modalities)") + """ + logger.info(f"") + logger.info(f"{'='*80}") + logger.info(f"Batch Converting DICOM Series") + logger.info(f"{'='*80}") + logger.info(f" HTJ2K compression: {'Yes' if compress_htj2k else 'No'}") + logger.info(f"") + + # Step 1: Collect all files from file_loader and group by SeriesInstanceUID + logger.info("Step 1: Scanning files and grouping by SeriesInstanceUID...") + series_files = {} # Maps SeriesInstanceUID -> list of file paths + series_metadata = {} # Maps SeriesInstanceUID -> metadata dict + series_output_dirs = {} # Maps SeriesInstanceUID -> output directory + output_dirs_set = set() # Track all output directories + + total_files_scanned = 0 + for input_files, output_dir_str in file_loader: + output_dir = Path(output_dir_str) + output_dirs_set.add(output_dir) + + # Create output directory + output_dir.mkdir(parents=True, exist_ok=True) + + for filepath_str in input_files: + filepath = Path(filepath_str) + total_files_scanned += 1 + + try: + ds = pydicom.dcmread(filepath, stop_before_pixels=True) + series_uid = getattr(ds, "SeriesInstanceUID", None) + + if series_uid: + if series_uid not in series_files: + series_files[series_uid] = [] + series_output_dirs[series_uid] = output_dir + # Store metadata from first file + series_metadata[series_uid] = { + "modality": getattr(ds, "Modality", "Unknown"), + "series_number": getattr(ds, "SeriesNumber", "N/A"), + "series_description": getattr(ds, "SeriesDescription", "N/A"), + "patient_id": getattr(ds, "PatientID", "N/A"), + } + series_files[series_uid].append(filepath) + except Exception as e: + logger.debug(f"Skipping file {filepath.name}: {e}") + continue + + total_series = len(series_files) + logger.info(f" Scanned {total_files_scanned} files") + logger.info(f" Found {total_series} unique series") + logger.info(f" Output directories: {len(output_dirs_set)}") + logger.info(f"") + + if total_series == 0: + logger.warning("No DICOM series found in input files") + return { + "total_series_input": 0, + "total_series_output": 0, + "total_frames_input": 0, + "total_frames_output": 0, + "total_size_input_mb": 0.0, + "total_size_output_mb": 0.0, + "converted_to_multiframe": 0, + "transcoded_htj2k": 0, + "copied": 0, + "failed": 0, + "series_info": [], + } + + # Step 2: Convert each series + stats = { + "total_series_input": total_series, + "total_series_output": 0, + "total_frames_input": 0, + "total_frames_output": 0, + "total_size_input_mb": 0.0, + "total_size_output_mb": 0.0, + "converted_to_multiframe": 0, + "transcoded_htj2k": 0, + "copied": 0, + "failed": 0, + "series_info": [], + } + + for idx, (series_uid, file_paths) in enumerate(series_files.items(), 1): + metadata = series_metadata[series_uid] + output_dir = series_output_dirs[series_uid] + num_files = len(file_paths) + + # Calculate input size for this series + series_input_size = sum(fp.stat().st_size for fp in file_paths) + series_input_size_mb = series_input_size / (1024 * 1024) + + # Always count input + stats["total_size_input_mb"] += series_input_size_mb + stats["total_frames_input"] += num_files + + logger.info(f"") + logger.info(f"{'-'*80}") + logger.info(f"Series {idx}/{total_series}") + logger.info(f"{'-'*80}") + logger.info(f" SeriesInstanceUID: {series_uid}") + logger.info(f" Modality: {metadata['modality']}") + logger.info(f" SeriesNumber: {metadata['series_number']}") + logger.info(f" 
SeriesDescription: {metadata['series_description']}") + logger.info(f" PatientID: {metadata['patient_id']}") + logger.info(f" Number of files: {num_files}") + logger.info(f" Input size: {series_input_size_mb:.2f} MB") + logger.info(f" Output directory: {output_dir}") + logger.info(f"") + + # Check if modality is supported for enhanced multi-frame conversion + if metadata["modality"] not in SUPPORTED_MODALITIES: + logger.info(f" Unsupported modality for enhanced conversion: {metadata['modality']}") + logger.info(f" Supported for enhanced conversion: {', '.join(SUPPORTED_MODALITIES)}") + + if compress_htj2k: + # Transcode to HTJ2K even though we can't convert to enhanced format + logger.info(f" Transcoding to HTJ2K (preserving original format)...") + + success, transcoded_size_mb = _transcode_files_to_htj2k(file_paths, output_dir, compression_kwargs) + + if success: + logger.info(f" Transcoded {num_files} files with HTJ2K") + logger.info(f" Input size: {series_input_size_mb:.2f} MB") + logger.info(f" Output size: {transcoded_size_mb:.2f} MB") + if series_input_size_mb > 0: + compression = series_input_size_mb / transcoded_size_mb + logger.info(f" Compression: {compression:.2f}x") + + # Update transcoded HTJ2K statistics + stats["transcoded_htj2k"] += 1 + stats["total_series_output"] += 1 + stats["total_frames_output"] += num_files + stats["total_size_output_mb"] += transcoded_size_mb + + stats["series_info"].append( + { + "series_uid": series_uid, + "status": "transcoded_htj2k", + "reason": f"Unsupported modality: {metadata['modality']} (HTJ2K compressed)", + "num_files": num_files, + "input_size_mb": series_input_size_mb, + "output_size_mb": transcoded_size_mb, + } + ) + else: + # Fall back to copying + logger.info(f" Falling back to copying files without compression...") + copied_count = _copy_files(file_paths, output_dir) + + logger.info(f" Copied {copied_count}/{num_files} files") + stats["copied"] += 1 + stats["total_series_output"] += 1 + stats["total_frames_output"] += num_files + stats["total_size_output_mb"] += series_input_size_mb + + stats["series_info"].append( + { + "series_uid": series_uid, + "status": "copied", + "reason": f"Unsupported modality: {metadata['modality']} (copied, transcode failed)", + "num_files": num_files, + } + ) + else: + # No compression - just copy files + logger.info(f" Copying files to output directory...") + copied_count = _copy_files(file_paths, output_dir) + + logger.info(f" Copied {copied_count}/{num_files} files") + + # Update copied statistics + stats["copied"] += 1 + stats["total_series_output"] += 1 + stats["total_frames_output"] += num_files + stats["total_size_output_mb"] += series_input_size_mb + + stats["series_info"].append( + { + "series_uid": series_uid, + "status": "copied", + "reason": f"Unsupported modality: {metadata['modality']} (copied)", + "num_files": num_files, + } + ) + + continue + + # Create temporary directory for this series + with tempfile.TemporaryDirectory() as temp_series_dir: + temp_series_path = Path(temp_series_dir) + + # Copy files to temporary directory + for file_path in file_paths: + shutil.copy(file_path, temp_series_path / file_path.name) + + # Generate output filename: {Modality}_{SeriesInstanceUID}.dcm + output_filename = f"{metadata['modality']}_{series_uid}.dcm" + output_file = output_dir / output_filename + logger.info(f" Output filename: {output_filename}") + + # Convert the series + try: + if compress_htj2k: + success = convert_and_convert_to_htj2k( + input_source=temp_series_path, + 
output_file=output_file, + preserve_series_uid=preserve_series_uid, + **compression_kwargs, + ) + else: + success = convert_to_enhanced_dicom( + input_source=temp_series_path, + output_file=output_file, + preserve_series_uid=preserve_series_uid, + ) + + if success: + # Get output file size + output_size = output_file.stat().st_size + output_size_mb = output_size / (1024 * 1024) + + # Update conversion statistics + stats["converted_to_multiframe"] += 1 + stats["total_series_output"] += 1 + stats["total_frames_output"] += num_files # Frames are combined into 1 file + stats["total_size_output_mb"] += output_size_mb + + stats["series_info"].append( + { + "series_uid": series_uid, + "status": "success", + "output_file": str(output_file), + "num_frames": num_files, + "input_size_mb": series_input_size_mb, + "output_size_mb": output_size_mb, + } + ) + else: + stats["failed"] += 1 + stats["series_info"].append( + {"series_uid": series_uid, "status": "failed", "reason": "Conversion returned False"} + ) + + except Exception as e: + logger.error(f" Failed to convert series: {e}") + stats["failed"] += 1 + stats["series_info"].append({"series_uid": series_uid, "status": "failed", "reason": str(e)}) + + # Calculate overall compression statistics + if stats["total_size_input_mb"] > 0 and stats["total_size_output_mb"] > 0: + overall_compression = stats["total_size_input_mb"] / stats["total_size_output_mb"] + size_reduction_pct = ( + (stats["total_size_input_mb"] - stats["total_size_output_mb"]) / stats["total_size_input_mb"] + ) * 100 + else: + overall_compression = 0.0 + size_reduction_pct = 0.0 + + # Print comprehensive summary + logger.info(f"") + logger.info(f"{'='*80}") + logger.info(f"Batch Conversion Summary") + logger.info(f"{'='*80}") + logger.info(f"") + logger.info(f" Total series (input): {stats['total_series_input']}") + logger.info(f" Total series (output): {stats['total_series_output']}") + logger.info(f" Total frames (input): {stats['total_frames_input']}") + logger.info(f" Total frames (output): {stats['total_frames_output']}") + logger.info(f" Total size (input): {stats['total_size_input_mb']:.2f} MB") + logger.info(f" Total size (output): {stats['total_size_output_mb']:.2f} MB") + if overall_compression > 0: + logger.info(f" Compression ratio: {overall_compression:.2f}x") + logger.info(f" Size reduction: {size_reduction_pct:.1f}%") + logger.info(f"") + logger.info(f" Details:") + if compress_htj2k: + logger.info(f" Converted to multi-frame + HTJ2K: {stats['converted_to_multiframe']}") + logger.info(f" Transcoded to HTJ2K only: {stats['transcoded_htj2k']} (unsupported modalities)") + else: + logger.info(f" Converted to multi-frame: {stats['converted_to_multiframe']}") + logger.info(f" Transcoded to HTJ2K: {stats['transcoded_htj2k']}") + logger.info(f" Copied: {stats['copied']}") + logger.info(f" Failed: {stats['failed']}") + logger.info(f"") + if len(output_dirs_set) == 1: + logger.info(f" Output directory: {list(output_dirs_set)[0]}") + else: + logger.info(f" Output directories: {len(output_dirs_set)}") + for out_dir in sorted(output_dirs_set): + logger.info(f" - {out_dir}") + logger.info(f"{'='*80}") + logger.info(f"") + + return stats + + +def convert_and_convert_to_htj2k( + input_source: Union[str, Path, List[Union[str, Path]]], + output_file: Union[str, Path], + preserve_series_uid: bool = True, + **compression_kwargs, +) -> bool: + """ + Convert legacy DICOM series to enhanced multi-frame format and compress with HTJ2K. 
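To confirm the compression step took effect, the output file's transfer syntax can be checked with pydicom; the filename is a placeholder, and the UIDs in the comments are the standard HTJ2K lossless transfer syntaxes:

```python
import pydicom

ds = pydicom.dcmread("enhanced_htj2k.dcm", stop_before_pixels=True)
ts = ds.file_meta.TransferSyntaxUID
# 1.2.840.10008.1.2.4.201  High-Throughput JPEG 2000 (Lossless Only)
# 1.2.840.10008.1.2.4.202  High-Throughput JPEG 2000 with RPCL Options (Lossless Only)
print(f"TransferSyntaxUID: {ts} ({ts.name})")
```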
+ + This is a convenience function that combines multi-frame conversion and HTJ2K compression + in one step. It creates an uncompressed enhanced DICOM first, then compresses it using + HTJ2K (High-Throughput JPEG2000). + + Args: + input_source: Either: + - Directory path containing legacy DICOM files + - List of DICOM file paths to convert as a single series + output_file: Path for output HTJ2K compressed enhanced DICOM file + preserve_series_uid: If True, preserve original SeriesInstanceUID + **compression_kwargs: Additional arguments for HTJ2K compression: + - num_resolutions (int): Number of wavelet decomposition levels (default: 6) + - code_block_size (tuple): Code block size (default: (64, 64)) + - progression_order (str): Progression order (default: "RPCL") + + Returns: + True if successful, False otherwise + + Example: + >>> # Convert to multi-frame and compress with HTJ2K from directory + >>> convert_and_convert_to_htj2k( + ... input_source="./legacy_ct/", + ... output_file="./enhanced_htj2k.dcm", + ... num_resolutions=6, + ... progression_order="RPCL" + ... ) + >>> + >>> # Convert from file list + >>> files = ['/data/ct_001.dcm', '/data/ct_002.dcm', '/data/ct_003.dcm'] + >>> convert_and_convert_to_htj2k( + ... input_source=files, + ... output_file="./enhanced_htj2k.dcm" + ... ) + """ + output_file = Path(output_file) + + # Import here to avoid circular dependency + try: + from monailabel.datastore.utils.convert_htj2k import transcode_dicom_to_htj2k + except ImportError as e: + logger.error(f"HTJ2K compression requires convert_htj2k module: {e}") + return False + + # Calculate original files size for statistics + original_size_bytes = 0 + num_files = 0 + + # Get list of file paths + if isinstance(input_source, (list, tuple)): + file_paths = [Path(f) for f in input_source] + else: + input_dir = Path(input_source) + if not input_dir.is_dir(): + logger.error(f"Input path is not a directory: {input_dir}") + return False + file_paths = [f for f in input_dir.iterdir() if f.is_file() and not f.name.startswith(".")] + + # Calculate total input size + for filepath in file_paths: + try: + ds = pydicom.dcmread(filepath, stop_before_pixels=True) + original_size_bytes += filepath.stat().st_size + num_files += 1 + except Exception: + continue + + original_size_mb = original_size_bytes / (1024 * 1024) + + # Step 1: Create uncompressed enhanced DICOM in temp file + with tempfile.NamedTemporaryFile(suffix=".dcm", delete=False) as tmp: + temp_file = tmp.name + + try: + logger.info("Step 1/2: Creating enhanced multi-frame DICOM (uncompressed)...") + success = convert_to_enhanced_dicom( + input_source=input_source, + output_file=temp_file, + preserve_series_uid=preserve_series_uid, + show_stats=False, # Suppress intermediate statistics + ) + + if not success: + logger.error("Failed to convert to enhanced multi-frame DICOM") + return False + + # Get intermediate uncompressed size + temp_size_bytes = Path(temp_file).stat().st_size + temp_size_mb = temp_size_bytes / (1024 * 1024) + + # Step 2: Compress with HTJ2K + logger.info("Step 2/2: Compressing with HTJ2K...") + + # Extract HTJ2K parameters + num_resolutions = compression_kwargs.get("num_resolutions", 6) + code_block_size = compression_kwargs.get("code_block_size", (64, 64)) + progression_order = compression_kwargs.get("progression_order", "RPCL") + + # Create output directory if needed + output_file.parent.mkdir(parents=True, exist_ok=True) + + # Compress the enhanced DICOM + # transcode_dicom_to_htj2k expects a file_loader iterable + file_loader = 
[([temp_file], [str(output_file)])]
+        transcode_dicom_to_htj2k(
+            file_loader=file_loader,
+            num_resolutions=num_resolutions,
+            code_block_size=code_block_size,
+            progression_order=progression_order,
+        )
+
+        # Check if output file was created successfully
+        if output_file.exists():
+            output_size_bytes = output_file.stat().st_size
+            output_size_mb = output_size_bytes / (1024 * 1024)
+
+            # Get image metadata from the output file
+            try:
+                output_ds = pydicom.dcmread(output_file, stop_before_pixels=True)
+                num_frames = getattr(output_ds, "NumberOfFrames", num_files)
+                rows = getattr(output_ds, "Rows", "N/A")
+                columns = getattr(output_ds, "Columns", "N/A")
+            except Exception:
+                num_frames = num_files
+                rows = "N/A"
+                columns = "N/A"
+
+            # Calculate compression statistics
+            if original_size_bytes > 0:
+                overall_compression_ratio = original_size_bytes / output_size_bytes
+                size_reduction_pct = ((original_size_bytes - output_size_bytes) / original_size_bytes) * 100
+                htj2k_compression_ratio = temp_size_bytes / output_size_bytes
+            else:
+                overall_compression_ratio = 0.0
+                size_reduction_pct = 0.0
+                htj2k_compression_ratio = 0.0
+
+            # Display comprehensive statistics at the end
+            logger.info("")
+            logger.info(f"✓ Successfully created HTJ2K compressed enhanced DICOM: {output_file}")
+            logger.info("")
+            logger.info("  Conversion Statistics:")
+            logger.info(f"    Original files: {num_files} files, {original_size_mb:.2f} MB")
+            logger.info(f"    Uncompressed enhanced: 1 file, {temp_size_mb:.2f} MB")
+            logger.info(f"    HTJ2K compressed: 1 file, {output_size_mb:.2f} MB")
+            logger.info("")
+            if original_size_bytes > 0:
+                logger.info("  Compression Performance:")
+                logger.info(f"    Overall size reduction: {size_reduction_pct:.1f}% smaller")
+                logger.info(f"    Overall compression: {overall_compression_ratio:.2f}x")
+                logger.info(f"    HTJ2K compression: {htj2k_compression_ratio:.2f}x")
+                logger.info("")
+            logger.info("  Image Information:")
+            logger.info(f"    Frames: {num_frames}")
+            if rows != "N/A" and columns != "N/A":
+                logger.info(f"    Dimensions: {rows}x{columns}")
+            logger.info("")
+
+            return True
+        else:
+            logger.error("HTJ2K compression failed - output file not created")
+            return False
+
+    finally:
+        # Clean up temp file
+        if os.path.exists(temp_file):
+            os.unlink(temp_file)
+
+
+if __name__ == "__main__":
+    # Example CLI usage
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Convert legacy DICOM series to enhanced multi-frame format")
+    parser.add_argument("input", type=str, help="Input directory containing legacy DICOM files")
+    parser.add_argument(
+        "-o", "--output", type=str, help="Output file path for enhanced DICOM (required unless --validate-only)"
+    )
+    parser.add_argument("--validate-only", action="store_true", help="Only validate series without creating output")
+
+    parser.add_argument(
+        "--batch",
+        action="store_true",
+        help="Batch mode: group files by SeriesInstanceUID and convert each series separately",
+    )
+
+    parser.add_argument("--htj2k", action="store_true", help="Compress the enhanced DICOM file with HTJ2K")
+
+    parser.add_argument(
+        "--num-resolutions", type=int, default=6, help="Number of wavelet decomposition levels (default: 6)"
+    )
+
+    # Code block size is given as two integers, e.g. --code-block-size 64 64
+    parser.add_argument(
+        "--code-block-size", type=int, nargs=2, default=(64, 64), metavar=("HEIGHT", "WIDTH"),
+        help="Code block size as two integers (default: 64 64)",
+    )
+
+    parser.add_argument("--progression-order", type=str, default="RPCL", help="Progression order (default: RPCL)")
+
+    parser.add_argument("--preserve-series-uid", action="store_true", help="Preserve the original 
SeriesInstanceUID") + + parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging") + + args = parser.parse_args() + + # Setup logging + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format="[%(asctime)s] [%(levelname)s] %(message)s") + + # Run conversion + try: + if args.batch: + # Batch mode: group by series and convert each + if not args.output: + parser.error("--output is required for batch mode (use as output directory)") + + # Scan input directory for DICOM files + input_dir = Path(args.input) + input_files = [] + for filepath in input_dir.rglob("*"): + if filepath.is_file() and not filepath.name.startswith("."): + try: + # Quick check if it's a DICOM file + pydicom.dcmread(filepath, stop_before_pixels=True) + input_files.append(str(filepath)) + except: + pass # Skip non-DICOM files + + # Create file_loader with single batch + file_loader = [(input_files, args.output)] + + stats = batch_convert_by_series( + file_loader=file_loader, + preserve_series_uid=args.preserve_series_uid, + compress_htj2k=args.htj2k, + num_resolutions=args.num_resolutions, + code_block_size=args.code_block_size, + progression_order=args.progression_order, + ) + exit(0 if stats["failed"] == 0 else 1) + else: + # Single series mode + if not args.validate_only and not args.output: + parser.error("--output is required unless --validate-only is specified") + + if args.htj2k: + success = convert_and_convert_to_htj2k( + input_source=args.input, + output_file=args.output, + preserve_series_uid=args.preserve_series_uid, + num_resolutions=args.num_resolutions, + code_block_size=args.code_block_size, + progression_order=args.progression_order, + ) + else: + success = convert_to_enhanced_dicom( + input_source=args.input, + output_file=args.output or "dummy.dcm", + validate_only=args.validate_only, + preserve_series_uid=args.preserve_series_uid, + ) + exit(0 if success else 1) + + except Exception as e: + logger.error(f"Error: {e}", exc_info=args.verbose) + exit(1) diff --git a/monailabel/endpoints/datastore.py b/monailabel/endpoints/datastore.py index 119f5f941..fdd63bb6e 100644 --- a/monailabel/endpoints/datastore.py +++ b/monailabel/endpoints/datastore.py @@ -133,8 +133,10 @@ def remove_label(id: str, tag: str, user: Optional[str] = None): def download_image(image: str, check_only=False, check_sum=None): instance: MONAILabelApp = app_instance() image = instance.datastore().get_image_uri(image) + if not os.path.isfile(image): - raise HTTPException(status_code=404, detail="Image NOT Found") + logger.error(f"Image NOT Found or is a directory: {image}") + raise HTTPException(status_code=404, detail="Image NOT Found or is a directory") if check_only: if check_sum: @@ -151,7 +153,8 @@ def download_label(label: str, tag: str, check_only=False): instance: MONAILabelApp = app_instance() label = instance.datastore().get_label_uri(label, tag) if not os.path.isfile(label): - raise HTTPException(status_code=404, detail="Label NOT Found") + logger.error(f"Label NOT Found or is a directory: {label}") + raise HTTPException(status_code=404, detail="Label NOT Found or is a directory") if check_only: return {} diff --git a/monailabel/transform/reader.py b/monailabel/transform/reader.py new file mode 100644 index 000000000..ea325d70f --- /dev/null +++ b/monailabel/transform/reader.py @@ -0,0 +1,1115 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with 
the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import logging +import os +import threading +import warnings +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +import numpy as np +from monai.config import PathLike +from monai.data import ImageReader +from monai.data.image_reader import _copy_compatible_dict, _stack_images +from monai.data.utils import orientation_ras_lps +from monai.utils import MetaKeys, SpaceKeys, TraceKeys, ensure_tuple, optional_import, require_pkg +from packaging import version +from torch.utils.data._utils.collate import np_str_obj_array_pattern + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + import pydicom + + has_pydicom = True + import cupy as cp + + has_cp = True + from nvidia import nvimgcodec as nvimgcodec + + has_nvimgcodec = True +else: + pydicom, has_pydicom = optional_import("pydicom") + cp, has_cp = optional_import("cupy") + nvimgcodec, has_nvimgcodec = optional_import("nvidia.nvimgcodec") + +logger = logging.getLogger(__name__) + +__all__ = ["NvDicomReader"] + +# Thread-local storage for nvimgcodec decoder +# Each thread gets its own decoder instance for thread safety +_thread_local = threading.local() + + +def _get_nvimgcodec_decoder(): + """Get or create a thread-local nvimgcodec decoder singleton.""" + if not has_nvimgcodec: + raise RuntimeError("nvimgcodec is not available. Cannot create decoder.") + + if not hasattr(_thread_local, "decoder") or _thread_local.decoder is None: + _thread_local.decoder = nvimgcodec.Decoder() + logger.debug(f"Initialized thread-local nvimgcodec.Decoder for thread {threading.current_thread().name}") + + return _thread_local.decoder + + +@require_pkg(pkg_name="pydicom") +class NvDicomReader(ImageReader): + """ + DICOM reader with proper spatial slice ordering. + + This reader properly handles DICOM slice ordering using ImagePositionPatient + and ImageOrientationPatient tags, ensuring correct 3D volume construction + for any orientation (axial, sagittal, coronal, or oblique). + + When reading a directory containing multiple series, only the first series + is read by default (similar to ITKReader behavior). + + Args: + channel_dim: the channel dimension of the input image, default is None. + This is used to set original_channel_dim in the metadata. + series_name: the SeriesInstanceUID to read when directory contains multiple series. + If empty (default), reads the first series found. + series_meta: whether to load series metadata (currently unused). + affine_lps_to_ras: whether to convert the affine matrix from "LPS" to "RAS". + Defaults to ``True``. Set to ``True`` to be consistent with ``NibabelReader``. + depth_last: whether to place depth dimension last in the returned data array. + If ``True`` (default), returns shape (width, height, depth) similar to ITK's layout. + If ``False``, returns shape (depth, height, width) following NumPy convention. + This option does not affect the metadata. + preserve_dtype: whether to preserve the original DICOM pixel data type after applying rescale. + If ``True`` (default), converts back to original dtype (matching ITK behavior). 
+ If ``False``, outputs float32 for all data after rescaling. + prefer_gpu_output: If True, prefer GPU output over CPU output if the underlying codec supports it. Otherwise, convert to CPU regardless. + Default is True. + use_nvimgcodec: If True, use nvImageCodec to decode the pixel data. Default is True. nvImageCodec is required for this option. + nvImageCodec supports JPEG2000, HTJ2K, and JPEG transfer syntaxes. + kwargs: additional args for `pydicom.dcmread` API. + + Example: + >>> # Read first series from directory (default: depth last, ITK-style) + >>> reader = NvDicomReader() + >>> img = reader.read("path/to/dicom/dir") + >>> volume, metadata = reader.get_data(img) + >>> volume.shape # (512, 512, 173) = (width, height, depth) + >>> + >>> # Read with NumPy-style layout (depth first) + >>> reader = NvDicomReader(depth_last=False) + >>> img = reader.read("path/to/dicom/dir") + >>> volume, metadata = reader.get_data(img) + >>> volume.shape # (173, 512, 512) = (depth, height, width) + >>> + >>> # Output float32 instead of preserving original dtype + >>> reader = NvDicomReader(preserve_dtype=False) + >>> img = reader.read("path/to/dicom/dir") + >>> volume, metadata = reader.get_data(img) + >>> volume.dtype # float32 (instead of int32) + >>> + >>> # Load to GPU memory with nvImageCodec acceleration + >>> reader = NvDicomReader(prefer_gpu_output=True, use_nvimgcodec=True) + >>> img = reader.read("path/to/dicom/dir") + >>> volume, metadata = reader.get_data(img) + >>> type(volume).__module__ # 'cupy' (GPU array) + >>> + >>> # Read specific series + >>> reader = NvDicomReader(series_name="1.2.3.4.5.6.7") + >>> img = reader.read("path/to/dicom/dir") + """ + + def __init__( + self, + channel_dim: str | int | None = None, + series_name: str = "", + series_meta: bool = False, + affine_lps_to_ras: bool = True, + depth_last: bool = True, + preserve_dtype: bool = True, + prefer_gpu_output: bool = True, + use_nvimgcodec: bool = True, + allow_fallback_decode: bool = False, + **kwargs, + ): + super().__init__() + self.kwargs = kwargs + self.channel_dim = float("nan") if channel_dim == "no_channel" else channel_dim + self.series_name = series_name + self.series_meta = series_meta + self.affine_lps_to_ras = affine_lps_to_ras + self.depth_last = depth_last + self.preserve_dtype = preserve_dtype + self.use_nvimgcodec = use_nvimgcodec + self.prefer_gpu_output = prefer_gpu_output + self.allow_fallback_decode = allow_fallback_decode + # Initialize decode params for nvImageCodec if needed + if self.use_nvimgcodec: + if not has_nvimgcodec: + warnings.warn("NvDicomReader: nvImageCodec not installed, will use pydicom for decoding.") + self.use_nvimgcodec = False + else: + self.decode_params = nvimgcodec.DecodeParams( + allow_any_depth=True, color_spec=nvimgcodec.ColorSpec.UNCHANGED + ) + + def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: + """ + Verify whether the specified file or files format is supported by NvDicom reader. + + Args: + filename: file name or a list of file names to read. + if a list of files, verify all the suffixes. + + Returns: + bool: True if pydicom and nvimgcodec are available and all paths are valid DICOM files or directories containing DICOM files. 
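+
+        Example (illustrative sketch; assumes pydicom and nvimgcodec are installed and
+        that ``/data/dicom`` is a hypothetical directory containing ``.dcm`` files):
+            >>> reader = NvDicomReader()
+            >>> reader.verify_suffix("/data/dicom")        # directory with DICOM files
+            True
+            >>> reader.verify_suffix("/data/image.png")    # not a DICOM file or directory
+            False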
+ """ + logger.info("verify_suffix: has_pydicom=%s has_nvimgcodec=%s", has_pydicom, has_nvimgcodec) + if not (has_pydicom and has_nvimgcodec): + logger.info( + "verify_suffix: has_pydicom=%s has_nvimgcodec=%s -> returning False", has_pydicom, has_nvimgcodec + ) + return False + + def _is_dcm_file(path): + return str(path).lower().endswith(".dcm") and os.path.isfile(str(path)) + + def _dir_contains_dcm(path): + if not os.path.isdir(str(path)): + return False + try: + for f in os.listdir(str(path)): + if f.lower().endswith(".dcm") and os.path.isfile(os.path.join(str(path), f)): + return True + except Exception: + return False + return False + + paths = ensure_tuple(filename) + if len(paths) < 1: + logger.info("verify_suffix: No paths provided.") + return False + + for fpath in paths: + if _is_dcm_file(fpath): + logger.info(f"verify_suffix: Path '{fpath}' is a DICOM file.") + continue + elif _dir_contains_dcm(fpath): + logger.info(f"verify_suffix: Path '{fpath}' is a directory containing at least one DICOM file.") + continue + else: + logger.info( + f"verify_suffix: Path '{fpath}' is neither a DICOM file nor a directory containing DICOM files." + ) + return False + return True + + def _apply_rescale_and_dtype(self, pixel_data, ds, original_dtype): + """ + Apply DICOM rescale slope/intercept and handle dtype preservation. + + Args: + pixel_data: numpy or cupy array of pixel data + ds: pydicom dataset containing RescaleSlope/RescaleIntercept tags + original_dtype: original dtype before any processing + + Returns: + Processed pixel data array (potentially rescaled and dtype converted) + """ + # Detect array library (numpy or cupy) + xp = cp if hasattr(pixel_data, "__cuda_array_interface__") else np + + # Check if rescaling is needed + has_rescale = hasattr(ds, "RescaleSlope") and hasattr(ds, "RescaleIntercept") + + if has_rescale: + slope = float(ds.RescaleSlope) + intercept = float(ds.RescaleIntercept) + slope = xp.asarray(slope, dtype=xp.float32) + intercept = xp.asarray(intercept, dtype=xp.float32) + pixel_data = pixel_data.astype(xp.float32) * slope + intercept + + # Convert back to original dtype if requested (matching ITK behavior) + if self.preserve_dtype: + # Determine target dtype based on original and rescale + # ITK converts to a dtype that can hold the rescaled values + # Handle both numpy and cupy dtypes + orig_dtype_str = str(original_dtype) + if "uint16" in orig_dtype_str: + # uint16 with rescale typically goes to int32 in ITK + target_dtype = xp.int32 + elif "int16" in orig_dtype_str: + target_dtype = xp.int32 + elif "uint8" in orig_dtype_str: + target_dtype = xp.int32 + else: + # Preserve original dtype for other types + target_dtype = original_dtype + pixel_data = pixel_data.astype(target_dtype) + + return pixel_data + + def _is_nvimgcodec_supported_syntax(self, img): + """ + Check if the DICOM transfer syntax is supported by nvImageCodec. + + Args: + img: a Pydicom dataset object. + + Returns: + bool: True if transfer syntax is supported by nvImageCodec, False otherwise. 
+ """ + if not has_nvimgcodec: + return False + + # Check if we have a transfer syntax that nvImageCodec can handle + file_meta = getattr(img, "file_meta", None) + if file_meta is None: + return False + + transfer_syntax = getattr(file_meta, "TransferSyntaxUID", None) + if transfer_syntax is None: + return False + + # Define supported transfer syntaxes for nvImageCodec + jpeg2000_syntaxes = [ + "1.2.840.10008.1.2.4.90", # JPEG 2000 Image Compression (Lossless Only) + "1.2.840.10008.1.2.4.91", # JPEG 2000 Image Compression + ] + + htj2k_syntaxes = [ + "1.2.840.10008.1.2.4.201", # High-Throughput JPEG 2000 Image Compression (Lossless Only) + "1.2.840.10008.1.2.4.202", # High-Throughput JPEG 2000 with RPCL Options Image Compression (Lossless Only) + "1.2.840.10008.1.2.4.203", # High-Throughput JPEG 2000 Image Compression + ] + + # JPEG transfer syntaxes (lossy) + jpeg_lossy_syntaxes = [ + "1.2.840.10008.1.2.4.50", # JPEG Baseline (Process 1) + "1.2.840.10008.1.2.4.51", # JPEG Extended (Process 2 & 4) + ] + + jpeg_lossless_syntaxes = [ + "1.2.840.10008.1.2.4.57", # JPEG Lossless, Non-Hierarchical (Process 14) + "1.2.840.10008.1.2.4.70", # JPEG Lossless, Non-Hierarchical, First-Order Prediction + ] + + return str(transfer_syntax) in jpeg2000_syntaxes + htj2k_syntaxes + jpeg_lossy_syntaxes + jpeg_lossless_syntaxes + + def _nvimgcodec_decode(self, img): + """ + Decode pixel data using nvImageCodec for supported transfer syntaxes. + + Args: + img: a Pydicom dataset object. + + Returns: + numpy or cupy array: Decoded pixel data. + + Raises: + ValueError: If pixel data is missing or decoding fails. + """ + logger.info(f"NvDicomReader: Starting nvImageCodec decoding") + + # Get raw pixel data + if not hasattr(img, "PixelData") or img.PixelData is None: + raise ValueError(f"dicom data: does not have a PixelData member.") + + pixel_data = img.PixelData + + # Decode the pixel data + data_sequence = [fragment for fragment in pydicom.encaps.generate_frames(pixel_data)] + logger.info(f"NvDicomReader: Decoding {len(data_sequence)} fragment(s) with nvImageCodec") + decoder = _get_nvimgcodec_decoder() + decoder_output = decoder.decode(data_sequence, params=self.decode_params) + if decoder_output is None: + raise ValueError(f"nvImageCodec failed to decode") + + # Not all fragments are images, so we need to filter out None images + decoded_data = [img for img in decoder_output if img is not None] + if len(decoded_data) == 0: + raise ValueError(f"nvImageCodec failed to decode or no valid images were found in the decoded data") + + buffer_kind_enum = decoded_data[0].buffer_kind + + # Concatenate all images into a volume if number_of_frames > 1 and multiple images are present + number_of_frames = getattr(img, "NumberOfFrames", 1) + if number_of_frames > 1 and len(decoded_data) > 1: + if number_of_frames != len(decoded_data): + raise ValueError( + f"Number of frames in the image ({number_of_frames}) does not match the number of decoded images ({len(decoded_data)})." 
+ ) + if buffer_kind_enum == nvimgcodec.ImageBufferKind.STRIDED_DEVICE: + decoded_array = cp.concatenate([cp.array(d.gpu()) for d in decoded_data], axis=0) + elif buffer_kind_enum == nvimgcodec.ImageBufferKind.STRIDED_HOST: + # Use .cpu() to get data from either GPU or CPU buffer + decoded_array = np.concatenate([np.array(d.cpu()) for d in decoded_data], axis=0) + else: + raise ValueError(f"Unknown buffer kind: {buffer_kind_enum}") + else: + if buffer_kind_enum == nvimgcodec.ImageBufferKind.STRIDED_DEVICE: + decoded_array = cp.array(decoded_data[0].cuda()) + elif buffer_kind_enum == nvimgcodec.ImageBufferKind.STRIDED_HOST: + # Use .cpu() to get data from either GPU or CPU buffer + decoded_array = np.array(decoded_data[0].cpu()) + else: + raise ValueError(f"Unknown buffer kind: {buffer_kind_enum}") + + # Reshape based on DICOM parameters + rows = getattr(img, "Rows", None) + columns = getattr(img, "Columns", None) + samples_per_pixel = getattr(img, "SamplesPerPixel", 1) + number_of_frames = getattr(img, "NumberOfFrames", 1) + + if rows and columns: + if number_of_frames > 1: + expected_shape = (number_of_frames, rows, columns) + if samples_per_pixel > 1: + expected_shape = expected_shape + (samples_per_pixel,) + else: + expected_shape = (rows, columns) + if samples_per_pixel > 1: + expected_shape = expected_shape + (samples_per_pixel,) + + # Reshape if necessary + if decoded_array.size == np.prod(expected_shape): + decoded_array = decoded_array.reshape(expected_shape) + + return decoded_array + + def read(self, data: Sequence[PathLike] | PathLike, **kwargs): + """ + Read image data from specified file or files, it can read a list of images + and stack them together as multi-channel data in `get_data()`. + If passing directory path instead of file path, will treat it as DICOM images series and read. + Note that the returned object is ITK image object or list of ITK image objects. + + Args: + data: file name or a list of file names to read, + kwargs: additional args for `itk.imread` API, will override `self.kwargs` for existing keys. + More details about available args: + https://github.com/InsightSoftwareConsortium/ITK/blob/master/Wrapping/Generators/Python/itk/support/extras.py + + """ + from pathlib import Path + + img_ = [] + + filenames: Sequence[PathLike] = ensure_tuple(data) + # Store filenames for later use in get_data (needed for nvImageCodec/GPU loading) + self.filenames: list = [] + kwargs_ = self.kwargs.copy() + kwargs_.update(kwargs) + for name in filenames: + name = f"{name}" + if Path(name).is_dir(): + # read DICOM series + # Use pydicom to read a DICOM series from the directory `name`. 
+ logger.info(f"NvDicomReader: Reading DICOM series from directory: {name}") + + # Collect all DICOM files in the directory + dicom_files = [os.path.join(name, f) for f in os.listdir(name) if os.path.isfile(os.path.join(name, f))] + if not dicom_files: + raise FileNotFoundError(f"No files found in: {name}.") + + # Group files by SeriesInstanceUID and collect metadata + series_dict = {} + series_metadata = {} + logger.info(f"NvDicomReader: Parsing {len(dicom_files)} DICOM files with pydicom") + for fp in dicom_files: + try: + ds = pydicom.dcmread(fp, stop_before_pixels=True) + if hasattr(ds, "SeriesInstanceUID"): + series_uid = ds.SeriesInstanceUID + if self.series_name and not series_uid.startswith(self.series_name): + continue + if series_uid not in series_dict: + series_dict[series_uid] = [] + # Store series metadata from first file + series_metadata[series_uid] = { + "SeriesDate": getattr(ds, "SeriesDate", ""), + "SeriesTime": getattr(ds, "SeriesTime", ""), + "SeriesNumber": getattr(ds, "SeriesNumber", 0), + "SeriesDescription": getattr(ds, "SeriesDescription", ""), + } + series_dict[series_uid].append((fp, ds)) + except Exception as e: + warnings.warn(f"Skipping file {fp}: {e}") + + if self.series_name: + if not series_dict: + raise FileNotFoundError( + f"No valid DICOM series found in {name} matching series name {self.series_name}." + ) + elif not series_dict: + raise FileNotFoundError(f"No valid DICOM series found in {name}.") + + # Sort series by SeriesDate (and SeriesTime as tiebreaker) + # This matches ITKReader's behavior with AddSeriesRestriction("0008|0021") + def series_sort_key(series_uid): + meta = series_metadata[series_uid] + # Format: (SeriesDate, SeriesTime, SeriesNumber) + # Empty strings sort first, so series without dates come first + return (meta["SeriesDate"], meta["SeriesTime"], meta["SeriesNumber"]) + + sorted_series_uids = sorted(series_dict.keys(), key=series_sort_key) + + # Determine which series to use + if len(sorted_series_uids) > 1: + logger.warning(f"NvDicomReader: Directory {name} contains {len(sorted_series_uids)} DICOM series") + + series_identifier = sorted_series_uids[0] if not self.series_name else self.series_name + logger.info(f"NvDicomReader: Selected series: {series_identifier}") + + if series_identifier not in series_dict: + raise ValueError( + f"Series '{series_identifier}' not found in directory. 
Available series: {sorted_series_uids}" + ) + + # Get files for the selected series + series_files = series_dict[series_identifier] + + # Prepare slices with position information for sorting + slices = [] + slices_without_position = [] + for fp, ds in series_files: + if hasattr(ds, "ImagePositionPatient"): + pos = np.array(ds.ImagePositionPatient) + slices.append((pos, fp, ds)) + else: + # Handle slices without ImagePositionPatient (e.g., localizers, single-slice images) + slices_without_position.append((fp, ds)) + + if not slices and not slices_without_position: + raise FileNotFoundError(f"No readable DICOM slices found in series {series_identifier}.") + + # Sort by spatial position using slice normal projection + # This works for ANY orientation (axial, sagittal, coronal, oblique) + if slices: + # We have slices with ImagePositionPatient - sort spatially + first_ds = slices[0][2] + if hasattr(first_ds, "ImageOrientationPatient"): + iop = np.array(first_ds.ImageOrientationPatient) + row_direction = iop[:3] + col_direction = iop[3:] + slice_normal = np.cross(row_direction, col_direction) + + # Project each position onto slice normal and sort by distance + slices_with_distance = [] + for pos, fp, ds in slices: + distance = np.dot(pos, slice_normal) + slices_with_distance.append((distance, fp, ds)) + slices_with_distance.sort(key=lambda s: s[0]) + slices = slices_with_distance + else: + # Fallback to Z-coordinate if no orientation info + slices_with_z = [(pos[2], fp, ds) for pos, fp, ds in slices] + slices_with_z.sort(key=lambda s: s[0]) + slices = slices_with_z + + # Return sorted list of file paths (not datasets without pixel data) + # We'll read the full datasets with pixel data in get_data() + sorted_filepaths = [fp for _, fp, _ in slices] + else: + # No ImagePositionPatient - sort by InstanceNumber or keep original order + slices_no_pos = [] + for fp, ds in slices_without_position: + inst_num = ds.InstanceNumber if hasattr(ds, "InstanceNumber") else 0 + slices_no_pos.append((inst_num, fp, ds)) + slices_no_pos.sort(key=lambda s: s[0]) + sorted_filepaths = [fp for _, fp, _ in slices_no_pos] + + # Read all DICOM files for the series and store as a list of Datasets + # This allows _process_dicom_series() to handle the series as a whole + logger.info(f"NvDicomReader: Series contains {len(sorted_filepaths)} slices") + series_datasets = [] + for fpath in sorted_filepaths: + ds = pydicom.dcmread(fpath, **kwargs_) + series_datasets.append(ds) + + # Append the list of datasets as a single series + img_.append(series_datasets) + self.filenames.extend(sorted_filepaths) + else: + # Single file + logger.info(f"NvDicomReader: Parsing single DICOM file with pydicom: {name}") + ds = pydicom.dcmread(name, **kwargs_) + img_.append(ds) + self.filenames.append(name) + + if len(filenames) == 1: + return img_[0] + return img_ + + def get_data(self, img) -> tuple[np.ndarray, dict]: + """ + Extract data array and metadata from loaded DICOM image(s). + + This function constructs 3D volumes from DICOM series by: + 1. Slices are already sorted by spatial position in read() + 2. Stacking slices into a 3D array + 3. Applying rescale slope/intercept if present + 4. Computing affine matrix for spatial transformations + + Args: + img: a pydicom dataset object or a list of pydicom dataset objects. 
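+                A list of Dataset objects (as produced by ``read()`` for a directory) is
+                treated as a single, already spatially sorted series and stacked into one
+                3D volume.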
+ + Returns: + tuple: (numpy array of image data, metadata dict) + - Array shape: (depth, height, width) for 3D volumes + - Metadata contains: affine, spacing, original_affine, spatial_shape + """ + img_array: list[np.ndarray] = [] + compatible_meta: dict = {} + + # Handle single dataset or list of datasets + if isinstance(img, pydicom.Dataset): + datasets = [img] + elif isinstance(img, list): + # Check if this is a list of Dataset objects from a DICOM series + if img and isinstance(img[0], pydicom.Dataset): + # This is a DICOM series - wrap it so it's processed as one unit + datasets = [img] + else: + # This is a list of something else (shouldn't happen normally) + datasets = img + else: + datasets = ensure_tuple(img) + + for idx, ds_or_list in enumerate(datasets): + # Check if it's a series (list of datasets) or single dataset + if isinstance(ds_or_list, list): + # List of datasets - process as series + data_array, metadata = self._process_dicom_series(ds_or_list) + elif isinstance(ds_or_list, pydicom.Dataset): + # Single multi-frame DICOM - process as a series with one dataset + # This ensures proper depth_last handling and metadata calculation + is_multiframe = hasattr(ds_or_list, "NumberOfFrames") and ds_or_list.NumberOfFrames > 1 + if is_multiframe: + # Process as a series to get proper spacing, depth_last handling, etc. + data_array, metadata = self._process_dicom_series([ds_or_list]) + else: + # Single-frame DICOM - process directly + data_array = self._get_array_data(ds_or_list) + metadata = self._get_meta_dict(ds_or_list) + metadata[MetaKeys.SPATIAL_SHAPE] = np.asarray(data_array.shape) + + # Calculate spacing for single-frame images + pixel_spacing = ds_or_list.PixelSpacing if hasattr(ds_or_list, "PixelSpacing") else [1.0, 1.0] + slice_spacing = float(ds_or_list.SliceThickness) if hasattr(ds_or_list, "SliceThickness") else 1.0 + metadata["spacing"] = np.array([float(pixel_spacing[1]), float(pixel_spacing[0]), slice_spacing]) + + img_array.append(data_array) + metadata[MetaKeys.ORIGINAL_AFFINE] = self._get_affine(metadata, self.affine_lps_to_ras) + metadata[MetaKeys.AFFINE] = metadata[MetaKeys.ORIGINAL_AFFINE].copy() + metadata[MetaKeys.SPACE] = SpaceKeys.RAS if self.affine_lps_to_ras else SpaceKeys.LPS + + if self.channel_dim is None: + metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = ( + float("nan") if len(data_array.shape) == len(metadata[MetaKeys.SPATIAL_SHAPE]) else -1 + ) + else: + metadata[MetaKeys.ORIGINAL_CHANNEL_DIM] = self.channel_dim + + _copy_compatible_dict(metadata, compatible_meta) + + return _stack_images(img_array, compatible_meta), compatible_meta + + def _process_dicom_series(self, datasets: list) -> tuple[np.ndarray, dict]: + """ + Process a list of sorted DICOM Dataset objects into a 3D volume. + + This method implements batch decoding optimization: when all files use + nvImageCodec-supported transfer syntaxes, all frames are decoded in a + single nvImageCodec call for better performance. Falls back to + frame-by-frame decoding if batch decode fails or is not applicable. 
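+
+        The stacked volume has shape (columns, rows, depth) when ``depth_last`` is True and
+        (depth, rows, columns) otherwise; RescaleSlope/RescaleIntercept from the first
+        dataset are applied to the whole volume after stacking.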
+ + Args: + datasets: list of pydicom Dataset objects (already sorted by spatial position) + + Returns: + tuple: (3D numpy array, metadata dict) + """ + if not datasets: + raise ValueError("Empty dataset list") + + first_ds = datasets[0] + needs_rescale = hasattr(first_ds, "RescaleSlope") and hasattr(first_ds, "RescaleIntercept") + rows = first_ds.Rows + cols = first_ds.Columns + + # For multi-frame DICOMs, depth is the total number of frames, not the number of files + # For single-frame DICOMs, depth is the number of files + depth = 0 + for ds in datasets: + num_frames = getattr(ds, "NumberOfFrames", 1) + depth += num_frames + + # Check if we can use nvImageCodec on the whole series + can_use_nvimgcodec = self.use_nvimgcodec and all(self._is_nvimgcodec_supported_syntax(ds) for ds in datasets) + + batch_decode_success = False + original_dtype = None + + if can_use_nvimgcodec: + logger.info(f"NvDicomReader: Using nvImageCodec batch decode for {depth} slices") + try: + # Batch decode all frames in a single nvImageCodec call + # Collect all compressed frames from all DICOM files + all_frames = [] + for ds in datasets: + if not hasattr(ds, "PixelData") or ds.PixelData is None: + raise ValueError("DICOM data does not have pixel data") + pixel_data = ds.PixelData + # Extract compressed frame(s) from this DICOM file + frames = [fragment for fragment in pydicom.encaps.generate_frames(pixel_data)] + all_frames.extend(frames) + + # Decode all frames at once + decoder = _get_nvimgcodec_decoder() + decoded_data = decoder.decode(all_frames, params=self.decode_params) + + if not decoded_data or any(d is None for d in decoded_data): + raise ValueError("nvImageCodec batch decode failed") + + # Determine buffer location (GPU or CPU) + buffer_kind_enum = decoded_data[0].buffer_kind + + # Convert all decoded frames to numpy/cupy arrays + if buffer_kind_enum == nvimgcodec.ImageBufferKind.STRIDED_DEVICE: + xp = cp + decoded_arrays = [cp.array(d.cuda()) for d in decoded_data] + elif buffer_kind_enum == nvimgcodec.ImageBufferKind.STRIDED_HOST: + xp = np + decoded_arrays = [np.array(d.cpu()) for d in decoded_data] + else: + raise ValueError(f"Unknown buffer kind: {buffer_kind_enum}") + + original_dtype = decoded_arrays[0].dtype + dtype_vol = xp.float32 if needs_rescale else original_dtype + + # Build 3D volume (use float32 for rescaling to avoid overflow) + # Shape depends on depth_last + if self.depth_last: + volume = xp.zeros((cols, rows, depth), dtype=dtype_vol) + else: + volume = xp.zeros((depth, rows, cols), dtype=dtype_vol) + + for frame_idx, frame_array in enumerate(decoded_arrays): + # Reshape if needed + if frame_array.shape != (rows, cols): + frame_array = frame_array.reshape(rows, cols) + + if self.depth_last: + volume[:, :, frame_idx] = frame_array.T + else: + volume[frame_idx, :, :] = frame_array + + batch_decode_success = True + + except Exception as e: + if not self.allow_fallback_decode: + raise ValueError(f"nvImageCodec batch decoding failed: {e}") + warnings.warn(f"nvImageCodec batch decoding failed: {e}. 
Falling back to frame-by-frame.") + batch_decode_success = False + + if not batch_decode_success or not can_use_nvimgcodec: + # Fallback: use pydicom pixel_array for each frame + logger.info(f"NvDicomReader: Using pydicom pixel_array decode for {depth} slices") + first_pixel_array = first_ds.pixel_array + original_dtype = first_pixel_array.dtype + + # Build 3D volume (use float32 for rescaling to avoid overflow if needed) + xp = cp if hasattr(first_pixel_array, "__cuda_array_interface__") else np + dtype_vol = xp.float32 if needs_rescale else original_dtype + + # Shape depends on depth_last + if self.depth_last: + volume = xp.zeros((cols, rows, depth), dtype=dtype_vol) + else: + volume = xp.zeros((depth, rows, cols), dtype=dtype_vol) + + # Handle both single-frame series and multi-frame DICOMs + frame_idx = 0 + if len(datasets) == 1 and getattr(datasets[0], "NumberOfFrames", 1) > 1: + # Multi-frame DICOM: all frames in a single dataset + ds = datasets[0] + pixel_array = ds.pixel_array + # Ensure correct array type + if hasattr(pixel_array, "__cuda_array_interface__"): + pixel_array = cp.asarray(pixel_array) + else: + pixel_array = np.asarray(pixel_array) + num_frames = getattr(ds, "NumberOfFrames", 1) + if not self.depth_last: + # Depth-first: copy whole volume at once + volume[:, :, :] = pixel_array + else: + # Depth-last: assign using transpose for the whole volume + volume[:, :, :num_frames] = pixel_array.transpose(2, 1, 0) + else: + # Single-frame DICOMs: each dataset is a single slice + for frame_idx, ds in enumerate(datasets): + pixel_array = ds.pixel_array + if hasattr(pixel_array, "__cuda_array_interface__"): + pixel_array = cp.asarray(pixel_array) + else: + pixel_array = np.asarray(pixel_array) + if self.depth_last: + volume[:, :, frame_idx] = pixel_array.T + else: + volume[frame_idx, :, :] = pixel_array + + # Ensure xp is defined for subsequent operations + xp = cp if hasattr(volume, "__cuda_array_interface__") else np + + # Ensure original_dtype is set + if original_dtype is None: + # Get dtype from first pixel array if not already set + original_dtype = first_ds.pixel_array.dtype + + # Apply rescaling and dtype conversion using common helper + volume = self._apply_rescale_and_dtype(volume, first_ds, original_dtype) + + # Calculate spacing + pixel_spacing = first_ds.PixelSpacing if hasattr(first_ds, "PixelSpacing") else [1.0, 1.0] + + # Calculate slice spacing + if depth > 1: + # For multi-frame DICOM, calculate spacing from per-frame positions + is_multiframe = len(datasets) == 1 and hasattr(first_ds, "NumberOfFrames") and first_ds.NumberOfFrames > 1 + + if is_multiframe and hasattr(first_ds, "PerFrameFunctionalGroupsSequence"): + # Multi-frame DICOM: extract positions from PerFrameFunctionalGroupsSequence + average_distance = 0.0 + positions = [] + + try: + # Extract all frame positions + for frame_idx, frame in enumerate(first_ds.PerFrameFunctionalGroupsSequence): + # Try to get PlanePositionSequence + plane_pos_seq = None + if hasattr(frame, "PlanePositionSequence"): + plane_pos_seq = frame.PlanePositionSequence + elif hasattr(frame, "get"): + plane_pos_seq = frame.get("PlanePositionSequence") + + if plane_pos_seq and len(plane_pos_seq) > 0: + plane_pos_item = plane_pos_seq[0] + if hasattr(plane_pos_item, "ImagePositionPatient"): + ipp = plane_pos_item.ImagePositionPatient + z_pos = float(ipp[2]) + positions.append(z_pos) + + # Calculate average distance between consecutive positions + if len(positions) > 1: + for i in range(1, len(positions)): + average_distance += 
abs(positions[i] - positions[i - 1]) + slice_spacing = average_distance / (len(positions) - 1) + else: + logger.warning( + f"NvDicomReader: Only found {len(positions)} positions, cannot calculate spacing" + ) + slice_spacing = 1.0 + + except Exception as e: + logger.warning(f"NvDicomReader: Failed to calculate spacing from per-frame positions: {e}") + # Fallback to SliceThickness or default + if hasattr(first_ds, "SliceThickness"): + slice_spacing = float(first_ds.SliceThickness) + else: + slice_spacing = 1.0 + + elif len(datasets) > 1 and hasattr(first_ds, "ImagePositionPatient"): + # Multiple single-frame DICOMs: calculate from dataset positions + average_distance = 0.0 + prev_pos = np.array(datasets[0].ImagePositionPatient)[2] + for i in range(1, len(datasets)): + if hasattr(datasets[i], "ImagePositionPatient"): + curr_pos = np.array(datasets[i].ImagePositionPatient)[2] + average_distance += abs(curr_pos - prev_pos) + prev_pos = curr_pos + slice_spacing = average_distance / (len(datasets) - 1) + logger.info( + f"NvDicomReader: Calculated slice spacing from {len(datasets)} datasets: {slice_spacing:.4f}" + ) + + elif hasattr(first_ds, "SliceThickness"): + # Fallback to SliceThickness tag if positions unavailable + slice_spacing = float(first_ds.SliceThickness) + logger.info(f"NvDicomReader: Using SliceThickness: {slice_spacing}") + else: + slice_spacing = 1.0 + logger.warning(f"NvDicomReader: No position data available, using default spacing: 1.0") + else: + slice_spacing = 1.0 + + # Build metadata + metadata = self._get_meta_dict(first_ds) + + metadata["spacing"] = np.array([float(pixel_spacing[1]), float(pixel_spacing[0]), slice_spacing]) + # Metadata should always use numpy arrays, even if data is on GPU + metadata[MetaKeys.SPATIAL_SHAPE] = np.asarray(volume.shape) + + # Store last position for affine calculation + last_ds = datasets[-1] + + # For multi-frame DICOM, try to get the last frame's position from PerFrameFunctionalGroupsSequence + is_multiframe = hasattr(last_ds, "NumberOfFrames") and last_ds.NumberOfFrames > 1 + if is_multiframe and hasattr(last_ds, "PerFrameFunctionalGroupsSequence"): + try: + last_frame_idx = last_ds.NumberOfFrames - 1 + last_frame = last_ds.PerFrameFunctionalGroupsSequence[last_frame_idx] + if hasattr(last_frame, "PlanePositionSequence") and len(last_frame.PlanePositionSequence) > 0: + last_ipp = last_frame.PlanePositionSequence[0].ImagePositionPatient + metadata["lastImagePositionPatient"] = np.array(last_ipp) + logger.info(f"[DICOM Reader] Multi-frame: extracted last frame IPP: {last_ipp}") + except Exception as e: + logger.warning(f"NvDicomReader: Failed to extract last frame position: {e}") + elif hasattr(last_ds, "ImagePositionPatient"): + metadata["lastImagePositionPatient"] = np.array(last_ds.ImagePositionPatient) + + # Log extracted DICOM metadata for debugging + logger.info(f"[DICOM Reader] Extracted metadata for {len(datasets)} slices") + logger.info(f"[DICOM Reader] Volume shape: {volume.shape}") + logger.info(f"[DICOM Reader] ImagePositionPatient (first): {metadata.get('ImagePositionPatient')}") + logger.info(f"[DICOM Reader] ImagePositionPatient (last): {metadata.get('lastImagePositionPatient')}") + logger.info(f"[DICOM Reader] ImageOrientationPatient: {metadata.get('ImageOrientationPatient')}") + logger.info(f"[DICOM Reader] Spacing: {metadata.get('spacing')}") + logger.info(f"[DICOM Reader] Is multi-frame: {is_multiframe}") + + return volume, metadata + + def _get_array_data(self, ds): + """ + Get pixel array from a single DICOM dataset. 
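+
+        nvImageCodec decoding is used when ``use_nvimgcodec`` is enabled and the transfer
+        syntax is supported; otherwise (or if that decode fails) the method falls back to
+        ``ds.pixel_array``. Rescale slope/intercept is applied to the decoded array.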
+ + Args: + ds: pydicom dataset object + + Returns: + numpy or cupy array of pixel data + """ + # Get pixel array using nvImageCodec or GPU loading if enabled and filename available + if self.use_nvimgcodec and self._is_nvimgcodec_supported_syntax(ds): + try: + pixel_array = self._nvimgcodec_decode(ds) + original_dtype = pixel_array.dtype + logger.info(f"NvDicomReader: Successfully decoded with nvImageCodec") + except Exception as e: + logger.warning(f"NvDicomReader: nvImageCodec decoding failed: {e}, falling back to pydicom") + pixel_array = ds.pixel_array + original_dtype = pixel_array.dtype + else: + logger.info(f"NvDicomReader: Using pydicom pixel_array decode") + pixel_array = ds.pixel_array + original_dtype = pixel_array.dtype + + # Apply rescaling and dtype conversion using common helper + pixel_array = self._apply_rescale_and_dtype(pixel_array, ds, original_dtype) + + return pixel_array + + def _get_meta_dict(self, ds) -> dict: + """Extract metadata from DICOM dataset, storing all tags like ITKReader does.""" + metadata = {} + + # Store all DICOM tags in ITK format (GGGG|EEEE) + for elem in ds: + # Skip pixel data and large binary data + if elem.tag in [ + (0x7FE0, 0x0010), # Pixel Data + (0x7FE0, 0x0008), # Float Pixel Data + (0x7FE0, 0x0009), + ]: # Double Float Pixel Data + continue + + # Format tag as 'GGGG|EEEE' (matching ITK format) + tag_str = f"{elem.tag.group:04x}|{elem.tag.element:04x}" + + # Store the value, converting to appropriate Python types + if elem.VR == "SQ": # Sequence - skip for now (can be very large) + continue + try: + # Convert value to appropriate Python type + value = elem.value + + # Handle pydicom special types + value_type_name = type(value).__name__ + if value_type_name == "MultiValue": + # MultiValue: convert to list + value = list(value) + elif value_type_name == "PersonName": + # PersonName: convert to string + value = str(value) + elif hasattr(value, "tolist"): + # NumPy arrays: convert to list or scalar + value = value.tolist() if value.size > 1 else value.item() + elif isinstance(value, bytes): + # Bytes: decode to string + try: + value = value.decode("utf-8", errors="ignore") + except: + value = str(value) + + metadata[tag_str] = value + except Exception: + # Some values might not be decodable, skip them + pass + + # Also store essential spatial tags with readable names + # (for convenience and backward compatibility) + + # For multi-frame (Enhanced) DICOM, extract per-frame metadata from the first frame + is_multiframe = hasattr(ds, "NumberOfFrames") and ds.NumberOfFrames > 1 + if is_multiframe and hasattr(ds, "PerFrameFunctionalGroupsSequence"): + try: + first_frame = ds.PerFrameFunctionalGroupsSequence[0] + + # Helper function to safely access sequence items (handles both attribute and dict access) + def get_sequence_item(obj, seq_name, item_idx=0): + """Get item from a sequence, handling both attribute and dict access.""" + seq = None + # Try attribute access + if hasattr(obj, seq_name): + seq = getattr(obj, seq_name, None) + # Try dict-style access + elif hasattr(obj, "get"): + seq = obj.get(seq_name) + elif hasattr(obj, "__getitem__"): + try: + seq = obj[seq_name] + except (KeyError, TypeError): + pass + + if seq and len(seq) > item_idx: + return seq[item_idx] + return None + + # Extract ImageOrientationPatient from per-frame sequence + plane_orient_item = get_sequence_item(first_frame, "PlaneOrientationSequence") + if plane_orient_item and hasattr(plane_orient_item, "ImageOrientationPatient"): + iop = 
plane_orient_item.ImageOrientationPatient + metadata["ImageOrientationPatient"] = list(iop) + + # Extract ImagePositionPatient from per-frame sequence + plane_pos_item = get_sequence_item(first_frame, "PlanePositionSequence") + if plane_pos_item and hasattr(plane_pos_item, "ImagePositionPatient"): + ipp = plane_pos_item.ImagePositionPatient + metadata["ImagePositionPatient"] = list(ipp) + else: + logger.warning(f"NvDicomReader: PlanePositionSequence not found or empty") + + # Extract PixelSpacing from per-frame sequence + pixel_measures_item = get_sequence_item(first_frame, "PixelMeasuresSequence") + if pixel_measures_item and hasattr(pixel_measures_item, "PixelSpacing"): + ps = pixel_measures_item.PixelSpacing + metadata["PixelSpacing"] = list(ps) + + # Also check SliceThickness from PixelMeasuresSequence + if pixel_measures_item and hasattr(pixel_measures_item, "SliceThickness"): + st = pixel_measures_item.SliceThickness + metadata["SliceThickness"] = float(st) + + except Exception as e: + logger.warning(f"NvDicomReader: Failed to extract per-frame metadata: {e}, falling back to top-level") + import traceback + + logger.warning(f"NvDicomReader: Traceback: {traceback.format_exc()}") + + # Fall back to top-level attributes if not extracted from per-frame sequence + if hasattr(ds, "ImageOrientationPatient") and "ImageOrientationPatient" not in metadata: + metadata["ImageOrientationPatient"] = list(ds.ImageOrientationPatient) + if hasattr(ds, "ImagePositionPatient") and "ImagePositionPatient" not in metadata: + metadata["ImagePositionPatient"] = list(ds.ImagePositionPatient) + if hasattr(ds, "PixelSpacing") and "PixelSpacing" not in metadata: + metadata["PixelSpacing"] = list(ds.PixelSpacing) + + return metadata + + def _get_affine(self, metadata: dict, lps_to_ras: bool = True) -> np.ndarray: + """ + Construct affine matrix from DICOM metadata. 
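+
+        The first two columns of the affine are the row/column direction cosines from
+        ImageOrientationPatient scaled by the in-plane spacing; the third column is the
+        slice direction, computed from the first and last ImagePositionPatient when both
+        are available and from the cross product of the direction cosines otherwise; the
+        translation is the first ImagePositionPatient. The matrix is converted from LPS to
+        RAS when ``lps_to_ras`` is True.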
+ + Args: + metadata: metadata dictionary + lps_to_ras: whether to convert from LPS to RAS + + Returns: + 4x4 affine matrix + """ + affine = np.eye(4) + + if "ImageOrientationPatient" not in metadata or "ImagePositionPatient" not in metadata: + # No explicit orientation info - use identity but still apply LPS->RAS if requested + # DICOM default coordinate system is LPS + if lps_to_ras: + affine = orientation_ras_lps(affine) + return affine + + iop = metadata["ImageOrientationPatient"] + ipp = metadata["ImagePositionPatient"] + spacing = metadata.get("spacing", np.array([1.0, 1.0, 1.0])) + + # Extract direction cosines + row_cosine = np.array(iop[:3]) + col_cosine = np.array(iop[3:]) + + # Build affine matrix + # Column 0: row direction * row spacing + affine[:3, 0] = row_cosine * spacing[0] + # Column 1: col direction * col spacing + affine[:3, 1] = col_cosine * spacing[1] + + # Calculate slice direction + # Determine the depth dimension (handle depth_last) + spatial_shape = metadata[MetaKeys.SPATIAL_SHAPE] + if len(spatial_shape) == 3: + # Find which dimension is the depth (smallest for typical medical images) + # When depth_last=True: shape is (W, H, D), depth is at index 2 + # When depth_last=False: shape is (D, H, W), depth is at index 0 + depth_idx = np.argmin(spatial_shape) + n_slices = spatial_shape[depth_idx] + + if n_slices > 1 and "lastImagePositionPatient" in metadata: + # Multi-slice: calculate from first and last positions + last_ipp = metadata["lastImagePositionPatient"] + slice_vec = (last_ipp - np.array(ipp)) / (n_slices - 1) + affine[:3, 2] = slice_vec + else: + # Single slice or no last position: use cross product + slice_normal = np.cross(row_cosine, col_cosine) + affine[:3, 2] = slice_normal * spacing[2] + else: + # 2D image - use cross product + slice_normal = np.cross(row_cosine, col_cosine) + affine[:3, 2] = slice_normal * spacing[2] + + # Translation + affine[:3, 3] = ipp + + # Log affine construction details + logger.info(f"[DICOM Reader] Affine matrix construction:") + logger.info(f"[DICOM Reader] Origin (IPP): {ipp}") + logger.info(f"[DICOM Reader] Spacing: {spacing}") + logger.info(f"[DICOM Reader] Spatial shape: {spatial_shape}") + if len(spatial_shape) == 3 and "lastImagePositionPatient" in metadata: + logger.info(f"[DICOM Reader] Last IPP: {metadata['lastImagePositionPatient']}") + logger.info(f"[DICOM Reader] Slice vector: {affine[:3, 2]}") + logger.info(f"[DICOM Reader] Affine (before LPS->RAS):\n{affine}") + + # Convert LPS to RAS if requested + if lps_to_ras: + affine = orientation_ras_lps(affine) + logger.info(f"[DICOM Reader] Affine (after LPS->RAS):\n{affine}") + + return affine diff --git a/monailabel/transform/writer.py b/monailabel/transform/writer.py index 402e1d17d..7c4e675cc 100644 --- a/monailabel/transform/writer.py +++ b/monailabel/transform/writer.py @@ -141,6 +141,15 @@ def write_seg_nrrd( ] ) + # Log NRRD geometry being written + logger.info(f"[NRRD Writer] Writing segmentation to: {output_file}") + logger.info(f"[NRRD Writer] Image shape: {image_np.shape}") + logger.info(f"[NRRD Writer] Affine matrix:\n{affine}") + logger.info(f"[NRRD Writer] Space origin: {origin}") + logger.info(f"[NRRD Writer] Space directions:\n{space_directions}") + logger.info(f"[NRRD Writer] Space: {space}") + logger.info(f"[NRRD Writer] Index order: {index_order}") + header.update( { "kinds": kinds, diff --git a/plugins/ohifv3/extensions/monai-label/src/components/MonaiLabelPanel.tsx b/plugins/ohifv3/extensions/monai-label/src/components/MonaiLabelPanel.tsx index 
4ab37b53a..940284bf1 100644 --- a/plugins/ohifv3/extensions/monai-label/src/components/MonaiLabelPanel.tsx +++ b/plugins/ohifv3/extensions/monai-label/src/components/MonaiLabelPanel.tsx @@ -44,6 +44,8 @@ export default class MonaiLabelPanel extends Component { classprompts: any; }; serverURI = 'http://127.0.0.1:8000'; + private _currentSeriesUID: string | null = null; + private _unsubscribeFromViewportGrid: any = null; constructor(props) { super(props); @@ -183,48 +185,18 @@ export default class MonaiLabelPanel extends Component { } const labelsOrdered = [...new Set(all_labels)].sort(); - const segmentations = [ - { - segmentationId: '1', - representation: { - type: Enums.SegmentationRepresentations.Labelmap, - }, - config: { - label: 'Segmentations', - segments: labelsOrdered.reduce((acc, label, index) => { + + // Prepare initial segmentation configuration - will be created per-series on inference + const initialSegs = labelsOrdered.reduce((acc, label, index) => { acc[index + 1] = { segmentIndex: index + 1, label: label, - active: index === 0, // First segment is active + active: index === 0, locked: false, color: this.segmentColor(label), }; return acc; - }, {}), - }, - }, - ]; - - const initialSegs = segmentations[0].config.segments; - const volumeLoadObject = cache.getVolume('1'); - if (!volumeLoadObject) { - this.props.commandsManager.runCommand('loadSegmentationsForViewport', { - segmentations, - }); - - // Wait for Above Segmentations to be added/available - setTimeout(() => { - const { viewport } = this.getActiveViewportInfo(); - for (const segmentIndex of Object.keys(initialSegs)) { - cornerstoneTools.segmentation.config.color.setSegmentIndexColor( - viewport.viewportId, - '1', - initialSegs[segmentIndex].segmentIndex, - initialSegs[segmentIndex].color - ); - } - }, 1000); - } + }, {}); const info = { models: models, @@ -260,6 +232,63 @@ export default class MonaiLabelPanel extends Component { } } this.setState({ action: name }); + + // Check if we switched series and need to reapply origin correction + this.checkAndApplyOriginCorrectionOnSeriesSwitch(); + }; + + // Check if series has changed and apply origin correction to existing segmentation + checkAndApplyOriginCorrectionOnSeriesSwitch = () => { + try { + const currentViewportInfo = this.getActiveViewportInfo(); + const currentSeriesUID = currentViewportInfo?.displaySet?.SeriesInstanceUID; + + // If series changed + if (currentSeriesUID && currentSeriesUID !== this._currentSeriesUID) { + this._currentSeriesUID = currentSeriesUID; + const segmentationId = `seg-${currentSeriesUID}`; + + // Check if this series already has a segmentation + const { segmentationService } = this.props.servicesManager.services; + try { + const volumeLoadObject = segmentationService.getLabelmapVolume(segmentationId); + if (volumeLoadObject) { + // Segmentation exists, apply origin correction + this.applyOriginCorrection(volumeLoadObject); + } + } catch (e) { + // No segmentation for this series yet, which is fine + } + } + } catch (e) { + // Ignore errors (e.g., viewport not ready) + } + }; + + // Apply origin correction - match segmentation origin to image volume origin + applyOriginCorrection = (volumeLoadObject) => { + const { displaySet } = this.getActiveViewportInfo(); + const imageVolumeId = displaySet.displaySetInstanceUID; + let imageVolume = cache.getVolume(imageVolumeId); + if (!imageVolume) { + imageVolume = cache.getVolume('cornerstoneStreamingImageVolume:' + imageVolumeId); + } + + if (imageVolume && displaySet.isMultiFrame) { + // 
Simply copy the image volume's origin to the segmentation + // This way the segmentation matches whatever origin OHIF has set for the image + volumeLoadObject.origin = [...imageVolume.origin]; + + if (volumeLoadObject.imageData) { + volumeLoadObject.imageData.setOrigin(volumeLoadObject.origin); + } + + // Trigger render to show the corrected segmentation + const renderingEngine = this.props.servicesManager.services.cornerstoneViewportService.getRenderingEngine(); + if (renderingEngine) { + renderingEngine.render(); + } + } }; updateView = async ( @@ -314,16 +343,62 @@ export default class MonaiLabelPanel extends Component { console.log('Index Remap', labels, modelToSegMapping); const data = new Uint8Array(ret.image); + // Use series-specific segmentation ID to ensure each series has its own segmentation + const currentViewportInfo = this.getActiveViewportInfo(); + const currentSeriesUID = currentViewportInfo?.displaySet?.SeriesInstanceUID; + const segmentationId = `seg-${currentSeriesUID || 'default'}`; + + // Track current series + this._currentSeriesUID = currentSeriesUID; + const { segmentationService } = this.props.servicesManager.services; - const volumeLoadObject = segmentationService.getLabelmapVolume('1'); + let volumeLoadObject = null; + + try { + volumeLoadObject = segmentationService.getLabelmapVolume(segmentationId); + } catch (e) { + // Segmentation doesn't exist yet - create it + const initialSegs = this.state.info?.initialSegs; + if (initialSegs) { + const segmentations = [{ + segmentationId: segmentationId, + representation: { + type: Enums.SegmentationRepresentations.Labelmap + }, + config: { + label: 'Segmentations', + segments: initialSegs + } + }]; + + this.props.commandsManager.runCommand('loadSegmentationsForViewport', { + segmentations + }); + + // Wait a bit for segmentation to be created, then try again + setTimeout(() => { + try { + const vol = segmentationService.getLabelmapVolume(segmentationId); + if (vol) { + this.updateView(response, model_id, labels, override, label_class_unknown, sidx); + } + } catch (err) { + console.error('Failed to create segmentation volume:', err); + } + }, 500); + return; + } + } + if (volumeLoadObject) { - // console.log('Volume Object is In Cache....'); let convertedData = data; + + // Convert label indices for (let i = 0; i < convertedData.length; i++) { const midx = convertedData[i]; - const sidx = modelToSegMapping[midx]; - if (midx && sidx) { - convertedData[i] = sidx; + const sidx_mapped = modelToSegMapping[midx]; + if (midx && sidx_mapped) { + convertedData[i] = sidx_mapped; } else if (override && label_class_unknown && labels.length === 1) { convertedData[i] = midx ? 
labelNames[labels[0]] : 0; } else if (labels.length > 0) { @@ -331,20 +406,15 @@ export default class MonaiLabelPanel extends Component { } } + // Handle override mode (partial update) if (override === true) { - const { segmentationService } = this.props.servicesManager.services; - const volumeLoadObject = segmentationService.getLabelmapVolume('1'); const { voxelManager } = volumeLoadObject; const scalarData = voxelManager?.getCompleteScalarDataArray(); - - // console.log('Current ScalarData: ', scalarData); const currentSegArray = new Uint8Array(scalarData.length); currentSegArray.set(scalarData); - // get unique values to determine which organs to update, keep rest const updateTargets = new Set(convertedData); - const numImageFrames = - this.getActiveViewportInfo().displaySet.numImageFrames; + const numImageFrames = this.getActiveViewportInfo().displaySet.numImageFrames; const sliceLength = scalarData.length / numImageFrames; const sliceBegin = sliceLength * sidx; const sliceEnd = sliceBegin + sliceLength; @@ -353,24 +423,36 @@ export default class MonaiLabelPanel extends Component { if (sidx >= 0 && (i < sliceBegin || i >= sliceEnd)) { continue; } - - if ( - convertedData[i] !== 255 && - updateTargets.has(currentSegArray[i]) - ) { + if (convertedData[i] !== 255 && updateTargets.has(currentSegArray[i])) { currentSegArray[i] = convertedData[i]; } } convertedData = currentSegArray; } - const { voxelManager } = volumeLoadObject; - voxelManager?.setCompleteScalarDataArray(convertedData); + + // Apply origin correction for multi-frame volumes + this.applyOriginCorrection(volumeLoadObject); + + // Apply segment colors + const { viewport } = this.getActiveViewportInfo(); + for (const label of labels) { + const segmentIndex = labelNames[label]; + if (segmentIndex) { + const color = this.segmentColor(label); + cornerstoneTools.segmentation.config.color.setSegmentIndexColor( + viewport.viewportId, + segmentationId, + segmentIndex, + color + ); + } + } + + // Set the voxel data + volumeLoadObject.voxelManager.setCompleteScalarDataArray(convertedData); triggerEvent(eventTarget, Enums.Events.SEGMENTATION_DATA_MODIFIED, { - segmentationId: '1', + segmentationId: segmentationId }); - console.log("updated the segmentation's scalar data"); - } else { - console.log('TODO:: Volume Object is NOT In Cache....'); } }; @@ -396,9 +478,33 @@ export default class MonaiLabelPanel extends Component { } console.log('(Component Mounted) Ready to Connect to MONAI Server...'); + + // Subscribe to viewport grid state changes to detect series switches + const { viewportGridService } = this.props.servicesManager.services; + + // Listen to any state change in the viewport grid + const handleViewportChange = () => { + // Multiple attempts with delays to catch the viewport at the right time + setTimeout(() => this.checkAndApplyOriginCorrectionOnSeriesSwitch(), 50); + setTimeout(() => this.checkAndApplyOriginCorrectionOnSeriesSwitch(), 200); + setTimeout(() => this.checkAndApplyOriginCorrectionOnSeriesSwitch(), 500); + }; + + this._unsubscribeFromViewportGrid = viewportGridService.subscribe( + viewportGridService.EVENTS.ACTIVE_VIEWPORT_ID_CHANGED, + handleViewportChange + ); + // await this.onInfo(); } + componentWillUnmount() { + if (this._unsubscribeFromViewportGrid) { + this._unsubscribeFromViewportGrid(); + this._unsubscribeFromViewportGrid = null; + } + } + onOptionsConfig = () => { return this.state.options; }; diff --git a/requirements.txt b/requirements.txt index 9a2873647..a14e3e325 100644 --- a/requirements.txt +++ 
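The override branch of updateView above merges a new model result into the labelmap already on screen instead of replacing it wholesale: when a slice index is given, only that slice window is touched, and within it a voxel is overwritten only if the incoming value is valid (not 255) and the voxel's current label is one of the labels present in the new result. A NumPy sketch of the same rule on the flat scalar-data array (illustrative only, not the panel's actual TypeScript):

import numpy as np

def merge_override(current, incoming, slice_len, sidx):
    # current/incoming are flat 1-D label arrays, as in the viewport's scalar data
    merged = current.copy()
    targets = np.unique(incoming)                                 # labels present in the new result
    window = np.zeros(current.shape, dtype=bool)
    if sidx >= 0:
        window[sidx * slice_len:(sidx + 1) * slice_len] = True    # single-slice edit
    else:
        window[:] = True                                          # whole-volume edit
    replace = window & (incoming != 255) & np.isin(merged, targets)
    merged[replace] = incoming[replace]
    return merged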
b/requirements.txt @@ -24,8 +24,8 @@ expiringdict==1.2.2 expiring_dict==1.1.0 cachetools==5.3.3 watchdog==4.0.0 -pydicom==2.4.4 -pydicom-seg==0.4.1 +pydicom==3.0.1 +highdicom==0.26.1 pynetdicom==2.0.2 pynrrd==1.0.0 numpymaxflow==0.0.7 @@ -52,6 +52,14 @@ SAM-2 @ git+https://github.com/facebookresearch/sam2.git@c2ec8e14a185632b0a5d8b1 # scipy and scikit-learn latest packages are missing on python 3.8 # sudo apt-get install openslide-tools -y +# Optional dependencies: +# - nvidia-nvimgcodec-cu{XX}[all] (replace {XX} with your CUDA major version, e.g., cu13 for CUDA 13.x) +# Required for HTJ2K DICOM support and accelerated DICOM decoding +# Installation guide: https://docs.nvidia.com/cuda/nvimagecodec/installation.html +# - dcmqi (provides itkimage2segimage command-line tool for legacy DICOM SEG conversion) +# Install with: pip install dcmqi +# More info: https://github.com/QIICR/dcmqi + # How to auto update versions? # pip install pur # pur -r requirements.txt diff --git a/sample-apps/radiology/lib/infers/deepedit.py b/sample-apps/radiology/lib/infers/deepedit.py index afc755c98..891d842a8 100644 --- a/sample-apps/radiology/lib/infers/deepedit.py +++ b/sample-apps/radiology/lib/infers/deepedit.py @@ -35,6 +35,7 @@ from monailabel.interfaces.tasks.infer_v2 import InferType from monailabel.tasks.infer.basic_infer import BasicInferTask from monailabel.transform.post import Restored +from monailabel.transform.reader import NvDicomReader logger = logging.getLogger(__name__) @@ -79,7 +80,7 @@ def __init__( def pre_transforms(self, data=None): t = [ - LoadImaged(keys="image", reader="ITKReader", image_only=False), + LoadImaged(keys="image", reader=["ITKReader", NvDicomReader()], image_only=False), EnsureChannelFirstd(keys="image"), Orientationd(keys="image", axcodes="RAS"), ScaleIntensityRanged(keys="image", a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), diff --git a/sample-apps/radiology/lib/infers/deepgrow.py b/sample-apps/radiology/lib/infers/deepgrow.py index 43f74af11..6fad1bb7c 100644 --- a/sample-apps/radiology/lib/infers/deepgrow.py +++ b/sample-apps/radiology/lib/infers/deepgrow.py @@ -36,6 +36,7 @@ from monailabel.interfaces.tasks.infer_v2 import InferType from monailabel.tasks.infer.basic_infer import BasicInferTask +from monailabel.transform.reader import NvDicomReader class Deepgrow(BasicInferTask): @@ -72,7 +73,7 @@ def __init__( def pre_transforms(self, data=None) -> Sequence[Callable]: t = [ - LoadImaged(keys="image", image_only=False), + LoadImaged(keys="image", reader=["ITKReader", NvDicomReader()], image_only=False), Transposed(keys="image", indices=[2, 0, 1]), Spacingd(keys="image", pixdim=[1.0] * self.dimension, mode="bilinear"), AddGuidanceFromPointsd(ref_image="image", guidance="guidance", spatial_dims=self.dimension), diff --git a/sample-apps/radiology/lib/infers/deepgrow_pipeline.py b/sample-apps/radiology/lib/infers/deepgrow_pipeline.py index 871f865e1..0eefdd321 100644 --- a/sample-apps/radiology/lib/infers/deepgrow_pipeline.py +++ b/sample-apps/radiology/lib/infers/deepgrow_pipeline.py @@ -39,6 +39,7 @@ from monailabel.interfaces.tasks.infer_v2 import InferTask, InferType from monailabel.tasks.infer.basic_infer import BasicInferTask from monailabel.transform.post import BoundingBoxd, LargestCCd +from monailabel.transform.reader import NvDicomReader logger = logging.getLogger(__name__) @@ -82,7 +83,7 @@ def __init__( def pre_transforms(self, data=None) -> Sequence[Callable]: t = [ - LoadImaged(keys="image", image_only=False), + LoadImaged(keys="image", 
reader=["ITKReader", NvDicomReader()], image_only=False), Transposed(keys="image", indices=[2, 0, 1]), Spacingd(keys="image", pixdim=[1.0, 1.0, 1.0], mode="bilinear"), AddGuidanceFromPointsd(ref_image="image", guidance="guidance", spatial_dims=3), diff --git a/sample-apps/radiology/lib/infers/localization_spine.py b/sample-apps/radiology/lib/infers/localization_spine.py index 347d1536e..5680c200a 100644 --- a/sample-apps/radiology/lib/infers/localization_spine.py +++ b/sample-apps/radiology/lib/infers/localization_spine.py @@ -29,6 +29,7 @@ from monailabel.interfaces.tasks.infer_v2 import InferType from monailabel.tasks.infer.basic_infer import BasicInferTask from monailabel.transform.post import Restored +from monailabel.transform.reader import NvDicomReader class LocalizationSpine(BasicInferTask): @@ -61,7 +62,7 @@ def __init__( def pre_transforms(self, data=None) -> Sequence[Callable]: return [ - LoadImaged(keys="image", reader="ITKReader"), + LoadImaged(keys="image", reader=["ITKReader", NvDicomReader()]), EnsureTyped(keys="image", device=data.get("device") if data else None), EnsureChannelFirstd(keys="image"), CacheObjectd(keys="image"), diff --git a/sample-apps/radiology/lib/infers/localization_vertebra.py b/sample-apps/radiology/lib/infers/localization_vertebra.py index fec4cc5a9..f5026aecf 100644 --- a/sample-apps/radiology/lib/infers/localization_vertebra.py +++ b/sample-apps/radiology/lib/infers/localization_vertebra.py @@ -31,6 +31,7 @@ from monailabel.interfaces.tasks.infer_v2 import InferType from monailabel.tasks.infer.basic_infer import BasicInferTask from monailabel.transform.post import Restored +from monailabel.transform.reader import NvDicomReader class LocalizationVertebra(BasicInferTask): @@ -64,7 +65,7 @@ def __init__( def pre_transforms(self, data=None) -> Sequence[Callable]: if data and isinstance(data.get("image"), str): t = [ - LoadImaged(keys="image", reader="ITKReader"), + LoadImaged(keys="image", reader=["ITKReader", NvDicomReader()]), EnsureTyped(keys="image", device=data.get("device") if data else None), EnsureChannelFirstd(keys="image"), CacheObjectd(keys="image"), diff --git a/sample-apps/radiology/lib/infers/segmentation.py b/sample-apps/radiology/lib/infers/segmentation.py index b10c9f499..2e796087b 100644 --- a/sample-apps/radiology/lib/infers/segmentation.py +++ b/sample-apps/radiology/lib/infers/segmentation.py @@ -30,6 +30,7 @@ from monailabel.interfaces.tasks.infer_v2 import InferType from monailabel.tasks.infer.basic_infer import BasicInferTask from monailabel.transform.post import Restored +from monailabel.transform.reader import NvDicomReader class Segmentation(BasicInferTask): @@ -62,7 +63,7 @@ def __init__( def pre_transforms(self, data=None) -> Sequence[Callable]: t = [ - LoadImaged(keys="image"), + LoadImaged(keys="image", reader=["ITKReader", NvDicomReader()]), EnsureTyped(keys="image", device=data.get("device") if data else None), EnsureChannelFirstd(keys="image"), Orientationd(keys="image", axcodes="RAS"), diff --git a/sample-apps/radiology/lib/infers/segmentation_spleen.py b/sample-apps/radiology/lib/infers/segmentation_spleen.py index 1e4c4102a..2a1cb043f 100644 --- a/sample-apps/radiology/lib/infers/segmentation_spleen.py +++ b/sample-apps/radiology/lib/infers/segmentation_spleen.py @@ -28,6 +28,7 @@ from monailabel.interfaces.tasks.infer_v2 import InferType from monailabel.tasks.infer.basic_infer import BasicInferTask from monailabel.transform.post import Restored +from monailabel.transform.reader import NvDicomReader class 
SegmentationSpleen(BasicInferTask): @@ -60,7 +61,7 @@ def __init__( def pre_transforms(self, data=None) -> Sequence[Callable]: return [ - LoadImaged(keys="image"), + LoadImaged(keys="image", reader=["ITKReader", NvDicomReader()]), EnsureTyped(keys="image", device=data.get("device") if data else None), EnsureChannelFirstd(keys="image"), Orientationd(keys="image", axcodes="RAS"), diff --git a/sample-apps/radiology/lib/infers/segmentation_vertebra.py b/sample-apps/radiology/lib/infers/segmentation_vertebra.py index 142adba33..d0fd60a34 100644 --- a/sample-apps/radiology/lib/infers/segmentation_vertebra.py +++ b/sample-apps/radiology/lib/infers/segmentation_vertebra.py @@ -38,6 +38,7 @@ from monailabel.interfaces.tasks.infer_v2 import InferType from monailabel.tasks.infer.basic_infer import BasicInferTask from monailabel.transform.post import Restored +from monailabel.transform.reader import NvDicomReader class SegmentationVertebra(BasicInferTask): @@ -75,7 +76,7 @@ def pre_transforms(self, data=None) -> Sequence[Callable]: add_cache = True t.extend( [ - LoadImaged(keys="image", reader="ITKReader"), + LoadImaged(keys="image", reader=["ITKReader", NvDicomReader()]), EnsureTyped(keys="image", device=data.get("device") if data else None), EnsureChannelFirstd(keys="image"), GetOriginalInformation(keys="image"), diff --git a/sample-apps/radiology/lib/infers/sw_fastedit.py b/sample-apps/radiology/lib/infers/sw_fastedit.py index fbfea3b0d..9430ee957 100644 --- a/sample-apps/radiology/lib/infers/sw_fastedit.py +++ b/sample-apps/radiology/lib/infers/sw_fastedit.py @@ -40,6 +40,7 @@ from monailabel.interfaces.tasks.infer_v2 import InferType from monailabel.tasks.infer.basic_infer import BasicInferTask, CallBackTypes +from monailabel.transform.reader import NvDicomReader # monai_version = pkg_resources.get_distribution("monai").version # if not pkg_resources.parse_version(monai_version) >= pkg_resources.parse_version("1.3.0"): @@ -119,7 +120,7 @@ def pre_transforms(self, data=None) -> Sequence[Callable]: t = [] t_val_1 = [ - LoadImaged(keys=input_keys, reader="ITKReader", image_only=False), + LoadImaged(keys=input_keys, reader=["ITKReader", NvDicomReader()], image_only=False), EnsureChannelFirstd(keys=input_keys), ScaleIntensityRangePercentilesd( keys="image", lower=0.05, upper=99.95, b_min=0.0, b_max=1.0, clip=True, relative=False diff --git a/sample-apps/radiology/lib/trainers/deepedit.py b/sample-apps/radiology/lib/trainers/deepedit.py index 3e8887fab..228279351 100644 --- a/sample-apps/radiology/lib/trainers/deepedit.py +++ b/sample-apps/radiology/lib/trainers/deepedit.py @@ -43,6 +43,7 @@ from monailabel.deepedit.handlers import TensorBoardImageHandler from monailabel.tasks.train.basic_train import BasicTrainTask, Context +from monailabel.transform.reader import NvDicomReader logger = logging.getLogger(__name__) @@ -100,7 +101,7 @@ def get_click_transforms(self, context: Context): def train_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label"), reader="ITKReader", image_only=False), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()], image_only=False), EnsureChannelFirstd(keys=("image", "label")), NormalizeLabelsInDatasetd(keys="label", label_names=self._labels), Orientationd(keys=["image", "label"], axcodes="RAS"), @@ -134,7 +135,7 @@ def train_post_transforms(self, context: Context): def val_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label"), reader="ITKReader"), + LoadImaged(keys=("image", "label"), 
reader=["ITKReader", NvDicomReader()]), EnsureChannelFirstd(keys=("image", "label")), NormalizeLabelsInDatasetd(keys="label", label_names=self._labels), Orientationd(keys=["image", "label"], axcodes="RAS"), diff --git a/sample-apps/radiology/lib/trainers/deepgrow.py b/sample-apps/radiology/lib/trainers/deepgrow.py index 99de1c884..7a05fa1dd 100644 --- a/sample-apps/radiology/lib/trainers/deepgrow.py +++ b/sample-apps/radiology/lib/trainers/deepgrow.py @@ -43,6 +43,7 @@ from monailabel.interfaces.datastore import Datastore from monailabel.tasks.train.basic_train import BasicTrainTask, Context +from monailabel.transform.reader import NvDicomReader logger = logging.getLogger(__name__) @@ -115,7 +116,7 @@ def get_click_transforms(self, context: Context): def train_pre_transforms(self, context: Context): # Dataset preparation t: List[Any] = [ - LoadImaged(keys=("image", "label"), image_only=False), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()], image_only=False), EnsureChannelFirstd(keys=("image", "label")), SpatialCropForegroundd(keys=("image", "label"), source_key="label", spatial_size=self.roi_size), Resized(keys=("image", "label"), spatial_size=self.model_size, mode=("area", "nearest")), diff --git a/sample-apps/radiology/lib/trainers/localization_spine.py b/sample-apps/radiology/lib/trainers/localization_spine.py index cd42658a1..eb17f8250 100644 --- a/sample-apps/radiology/lib/trainers/localization_spine.py +++ b/sample-apps/radiology/lib/trainers/localization_spine.py @@ -32,6 +32,7 @@ from monailabel.tasks.train.basic_train import BasicTrainTask, Context from monailabel.tasks.train.utils import region_wise_metrics +from monailabel.transform.reader import NvDicomReader logger = logging.getLogger(__name__) @@ -71,7 +72,7 @@ def train_data_loader(self, context, num_workers=0, shuffle=False): def train_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label"), reader="ITKReader"), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()]), NormalizeLabelsInDatasetd(keys="label", label_names=self._labels), # Specially for missing labels EnsureChannelFirstd(keys=("image", "label")), EnsureTyped(keys=("image", "label"), device=context.device), @@ -101,7 +102,7 @@ def train_post_transforms(self, context: Context): def val_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label"), reader="ITKReader"), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()]), NormalizeLabelsInDatasetd(keys="label", label_names=self._labels), # Specially for missing labels EnsureTyped(keys=("image", "label")), EnsureChannelFirstd(keys=("image", "label")), diff --git a/sample-apps/radiology/lib/trainers/localization_vertebra.py b/sample-apps/radiology/lib/trainers/localization_vertebra.py index 726215197..94528a075 100644 --- a/sample-apps/radiology/lib/trainers/localization_vertebra.py +++ b/sample-apps/radiology/lib/trainers/localization_vertebra.py @@ -33,6 +33,7 @@ from monailabel.tasks.train.basic_train import BasicTrainTask, Context from monailabel.tasks.train.utils import region_wise_metrics +from monailabel.transform.reader import NvDicomReader logger = logging.getLogger(__name__) @@ -71,7 +72,7 @@ def train_data_loader(self, context, num_workers=0, shuffle=False): def train_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label"), reader="ITKReader"), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()]), 
NormalizeLabelsInDatasetd(keys="label", label_names=self._labels), # Specially for missing labels EnsureChannelFirstd(keys=("image", "label")), EnsureTyped(keys=("image", "label"), device=context.device), @@ -107,7 +108,7 @@ def train_post_transforms(self, context: Context): def val_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label"), reader="ITKReader"), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()]), NormalizeLabelsInDatasetd(keys="label", label_names=self._labels), # Specially for missing labels EnsureTyped(keys=("image", "label")), EnsureChannelFirstd(keys=("image", "label")), diff --git a/sample-apps/radiology/lib/trainers/segmentation.py b/sample-apps/radiology/lib/trainers/segmentation.py index 07cbc6b7d..1ea8ebdd7 100644 --- a/sample-apps/radiology/lib/trainers/segmentation.py +++ b/sample-apps/radiology/lib/trainers/segmentation.py @@ -34,6 +34,7 @@ from monailabel.tasks.train.basic_train import BasicTrainTask, Context from monailabel.tasks.train.utils import region_wise_metrics +from monailabel.transform.reader import NvDicomReader logger = logging.getLogger(__name__) @@ -72,7 +73,7 @@ def train_data_loader(self, context, num_workers=0, shuffle=False): def train_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label")), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()]), NormalizeLabelsInDatasetd(keys="label", label_names=self._labels), # Specially for missing labels EnsureChannelFirstd(keys=("image", "label")), EnsureTyped(keys=("image", "label"), device=context.device), @@ -108,7 +109,7 @@ def train_post_transforms(self, context: Context): def val_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label")), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()]), NormalizeLabelsInDatasetd(keys="label", label_names=self._labels), # Specially for missing labels EnsureTyped(keys=("image", "label")), EnsureChannelFirstd(keys=("image", "label")), diff --git a/sample-apps/radiology/lib/trainers/segmentation_spleen.py b/sample-apps/radiology/lib/trainers/segmentation_spleen.py index 1dc0df6cf..0f63499ce 100644 --- a/sample-apps/radiology/lib/trainers/segmentation_spleen.py +++ b/sample-apps/radiology/lib/trainers/segmentation_spleen.py @@ -31,6 +31,7 @@ ) from monailabel.tasks.train.basic_train import BasicTrainTask, Context +from monailabel.transform.reader import NvDicomReader logger = logging.getLogger(__name__) @@ -61,7 +62,7 @@ def loss_function(self, context: Context): def train_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label")), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()]), NormalizeLabelsInDatasetd(keys="label", label_names=self._labels), # Specially for missing labels EnsureChannelFirstd(keys=("image", "label")), EnsureTyped(keys=("image", "label"), device=context.device), @@ -90,7 +91,7 @@ def train_post_transforms(self, context: Context): def val_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label")), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()]), NormalizeLabelsInDatasetd(keys="label", label_names=self._labels), # Specially for missing labels EnsureTyped(keys=("image", "label")), EnsureChannelFirstd(keys=("image", "label")), diff --git a/sample-apps/radiology/lib/trainers/segmentation_vertebra.py b/sample-apps/radiology/lib/trainers/segmentation_vertebra.py index 
20f69bd7d..8601668bf 100644 --- a/sample-apps/radiology/lib/trainers/segmentation_vertebra.py +++ b/sample-apps/radiology/lib/trainers/segmentation_vertebra.py @@ -38,6 +38,7 @@ from monailabel.tasks.train.basic_train import BasicTrainTask, Context from monailabel.tasks.train.utils import region_wise_metrics +from monailabel.transform.reader import NvDicomReader logger = logging.getLogger(__name__) @@ -76,7 +77,7 @@ def train_data_loader(self, context, num_workers=0, shuffle=False): def train_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label"), reader="ITKReader"), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()]), EnsureChannelFirstd(keys=("image", "label")), # NormalizeIntensityd(keys="image", divisor=2048.0), ScaleIntensityRanged(keys="image", a_min=-1000, a_max=1900, b_min=0.0, b_max=1.0, clip=True), @@ -107,7 +108,7 @@ def train_post_transforms(self, context: Context): def val_pre_transforms(self, context: Context): return [ - LoadImaged(keys=("image", "label"), reader="ITKReader"), + LoadImaged(keys=("image", "label"), reader=["ITKReader", NvDicomReader()]), EnsureChannelFirstd(keys=("image", "label")), # NormalizeIntensityd(keys="image", divisor=2048.0), ScaleIntensityRanged(keys="image", a_min=-1000, a_max=1900, b_min=0.0, b_max=1.0, clip=True), diff --git a/setup.cfg b/setup.cfg index 83b3d77e0..bc795898c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -50,8 +50,8 @@ install_requires = expiring_dict>=1.1.0 cachetools>=5.3.3 watchdog>=4.0.0 - pydicom>=2.4.4 - pydicom-seg>=0.4.1 + pydicom>=3.0.1 + highdicom>=0.26.1 pynetdicom>=2.0.2 pynrrd>=1.0.0 numpymaxflow>=0.0.7 diff --git a/tests/integration/radiology_serverless/__init__.py b/tests/integration/radiology_serverless/__init__.py new file mode 100644 index 000000000..1e97f8940 --- /dev/null +++ b/tests/integration/radiology_serverless/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/integration/radiology_serverless/test_dicom_segmentation.py b/tests/integration/radiology_serverless/test_dicom_segmentation.py new file mode 100644 index 000000000..b0c30eca3 --- /dev/null +++ b/tests/integration/radiology_serverless/test_dicom_segmentation.py @@ -0,0 +1,509 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import os +import tempfile +import time +import unittest +from pathlib import Path + +import numpy as np +import torch + +from monailabel.config import settings +from monailabel.interfaces.app import MONAILabelApp +from monailabel.interfaces.utils.app import app_instance + +logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s", +) +logger = logging.getLogger(__name__) + + +class TestDicomSegmentation(unittest.TestCase): + """ + Test direct MONAI Label inference on DICOM series without server. + + This test demonstrates serverless usage of MONAILabel for DICOM segmentation, + loading DICOM series from test data directories and running inference directly + through the app instance. + """ + + app = None + base_dir = os.path.realpath(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))) + data_dir = os.path.join(base_dir, "tests", "data") + + app_dir = os.path.join(base_dir, "sample-apps", "radiology") + studies = os.path.join(data_dir, "dataset", "local", "spleen") + + # DICOM test data directories + dicomweb_dir = os.path.join(data_dir, "dataset", "dicomweb") + dicomweb_htj2k_dir = os.path.join(data_dir, "dataset", "dicomweb_htj2k") + + # Specific DICOM series for testing + dicomweb_series = os.path.join( + data_dir, + "dataset", + "dicomweb", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686521.1629744176.620266", + ) + dicomweb_htj2k_series = os.path.join( + data_dir, + "dataset", + "dicomweb_htj2k", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686521.1629744176.620266", + ) + + dicomweb_htj2k_multiframe_series = os.path.join( + data_dir, + "dataset", + "dicomweb_htj2k_multiframe", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686521.1629744176.620251", + ) + + @classmethod + def setUpClass(cls) -> None: + """Initialize MONAI Label app for direct usage without server.""" + settings.MONAI_LABEL_APP_DIR = cls.app_dir + settings.MONAI_LABEL_STUDIES = cls.studies + settings.MONAI_LABEL_DATASTORE_AUTO_RELOAD = False + + if torch.cuda.is_available(): + logger.info(f"Initializing MONAI Label app from: {cls.app_dir}") + logger.info(f"Studies directory: {cls.studies}") + + cls.app: MONAILabelApp = app_instance( + app_dir=cls.app_dir, + studies=cls.studies, + conf={ + "preload": "true", + "models": "segmentation_spleen", + }, + ) + + logger.info("App initialized successfully") + + @classmethod + def tearDownClass(cls) -> None: + """Clean up after tests.""" + pass + + def _run_inference(self, image_path: str, model_name: str = "segmentation_spleen") -> tuple: + """ + Run segmentation inference on an image (DICOM series directory or NIfTI file). 
+ + Args: + image_path: Path to DICOM series directory or NIfTI file + model_name: Name of the segmentation model to use + + Returns: + Tuple of (label_data, label_json, inference_time) + """ + logger.info(f"Running inference on: {image_path}") + logger.info(f"Model: {model_name}") + + # Prepare inference request + request = { + "model": model_name, + "image": image_path, # Can be DICOM directory or NIfTI file + "device": "cuda" if torch.cuda.is_available() else "cpu", + "result_extension": ".nii.gz", # Force NIfTI output format + "result_dtype": "uint8", # Set output data type + } + + # Get the inference task directly + task = self.app._infers[model_name] + + # Run inference + inference_start = time.time() + label_data, label_json = task(request) + inference_time = time.time() - inference_start + + logger.info(f"Inference completed in {inference_time:.3f} seconds") + + return label_data, label_json, inference_time + + def _load_segmentation_array(self, label_data): + """ + Load segmentation data as numpy array. + + Args: + label_data: File path (str) or numpy array + + Returns: + numpy array of segmentation + """ + if isinstance(label_data, str): + import nibabel as nib + + nii = nib.load(label_data) + return nii.get_fdata() + elif isinstance(label_data, np.ndarray): + return label_data + else: + raise ValueError(f"Unexpected label data type: {type(label_data)}") + + def _validate_segmentation_output(self, label_data, label_json): + """ + Validate that the segmentation output is correct. + + Args: + label_data: The segmentation result (file path or numpy array) + label_json: Metadata about the segmentation + """ + self.assertIsNotNone(label_data, "Label data should not be None") + self.assertIsNotNone(label_json, "Label JSON should not be None") + + # Check if it's a file path or numpy array + if isinstance(label_data, str): + self.assertTrue(os.path.exists(label_data), f"Output file should exist: {label_data}") + logger.info(f"Segmentation saved to: {label_data}") + + # Try to load and verify the file + try: + array = self._load_segmentation_array(label_data) + self.assertGreater(array.size, 0, "Segmentation array should not be empty") + logger.info(f"Segmentation shape: {array.shape}, dtype: {array.dtype}") + logger.info(f"Unique labels: {np.unique(array)}") + except Exception as e: + logger.warning(f"Could not load segmentation file: {e}") + + elif isinstance(label_data, np.ndarray): + self.assertGreater(label_data.size, 0, "Segmentation array should not be empty") + logger.info(f"Segmentation shape: {label_data.shape}, dtype: {label_data.dtype}") + logger.info(f"Unique labels: {np.unique(label_data)}") + else: + self.fail(f"Unexpected label data type: {type(label_data)}") + + # Validate metadata + self.assertIsInstance(label_json, dict, "Label JSON should be a dictionary") + logger.info(f"Label metadata keys: {list(label_json.keys())}") + + def _compare_segmentations( + self, label_data_1, label_data_2, name_1="Reference", name_2="Comparison", tolerance=0.05 + ): + """ + Compare two segmentation outputs to verify they are similar. 
+ + Args: + label_data_1: First segmentation (file path or array) + label_data_2: Second segmentation (file path or array) + name_1: Name for first segmentation (for logging) + name_2: Name for second segmentation (for logging) + tolerance: Maximum allowed dice coefficient difference (0.0-1.0) + + Returns: + dict with comparison metrics + """ + # Load arrays + array_1 = self._load_segmentation_array(label_data_1) + array_2 = self._load_segmentation_array(label_data_2) + + # Check shapes match + self.assertEqual( + array_1.shape, array_2.shape, f"Segmentation shapes should match: {array_1.shape} vs {array_2.shape}" + ) + + # Calculate dice coefficient for each label + unique_labels = np.union1d(np.unique(array_1), np.unique(array_2)) + unique_labels = unique_labels[unique_labels != 0] # Exclude background + + dice_scores = {} + for label in unique_labels: + mask_1 = (array_1 == label).astype(np.float32) + mask_2 = (array_2 == label).astype(np.float32) + + intersection = np.sum(mask_1 * mask_2) + sum_masks = np.sum(mask_1) + np.sum(mask_2) + + if sum_masks > 0: + dice = (2.0 * intersection) / sum_masks + dice_scores[int(label)] = dice + else: + dice_scores[int(label)] = 0.0 + + # Calculate overall metrics + exact_match = np.array_equal(array_1, array_2) + pixel_accuracy = np.mean(array_1 == array_2) + + comparison_result = { + "exact_match": exact_match, + "pixel_accuracy": pixel_accuracy, + "dice_scores": dice_scores, + "avg_dice": np.mean(list(dice_scores.values())) if dice_scores else 0.0, + } + + # Log results + logger.info(f"\nComparing {name_1} vs {name_2}:") + logger.info(f" Exact match: {exact_match}") + logger.info(f" Pixel accuracy: {pixel_accuracy:.4f}") + logger.info(f" Dice scores by label: {dice_scores}") + logger.info(f" Average Dice: {comparison_result['avg_dice']:.4f}") + + # Assert high similarity + self.assertGreater( + comparison_result["avg_dice"], + 1.0 - tolerance, + f"Segmentations should be similar (Dice > {1.0 - tolerance:.2f}). 
" + f"Got {comparison_result['avg_dice']:.4f}", + ) + + return comparison_result + + def test_01_app_initialized(self): + """Test that the app is properly initialized.""" + if not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + self.assertIsNotNone(self.app, "App should be initialized") + self.assertIn("segmentation_spleen", self.app._infers, "segmentation_spleen model should be available") + logger.info(f"Available models: {list(self.app._infers.keys())}") + + def test_02_dicom_inference_dicomweb(self): + """Test inference on DICOM series from dicomweb directory.""" + if not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + if not self.app: + self.skipTest("App not initialized") + + # Use specific DICOM series + if not os.path.exists(self.dicomweb_series): + self.skipTest(f"DICOM series not found: {self.dicomweb_series}") + + logger.info(f"Testing on DICOM series: {self.dicomweb_series}") + + # Run inference + label_data, label_json, inference_time = self._run_inference(self.dicomweb_series) + + # Validate output + self._validate_segmentation_output(label_data, label_json) + + # Performance check + self.assertLess(inference_time, 60.0, "Inference should complete within 60 seconds") + logger.info(f"✓ DICOM inference test passed (dicomweb) in {inference_time:.3f}s") + + def test_03_dicom_inference_dicomweb_htj2k(self): + """Test inference on DICOM series from dicomweb_htj2k directory (HTJ2K compressed).""" + if not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + if not self.app: + self.skipTest("App not initialized") + + # Use specific HTJ2K DICOM series + if not os.path.exists(self.dicomweb_htj2k_series): + self.skipTest(f"HTJ2K DICOM series not found: {self.dicomweb_htj2k_series}") + + logger.info(f"Testing on HTJ2K compressed DICOM series: {self.dicomweb_htj2k_series}") + + # Run inference + label_data, label_json, inference_time = self._run_inference(self.dicomweb_htj2k_series) + + # Validate output + self._validate_segmentation_output(label_data, label_json) + + # Performance check + self.assertLess(inference_time, 60.0, "Inference should complete within 60 seconds") + logger.info(f"✓ DICOM inference test passed (HTJ2K) in {inference_time:.3f}s") + + def test_04_compare_all_formats(self): + """ + Compare segmentation outputs across all DICOM format variations. + + This is the KEY test that validates: + - Standard DICOM (uncompressed, single-frame) + - HTJ2K compressed DICOM (single-frame) + - Multi-frame HTJ2K DICOM + + All produce IDENTICAL or highly similar segmentation results. 
+ """ + if not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + if not self.app: + self.skipTest("App not initialized") + + logger.info(f"\n{'='*60}") + logger.info("Comparing Segmentation Outputs Across All Formats") + logger.info(f"{'='*60}") + + # Test all series types + test_series = [ + ("Standard DICOM", self.dicomweb_series), + ("HTJ2K DICOM", self.dicomweb_htj2k_series), + ("Multi-frame HTJ2K", self.dicomweb_htj2k_multiframe_series), + ] + + # Run inference on all available formats + results = {} + for series_name, series_path in test_series: + if not os.path.exists(series_path): + logger.warning(f"Skipping {series_name}: not found") + continue + + logger.info(f"\nRunning {series_name}...") + try: + label_data, label_json, inference_time = self._run_inference(series_path) + self._validate_segmentation_output(label_data, label_json) + + results[series_name] = {"label_data": label_data, "label_json": label_json, "time": inference_time} + logger.info(f" ✓ {series_name} completed in {inference_time:.3f}s") + except Exception as e: + logger.error(f" ✗ {series_name} failed: {e}", exc_info=True) + + # Require at least 2 formats to compare + self.assertGreaterEqual(len(results), 2, "Need at least 2 formats to compare. Check test data availability.") + + # Compare all pairs + logger.info(f"\n{'='*60}") + logger.info("Cross-Format Comparison:") + logger.info(f"{'='*60}") + + format_names = list(results.keys()) + comparison_results = [] + + for i in range(len(format_names)): + for j in range(i + 1, len(format_names)): + name1 = format_names[i] + name2 = format_names[j] + + logger.info(f"\nComparing: {name1} vs {name2}") + try: + comparison = self._compare_segmentations( + results[name1]["label_data"], + results[name2]["label_data"], + name_1=name1, + name_2=name2, + tolerance=0.05, # Allow 5% dice variation + ) + comparison_results.append( + { + "pair": f"{name1} vs {name2}", + "dice": comparison["avg_dice"], + "pixel_accuracy": comparison["pixel_accuracy"], + } + ) + except Exception as e: + logger.error(f"Comparison failed: {e}", exc_info=True) + raise + + # Summary + logger.info(f"\n{'='*60}") + logger.info("Comparison Summary:") + for comp in comparison_results: + logger.info(f" {comp['pair']}: Dice={comp['dice']:.4f}, Accuracy={comp['pixel_accuracy']:.4f}") + logger.info(f"{'='*60}") + + # All comparisons should show high similarity + self.assertTrue(len(comparison_results) > 0, "Should have at least one comparison") + avg_dice = np.mean([c["dice"] for c in comparison_results]) + logger.info(f"\nOverall average Dice across all comparisons: {avg_dice:.4f}") + self.assertGreater(avg_dice, 0.95, "All formats should produce highly similar segmentations (avg Dice > 0.95)") + + def test_05_compare_dicom_vs_nifti(self): + """ + Compare inference results between DICOM series and pre-converted NIfTI files. + + Validates that the DICOM reader produces identical results to pre-converted NIfTI. 
+ """ + if not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + if not self.app: + self.skipTest("App not initialized") + + # Use specific DICOM series and its NIfTI equivalent + dicom_dir = self.dicomweb_series + nifti_file = f"{dicom_dir}.nii.gz" + + if not os.path.exists(dicom_dir): + self.skipTest(f"DICOM series not found: {dicom_dir}") + + if not os.path.exists(nifti_file): + self.skipTest(f"Corresponding NIfTI file not found: {nifti_file}") + + logger.info(f"\n{'='*60}") + logger.info("Comparing DICOM vs NIfTI Segmentation") + logger.info(f"{'='*60}") + logger.info(f" DICOM: {dicom_dir}") + logger.info(f" NIfTI: {nifti_file}") + + # Run inference on DICOM + logger.info("\n--- Running inference on DICOM series ---") + dicom_label, dicom_json, dicom_time = self._run_inference(dicom_dir) + self._validate_segmentation_output(dicom_label, dicom_json) + + # Run inference on NIfTI + logger.info("\n--- Running inference on NIfTI file ---") + nifti_label, nifti_json, nifti_time = self._run_inference(nifti_file) + self._validate_segmentation_output(nifti_label, nifti_json) + + # Compare the segmentation outputs + comparison = self._compare_segmentations( + dicom_label, + nifti_label, + name_1="DICOM", + name_2="NIfTI", + tolerance=0.01, # Stricter tolerance - should be nearly identical + ) + + logger.info(f"\n{'='*60}") + logger.info("Comparison Summary:") + logger.info(f" DICOM inference time: {dicom_time:.3f}s") + logger.info(f" NIfTI inference time: {nifti_time:.3f}s") + logger.info(f" Dice coefficient: {comparison['avg_dice']:.4f}") + logger.info(f" Pixel accuracy: {comparison['pixel_accuracy']:.4f}") + logger.info(f" Exact match: {comparison['exact_match']}") + logger.info(f"{'='*60}") + + # Should be nearly identical (Dice > 0.99) + self.assertGreater(comparison["avg_dice"], 0.99, "DICOM and NIfTI segmentations should be nearly identical") + + def test_06_multiframe_htj2k_inference(self): + """ + Test basic inference on multi-frame HTJ2K compressed DICOM series. + + Note: Comprehensive cross-format comparison is done in test_04. + This test ensures multi-frame HTJ2K inference works standalone. + """ + if not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + if not self.app: + self.skipTest("App not initialized") + + if not os.path.exists(self.dicomweb_htj2k_multiframe_series): + self.skipTest(f"Multi-frame HTJ2K series not found: {self.dicomweb_htj2k_multiframe_series}") + + logger.info(f"\n{'='*60}") + logger.info("Testing Multi-Frame HTJ2K DICOM Inference") + logger.info(f"{'='*60}") + logger.info(f"Series path: {self.dicomweb_htj2k_multiframe_series}") + + # Run inference + label_data, label_json, inference_time = self._run_inference(self.dicomweb_htj2k_multiframe_series) + + # Validate output + self._validate_segmentation_output(label_data, label_json) + + # Performance check + self.assertLess(inference_time, 60.0, "Inference should complete within 60 seconds") + + logger.info(f"✓ Multi-frame HTJ2K inference test passed in {inference_time:.3f}s") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/setup.py b/tests/setup.py index e33aeaf08..6f0693876 100644 --- a/tests/setup.py +++ b/tests/setup.py @@ -9,14 +9,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging import os import shutil +import tempfile +from pathlib import Path from monai.apps import download_url, extractall TEST_DIR = os.path.realpath(os.path.dirname(__file__)) TEST_DATA = os.path.join(TEST_DIR, "data") +logger = logging.getLogger(__name__) + def run_main(): downloaded_dataset_file = os.path.join(TEST_DIR, "downloads", "dataset.zip") @@ -50,11 +55,75 @@ def run_main(): os.makedirs(os.path.join(TEST_DATA, "detection")) extractall(filepath=downloaded_detection_file, output_dir=os.path.join(TEST_DATA, "detection")) - downloaded_dicom_file = os.path.join(TEST_DIR, "downloads", "dicom.zip") - dicom_url = "https://github.com/Project-MONAI/MONAILabel/releases/download/data/dicom.zip" - if not os.path.exists(downloaded_dicom_file): - download_url(url=dicom_url, filepath=downloaded_dicom_file) + # Create HTJ2K-encoded versions of dicomweb test data if nvimgcodec is available + try: + import sys + + sys.path.insert(0, TEST_DIR) + from monailabel.datastore.utils.convert_htj2k import ( + DicomFileLoader, + convert_single_frame_dicom_series_to_multiframe, + transcode_dicom_to_htj2k, + ) + + # Create regular HTJ2K files (preserving file structure) + logger.info("Creating HTJ2K test data (single-frame per file)...") + source_base_dir = Path(TEST_DATA) / "dataset" / "dicomweb" + htj2k_base_dir = Path(TEST_DATA) / "dataset" / "dicomweb_htj2k" + + if source_base_dir.exists() and not (htj2k_base_dir.exists() and any(htj2k_base_dir.rglob("*.dcm"))): + series_dirs = [d for d in source_base_dir.rglob("*") if d.is_dir() and any(d.glob("*.dcm"))] + for series_dir in series_dirs: + rel_path = series_dir.relative_to(source_base_dir) + output_series_dir = htj2k_base_dir / rel_path + if not (output_series_dir.exists() and any(output_series_dir.glob("*.dcm"))): + logger.info(f" Processing series: {rel_path}") + # Create file_loader using DicomFileLoader + file_loader = DicomFileLoader( + input_dir=str(series_dir), output_dir=str(output_series_dir), batch_size=256 + ) + transcode_dicom_to_htj2k( + file_loader=file_loader, + num_resolutions=6, + code_block_size=(64, 64), + add_basic_offset_table=False, + ) + logger.info(f"✓ HTJ2K test data created at: {htj2k_base_dir}") + else: + logger.info("HTJ2K test data already exists, skipping.") + + # Create multi-frame HTJ2K files (one file per series) + logger.info("Creating multi-frame HTJ2K test data...") + htj2k_multiframe_dir = Path(TEST_DATA) / "dataset" / "dicomweb_htj2k_multiframe" + + if source_base_dir.exists() and not ( + htj2k_multiframe_dir.exists() and any(htj2k_multiframe_dir.rglob("*.dcm")) + ): + convert_single_frame_dicom_series_to_multiframe( + input_dir=str(source_base_dir), + output_dir=str(htj2k_multiframe_dir), + convert_to_htj2k=True, + num_resolutions=6, + code_block_size=(64, 64), + ) + logger.info(f"✓ Multi-frame HTJ2K test data created at: {htj2k_multiframe_dir}") + else: + logger.info("Multi-frame HTJ2K test data already exists, skipping.") + + except ImportError as e: + if "nvidia" in str(e).lower() or "nvimgcodec" in str(e).lower(): + logger.info("Note: nvidia-nvimgcodec not installed. 
HTJ2K test data will not be created.") + logger.info("To enable HTJ2K support, install the package matching your CUDA version:") + logger.info(" pip install nvidia-nvimgcodec-cu{XX}[all]") + logger.info(" (Replace {XX} with your CUDA major version, e.g., cu13 for CUDA 13.x)") + logger.info("Installation guide: https://docs.nvidia.com/cuda/nvimagecodec/installation.html") + else: + logger.warning(f"Could not import HTJ2K creation module: {e}") + except Exception as e: + logger.warning(f"HTJ2K test data creation failed: {e}") + logger.info("You can manually run: python tests/prepare_htj2k_test_data.py") if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") run_main() diff --git a/tests/unit/datastore/test_convert.py b/tests/unit/datastore/test_convert.py index 9c190f162..7111ca1ba 100644 --- a/tests/unit/datastore/test_convert.py +++ b/tests/unit/datastore/test_convert.py @@ -10,13 +10,38 @@ # limitations under the License. import os +import subprocess import tempfile import unittest +from pathlib import Path import numpy as np +import pydicom from monai.transforms import LoadImage -from monailabel.datastore.utils.convert import binary_to_image, dicom_to_nifti, nifti_to_dicom_seg +from monailabel.datastore.utils.convert import ( + binary_to_image, + dicom_to_nifti, + nifti_to_dicom_seg, +) + +# Check if nvimgcodec is available +try: + from nvidia import nvimgcodec + + HAS_NVIMGCODEC = True +except ImportError: + HAS_NVIMGCODEC = False + nvimgcodec = None + +# HTJ2K Transfer Syntax UIDs +HTJ2K_TRANSFER_SYNTAXES = frozenset( + [ + "1.2.840.10008.1.2.4.201", # High-Throughput JPEG 2000 Image Compression (Lossless Only) + "1.2.840.10008.1.2.4.202", # High-Throughput JPEG 2000 with RPCL Options Image Compression (Lossless Only) + "1.2.840.10008.1.2.4.203", # High-Throughput JPEG 2000 Image Compression + ] +) class TestConvert(unittest.TestCase): @@ -48,25 +73,290 @@ def test_binary_to_image(self): assert result.endswith(".nii.gz") os.unlink(result) - def test_nifti_to_dicom_seg(self): - image = os.path.join(self.dicom_dataset, "1.2.826.0.1.3680043.8.274.1.1.8323329.686549.1629744177.996087") - label = os.path.join( + def test_nifti_to_dicom_seg_highdicom(self): + """Test NIfTI to DICOM SEG conversion using highdicom (use_itk=False).""" + series_dir = os.path.join(self.dicom_dataset, "1.2.826.0.1.3680043.8.274.1.1.8323329.686549.1629744177.996087") + label_file = os.path.join( self.dicom_dataset, "labels", "final", "1.2.826.0.1.3680043.8.274.1.1.8323329.686549.1629744177.996087.nii.gz", ) - result = nifti_to_dicom_seg(image, label, None, use_itk=False) - assert os.path.exists(result) - assert result.endswith(".dcm") + # Convert using highdicom (use_itk=False) + result = nifti_to_dicom_seg(series_dir, label_file, None, use_itk=False) + + # Verify output + self.assertTrue(os.path.exists(result), "DICOM SEG file should be created") + self.assertTrue(result.endswith(".dcm"), "Output should be a DICOM file") + + # Verify it's a valid DICOM file + ds = pydicom.dcmread(result) + self.assertEqual(ds.Modality, "SEG", "Should be a DICOM Segmentation object") + + # Verify segment count + input_label = LoadImage(image_only=True)(label_file) + num_labels = len(np.unique(input_label)) - 1 # Exclude background (0) + if hasattr(ds, "SegmentSequence"): + num_segments = len(ds.SegmentSequence) + print(f" Segments in DICOM SEG: {num_segments}, Unique labels in input: {num_labels}") + + # Clean up + os.unlink(result) + + print(f"✓ NIfTI → DICOM SEG conversion 
successful (highdicom)") + + def test_nifti_to_dicom_seg_itk(self): + """Test NIfTI to DICOM SEG conversion using ITK (use_itk=True).""" + series_dir = os.path.join(self.dicom_dataset, "1.2.826.0.1.3680043.8.274.1.1.8323329.686549.1629744177.996087") + label_file = os.path.join( + self.dicom_dataset, + "labels", + "final", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686549.1629744177.996087.nii.gz", + ) + + # Check if ITK/dcmqi is available + import shutil + + itk_available = shutil.which("itkimage2segimage") is not None + + if not itk_available: + self.skipTest( + "itkimage2segimage command-line tool not found. " + "Install dcmqi: pip install dcmqi (https://github.com/QIICR/dcmqi)" + ) + + # Convert using ITK (use_itk=True) + result = nifti_to_dicom_seg(series_dir, label_file, None, use_itk=True) + + # Verify output + self.assertTrue(os.path.exists(result), "DICOM SEG file should be created") + self.assertTrue(result.endswith(".dcm"), "Output should be a DICOM file") + + # Verify it's a valid DICOM file + ds = pydicom.dcmread(result) + self.assertEqual(ds.Modality, "SEG", "Should be a DICOM Segmentation object") + + # Verify segment count + input_label = LoadImage(image_only=True)(label_file) + num_labels = len(np.unique(input_label)) - 1 # Exclude background (0) + if hasattr(ds, "SegmentSequence"): + num_segments = len(ds.SegmentSequence) + print(f" Segments in DICOM SEG: {num_segments}, Unique labels in input: {num_labels}") + + # Clean up + os.unlink(result) + + print(f"✓ NIfTI → DICOM SEG conversion successful (ITK)") + + def test_dicom_series_to_nifti_original(self): + """Test DICOM to NIfTI conversion with original DICOM files (Explicit VR Little Endian).""" + # Use a specific series from dicomweb + dicom_dir = os.path.join( + self.base_dir, + "data", + "dataset", + "dicomweb", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721", + ) + + # Find DICOM files in this series + dcm_files = list(Path(dicom_dir).glob("*.dcm")) + self.assertTrue(len(dcm_files) > 0, f"No DICOM files found in {dicom_dir}") + + # Reference NIfTI file (in parent directory with same name as series) + series_uid = os.path.basename(dicom_dir) + reference_nifti = os.path.join(os.path.dirname(dicom_dir), f"{series_uid}.nii.gz") + + # Convert DICOM series to NIfTI + result = dicom_to_nifti(dicom_dir) + + # Verify the result + self.assertTrue(os.path.exists(result), "NIfTI file should be created") + self.assertTrue(result.endswith(".nii.gz"), "Output should be a compressed NIfTI file") + + # Load and verify the NIfTI data + nifti_data, nifti_meta = LoadImage(image_only=False)(result) + + # Verify it's a 3D volume with expected dimensions (512x512x77) + self.assertEqual(len(nifti_data.shape), 3, "Should be a 3D volume") + self.assertEqual(nifti_data.shape[0], 512, "Should have 512 rows") + self.assertEqual(nifti_data.shape[1], 512, "Should have 512 columns") + self.assertEqual(nifti_data.shape[2], 77, "Should have 77 slices") + + # Verify metadata includes affine transformation + self.assertIn("affine", nifti_meta, "Metadata should include affine transformation") + + # Compare with reference NIfTI + ref_data, ref_meta = LoadImage(image_only=False)(reference_nifti) + self.assertEqual(nifti_data.shape, ref_data.shape, "Shape should match reference NIfTI") + # Check if pixel values are similar (allowing for minor differences in conversion) + np.testing.assert_allclose( + nifti_data, ref_data, rtol=1e-5, atol=1e-5, err_msg="Pixel values should match reference NIfTI" + ) + 
print(f" ✓ Matches reference NIfTI") + + # Clean up os.unlink(result) - def test_itk_image_to_dicom_seg(self): - pass + print(f"✓ Original DICOM → NIfTI conversion successful") + print(f" Input: {len(dcm_files)} DICOM files") + print(f" Output shape: {nifti_data.shape}") + + def test_dicom_series_to_nifti_htj2k(self): + """Test DICOM to NIfTI conversion with HTJ2K-encoded DICOM files.""" + if not HAS_NVIMGCODEC: + self.skipTest( + "nvimgcodec not available. Install nvidia-nvimgcodec-cu{XX} matching your CUDA version (e.g., nvidia-nvimgcodec-cu13 for CUDA 13.x)" + ) + + # Use a specific HTJ2K series from dicomweb_htj2k + htj2k_dir = os.path.join( + self.base_dir, + "data", + "dataset", + "dicomweb_htj2k", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721", + ) + + # Find HTJ2K files in this series + htj2k_files = list(Path(htj2k_dir).glob("*.dcm")) + + # If no HTJ2K files found but nvimgcodec is available, create them + if len(htj2k_files) == 0: + print("\nHTJ2K test data not found. Creating HTJ2K-encoded DICOM files...") + import sys + + sys.path.insert(0, os.path.join(self.base_dir)) + from prepare_htj2k_test_data import create_htj2k_data + + create_htj2k_data(os.path.join(self.base_dir, "data")) + # Re-check for files + htj2k_files = list(Path(htj2k_dir).glob("*.dcm")) + + if len(htj2k_files) == 0: + self.skipTest(f"No HTJ2K DICOM files found in {htj2k_dir}") + + # Reference NIfTI file (from original dicomweb directory) + series_uid = os.path.basename(htj2k_dir) + # Go up from dicomweb_htj2k to dataset, then to dicomweb + reference_nifti = os.path.join( + self.base_dir, "data", "dataset", "dicomweb", "e7567e0a064f0c334226a0658de23afd", f"{series_uid}.nii.gz" + ) + + # Convert HTJ2K DICOM series to NIfTI + result = dicom_to_nifti(htj2k_dir) + + # Verify the result + self.assertTrue(os.path.exists(result), "NIfTI file should be created") + self.assertTrue(result.endswith(".nii.gz"), "Output should be a compressed NIfTI file") + + # Load and verify the NIfTI data + nifti_data, nifti_meta = LoadImage(image_only=False)(result) + + # Verify it's a 3D volume with expected dimensions (512x512x77) + self.assertEqual(len(nifti_data.shape), 3, "Should be a 3D volume") + self.assertEqual(nifti_data.shape[0], 512, "Should have 512 rows") + self.assertEqual(nifti_data.shape[1], 512, "Should have 512 columns") + self.assertEqual(nifti_data.shape[2], 77, "Should have 77 slices") + + # Verify metadata includes affine transformation + self.assertIn("affine", nifti_meta, "Metadata should include affine transformation") + + # Compare with reference NIfTI + ref_data, ref_meta = LoadImage(image_only=False)(reference_nifti) + self.assertEqual(nifti_data.shape, ref_data.shape, "Shape should match reference NIfTI") + # HTJ2K is lossless, so pixel values should be identical + np.testing.assert_allclose( + nifti_data, ref_data, rtol=1e-5, atol=1e-5, err_msg="Pixel values should match reference NIfTI" + ) + print(f" ✓ Matches reference NIfTI (lossless HTJ2K compression verified)") + + # Clean up + os.unlink(result) + + print(f"✓ HTJ2K DICOM → NIfTI conversion successful") + print(f" Input: {len(htj2k_files)} HTJ2K DICOM files") + print(f" Output shape: {nifti_data.shape}") + + def test_dicom_to_nifti_consistency(self): + """Test that original and HTJ2K DICOM files produce identical NIfTI outputs.""" + if not HAS_NVIMGCODEC: + self.skipTest( + "nvimgcodec not available. 
Install nvidia-nvimgcodec-cu{XX} matching your CUDA version (e.g., nvidia-nvimgcodec-cu13 for CUDA 13.x)" + ) + + # Use specific series directories for both original and HTJ2K + dicom_dir = os.path.join( + self.base_dir, + "data", + "dataset", + "dicomweb", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721", + ) + htj2k_dir = os.path.join( + self.base_dir, + "data", + "dataset", + "dicomweb_htj2k", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721", + ) + + # Check if HTJ2K files exist, create if needed + htj2k_files = list(Path(htj2k_dir).glob("*.dcm")) + if len(htj2k_files) == 0: + print("\nHTJ2K test data not found. Creating HTJ2K-encoded DICOM files...") + import sys + + sys.path.insert(0, os.path.join(self.base_dir)) + from prepare_htj2k_test_data import create_htj2k_data + + create_htj2k_data(os.path.join(self.base_dir, "data")) + # Re-check for files + htj2k_files = list(Path(htj2k_dir).glob("*.dcm")) + + # If still no HTJ2K files, skip the test (encoding may have failed) + if len(htj2k_files) == 0: + self.skipTest( + f"No HTJ2K DICOM files found in {htj2k_dir}. HTJ2K encoding may not be supported for these files." + ) + + # Convert both versions + result_original = dicom_to_nifti(dicom_dir) + result_htj2k = dicom_to_nifti(htj2k_dir) + + try: + # Load both NIfTI files + data_original = LoadImage(image_only=True)(result_original) + data_htj2k = LoadImage(image_only=True)(result_htj2k) + + # Verify shapes match + self.assertEqual(data_original.shape, data_htj2k.shape, "Original and HTJ2K should produce same shape") + + # Verify data types match + self.assertEqual(data_original.dtype, data_htj2k.dtype, "Original and HTJ2K should produce same data type") + + # Verify pixel values are identical (HTJ2K is lossless) + np.testing.assert_array_equal( + data_original, data_htj2k, err_msg="Original and HTJ2K should produce identical pixel values (lossless)" + ) + + print(f"✓ Original and HTJ2K produce identical NIfTI outputs") + print(f" Shape: {data_original.shape}") + print(f" Data type: {data_original.dtype}") + print(f" Pixel values: Identical (lossless compression verified)") - def test_itk_dicom_seg_to_image(self): - pass + finally: + # Clean up + if os.path.exists(result_original): + os.unlink(result_original) + if os.path.exists(result_htj2k): + os.unlink(result_htj2k) if __name__ == "__main__": diff --git a/tests/unit/datastore/test_convert_htj2k.py b/tests/unit/datastore/test_convert_htj2k.py new file mode 100644 index 000000000..fc820144e --- /dev/null +++ b/tests/unit/datastore/test_convert_htj2k.py @@ -0,0 +1,2115 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
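+
+"""Unit tests for the HTJ2K transcoding utilities in monailabel.datastore.utils.convert_htj2k.
+
+These tests cover single- and multi-frame HTJ2K transcoding via nvimgcodec, color-space
+handling (grayscale, RGB, and YBR inputs), progression-order options, the default skip
+list for already-compressed transfer syntaxes, and consistency of NIfTI outputs before
+and after transcoding. All tests are skipped when the optional nvidia-nvimgcodec
+package is not installed.
+"""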
+ +import os +import shutil +import tempfile +import unittest +from pathlib import Path + +import numpy as np +import pydicom +from monai.transforms import LoadImage + +from monailabel.datastore.utils.convert import dicom_to_nifti +from monailabel.datastore.utils.convert_htj2k import ( + DicomFileLoader, + convert_single_frame_dicom_series_to_multiframe, + transcode_dicom_to_htj2k, +) + +# Check if nvimgcodec is available +try: + from nvidia import nvimgcodec + + HAS_NVIMGCODEC = True +except ImportError: + HAS_NVIMGCODEC = False + nvimgcodec = None + +# HTJ2K Transfer Syntax UIDs +HTJ2K_TRANSFER_SYNTAXES = frozenset( + [ + "1.2.840.10008.1.2.4.201", # High-Throughput JPEG 2000 Image Compression (Lossless Only) + "1.2.840.10008.1.2.4.202", # High-Throughput JPEG 2000 with RPCL Options Image Compression (Lossless Only) + "1.2.840.10008.1.2.4.203", # High-Throughput JPEG 2000 Image Compression + ] +) + + +class TestConvertHTJ2K(unittest.TestCase): + base_dir = os.path.realpath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + dicom_dataset = os.path.join(base_dir, "data", "dataset", "dicomweb", "e7567e0a064f0c334226a0658de23afd") + + def test_transcode_multiframe_jpeg_ybr_to_htj2k(self): + """Test transcoding multi-frame JPEG with YCbCr color space to HTJ2K.""" + if not HAS_NVIMGCODEC: + self.skipTest( + "nvimgcodec not available. Install nvidia-nvimgcodec-cu{XX} matching your CUDA version (e.g., nvidia-nvimgcodec-cu13 for CUDA 13.x)" + ) + + # Use pydicom's built-in YBR color multi-frame JPEG example + import pydicom.data + + try: + source_file = pydicom.data.get_testdata_file("examples_ybr_color.dcm") + except Exception as e: + self.skipTest(f"Could not load pydicom test data: {e}") + + print(f"\nSource file: {source_file}") + + # Create temporary directories + input_dir = tempfile.mkdtemp(prefix="htj2k_multiframe_input_") + output_dir = tempfile.mkdtemp(prefix="htj2k_multiframe_output_") + + try: + # Copy file to input directory + import shutil + + test_filename = "multiframe_ybr.dcm" + shutil.copy2(source_file, os.path.join(input_dir, test_filename)) + + # Read original DICOM + ds_original = pydicom.dcmread(source_file) + original_pixels = ds_original.pixel_array.copy() + original_transfer_syntax = str(ds_original.file_meta.TransferSyntaxUID) + num_frames = int(ds_original.NumberOfFrames) if hasattr(ds_original, "NumberOfFrames") else 1 + + print(f"\nOriginal file:") + print(f" Transfer Syntax: {original_transfer_syntax}") + print(f" Transfer Syntax Name: {ds_original.file_meta.TransferSyntaxUID.name}") + print(f" PhotometricInterpretation: {ds_original.PhotometricInterpretation}") + print(f" Number of Frames: {num_frames}") + print(f" Dimensions: {ds_original.Rows} x {ds_original.Columns}") + print(f" Samples Per Pixel: {ds_original.SamplesPerPixel}") + print(f" Pixel shape: {original_pixels.shape}") + print(f" File size: {os.path.getsize(source_file):,} bytes") + + # Perform transcoding + print(f"\nTranscoding multi-frame YBR JPEG to HTJ2K...") + import time + + start_time = time.time() + + # Override default skip list to force transcoding of JPEG files + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k( + file_loader=file_loader, skip_transfer_syntaxes=None # Override default to test JPEG transcoding + ) + + elapsed_time = time.time() - start_time + print(f"Transcoding completed in {elapsed_time:.2f} seconds") + + # Find transcoded file + transcoded_file = os.path.join(output_dir, test_filename) + 
self.assertTrue(os.path.exists(transcoded_file), f"Transcoded file should exist: {transcoded_file}") + + # Read transcoded DICOM + ds_transcoded = pydicom.dcmread(transcoded_file) + transcoded_pixels = ds_transcoded.pixel_array + transcoded_transfer_syntax = str(ds_transcoded.file_meta.TransferSyntaxUID) + + print(f"\nTranscoded file:") + print(f" Transfer Syntax: {transcoded_transfer_syntax}") + print(f" PhotometricInterpretation: {ds_transcoded.PhotometricInterpretation}") + print(f" Pixel shape: {transcoded_pixels.shape}") + print(f" File size: {os.path.getsize(transcoded_file):,} bytes") + + # Verify transfer syntax is HTJ2K + self.assertIn( + transcoded_transfer_syntax, + HTJ2K_TRANSFER_SYNTAXES, + f"Transfer syntax should be HTJ2K, got {transcoded_transfer_syntax}", + ) + print(f"✓ Transfer syntax is HTJ2K: {transcoded_transfer_syntax}") + + # Verify PhotometricInterpretation was updated to RGB + self.assertEqual( + ds_transcoded.PhotometricInterpretation, + "RGB", + "PhotometricInterpretation should be updated to RGB after YCbCr conversion", + ) + print( + f"✓ PhotometricInterpretation updated: {ds_original.PhotometricInterpretation} -> {ds_transcoded.PhotometricInterpretation}" + ) + + # Verify shapes match + self.assertEqual(original_pixels.shape, transcoded_pixels.shape, "Pixel array shapes should match") + print(f"✓ Shapes match: {original_pixels.shape}") + + # Verify pixel values are close (allowing small differences due to color space conversions) + # Use allclose with tolerance since YCbCr->RGB conversion may have rounding differences + # between pydicom and nvimgcodec (atol=5 allows for typical conversion differences) + max_diff = np.abs(original_pixels.astype(np.float32) - transcoded_pixels.astype(np.float32)).max() + mean_diff = np.abs(original_pixels.astype(np.float32) - transcoded_pixels.astype(np.float32)).mean() + print(f" Pixel differences: max={max_diff}, mean={mean_diff:.3f}") + + if not np.allclose(original_pixels, transcoded_pixels, atol=5, rtol=0): + print(f"✗ Pixel values differ beyond tolerance") + self.fail(f"Pixel values should be close (atol=5), but max diff is {max_diff}") + + print(f"✓ Pixel values match within tolerance (atol=5, max_diff={max_diff})") + + # Verify metadata is preserved + self.assertEqual(ds_original.Rows, ds_transcoded.Rows, "Rows should be preserved") + self.assertEqual(ds_original.Columns, ds_transcoded.Columns, "Columns should be preserved") + self.assertEqual( + ds_original.NumberOfFrames, ds_transcoded.NumberOfFrames, "NumberOfFrames should be preserved" + ) + print(f"✓ Metadata preserved: {num_frames} frames, {ds_original.Rows}x{ds_original.Columns}") + + # Compare file sizes + size_ratio = os.path.getsize(transcoded_file) / os.path.getsize(source_file) + print(f"\nCompression ratio: {size_ratio:.2%}") + + print(f"\n✓ Multi-frame YBR JPEG to HTJ2K transcoding test passed!") + + finally: + # Clean up temporary directories + import shutil + + for temp_dir in [input_dir, output_dir]: + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + + def test_transcode_ct_example_to_htj2k(self): + """Test transcoding uncompressed CT grayscale image to HTJ2K.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import shutil + + import pydicom.examples as examples + + source_file = str(examples.get_path("ct")) + print(f"\nSource: {source_file}") + + # Create temp directories + input_dir = tempfile.mkdtemp(prefix="htj2k_ct_input_") + output_dir = tempfile.mkdtemp(prefix="htj2k_ct_output_") + + try: + test_filename = 
"ct_small.dcm" + shutil.copy2(source_file, os.path.join(input_dir, test_filename)) + + # Read original + ds_original = pydicom.dcmread(source_file) + original_pixels = ds_original.pixel_array.copy() + + print(f"Original: {ds_original.file_meta.TransferSyntaxUID.name}") + print(f" PhotometricInterpretation: {ds_original.PhotometricInterpretation}") + print(f" Shape: {original_pixels.shape}") + + # Transcode + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k(file_loader=file_loader) + + # Read transcoded + transcoded_file = os.path.join(output_dir, test_filename) + self.assertTrue(os.path.exists(transcoded_file)) + + ds_transcoded = pydicom.dcmread(transcoded_file) + transcoded_pixels = ds_transcoded.pixel_array + + print(f"Transcoded: {ds_transcoded.file_meta.TransferSyntaxUID.name}") + print(f" PhotometricInterpretation: {ds_transcoded.PhotometricInterpretation}") + + # Verify HTJ2K + self.assertIn(str(ds_transcoded.file_meta.TransferSyntaxUID), HTJ2K_TRANSFER_SYNTAXES) + + # Verify lossless (grayscale should be exact) + np.testing.assert_array_equal(original_pixels, transcoded_pixels) + print("✓ CT grayscale lossless transcoding verified") + + finally: + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + def test_transcode_mr_example_to_htj2k(self): + """Test transcoding uncompressed MR grayscale image to HTJ2K.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import shutil + + import pydicom.examples as examples + + source_file = str(examples.get_path("mr")) + print(f"\nSource: {source_file}") + + # Create temp directories + input_dir = tempfile.mkdtemp(prefix="htj2k_mr_input_") + output_dir = tempfile.mkdtemp(prefix="htj2k_mr_output_") + + try: + test_filename = "mr_small.dcm" + shutil.copy2(source_file, os.path.join(input_dir, test_filename)) + + # Read original + ds_original = pydicom.dcmread(source_file) + original_pixels = ds_original.pixel_array.copy() + + print(f"Original: {ds_original.file_meta.TransferSyntaxUID.name}") + print(f" PhotometricInterpretation: {ds_original.PhotometricInterpretation}") + print(f" Shape: {original_pixels.shape}") + + # Transcode + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k(file_loader=file_loader) + + # Read transcoded + transcoded_file = os.path.join(output_dir, test_filename) + self.assertTrue(os.path.exists(transcoded_file)) + + ds_transcoded = pydicom.dcmread(transcoded_file) + transcoded_pixels = ds_transcoded.pixel_array + + print(f"Transcoded: {ds_transcoded.file_meta.TransferSyntaxUID.name}") + print(f" PhotometricInterpretation: {ds_transcoded.PhotometricInterpretation}") + + # Verify HTJ2K + self.assertIn(str(ds_transcoded.file_meta.TransferSyntaxUID), HTJ2K_TRANSFER_SYNTAXES) + + # Verify lossless (grayscale should be exact) + np.testing.assert_array_equal(original_pixels, transcoded_pixels) + print("✓ MR grayscale lossless transcoding verified") + + finally: + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + def test_transcode_rgb_color_example_to_htj2k(self): + """Test transcoding uncompressed RGB color image to HTJ2K.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import shutil + + import pydicom.examples as examples + + source_file = str(examples.get_path("rgb_color")) + print(f"\nSource: {source_file}") + + # Create temp directories + input_dir = tempfile.mkdtemp(prefix="htj2k_rgb_input_") + output_dir = 
tempfile.mkdtemp(prefix="htj2k_rgb_output_") + + try: + test_filename = "rgb_color.dcm" + shutil.copy2(source_file, os.path.join(input_dir, test_filename)) + + # Read original + ds_original = pydicom.dcmread(source_file) + original_pixels = ds_original.pixel_array.copy() + + print(f"Original: {ds_original.file_meta.TransferSyntaxUID.name}") + print(f" PhotometricInterpretation: {ds_original.PhotometricInterpretation}") + print(f" Shape: {original_pixels.shape}") + + # Transcode + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k(file_loader=file_loader) + + # Read transcoded + transcoded_file = os.path.join(output_dir, test_filename) + self.assertTrue(os.path.exists(transcoded_file)) + + ds_transcoded = pydicom.dcmread(transcoded_file) + transcoded_pixels = ds_transcoded.pixel_array + + print(f"Transcoded: {ds_transcoded.file_meta.TransferSyntaxUID.name}") + print(f" PhotometricInterpretation: {ds_transcoded.PhotometricInterpretation}") + + # Verify HTJ2K + self.assertIn(str(ds_transcoded.file_meta.TransferSyntaxUID), HTJ2K_TRANSFER_SYNTAXES) + + # Verify PhotometricInterpretation stays RGB + self.assertEqual(ds_transcoded.PhotometricInterpretation, "RGB") + + # Verify lossless (RGB uncompressed should be exact) + np.testing.assert_array_equal(original_pixels, transcoded_pixels) + print("✓ RGB color lossless transcoding verified") + + finally: + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + def test_transcode_jpeg2k_example_to_htj2k(self): + """Test transcoding JPEG 2000 (YBR_RCT) color image to HTJ2K.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import shutil + + import pydicom.examples as examples + + source_file = str(examples.get_path("jpeg2k")) + print(f"\nSource: {source_file}") + + # Create temp directories + input_dir = tempfile.mkdtemp(prefix="htj2k_jpeg2k_input_") + output_dir = tempfile.mkdtemp(prefix="htj2k_jpeg2k_output_") + + try: + test_filename = "jpeg2k.dcm" + shutil.copy2(source_file, os.path.join(input_dir, test_filename)) + + # Read original + ds_original = pydicom.dcmread(source_file) + original_pixels = ds_original.pixel_array.copy() + + print(f"Original: {ds_original.file_meta.TransferSyntaxUID.name}") + print(f" PhotometricInterpretation: {ds_original.PhotometricInterpretation}") + print(f" Shape: {original_pixels.shape}") + + # Transcode + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k(file_loader=file_loader) + + # Read transcoded + transcoded_file = os.path.join(output_dir, test_filename) + self.assertTrue(os.path.exists(transcoded_file)) + + ds_transcoded = pydicom.dcmread(transcoded_file) + transcoded_pixels = ds_transcoded.pixel_array + + print(f"Transcoded: {ds_transcoded.file_meta.TransferSyntaxUID.name}") + print(f" PhotometricInterpretation: {ds_transcoded.PhotometricInterpretation}") + + # Verify HTJ2K + self.assertIn(str(ds_transcoded.file_meta.TransferSyntaxUID), HTJ2K_TRANSFER_SYNTAXES) + + # Verify PhotometricInterpretation updated to RGB (from YBR_RCT) + self.assertEqual(ds_transcoded.PhotometricInterpretation, "RGB") + print(f"✓ PhotometricInterpretation updated: {ds_original.PhotometricInterpretation} -> RGB") + + # Verify pixels match within tolerance (color space conversion may have small differences) + max_diff = np.abs(original_pixels.astype(np.float32) - transcoded_pixels.astype(np.float32)).max() + mean_diff = np.abs(original_pixels.astype(np.float32) - transcoded_pixels.astype(np.float32)).mean() + 
print(f" Pixel differences: max={max_diff}, mean={mean_diff:.3f}") + + # YBR_RCT is reversible, so differences should be minimal + self.assertTrue(np.allclose(original_pixels, transcoded_pixels, atol=5, rtol=0)) + print(f"✓ JPEG2K (YBR_RCT) to HTJ2K transcoding verified (max_diff={max_diff})") + + finally: + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + def test_transcode_dicom_to_htj2k_batch(self): + """Test batch transcoding of entire DICOM series to HTJ2K.""" + if not HAS_NVIMGCODEC: + self.skipTest( + "nvimgcodec not available. Install nvidia-nvimgcodec-cu{XX} matching your CUDA version (e.g., nvidia-nvimgcodec-cu13 for CUDA 13.x)" + ) + + # Use a specific series from dicomweb + dicom_dir = os.path.join( + self.base_dir, + "data", + "dataset", + "dicomweb", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721", + ) + + # Find DICOM files in source directory + source_files = sorted(list(Path(dicom_dir).glob("*.dcm"))) + if not source_files: + source_files = sorted([f for f in Path(dicom_dir).iterdir() if f.is_file()]) + + self.assertGreater(len(source_files), 0, f"No DICOM files found in {dicom_dir}") + print(f"\nSource directory: {dicom_dir}") + print(f"Source files: {len(source_files)}") + + # Create a temporary directory for transcoded output + output_dir = tempfile.mkdtemp(prefix="htj2k_test_") + + try: + # Perform batch transcoding + print("\nTranscoding DICOM series to HTJ2K...") + file_loader = DicomFileLoader(dicom_dir, output_dir) + transcode_dicom_to_htj2k( + file_loader=file_loader, + ) + + # Find transcoded files + transcoded_files = sorted(list(Path(output_dir).glob("*.dcm"))) + if not transcoded_files: + transcoded_files = sorted([f for f in Path(output_dir).iterdir() if f.is_file()]) + + print(f"\nTranscoded files: {len(transcoded_files)}") + + # Verify file count matches + self.assertEqual( + len(transcoded_files), + len(source_files), + f"Number of transcoded files ({len(transcoded_files)}) should match source files ({len(source_files)})", + ) + print(f"✓ File count matches: {len(transcoded_files)} files") + + # Verify filenames match (directory structure) + source_names = sorted([f.name for f in source_files]) + transcoded_names = sorted([f.name for f in transcoded_files]) + self.assertEqual( + source_names, transcoded_names, "Filenames should match between source and transcoded directories" + ) + print(f"✓ Directory structure preserved: all filenames match") + + # Verify each file has been correctly transcoded + print("\nVerifying lossless transcoding...") + verified_count = 0 + + for source_file, transcoded_file in zip(source_files, transcoded_files): + # Read original DICOM + ds_original = pydicom.dcmread(str(source_file)) + original_pixels = ds_original.pixel_array + + # Read transcoded DICOM + ds_transcoded = pydicom.dcmread(str(transcoded_file)) + + # Verify transfer syntax is HTJ2K + transfer_syntax = str(ds_transcoded.file_meta.TransferSyntaxUID) + self.assertIn( + transfer_syntax, HTJ2K_TRANSFER_SYNTAXES, f"Transfer syntax should be HTJ2K, got {transfer_syntax}" + ) + + # Decode transcoded pixels + transcoded_pixels = ds_transcoded.pixel_array + + # Verify pixel values are identical (lossless) + np.testing.assert_array_equal( + original_pixels, + transcoded_pixels, + err_msg=f"Pixel values should be identical (lossless) for {source_file.name}", + ) + + # Verify metadata is preserved + self.assertEqual(ds_original.Rows, ds_transcoded.Rows, "Image dimensions 
(Rows) should be preserved") + self.assertEqual( + ds_original.Columns, ds_transcoded.Columns, "Image dimensions (Columns) should be preserved" + ) + self.assertEqual( + ds_original.BitsAllocated, ds_transcoded.BitsAllocated, "BitsAllocated should be preserved" + ) + self.assertEqual(ds_original.BitsStored, ds_transcoded.BitsStored, "BitsStored should be preserved") + + verified_count += 1 + + print(f"✓ All {verified_count} files verified: pixel values are identical (lossless)") + print(f"✓ Transfer syntax verified: HTJ2K (1.2.840.10008.1.2.4.20*)") + print(f"✓ Metadata preserved: dimensions, bit depth, etc.") + + # Verify that transcoded files are actually compressed + # HTJ2K files should typically be smaller or similar size for lossless + source_size = sum(f.stat().st_size for f in source_files) + transcoded_size = sum(f.stat().st_size for f in transcoded_files) + print(f"\nFile size comparison:") + print(f" Original: {source_size:,} bytes") + print(f" Transcoded: {transcoded_size:,} bytes") + print(f" Ratio: {transcoded_size/source_size:.2%}") + + print(f"\n✓ Batch HTJ2K transcoding test passed!") + + finally: + # Clean up temporary directory + import shutil + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + print(f"\n✓ Cleaned up temporary directory: {output_dir}") + + def test_transcode_mixed_directory(self): + """Test transcoding a directory with both uncompressed and HTJ2K images.""" + if not HAS_NVIMGCODEC: + self.skipTest( + "nvimgcodec not available. Install nvidia-nvimgcodec-cu{XX} matching your CUDA version (e.g., nvidia-nvimgcodec-cu13 for CUDA 13.x)" + ) + + # Use uncompressed DICOM series + uncompressed_dir = os.path.join( + self.base_dir, + "data", + "dataset", + "dicomweb", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721", + ) + + # Find uncompressed DICOM files + uncompressed_files = sorted(list(Path(uncompressed_dir).glob("*.dcm"))) + if not uncompressed_files: + uncompressed_files = sorted([f for f in Path(uncompressed_dir).iterdir() if f.is_file()]) + + self.assertGreater(len(uncompressed_files), 10, f"Need at least 10 DICOM files in {uncompressed_dir}") + + # Create a mixed directory with some uncompressed and some HTJ2K files + import shutil + + mixed_dir = tempfile.mkdtemp(prefix="htj2k_mixed_") + output_dir = tempfile.mkdtemp(prefix="htj2k_output_") + htj2k_intermediate = tempfile.mkdtemp(prefix="htj2k_intermediate_") + + try: + print(f"\nCreating mixed directory with uncompressed and HTJ2K files...") + + # First, transcode half of the files to HTJ2K + mid_point = len(uncompressed_files) // 2 + + # Copy first half as uncompressed + uncompressed_subset = uncompressed_files[:mid_point] + for f in uncompressed_subset: + shutil.copy2(str(f), os.path.join(mixed_dir, f.name)) + + print(f" Copied {len(uncompressed_subset)} uncompressed files") + + # Transcode second half to HTJ2K + htj2k_source_dir = tempfile.mkdtemp(prefix="htj2k_source_", dir=htj2k_intermediate) + for f in uncompressed_files[mid_point:]: + shutil.copy2(str(f), os.path.join(htj2k_source_dir, f.name)) + + # Transcode this subset to HTJ2K + htj2k_output_dir = tempfile.mkdtemp(prefix="htj2k_subset_output_") + file_loader_subset = DicomFileLoader(htj2k_source_dir, htj2k_output_dir) + transcode_dicom_to_htj2k( + file_loader=file_loader_subset, + ) + + # Copy the transcoded HTJ2K files to mixed directory + htj2k_files_to_copy = list(Path(htj2k_output_dir).glob("*.dcm")) + if not htj2k_files_to_copy: + htj2k_files_to_copy = [f for f in 
Path(htj2k_output_dir).iterdir() if f.is_file()] + + for f in htj2k_files_to_copy: + shutil.copy2(str(f), os.path.join(mixed_dir, f.name)) + + print(f" Copied {len(htj2k_files_to_copy)} HTJ2K files") + + # Now we have a mixed directory + mixed_files = sorted(list(Path(mixed_dir).iterdir())) + self.assertEqual(len(mixed_files), len(uncompressed_files), "Mixed directory should have all files") + + print(f"\nMixed directory created with {len(mixed_files)} files:") + print(f" - {len(uncompressed_subset)} uncompressed") + print(f" - {len(htj2k_files_to_copy)} HTJ2K") + + # Verify the transfer syntaxes before transcoding + uncompressed_count_before = 0 + htj2k_count_before = 0 + for f in mixed_files: + ds = pydicom.dcmread(str(f)) + ts = str(ds.file_meta.TransferSyntaxUID) + if ts in HTJ2K_TRANSFER_SYNTAXES: + htj2k_count_before += 1 + else: + uncompressed_count_before += 1 + + print(f"\nBefore transcoding:") + print(f" - Uncompressed: {uncompressed_count_before}") + print(f" - HTJ2K: {htj2k_count_before}") + + # Store original pixel data from HTJ2K files for comparison + htj2k_original_data = {} + for f in mixed_files: + ds = pydicom.dcmread(str(f)) + ts = str(ds.file_meta.TransferSyntaxUID) + if ts in HTJ2K_TRANSFER_SYNTAXES: + htj2k_original_data[f.name] = { + "pixels": ds.pixel_array.copy(), + "mtime": f.stat().st_mtime, + } + + # Now transcode the mixed directory + print(f"\nTranscoding mixed directory...") + file_loader = DicomFileLoader(mixed_dir, output_dir) + transcode_dicom_to_htj2k( + file_loader=file_loader, + ) + + # Verify all files are in output + output_files = sorted(list(Path(output_dir).iterdir())) + self.assertEqual(len(output_files), len(mixed_files), "Output should have same number of files as input") + print(f"\n✓ File count matches: {len(output_files)} files") + + # Verify all filenames match + input_names = sorted([f.name for f in mixed_files]) + output_names = sorted([f.name for f in output_files]) + self.assertEqual(input_names, output_names, "All filenames should be preserved") + print(f"✓ Directory structure preserved: all filenames match") + + # Verify all output files are HTJ2K + all_htj2k = True + for f in output_files: + ds = pydicom.dcmread(str(f)) + ts = str(ds.file_meta.TransferSyntaxUID) + if ts not in HTJ2K_TRANSFER_SYNTAXES: + all_htj2k = False + print(f" ERROR: {f.name} has transfer syntax {ts}") + + self.assertTrue(all_htj2k, "All output files should be HTJ2K") + print(f"✓ All {len(output_files)} output files are HTJ2K") + + # Verify that HTJ2K files were copied (not re-transcoded) + print(f"\nVerifying HTJ2K files were copied correctly...") + for filename, original_data in htj2k_original_data.items(): + output_file = Path(output_dir) / filename + self.assertTrue(output_file.exists(), f"HTJ2K file {filename} should exist in output") + + # Read the output file + ds_output = pydicom.dcmread(str(output_file)) + output_pixels = ds_output.pixel_array + + # Verify pixel data is identical (proving it was copied, not re-transcoded) + np.testing.assert_array_equal( + original_data["pixels"], + output_pixels, + err_msg=f"HTJ2K file {filename} should have identical pixels after copy", + ) + + print(f"✓ All {len(htj2k_original_data)} HTJ2K files were copied correctly") + + # Verify that uncompressed files were transcoded and have correct pixel values + print(f"\nVerifying uncompressed files were transcoded correctly...") + transcoded_count = 0 + for input_file in mixed_files: + ds_input = pydicom.dcmread(str(input_file)) + ts_input = 
str(ds_input.file_meta.TransferSyntaxUID) + + if ts_input not in HTJ2K_TRANSFER_SYNTAXES: + # This was an uncompressed file, verify it was transcoded + output_file = Path(output_dir) / input_file.name + ds_output = pydicom.dcmread(str(output_file)) + + # Verify transfer syntax changed to HTJ2K + ts_output = str(ds_output.file_meta.TransferSyntaxUID) + self.assertIn( + ts_output, HTJ2K_TRANSFER_SYNTAXES, f"File {input_file.name} should be HTJ2K after transcoding" + ) + + # Verify lossless transcoding (pixel values identical) + np.testing.assert_array_equal( + ds_input.pixel_array, + ds_output.pixel_array, + err_msg=f"File {input_file.name} should have identical pixels after lossless transcoding", + ) + + transcoded_count += 1 + + print(f"✓ All {transcoded_count} uncompressed files were transcoded correctly (lossless)") + + print(f"\n✓ Mixed directory transcoding test passed!") + print(f" - HTJ2K files copied: {len(htj2k_original_data)}") + print(f" - Uncompressed files transcoded: {transcoded_count}") + print(f" - Total output files: {len(output_files)}") + + finally: + # Clean up all temporary directories + import shutil + + for temp_dir in [mixed_dir, output_dir, htj2k_intermediate]: + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) + + def test_transcode_dicom_to_htj2k_multiframe_metadata(self): + """Test that multi-frame HTJ2K files preserve correct DICOM metadata from original files.""" + if not HAS_NVIMGCODEC: + self.skipTest( + "nvimgcodec not available. Install nvidia-nvimgcodec-cu{XX} matching your CUDA version (e.g., nvidia-nvimgcodec-cu13 for CUDA 13.x)" + ) + + # Use a specific series from dicomweb + dicom_dir = os.path.join( + self.base_dir, + "data", + "dataset", + "dicomweb", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721", + ) + + # Load original DICOM files and sort by Z-coordinate (same as transcode function does) + source_files = sorted(list(Path(dicom_dir).glob("*.dcm"))) + if not source_files: + source_files = sorted([f for f in Path(dicom_dir).iterdir() if f.is_file()]) + + print(f"\nLoading {len(source_files)} original DICOM files...") + original_datasets = [] + for source_file in source_files: + ds = pydicom.dcmread(str(source_file)) + z_pos = float(ds.ImagePositionPatient[2]) if hasattr(ds, "ImagePositionPatient") else 0 + original_datasets.append((z_pos, ds)) + + # Sort by Z position (same as convert_single_frame_dicom_series_to_multiframe does) + original_datasets.sort(key=lambda x: x[0]) + original_datasets = [ds for _, ds in original_datasets] + print(f"✓ Original files loaded and sorted by Z-coordinate") + + # Create temporary output directory + output_dir = tempfile.mkdtemp(prefix="htj2k_multiframe_metadata_") + + try: + # Transcode to multi-frame + result_dir = convert_single_frame_dicom_series_to_multiframe( + input_dir=dicom_dir, + output_dir=output_dir, + convert_to_htj2k=True, + ) + + # Find the multi-frame file + multiframe_files = list(Path(output_dir).rglob("*.dcm")) + self.assertEqual(len(multiframe_files), 1, "Should have one multi-frame file") + + # Load the multi-frame file + ds_multiframe = pydicom.dcmread(str(multiframe_files[0])) + + print(f"\nVerifying multi-frame metadata against original files...") + + # Check NumberOfFrames matches source file count + self.assertTrue(hasattr(ds_multiframe, "NumberOfFrames"), "Should have NumberOfFrames") + num_frames = int(ds_multiframe.NumberOfFrames) + self.assertEqual(num_frames, len(original_datasets), "NumberOfFrames should match source 
file count") + print(f"✓ NumberOfFrames: {num_frames} (matches source)") + + # Check FrameIncrementPointer (required for multi-frame) + self.assertTrue(hasattr(ds_multiframe, "FrameIncrementPointer"), "Should have FrameIncrementPointer") + self.assertEqual(ds_multiframe.FrameIncrementPointer, 0x00200032, "Should point to ImagePositionPatient") + print(f"✓ FrameIncrementPointer: {hex(ds_multiframe.FrameIncrementPointer)} (ImagePositionPatient)") + + # Verify top-level metadata matches first frame + first_original = original_datasets[0] + + # Check ImagePositionPatient is NOT there at top level DICOM file + self.assertFalse( + hasattr(ds_multiframe, "ImagePositionPatient"), "Should not have ImagePositionPatient at top level" + ) + + # Check PixelSpacing + self.assertTrue(hasattr(ds_multiframe, "PixelSpacing"), "Should have PixelSpacing") + np.testing.assert_array_almost_equal( + np.array([float(x) for x in ds_multiframe.PixelSpacing]), + np.array([float(x) for x in first_original.PixelSpacing]), + decimal=6, + err_msg="PixelSpacing should match original", + ) + print(f"✓ PixelSpacing matches original: {ds_multiframe.PixelSpacing}") + + # Check SliceThickness + if hasattr(first_original, "SliceThickness"): + self.assertTrue(hasattr(ds_multiframe, "SliceThickness"), "Should have SliceThickness") + self.assertAlmostEqual( + float(ds_multiframe.SliceThickness), + float(first_original.SliceThickness), + places=6, + msg="SliceThickness should match original", + ) + print(f"✓ SliceThickness matches original: {ds_multiframe.SliceThickness}") + + # Check SOPClassUID conversion to Enhanced/Multi-frame + self.assertTrue(hasattr(ds_multiframe, "SOPClassUID"), "Should have SOPClassUID") + self.assertTrue(hasattr(first_original, "SOPClassUID"), "Original should have SOPClassUID") + + # Map of single-frame to enhanced/multi-frame SOPClassUIDs + sopclass_map = { + "1.2.840.10008.5.1.4.1.1.2": "1.2.840.10008.5.1.4.1.1.2.1", # CT -> Enhanced CT + "1.2.840.10008.5.1.4.1.1.4": "1.2.840.10008.5.1.4.1.1.4.1", # MR -> Enhanced MR + "1.2.840.10008.5.1.4.1.1.6.1": "1.2.840.10008.5.1.4.1.1.3.1", # US -> Ultrasound Multi-frame + } + + original_sopclass = str(first_original.SOPClassUID) + multiframe_sopclass = str(ds_multiframe.SOPClassUID) + + if original_sopclass in sopclass_map: + expected_sopclass = sopclass_map[original_sopclass] + self.assertEqual( + multiframe_sopclass, + expected_sopclass, + f"SOPClassUID should be converted from {original_sopclass} to {expected_sopclass}", + ) + print(f"✓ SOPClassUID converted: {original_sopclass} -> {multiframe_sopclass}") + else: + # If not in map, should remain unchanged + self.assertEqual( + multiframe_sopclass, + original_sopclass, + "SOPClassUID should remain unchanged if not in conversion map", + ) + print(f"✓ SOPClassUID unchanged: {multiframe_sopclass}") + + # Check for PerFrameFunctionalGroupsSequence + self.assertTrue( + hasattr(ds_multiframe, "PerFrameFunctionalGroupsSequence"), + "Should have PerFrameFunctionalGroupsSequence", + ) + per_frame_seq = ds_multiframe.PerFrameFunctionalGroupsSequence + self.assertEqual( + len(per_frame_seq), num_frames, f"PerFrameFunctionalGroupsSequence should have {num_frames} items" + ) + print(f"✓ PerFrameFunctionalGroupsSequence: {len(per_frame_seq)} frames") + + # Verify each frame's metadata matches corresponding original file + print(f"\nVerifying per-frame metadata...") + mismatches = [] + for frame_idx in range(num_frames): + frame_item = per_frame_seq[frame_idx] + original_ds = original_datasets[frame_idx] + + # Check 
PlanePositionSequence + self.assertTrue( + hasattr(frame_item, "PlanePositionSequence"), f"Frame {frame_idx} should have PlanePositionSequence" + ) + plane_pos = frame_item.PlanePositionSequence[0] + self.assertTrue( + hasattr(plane_pos, "ImagePositionPatient"), + f"Frame {frame_idx} should have ImagePositionPatient in PlanePositionSequence", + ) + + # Verify ImagePositionPatient matches original + multiframe_ipp = np.array([float(x) for x in plane_pos.ImagePositionPatient]) + original_ipp = np.array([float(x) for x in original_ds.ImagePositionPatient]) + + try: + np.testing.assert_array_almost_equal( + multiframe_ipp, + original_ipp, + decimal=6, + err_msg=f"Frame {frame_idx} ImagePositionPatient should match original", + ) + except AssertionError as e: + mismatches.append(f"Frame {frame_idx}: {e}") + + # PlaneOrientationSequence should ONLY be in SharedFunctionalGroupsSequence, not per-frame + self.assertFalse( + hasattr(frame_item, "PlaneOrientationSequence"), + f"Frame {frame_idx} should not have PlaneOrientationSequence", + ) + + # Verify ImageOrientationPatient in SharedFunctionalGroupsSequence matches original + shared_fg = ds_multiframe.SharedFunctionalGroupsSequence[0] + self.assertTrue( + hasattr(shared_fg, "PlaneOrientationSequence"), + "SharedFunctionalGroupsSequence should have PlaneOrientationSequence", + ) + plane_orient = shared_fg.PlaneOrientationSequence[0] + multiframe_iop = np.array([float(x) for x in plane_orient.ImageOrientationPatient]) + original_iop = np.array([float(x) for x in original_datasets[0].ImageOrientationPatient]) # Use first frame + + try: + np.testing.assert_array_almost_equal( + multiframe_iop, + original_iop, + decimal=6, + err_msg="SharedFunctionalGroupsSequence ImageOrientationPatient should match original", # Remove frame_idx + ) + except AssertionError as e: + mismatches.append(f"Shared orientation: {e}") # Remove frame_idx reference + + # Report any mismatches + if mismatches: + self.fail(f"Per-frame metadata mismatches:\n" + "\n".join(mismatches)) + + print(f"✓ All {num_frames} frames have metadata matching original files") + + # Verify frame ordering (first and last frame positions) + first_frame_pos = per_frame_seq[0].PlanePositionSequence[0].ImagePositionPatient + last_frame_pos = per_frame_seq[-1].PlanePositionSequence[0].ImagePositionPatient + + first_original_pos = original_datasets[0].ImagePositionPatient + last_original_pos = original_datasets[-1].ImagePositionPatient + + print(f"\nFrame ordering verification:") + print(f" First frame Z: {first_frame_pos[2]} (original: {first_original_pos[2]})") + print(f" Last frame Z: {last_frame_pos[2]} (original: {last_original_pos[2]})") + + # Verify positions match originals + self.assertAlmostEqual( + float(first_frame_pos[2]), + float(first_original_pos[2]), + places=6, + msg="First frame Z should match first original", + ) + self.assertAlmostEqual( + float(last_frame_pos[2]), + float(last_original_pos[2]), + places=6, + msg="Last frame Z should match last original", + ) + print(f"✓ Frame ordering matches original files") + + print(f"\n✓ Multi-frame metadata test passed - all metadata preserved correctly!") + + finally: + # Clean up + import shutil + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + def test_transcode_dicom_to_htj2k_multiframe_lossless(self): + """Test that multi-frame HTJ2K transcoding is lossless.""" + if not HAS_NVIMGCODEC: + self.skipTest( + "nvimgcodec not available. 
Install nvidia-nvimgcodec-cu{XX} matching your CUDA version (e.g., nvidia-nvimgcodec-cu13 for CUDA 13.x)" + ) + + # Use a specific series from dicomweb + dicom_dir = os.path.join( + self.base_dir, + "data", + "dataset", + "dicomweb", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721", + ) + + # Load original files + source_files = sorted(list(Path(dicom_dir).glob("*.dcm"))) + if not source_files: + source_files = sorted([f for f in Path(dicom_dir).iterdir() if f.is_file()]) + + print(f"\nLoading {len(source_files)} original DICOM files...") + + # Read original pixel data and sort by ImagePositionPatient Z-coordinate + original_frames = [] + for source_file in source_files: + ds = pydicom.dcmread(str(source_file)) + z_pos = float(ds.ImagePositionPatient[2]) if hasattr(ds, "ImagePositionPatient") else 0 + original_frames.append((z_pos, ds.pixel_array.copy())) + + # Sort by Z position (same as convert_single_frame_dicom_series_to_multiframe does) + original_frames.sort(key=lambda x: x[0]) + original_pixel_stack = np.stack([frame for _, frame in original_frames], axis=0) + + print(f"✓ Original pixel data loaded: {original_pixel_stack.shape}") + + # Create temporary output directory + output_dir = tempfile.mkdtemp(prefix="htj2k_multiframe_lossless_") + + try: + # Transcode to multi-frame HTJ2K + print(f"\nTranscoding to multi-frame HTJ2K...") + result_dir = convert_single_frame_dicom_series_to_multiframe( + input_dir=dicom_dir, + output_dir=output_dir, + convert_to_htj2k=True, + ) + + # Find the multi-frame file + multiframe_files = list(Path(output_dir).rglob("*.dcm")) + self.assertEqual(len(multiframe_files), 1, "Should have one multi-frame file") + + # Load the multi-frame file + ds_multiframe = pydicom.dcmread(str(multiframe_files[0])) + multiframe_pixels = ds_multiframe.pixel_array + + print(f"✓ Multi-frame pixel data loaded: {multiframe_pixels.shape}") + + # Verify shapes match + self.assertEqual( + multiframe_pixels.shape, + original_pixel_stack.shape, + "Multi-frame shape should match original stacked shape", + ) + + # Verify pixel values are identical (lossless) + print(f"\nVerifying lossless transcoding...") + np.testing.assert_array_equal( + original_pixel_stack, + multiframe_pixels, + err_msg="Multi-frame pixel values should be identical to original (lossless)", + ) + + print(f"✓ All {len(source_files)} frames are identical (lossless compression verified)") + + # Verify each frame individually + for frame_idx in range(len(source_files)): + np.testing.assert_array_equal( + original_pixel_stack[frame_idx], + multiframe_pixels[frame_idx], + err_msg=f"Frame {frame_idx} should be identical", + ) + + print(f"✓ Individual frame verification passed for all {len(source_files)} frames") + + print(f"\n✓ Lossless multi-frame HTJ2K transcoding test passed!") + + finally: + # Clean up + import shutil + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + def test_transcode_dicom_to_htj2k_multiframe_nifti_consistency(self): + """Test that multi-frame HTJ2K produces same NIfTI output as original series.""" + if not HAS_NVIMGCODEC: + self.skipTest( + "nvimgcodec not available. 
Install nvidia-nvimgcodec-cu{XX} matching your CUDA version (e.g., nvidia-nvimgcodec-cu13 for CUDA 13.x)" + ) + + # Use a specific series from dicomweb + dicom_dir = os.path.join( + self.base_dir, + "data", + "dataset", + "dicomweb", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721", + ) + + print(f"\nConverting original DICOM series to NIfTI...") + nifti_from_original = dicom_to_nifti(dicom_dir) + + # Create temporary output directory for multi-frame + output_dir = tempfile.mkdtemp(prefix="htj2k_multiframe_nifti_") + + try: + # Transcode to multi-frame HTJ2K + print(f"\nTranscoding to multi-frame HTJ2K...") + result_dir = convert_single_frame_dicom_series_to_multiframe( + input_dir=dicom_dir, + output_dir=output_dir, + convert_to_htj2k=True, + ) + + # Find the multi-frame file + multiframe_files = list(Path(output_dir).rglob("*.dcm")) + self.assertEqual(len(multiframe_files), 1, "Should have one multi-frame file") + multiframe_dir = multiframe_files[0].parent + + # Convert multi-frame to NIfTI + print(f"\nConverting multi-frame HTJ2K to NIfTI...") + nifti_from_multiframe = dicom_to_nifti(str(multiframe_dir)) + + # Load both NIfTI files + data_original = LoadImage(image_only=True)(nifti_from_original) + data_multiframe = LoadImage(image_only=True)(nifti_from_multiframe) + + print(f"\nComparing NIfTI outputs...") + print(f" Original shape: {data_original.shape}") + print(f" Multi-frame shape: {data_multiframe.shape}") + + # Verify shapes match + self.assertEqual( + data_original.shape, data_multiframe.shape, "Original and multi-frame should produce same NIfTI shape" + ) + + # Verify data types match + self.assertEqual( + data_original.dtype, + data_multiframe.dtype, + "Original and multi-frame should produce same NIfTI data type", + ) + + # Verify pixel values are identical + np.testing.assert_array_equal( + data_original, + data_multiframe, + err_msg="Original and multi-frame should produce identical NIfTI pixel values", + ) + + print(f"✓ NIfTI outputs are identical") + print(f" Shape: {data_original.shape}") + print(f" Data type: {data_original.dtype}") + print(f" Pixel values: Identical") + + print(f"\n✓ Multi-frame HTJ2K NIfTI consistency test passed!") + + finally: + # Clean up + import shutil + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + if os.path.exists(nifti_from_original): + os.unlink(nifti_from_original) + if os.path.exists(nifti_from_multiframe): + os.unlink(nifti_from_multiframe) + + def test_default_progression_order(self): + """Test that the default progression order is RPCL.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import shutil + + import pydicom.examples as examples + + source_file = str(examples.get_path("ct")) + + # Create temp directories + input_dir = tempfile.mkdtemp(prefix="htj2k_default_input_") + output_dir = tempfile.mkdtemp(prefix="htj2k_default_output_") + + try: + test_filename = "ct_small.dcm" + shutil.copy2(source_file, os.path.join(input_dir, test_filename)) + + # Transcode WITHOUT specifying progression_order (should default to RPCL) + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k(file_loader=file_loader) + + # Read transcoded + transcoded_file = os.path.join(output_dir, test_filename) + ds_transcoded = pydicom.dcmread(transcoded_file) + transcoded_ts = str(ds_transcoded.file_meta.TransferSyntaxUID) + + # Default should be RPCL which uses transfer syntax 1.2.840.10008.1.2.4.202 + expected_ts = 
"1.2.840.10008.1.2.4.202" + self.assertEqual( + transcoded_ts, + expected_ts, + f"Default progression order should produce transfer syntax {expected_ts} (RPCL)", + ) + print(f"✓ Default progression order is RPCL (transfer syntax: {transcoded_ts})") + + finally: + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + def test_progression_order_options(self): + """Test that all 5 progression orders work correctly with grayscale images.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import shutil + + import pydicom.examples as examples + + source_file = str(examples.get_path("ct")) + + # Test all 5 progression orders + progression_orders = [ + ("LRCP", "1.2.840.10008.1.2.4.201"), # HTJ2K (Lossless Only) - quality scalability + ("RLCP", "1.2.840.10008.1.2.4.201"), # HTJ2K (Lossless Only) - resolution scalability + ("RPCL", "1.2.840.10008.1.2.4.202"), # HTJ2K with RPCL Options - progressive by resolution + ("PCRL", "1.2.840.10008.1.2.4.201"), # HTJ2K (Lossless Only) - progressive by spatial area + ("CPRL", "1.2.840.10008.1.2.4.201"), # HTJ2K (Lossless Only) - component scalability + ] + + for prog_order, expected_ts in progression_orders: + with self.subTest(progression_order=prog_order): + print(f"\nTesting progression_order={prog_order}") + + # Create temp directories + input_dir = tempfile.mkdtemp(prefix=f"htj2k_{prog_order.lower()}_input_") + output_dir = tempfile.mkdtemp(prefix=f"htj2k_{prog_order.lower()}_output_") + + try: + test_filename = "ct_small.dcm" + shutil.copy2(source_file, os.path.join(input_dir, test_filename)) + + # Read original + ds_original = pydicom.dcmread(source_file) + original_pixels = ds_original.pixel_array.copy() + + # Transcode with specific progression order + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k(file_loader=file_loader, progression_order=prog_order) + + # Read transcoded + transcoded_file = os.path.join(output_dir, test_filename) + self.assertTrue(os.path.exists(transcoded_file)) + + ds_transcoded = pydicom.dcmread(transcoded_file) + transcoded_pixels = ds_transcoded.pixel_array + transcoded_ts = str(ds_transcoded.file_meta.TransferSyntaxUID) + + print(f" Transfer Syntax: {transcoded_ts}") + print(f" Expected: {expected_ts}") + + # Verify correct transfer syntax for progression order + self.assertEqual( + transcoded_ts, expected_ts, f"Transfer syntax should be {expected_ts} for {prog_order}" + ) + + # Verify lossless (grayscale should be exact) + np.testing.assert_array_equal(original_pixels, transcoded_pixels) + print(f"✓ {prog_order} progression order works correctly") + + finally: + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + def test_progression_order_with_ybr_color(self): + """Test progression orders work correctly with YBR color space conversion.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import pydicom.data + + try: + source_file = pydicom.data.get_testdata_file("examples_ybr_color.dcm") + except Exception as e: + self.skipTest(f"Could not load pydicom test data: {e}") + + # Test a subset of progression orders with color images + # (testing all 5 would take too long, so we test RPCL, LRCP, and RLCP) + progression_orders = [ + ("RPCL", "1.2.840.10008.1.2.4.202"), # Default + ("LRCP", "1.2.840.10008.1.2.4.201"), # Quality scalability + ("RLCP", "1.2.840.10008.1.2.4.201"), # Resolution scalability + ] + + for prog_order, expected_ts in progression_orders: + 
with self.subTest(progression_order=prog_order): + print(f"\nTesting YBR color with progression_order={prog_order}") + + import shutil + + input_dir = tempfile.mkdtemp(prefix=f"htj2k_ybr_{prog_order.lower()}_input_") + output_dir = tempfile.mkdtemp(prefix=f"htj2k_ybr_{prog_order.lower()}_output_") + + try: + test_filename = "ybr_color.dcm" + shutil.copy2(source_file, os.path.join(input_dir, test_filename)) + + # Read original + ds_original = pydicom.dcmread(source_file) + original_pixels = ds_original.pixel_array.copy() + original_pi = ds_original.PhotometricInterpretation + + # Transcode with specific progression order + # Override default skip list to force transcoding of JPEG files + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k( + file_loader=file_loader, + progression_order=prog_order, + skip_transfer_syntaxes=None, # Override default to test JPEG transcoding + ) + + # Read transcoded + transcoded_file = os.path.join(output_dir, test_filename) + ds_transcoded = pydicom.dcmread(transcoded_file) + transcoded_pixels = ds_transcoded.pixel_array + transcoded_ts = str(ds_transcoded.file_meta.TransferSyntaxUID) + + print(f" Original PI: {original_pi}") + print(f" Transcoded PI: {ds_transcoded.PhotometricInterpretation}") + print(f" Transfer Syntax: {transcoded_ts}") + + # Verify transfer syntax matches progression order + self.assertEqual(transcoded_ts, expected_ts) + + # Verify PhotometricInterpretation was updated to RGB (from YBR) + self.assertEqual(ds_transcoded.PhotometricInterpretation, "RGB") + + # Verify pixels match within tolerance (color conversion) + max_diff = np.abs(original_pixels.astype(np.float32) - transcoded_pixels.astype(np.float32)).max() + self.assertTrue( + np.allclose(original_pixels, transcoded_pixels, atol=5, rtol=0), + f"Pixels should match within tolerance, max_diff={max_diff}", + ) + print(f"✓ {prog_order} works with YBR color conversion (max_diff={max_diff})") + + finally: + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + def test_progression_order_with_rgb_color(self): + """Test progression orders work correctly with RGB color images.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import shutil + + import pydicom.examples as examples + + source_file = str(examples.get_path("rgb_color")) + + # Test a subset of progression orders with RGB images + progression_orders = [ + ("RPCL", "1.2.840.10008.1.2.4.202"), + ("PCRL", "1.2.840.10008.1.2.4.201"), # Position-Component (good for spatial access) + ("CPRL", "1.2.840.10008.1.2.4.201"), # Component-Position (good for component access) + ] + + for prog_order, expected_ts in progression_orders: + with self.subTest(progression_order=prog_order): + print(f"\nTesting RGB color with progression_order={prog_order}") + + input_dir = tempfile.mkdtemp(prefix=f"htj2k_rgb_{prog_order.lower()}_input_") + output_dir = tempfile.mkdtemp(prefix=f"htj2k_rgb_{prog_order.lower()}_output_") + + try: + test_filename = "rgb_color.dcm" + shutil.copy2(source_file, os.path.join(input_dir, test_filename)) + + # Read original + ds_original = pydicom.dcmread(source_file) + original_pixels = ds_original.pixel_array.copy() + + # Transcode with specific progression order + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k(file_loader=file_loader, progression_order=prog_order) + + # Read transcoded + transcoded_file = os.path.join(output_dir, test_filename) + ds_transcoded = pydicom.dcmread(transcoded_file) + 
transcoded_pixels = ds_transcoded.pixel_array + transcoded_ts = str(ds_transcoded.file_meta.TransferSyntaxUID) + + # Verify transfer syntax matches progression order + self.assertEqual(transcoded_ts, expected_ts) + + # Verify PhotometricInterpretation stays RGB + self.assertEqual(ds_transcoded.PhotometricInterpretation, "RGB") + + # Verify lossless (RGB uncompressed should be exact) + np.testing.assert_array_equal(original_pixels, transcoded_pixels) + print(f"✓ {prog_order} works with RGB color images (lossless)") + + finally: + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + def test_invalid_progression_order(self): + """Test that invalid progression orders raise ValueError.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import shutil + + import pydicom.examples as examples + + source_file = str(examples.get_path("ct")) + + # Create temp directories + input_dir = tempfile.mkdtemp(prefix="htj2k_invalid_input_") + output_dir = tempfile.mkdtemp(prefix="htj2k_invalid_output_") + + try: + test_filename = "ct_small.dcm" + shutil.copy2(source_file, os.path.join(input_dir, test_filename)) + + # Test various invalid progression orders (lowercase, mixed case, or completely invalid) + invalid_orders = ["invalid", "rpcl", "lrcp", "rlcp", "pcrl", "cprl", "ABCD", ""] + + for invalid_order in invalid_orders: + with self.subTest(invalid_progression_order=invalid_order): + print(f"\nTesting invalid progression_order={repr(invalid_order)}") + + with self.assertRaises(ValueError) as context: + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k(file_loader=file_loader, progression_order=invalid_order) + + # Verify error message is helpful and lists all valid options + error_msg = str(context.exception) + self.assertIn("progression_order", error_msg.lower()) + # Check that all valid progression orders are mentioned in the error + for valid_order in ["LRCP", "RLCP", "RPCL", "PCRL", "CPRL"]: + self.assertIn(valid_order, error_msg) + print(f"✓ Correctly raised ValueError: {error_msg}") + + finally: + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + def test_progression_order_multiframe_conversion(self): + """Test progression orders work with multiframe conversion.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + # Use a specific series from dicomweb + dicom_dir = os.path.join( + self.base_dir, + "data", + "dataset", + "dicomweb", + "e7567e0a064f0c334226a0658de23afd", + "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721", + ) + + # Test all 5 progression orders + progression_orders = [ + ("LRCP", "1.2.840.10008.1.2.4.201"), + ("RLCP", "1.2.840.10008.1.2.4.201"), + ("RPCL", "1.2.840.10008.1.2.4.202"), + ("PCRL", "1.2.840.10008.1.2.4.201"), + ("CPRL", "1.2.840.10008.1.2.4.201"), + ] + + for prog_order, expected_ts in progression_orders: + with self.subTest(progression_order=prog_order): + print(f"\nTesting multiframe conversion with progression_order={prog_order}") + + output_dir = tempfile.mkdtemp(prefix=f"htj2k_multiframe_{prog_order.lower()}_") + + try: + # Convert to multiframe with specific progression order + result_dir = convert_single_frame_dicom_series_to_multiframe( + input_dir=dicom_dir, output_dir=output_dir, convert_to_htj2k=True, progression_order=prog_order + ) + + # Find the multi-frame file + multiframe_files = list(Path(output_dir).rglob("*.dcm")) + self.assertEqual(len(multiframe_files), 1, "Should have 
one multi-frame file") + + # Load and verify + ds_multiframe = pydicom.dcmread(str(multiframe_files[0])) + transcoded_ts = str(ds_multiframe.file_meta.TransferSyntaxUID) + + print(f" Transfer Syntax: {transcoded_ts}") + print(f" Expected: {expected_ts}") + + # Verify correct transfer syntax + self.assertEqual( + transcoded_ts, expected_ts, f"Transfer syntax should be {expected_ts} for {prog_order}" + ) + + print(f"✓ Multiframe conversion with {prog_order} works correctly") + + finally: + import shutil + + if os.path.exists(output_dir): + shutil.rmtree(output_dir) + + def test_skip_transfer_syntaxes_htj2k(self): + """Test skipping files with HTJ2K transfer syntax (copy instead of transcode) - using default skip list.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import shutil + import time + + # Create temp directories + input_dir = tempfile.mkdtemp(prefix="htj2k_skip_input_") + output_dir = tempfile.mkdtemp(prefix="htj2k_skip_output_") + intermediate_dir = tempfile.mkdtemp(prefix="htj2k_intermediate_") + + try: + # First, create an HTJ2K file by transcoding a test file + import pydicom.examples as examples + + source_file = str(examples.get_path("ct")) + test_filename = "ct_htj2k.dcm" + + # Copy to intermediate directory + shutil.copy2(source_file, os.path.join(intermediate_dir, test_filename)) + + # Transcode to HTJ2K (disable default skip list to force transcoding) + print("\nStep 1: Creating HTJ2K file...") + htj2k_dir = tempfile.mkdtemp(prefix="htj2k_created_") + file_loader = DicomFileLoader(intermediate_dir, htj2k_dir) + transcode_dicom_to_htj2k( + file_loader=file_loader, + progression_order="RPCL", + skip_transfer_syntaxes=None, # Override default to force transcoding + ) + + # Copy the HTJ2K file to input directory + htj2k_file = os.path.join(htj2k_dir, test_filename) + shutil.copy2(htj2k_file, os.path.join(input_dir, test_filename)) + + # Read the HTJ2K file to get its transfer syntax and checksum + ds_htj2k = pydicom.dcmread(htj2k_file) + htj2k_ts = str(ds_htj2k.file_meta.TransferSyntaxUID) + print(f"HTJ2K Transfer Syntax: {htj2k_ts}") + + # Calculate checksum of original HTJ2K file + import hashlib + + with open(os.path.join(input_dir, test_filename), "rb") as f: + original_checksum = hashlib.md5(f.read()).hexdigest() + original_size = os.path.getsize(os.path.join(input_dir, test_filename)) + original_mtime = os.path.getmtime(os.path.join(input_dir, test_filename)) + + print(f"\nStep 2: Testing default skip functionality (HTJ2K should be skipped by default)...") + print(f" Original file size: {original_size:,} bytes") + print(f" Original checksum: {original_checksum}") + + # Sleep briefly to ensure timestamps differ if file is modified + time.sleep(0.1) + + # Now transcode with DEFAULT skip_transfer_syntaxes (should skip HTJ2K by default) + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k( + file_loader=file_loader + # Note: NOT passing skip_transfer_syntaxes, using default which includes HTJ2K + ) + + # Verify output file exists + output_file = os.path.join(output_dir, test_filename) + self.assertTrue(os.path.exists(output_file), "Output file should exist") + + # Calculate checksum of output file + with open(output_file, "rb") as f: + output_checksum = hashlib.md5(f.read()).hexdigest() + output_size = os.path.getsize(output_file) + output_mtime = os.path.getmtime(output_file) + + print(f"\nStep 3: Verifying file was copied, not transcoded...") + print(f" Output file size: {output_size:,} bytes") + print(f" Output 
checksum: {output_checksum}") + + # Verify file is identical (not re-encoded) + self.assertEqual(original_checksum, output_checksum, "File should be identical (copied, not transcoded)") + self.assertEqual(original_size, output_size, "File size should be identical") + + # Verify transfer syntax is still HTJ2K + ds_output = pydicom.dcmread(output_file) + self.assertEqual( + str(ds_output.file_meta.TransferSyntaxUID), htj2k_ts, "Transfer syntax should be preserved" + ) + + print(f"✓ File was copied without transcoding (HTJ2K skipped by default)") + print(f"✓ Transfer syntax preserved: {htj2k_ts}") + print(f"✓ Default behavior correctly skips HTJ2K files") + + finally: + # Clean up + for temp_dir in [input_dir, output_dir, intermediate_dir]: + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir, ignore_errors=True) + if "htj2k_dir" in locals() and os.path.exists(htj2k_dir): + shutil.rmtree(htj2k_dir, ignore_errors=True) + + def test_skip_transfer_syntaxes_mixed_batch(self): + """Test mixed batch with some files to skip and some to transcode - using default skip list.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import hashlib + import shutil + + # Create temp directories + input_dir = tempfile.mkdtemp(prefix="htj2k_mixed_input_") + output_dir = tempfile.mkdtemp(prefix="htj2k_mixed_output_") + + try: + # Create two test files: one uncompressed CT, one uncompressed MR + import pydicom.examples as examples + + ct_source = str(examples.get_path("ct")) + mr_source = str(examples.get_path("mr")) + + ct_filename = "ct_uncompressed.dcm" + mr_filename = "mr_uncompressed.dcm" + + # Copy both to input + shutil.copy2(ct_source, os.path.join(input_dir, ct_filename)) + shutil.copy2(mr_source, os.path.join(input_dir, mr_filename)) + + # First pass: transcode CT to HTJ2K with LRCP, keep MR uncompressed + print("\nStep 1: Creating HTJ2K file with LRCP progression order...") + first_pass_dir = tempfile.mkdtemp(prefix="htj2k_first_pass_") + file_loader_first = DicomFileLoader(input_dir, first_pass_dir) + transcode_dicom_to_htj2k( + file_loader=file_loader_first, + progression_order="LRCP", # This will create 1.2.840.10008.1.2.4.201 + skip_transfer_syntaxes=None, # Override default to force transcoding + ) + + # Now use the CT HTJ2K file and MR uncompressed for the mixed test + input_dir2 = tempfile.mkdtemp(prefix="htj2k_mixed2_input_") + + # Copy HTJ2K CT to new input + htj2k_ct_file = os.path.join(first_pass_dir, ct_filename) + shutil.copy2(htj2k_ct_file, os.path.join(input_dir2, ct_filename)) + + # Copy uncompressed MR to new input + shutil.copy2(mr_source, os.path.join(input_dir2, mr_filename)) + + # Get checksums before processing + with open(os.path.join(input_dir2, ct_filename), "rb") as f: + ct_original_checksum = hashlib.md5(f.read()).hexdigest() + + ds_ct = pydicom.dcmread(os.path.join(input_dir2, ct_filename)) + ct_ts = str(ds_ct.file_meta.TransferSyntaxUID) + + ds_mr_orig = pydicom.dcmread(os.path.join(input_dir2, mr_filename)) + mr_ts_orig = str(ds_mr_orig.file_meta.TransferSyntaxUID) + mr_pixels_orig = ds_mr_orig.pixel_array.copy() + + print(f"\nStep 2: Processing mixed batch with DEFAULT skip list...") + print(f" CT file: {ct_filename} (Transfer Syntax: {ct_ts}) - SKIP (by default)") + print(f" MR file: {mr_filename} (Transfer Syntax: {mr_ts_orig}) - TRANSCODE") + + # Transcode with DEFAULT skip list (HTJ2K files will be skipped by default) + file_loader = DicomFileLoader(input_dir2, output_dir) + transcode_dicom_to_htj2k( + file_loader=file_loader, + # Using 
default skip list which includes HTJ2K formats + progression_order="RPCL", # Will use 1.2.840.10008.1.2.4.202 for transcoded files + ) + + print("\nStep 3: Verifying results...") + + # Verify CT file was copied (not transcoded) + ct_output = os.path.join(output_dir, ct_filename) + self.assertTrue(os.path.exists(ct_output), "CT output should exist") + + with open(ct_output, "rb") as f: + ct_output_checksum = hashlib.md5(f.read()).hexdigest() + + self.assertEqual( + ct_original_checksum, ct_output_checksum, "CT file should be identical (copied, not transcoded)" + ) + + ds_ct_output = pydicom.dcmread(ct_output) + self.assertEqual( + str(ds_ct_output.file_meta.TransferSyntaxUID), ct_ts, "CT transfer syntax should be preserved (LRCP)" + ) + print(f"✓ CT file copied without transcoding (HTJ2K skipped by default: {ct_ts})") + + # Verify MR file was transcoded to RPCL HTJ2K + mr_output = os.path.join(output_dir, mr_filename) + self.assertTrue(os.path.exists(mr_output), "MR output should exist") + + ds_mr_output = pydicom.dcmread(mr_output) + mr_ts_output = str(ds_mr_output.file_meta.TransferSyntaxUID) + + self.assertEqual(mr_ts_output, "1.2.840.10008.1.2.4.202", "MR should be transcoded to RPCL HTJ2K") + + # Verify MR pixels are lossless + mr_pixels_output = ds_mr_output.pixel_array + np.testing.assert_array_equal(mr_pixels_orig, mr_pixels_output, err_msg="MR pixels should be lossless") + print(f"✓ MR file transcoded to HTJ2K RPCL ({mr_ts_output})") + print(f"✓ MR pixels are lossless") + print(f"✓ Mixed batch correctly handles default skip list") + + finally: + # Clean up + for temp_dir in [input_dir, output_dir]: + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir, ignore_errors=True) + if "first_pass_dir" in locals() and os.path.exists(first_pass_dir): + shutil.rmtree(first_pass_dir, ignore_errors=True) + if "input_dir2" in locals() and os.path.exists(input_dir2): + shutil.rmtree(input_dir2, ignore_errors=True) + + def test_skip_transfer_syntaxes_multiple(self): + """Test skipping multiple transfer syntax UIDs - using default skip list.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import hashlib + import shutil + + # Create temp directories + input_dir = tempfile.mkdtemp(prefix="htj2k_multi_skip_input_") + output_dir = tempfile.mkdtemp(prefix="htj2k_multi_skip_output_") + + try: + import pydicom.examples as examples + + ct_source = str(examples.get_path("ct")) + + # Create two HTJ2K files with different progression orders + file1_name = "file_lrcp.dcm" + file2_name = "file_rpcl.dcm" + + # Create LRCP HTJ2K file + print("\nStep 1: Creating LRCP HTJ2K file...") + temp_dir1 = tempfile.mkdtemp(prefix="htj2k_temp1_") + shutil.copy2(ct_source, os.path.join(temp_dir1, file1_name)) + htj2k_dir1 = tempfile.mkdtemp(prefix="htj2k_lrcp_") + file_loader1 = DicomFileLoader(temp_dir1, htj2k_dir1) + transcode_dicom_to_htj2k( + file_loader=file_loader1, + progression_order="LRCP", + skip_transfer_syntaxes=None, # Override default to force transcoding + ) + shutil.copy2(os.path.join(htj2k_dir1, file1_name), os.path.join(input_dir, file1_name)) + + # Create RPCL HTJ2K file + print("Step 2: Creating RPCL HTJ2K file...") + temp_dir2 = tempfile.mkdtemp(prefix="htj2k_temp2_") + shutil.copy2(ct_source, os.path.join(temp_dir2, file2_name)) + htj2k_dir2 = tempfile.mkdtemp(prefix="htj2k_rpcl_") + file_loader2 = DicomFileLoader(temp_dir2, htj2k_dir2) + transcode_dicom_to_htj2k( + file_loader=file_loader2, + progression_order="RPCL", + skip_transfer_syntaxes=None, # Override default to 
force transcoding + ) + shutil.copy2(os.path.join(htj2k_dir2, file2_name), os.path.join(input_dir, file2_name)) + + # Get checksums + with open(os.path.join(input_dir, file1_name), "rb") as f: + checksum1 = hashlib.md5(f.read()).hexdigest() + with open(os.path.join(input_dir, file2_name), "rb") as f: + checksum2 = hashlib.md5(f.read()).hexdigest() + + ds1 = pydicom.dcmread(os.path.join(input_dir, file1_name)) + ds2 = pydicom.dcmread(os.path.join(input_dir, file2_name)) + + ts1 = str(ds1.file_meta.TransferSyntaxUID) + ts2 = str(ds2.file_meta.TransferSyntaxUID) + + print(f"\nStep 3: Processing with DEFAULT skip list (both HTJ2K formats should be skipped)...") + print(f" File 1: {file1_name} - {ts1}") + print(f" File 2: {file2_name} - {ts2}") + + # Use DEFAULT skip list (includes all HTJ2K transfer syntaxes) + file_loader = DicomFileLoader(input_dir, output_dir) + transcode_dicom_to_htj2k( + file_loader=file_loader + # Using default skip list which includes all HTJ2K formats + ) + + print("\nStep 4: Verifying both files were copied...") + + # Verify both files were copied + output1 = os.path.join(output_dir, file1_name) + output2 = os.path.join(output_dir, file2_name) + + self.assertTrue(os.path.exists(output1), "File 1 should exist") + self.assertTrue(os.path.exists(output2), "File 2 should exist") + + # Verify checksums match (files were copied, not transcoded) + with open(output1, "rb") as f: + output_checksum1 = hashlib.md5(f.read()).hexdigest() + with open(output2, "rb") as f: + output_checksum2 = hashlib.md5(f.read()).hexdigest() + + self.assertEqual(checksum1, output_checksum1, "File 1 should be identical") + self.assertEqual(checksum2, output_checksum2, "File 2 should be identical") + + # Verify transfer syntaxes preserved + ds_out1 = pydicom.dcmread(output1) + ds_out2 = pydicom.dcmread(output2) + + self.assertEqual(str(ds_out1.file_meta.TransferSyntaxUID), ts1) + self.assertEqual(str(ds_out2.file_meta.TransferSyntaxUID), ts2) + + print(f"✓ Both files copied without transcoding (HTJ2K skipped by default)") + print(f"✓ File 1 preserved: {ts1}") + print(f"✓ File 2 preserved: {ts2}") + print(f"✓ Default skip list correctly handles multiple HTJ2K formats") + + finally: + # Clean up + for temp_dir in [input_dir, output_dir]: + if os.path.exists(temp_dir): + shutil.rmtree(temp_dir, ignore_errors=True) + for var in ["temp_dir1", "temp_dir2", "htj2k_dir1", "htj2k_dir2"]: + if var in locals() and os.path.exists(locals()[var]): + shutil.rmtree(locals()[var], ignore_errors=True) + + def test_skip_transfer_syntaxes_override_to_transcode_all(self): + """Test that explicitly overriding skip list to None/empty transcodes all files (overrides default).""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + import shutil + + import pydicom.examples as examples + + input_dir = tempfile.mkdtemp(prefix="htj2k_override_input_") + output_dir = tempfile.mkdtemp(prefix="htj2k_override_output_") + + try: + source_file = str(examples.get_path("ct")) + test_filename = "ct_test.dcm" + shutil.copy2(source_file, os.path.join(input_dir, test_filename)) + + # Read original + ds_original = pydicom.dcmread(source_file) + original_pixels = ds_original.pixel_array.copy() + original_ts = str(ds_original.file_meta.TransferSyntaxUID) + + print(f"\nOriginal transfer syntax: {original_ts}") + print(f"Testing override of default skip list to force transcoding...") + + # Transcode with None (override default skip list to force transcoding) + file_loader = DicomFileLoader(input_dir, output_dir) + 
transcode_dicom_to_htj2k( + file_loader=file_loader, skip_transfer_syntaxes=None # Override default to force transcoding + ) + + # Verify file was transcoded + output_file = os.path.join(output_dir, test_filename) + self.assertTrue(os.path.exists(output_file)) + + ds_output = pydicom.dcmread(output_file) + output_ts = str(ds_output.file_meta.TransferSyntaxUID) + output_pixels = ds_output.pixel_array + + print(f"Output transfer syntax: {output_ts}") + + # Should be transcoded to HTJ2K + self.assertIn(output_ts, HTJ2K_TRANSFER_SYNTAXES) + self.assertNotEqual(original_ts, output_ts, "Transfer syntax should have changed") + + # Pixels should still be lossless + np.testing.assert_array_equal(original_pixels, output_pixels) + + print("✓ Override with None successfully forces transcoding (bypasses default skip list)") + + finally: + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + def test_transcode_with_pytorch_dataloader(self): + """Test transcoding using PyTorch DataLoader as file_loader.""" + if not HAS_NVIMGCODEC: + self.skipTest("nvimgcodec not available") + + try: + import torch + from torch.utils.data import DataLoader, Dataset + except ImportError: + self.skipTest("PyTorch not available") + + import shutil + + # Create temp directories + input_dir = tempfile.mkdtemp(prefix="htj2k_pytorch_input_") + output_dir = tempfile.mkdtemp(prefix="htj2k_pytorch_output_") + + try: + # Copy test files + source_dir = os.path.join( + self.dicom_dataset, "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721" + ) + + # Copy a subset of files for testing + test_files = [] + for i, filename in enumerate(sorted(os.listdir(source_dir))): + if i >= 5: # Only copy 5 files for this test + break + src = os.path.join(source_dir, filename) + dst = os.path.join(input_dir, filename) + shutil.copy2(src, dst) + test_files.append(filename) + + print(f"\nTesting with PyTorch DataLoader using {len(test_files)} files") + + # Create a custom Dataset that yields (input_paths, output_paths) tuples + class DicomFileDataset(Dataset): + """Custom Dataset that yields batches of (input_paths, output_paths).""" + + def __init__(self, input_dir, output_dir, files): + self.input_dir = input_dir + self.output_dir = output_dir + self.files = files + + def __len__(self): + return len(self.files) + + def __getitem__(self, idx): + """Return a single file path tuple.""" + filename = self.files[idx] + input_path = os.path.join(self.input_dir, filename) + output_path = os.path.join(self.output_dir, filename) + return input_path, output_path + + # Custom collate function to group paths into batches + def collate_paths(batch): + """Collate function that returns (batch_input_paths, batch_output_paths).""" + input_paths = [item[0] for item in batch] + output_paths = [item[1] for item in batch] + return input_paths, output_paths + + # Create Dataset and DataLoader + dataset = DicomFileDataset(input_dir, output_dir, test_files) + dataloader = DataLoader( + dataset, + batch_size=2, # Process 2 files per batch + shuffle=False, + collate_fn=collate_paths, + num_workers=0, # Use 0 for compatibility in tests + ) + + print(f"Created PyTorch DataLoader with batch_size=2") + print(f"Number of batches: {len(dataloader)}") + + # Read original files to verify later + original_data = {} + for filename in test_files: + filepath = os.path.join(input_dir, filename) + ds = pydicom.dcmread(filepath) + original_data[filename] = { + "pixels": ds.pixel_array.copy(), + "transfer_syntax": 
ds.file_meta.TransferSyntaxUID, + } + + # Run transcoding with PyTorch DataLoader + transcode_dicom_to_htj2k( + file_loader=dataloader, + num_resolutions=6, + code_block_size=(64, 64), + progression_order="RPCL", + max_batch_size=256, + add_basic_offset_table=True, + ) + + print(f"✓ Transcoding completed, output_dir: {output_dir}") + + # Verify all files were transcoded + output_files = os.listdir(output_dir) + self.assertEqual(len(output_files), len(test_files)) + print(f"✓ All {len(test_files)} files were processed") + + # Verify transcoding was correct + for filename in test_files: + output_path = os.path.join(output_dir, filename) + self.assertTrue(os.path.exists(output_path), f"Output file {filename} should exist") + + # Read transcoded file + ds_transcoded = pydicom.dcmread(output_path) + transcoded_pixels = ds_transcoded.pixel_array + + # Verify transfer syntax changed to HTJ2K + transcoded_ts = ds_transcoded.file_meta.TransferSyntaxUID + self.assertIn(str(transcoded_ts), HTJ2K_TRANSFER_SYNTAXES) + + # Verify pixels are identical (lossless) + original_pixels = original_data[filename]["pixels"] + np.testing.assert_array_equal( + original_pixels, transcoded_pixels, err_msg=f"Pixels should match for {filename}" + ) + + print(f"✓ All files transcoded to HTJ2K with lossless compression") + print(f"✓ PyTorch DataLoader test passed!") + + finally: + # Clean up + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + def test_convert_multiframe_handles_missing_pixeldata(self): + """Test that convert_single_frame_dicom_series_to_multiframe handles datasets without PixelData.""" + if not HAS_NVIMGCODEC: + self.skipTest( + "nvimgcodec not available. Install nvidia-nvimgcodec-cu{XX} matching your CUDA version (e.g., nvidia-nvimgcodec-cu13 for CUDA 13.x)" + ) + + # Create temporary directory with mixed DICOM files + input_dir = tempfile.mkdtemp(prefix="test_missing_pixeldata_") + output_dir = tempfile.mkdtemp(prefix="test_missing_pixeldata_output_") + + try: + # Create a series with some files having PixelData and some without + study_uid = pydicom.uid.generate_uid() + series_uid = pydicom.uid.generate_uid() + + print(f"\nCreating test series with mixed PixelData presence...") + + # Create 3 valid DICOM files with PixelData + valid_files = [] + for i in range(3): + ds = pydicom.Dataset() + ds.StudyInstanceUID = study_uid + ds.SeriesInstanceUID = series_uid + ds.SOPInstanceUID = pydicom.uid.generate_uid() + ds.SOPClassUID = "1.2.840.10008.5.1.4.1.1.2" # CT Image Storage + ds.InstanceNumber = i + 1 + ds.Modality = "CT" + ds.PatientName = "Test^Patient" + ds.PatientID = "12345" + + # Add spatial metadata + ds.ImagePositionPatient = [0.0, 0.0, float(i * 2.5)] + ds.ImageOrientationPatient = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0] + ds.PixelSpacing = [0.5, 0.5] + ds.SliceThickness = 2.5 + + # Add image data + ds.Rows = 64 + ds.Columns = 64 + ds.SamplesPerPixel = 1 + ds.PhotometricInterpretation = "MONOCHROME2" + ds.BitsAllocated = 16 + ds.BitsStored = 16 + ds.HighBit = 15 + ds.PixelRepresentation = 0 + + # Create pixel data + pixel_array = np.random.randint(0, 1000, (64, 64), dtype=np.uint16) + ds.PixelData = pixel_array.tobytes() + + # Save file with proper file meta + ds.file_meta = pydicom.dataset.FileMetaDataset() + ds.file_meta.FileMetaInformationVersion = b"\x00\x01" + ds.file_meta.TransferSyntaxUID = pydicom.uid.ExplicitVRLittleEndian + ds.file_meta.MediaStorageSOPClassUID = ds.SOPClassUID + ds.file_meta.MediaStorageSOPInstanceUID = ds.SOPInstanceUID + 
ds.file_meta.ImplementationClassUID = pydicom.uid.PYDICOM_IMPLEMENTATION_UID + + filepath = os.path.join(input_dir, f"valid_{i:03d}.dcm") + # Use save_as which properly writes DICOM Part 10 format with preamble + ds.save_as(filepath, enforce_file_format=True) + valid_files.append(filepath) + print(f" Created valid file: {os.path.basename(filepath)}") + + # Create 2 DICOM files WITHOUT PixelData (like SR or metadata-only) + for i in range(2): + ds = pydicom.Dataset() + ds.StudyInstanceUID = study_uid + ds.SeriesInstanceUID = series_uid + ds.SOPInstanceUID = pydicom.uid.generate_uid() + ds.SOPClassUID = "1.2.840.10008.5.1.4.1.1.2" # CT Image Storage + ds.InstanceNumber = i + 10 + ds.Modality = "CT" + ds.PatientName = "Test^Patient" + ds.PatientID = "12345" + + # Add spatial metadata but NO PixelData + ds.ImagePositionPatient = [0.0, 0.0, float((i + 10) * 2.5)] + ds.ImageOrientationPatient = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0] + + # Save file with proper file meta + ds.file_meta = pydicom.dataset.FileMetaDataset() + ds.file_meta.FileMetaInformationVersion = b"\x00\x01" + ds.file_meta.TransferSyntaxUID = pydicom.uid.ExplicitVRLittleEndian + ds.file_meta.MediaStorageSOPClassUID = ds.SOPClassUID + ds.file_meta.MediaStorageSOPInstanceUID = ds.SOPInstanceUID + ds.file_meta.ImplementationClassUID = pydicom.uid.PYDICOM_IMPLEMENTATION_UID + + filepath = os.path.join(input_dir, f"no_pixel_{i:03d}.dcm") + # Use save_as which properly writes DICOM Part 10 format with preamble + ds.save_as(filepath, enforce_file_format=True) + print(f" Created file without PixelData: {os.path.basename(filepath)}") + + print(f"✓ Created {len(valid_files)} valid files and 2 files without PixelData") + + # Convert to multiframe - should skip files without PixelData + result_dir = convert_single_frame_dicom_series_to_multiframe( + input_dir=input_dir, + output_dir=output_dir, + convert_to_htj2k=True, + ) + + # Verify multiframe file was created + multiframe_files = list(Path(result_dir).rglob("*.dcm")) + self.assertEqual(len(multiframe_files), 1, "Should create one multiframe file") + print(f"✓ Created multiframe file: {multiframe_files[0]}") + + # Load and verify the multiframe file + ds_multiframe = pydicom.dcmread(str(multiframe_files[0])) + + # Should have 3 frames (only the valid files) + self.assertTrue(hasattr(ds_multiframe, "NumberOfFrames"), "Should have NumberOfFrames") + num_frames = int(ds_multiframe.NumberOfFrames) + self.assertEqual(num_frames, 3, "Should have 3 frames (files without PixelData excluded)") + print(f"✓ NumberOfFrames: {num_frames} (correctly excluded files without PixelData)") + + # Verify PerFrameFunctionalGroupsSequence has correct number of items + self.assertTrue( + hasattr(ds_multiframe, "PerFrameFunctionalGroupsSequence"), + "Should have PerFrameFunctionalGroupsSequence", + ) + per_frame_seq = ds_multiframe.PerFrameFunctionalGroupsSequence + self.assertEqual(len(per_frame_seq), 3, "Should have 3 per-frame items") + print(f"✓ PerFrameFunctionalGroupsSequence has {len(per_frame_seq)} items") + + print(f"✓ Test passed: Files without PixelData were correctly skipped") + + finally: + # Clean up + shutil.rmtree(input_dir, ignore_errors=True) + shutil.rmtree(output_dir, ignore_errors=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/datastore/test_convert_multiframe.py b/tests/unit/datastore/test_convert_multiframe.py new file mode 100644 index 000000000..add727415 --- /dev/null +++ b/tests/unit/datastore/test_convert_multiframe.py @@ -0,0 +1,256 @@ +""" +Unit tests for 
convert_multiframe module. + +Tests the conversion of legacy DICOM CT, MR, and PET series to enhanced multi-frame format. +""" + +import os +import tempfile +import unittest +from pathlib import Path + +import highdicom +import pydicom + +from monailabel.datastore.utils.convert_multiframe import ( + batch_convert_by_series, + convert_and_convert_to_htj2k, + convert_to_enhanced_dicom, + validate_dicom_series, +) + + +class TestConvertMultiframe(unittest.TestCase): + """Test DICOM series conversion to enhanced multi-frame format.""" + + @classmethod + def setUpClass(cls): + """Set up test data paths.""" + cls.base_dir = Path(__file__).parent.parent.parent.parent + cls.test_data_dir = cls.base_dir / "tests" / "data" / "dataset" + + # Find available test data directories + cls.dicomweb_dir = cls.test_data_dir / "dicomweb" + cls.dicomweb_htj2k_dir = cls.test_data_dir / "dicomweb_htj2k" + + def test_01_validate_series(self): + """Test validation of a DICOM series.""" + for root, dirs, files in os.walk(self.dicomweb_dir): + if files and any(f.endswith(".dcm") for f in files): + series_dir = Path(root) + print(f"Testing validation on: {series_dir}") + + # Validate the series + is_valid = validate_dicom_series(series_dir) + print(f"Validation result: {is_valid}") + + # We may get False if the series is not CT/MR/PT or has issues + # But the test passes if no exception is raised + self.assertIsInstance(is_valid, bool) + break + + def test_02_convert_series_full(self): + """Test full conversion to enhanced multi-frame format.""" + for root, dirs, files in os.walk(self.dicomweb_dir): + if files and any(f.endswith(".dcm") for f in files): + series_dir = Path(root) + + # Check if this is a CT/MR/PT series + first_file = next((f for f in files if f.endswith(".dcm")), None) + if first_file: + try: + ds = pydicom.dcmread(Path(root) / first_file, stop_before_pixels=True) + modality = getattr(ds, "Modality", None) + + if modality not in {"CT", "MR", "PT"}: + print(f"Skipping series with modality: {modality}") + continue + + print(f"Testing full conversion on {modality} series: {series_dir}") + + # Create temporary output file + with tempfile.NamedTemporaryFile(suffix=".dcm", delete=False) as tmp: + output_file = tmp.name + + try: + # Convert to enhanced format + result = convert_to_enhanced_dicom( + input_source=series_dir, + output_file=output_file, + ) + + if result: + # Verify the output file was created + self.assertTrue(os.path.exists(output_file)) + + # Load and verify the enhanced DICOM + enhanced_ds = pydicom.dcmread(output_file) + print(f"Enhanced DICOM created:") + print(f" Modality: {enhanced_ds.Modality}") + print(f" NumberOfFrames: {getattr(enhanced_ds, 'NumberOfFrames', 'N/A')}") + print(f" SOPClassUID: {enhanced_ds.SOPClassUID}") + + # Should have NumberOfFrames attribute + self.assertTrue(hasattr(enhanced_ds, "NumberOfFrames")) + self.assertGreater(enhanced_ds.NumberOfFrames, 0) + else: + print("Conversion returned False") + + finally: + # Clean up + if os.path.exists(output_file): + os.unlink(output_file) + + # Only test one series + break + + except Exception as e: + print(f"Error processing series: {e}") + continue + + def test_03_convert_and_compress_htj2k(self): + """Test conversion to enhanced multi-frame format with HTJ2K compression.""" + # Check if nvImageCodec is available for HTJ2K + try: + from nvidia import nvimgcodec + except ImportError: + self.skipTest("nvImageCodec is not installed (required for HTJ2K)") + + for root, dirs, files in os.walk(self.dicomweb_dir): + if files and 
any(f.endswith(".dcm") for f in files): + series_dir = Path(root) + + # Check if this is a CT/MR/PT series + first_file = next((f for f in files if f.endswith(".dcm")), None) + if first_file: + try: + ds = pydicom.dcmread(Path(root) / first_file, stop_before_pixels=True) + modality = getattr(ds, "Modality", None) + + if modality not in {"CT", "MR", "PT"}: + print(f"Skipping series with modality: {modality}") + continue + + print(f"Testing HTJ2K conversion on {modality} series: {series_dir}") + + # Create temporary output file + with tempfile.NamedTemporaryFile(suffix="_htj2k.dcm", delete=False) as tmp: + output_file = tmp.name + + try: + # Convert to enhanced format and compress with HTJ2K + result = convert_and_convert_to_htj2k( + input_source=series_dir, + output_file=output_file, + preserve_series_uid=True, + num_resolutions=6, + progression_order="RPCL", + ) + + if result: + # Verify the output file was created + self.assertTrue(os.path.exists(output_file)) + + # Load and verify the enhanced DICOM + enhanced_ds = pydicom.dcmread(output_file) + print(f"Enhanced HTJ2K DICOM created:") + print(f" Modality: {enhanced_ds.Modality}") + print(f" NumberOfFrames: {getattr(enhanced_ds, 'NumberOfFrames', 'N/A')}") + print(f" TransferSyntaxUID: {enhanced_ds.file_meta.TransferSyntaxUID}") + print(f" File size: {os.path.getsize(output_file) / (1024*1024):.2f} MB") + + # Should have NumberOfFrames attribute + self.assertTrue(hasattr(enhanced_ds, "NumberOfFrames")) + self.assertGreater(enhanced_ds.NumberOfFrames, 0) + + # Should be HTJ2K compressed + htj2k_syntaxes = { + "1.2.840.10008.1.2.4.201", # HTJ2K Lossless Only + "1.2.840.10008.1.2.4.202", # HTJ2K with RPCL + "1.2.840.10008.1.2.4.203", # HTJ2K + } + self.assertIn( + str(enhanced_ds.file_meta.TransferSyntaxUID), + htj2k_syntaxes, + "Output should be HTJ2K compressed", + ) + else: + print("Conversion returned False") + + finally: + # Clean up + if os.path.exists(output_file): + os.unlink(output_file) + + # Only test one series + break + + except Exception as e: + print(f"Error processing series: {e}") + continue + + def test_04_batch_convert_by_series(self): + """Test batch conversion that groups files by SeriesInstanceUID.""" + # Use the dicomweb directory which may contain multiple series + if not self.dicomweb_dir.exists(): + self.skipTest("Test DICOM data not found") + + # Create a temporary output directory + with tempfile.TemporaryDirectory() as temp_output: + output_dir = Path(temp_output) + + print(f"Testing batch conversion on: {self.dicomweb_dir}") + print(f"Output directory: {output_dir}") + + try: + # Scan for DICOM files + input_files = [] + for filepath in self.dicomweb_dir.rglob("*"): + if filepath.is_file() and not filepath.name.startswith("."): + try: + pydicom.dcmread(filepath, stop_before_pixels=True) + input_files.append(str(filepath)) + except: + pass # Skip non-DICOM files + + print(f"Found {len(input_files)} DICOM files") + + # Create file_loader + file_loader = [(input_files, str(output_dir))] + + # Run batch conversion + stats = batch_convert_by_series( + file_loader=file_loader, + preserve_series_uid=True, + compress_htj2k=False, + ) + + print(f"Batch conversion results:") + print(f" Total series input: {stats.get('total_series_input', stats.get('total_series', 0))}") + print(f" Total series output: {stats.get('total_series_output', 0)}") + print(f" Converted to multiframe: {stats.get('converted_to_multiframe', stats.get('converted', 0))}") + print(f" Failed: {stats['failed']}") + + # Verify results + total_series = 
stats.get("total_series_input", stats.get("total_series", 0)) + self.assertGreater(total_series, 0, "Should find at least one series") + self.assertIsInstance(stats["series_info"], list) + + # Check that output files were created for successful conversions + for series_info in stats["series_info"]: + if series_info["status"] == "success": + output_file = Path(series_info["output_file"]) + self.assertTrue(output_file.exists(), f"Output file should exist: {output_file}") + + # Verify it's a valid DICOM file + ds = pydicom.dcmread(output_file, stop_before_pixels=True) + self.assertTrue(hasattr(ds, "NumberOfFrames")) + print(f" ✓ Created: {output_file.name} ({ds.NumberOfFrames} frames)") + + except Exception as e: + print(f"Error during batch conversion: {e}") + raise + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/transform/test_reader.py b/tests/unit/transform/test_reader.py new file mode 100644 index 000000000..bd456d5c2 --- /dev/null +++ b/tests/unit/transform/test_reader.py @@ -0,0 +1,594 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +from pathlib import Path + +import numpy as np +from monai.transforms import LoadImage + +# Check if required dependencies are available +try: + from nvidia import nvimgcodec + + HAS_NVIMGCODEC = True +except ImportError: + HAS_NVIMGCODEC = False + nvimgcodec = None + +try: + import pydicom + + HAS_PYDICOM = True +except ImportError: + HAS_PYDICOM = False + pydicom = None + +# Import the reader +try: + from monailabel.transform.reader import NvDicomReader + + HAS_NVDICOMREADER = True +except ImportError: + HAS_NVDICOMREADER = False + NvDicomReader = None + + +@unittest.skipIf(not HAS_NVDICOMREADER, "NvDicomReader not available") +@unittest.skipIf(not HAS_PYDICOM, "pydicom not available") +class TestNvDicomReader(unittest.TestCase): + """Test suite for NvDicomReader with HTJ2K encoded DICOM files.""" + + base_dir = os.path.realpath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + dicom_dataset = os.path.join(base_dir, "data", "dataset", "dicomweb", "e7567e0a064f0c334226a0658de23afd") + + # Test series for HTJ2K decoding + test_series_uid = "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721" + + def setUp(self): + """Set up test fixtures.""" + # Paths to test data + self.original_series_dir = os.path.join(self.dicom_dataset, self.test_series_uid) + self.htj2k_series_dir = os.path.join( + self.base_dir, "data", "dataset", "dicomweb_htj2k", "e7567e0a064f0c334226a0658de23afd", self.test_series_uid + ) + self.reference_nifti = os.path.join(self.dicom_dataset, f"{self.test_series_uid}.nii.gz") + + def _check_test_data(self, directory, desc="DICOM"): + """Check if test data exists.""" + if not os.path.exists(directory): + return False + dcm_files = list(Path(directory).glob("*.dcm")) + if len(dcm_files) == 0: + return False + return True + + def _get_reference_image(self): + """Load reference NIfTI image.""" + if not os.path.exists(self.reference_nifti): + 
self.fail(f"Reference NIfTI file not found: {self.reference_nifti}") + + loader = LoadImage(image_only=False) + img_array, meta = loader(self.reference_nifti) + # Reference NIfTI is in (W, H, D) order + return np.array(img_array), meta + + def test_nvdicomreader_original_series(self): + """Test NvDicomReader with original (non-HTJ2K) DICOM series.""" + # Check test data exists + if not self._check_test_data(self.original_series_dir, "original DICOM"): + self.skipTest(f"Original DICOM test data not found at {self.original_series_dir}") + + # Load with NvDicomReader (default depth_last=True matches NIfTI W,H,D layout) + reader = NvDicomReader() + img_obj = reader.read(self.original_series_dir) + volume, metadata = reader.get_data(img_obj) + + # Verify shape (should be W, H, D with depth_last=True, the default) + self.assertEqual(volume.shape, (512, 512, 77), f"Expected shape (512, 512, 77), got {volume.shape}") + + # Load reference NIfTI for comparison + reference, ref_meta = self._get_reference_image() + + # Compare with reference (allowing for small numerical differences) + np.testing.assert_allclose( + volume, reference, rtol=1e-5, atol=1e-3, err_msg="NvDicomReader output differs from reference NIfTI" + ) + + print(f"✓ NvDicomReader original DICOM series test passed") + + @unittest.skipIf(not HAS_NVIMGCODEC, "nvimgcodec not available for HTJ2K decoding") + def test_nvdicomreader_htj2k_series(self): + """Test NvDicomReader with HTJ2K-encoded DICOM series.""" + # Check HTJ2K test data exists + if not self._check_test_data(self.htj2k_series_dir, "HTJ2K DICOM"): + # Try to create HTJ2K data if nvimgcodec is available + print("\nHTJ2K test data not found. Attempting to create...") + import sys + + sys.path.insert(0, os.path.join(self.base_dir)) + try: + from prepare_htj2k_test_data import create_htj2k_data + + create_htj2k_data(os.path.join(self.base_dir, "data")) + except Exception as e: + self.skipTest(f"Could not create HTJ2K test data: {e}") + + # Re-check after creation attempt + if not self._check_test_data(self.htj2k_series_dir, "HTJ2K DICOM"): + self.skipTest(f"HTJ2K DICOM files not found at {self.htj2k_series_dir}") + + # Verify these are actually HTJ2K encoded + htj2k_files = list(Path(self.htj2k_series_dir).glob("*.dcm")) + first_dcm = pydicom.dcmread(str(htj2k_files[0])) + transfer_syntax = first_dcm.file_meta.TransferSyntaxUID + htj2k_syntaxes = [ + "1.2.840.10008.1.2.4.201", # HTJ2K Lossless + "1.2.840.10008.1.2.4.202", # HTJ2K with RPCL + "1.2.840.10008.1.2.4.203", # HTJ2K Lossy + ] + if str(transfer_syntax) not in htj2k_syntaxes: + self.skipTest(f"DICOM files are not HTJ2K encoded (Transfer Syntax: {transfer_syntax})") + + # Load with NvDicomReader (default depth_last=True matches NIfTI W,H,D layout) + reader = NvDicomReader(use_nvimgcodec=True, prefer_gpu_output=False) + img_obj = reader.read(self.htj2k_series_dir) + volume, metadata = reader.get_data(img_obj) + + # Verify shape (should be W, H, D with depth_last=True, the default) + self.assertEqual(volume.shape, (512, 512, 77), f"Expected shape (512, 512, 77), got {volume.shape}") + + # Load reference NIfTI for comparison + reference, ref_meta = self._get_reference_image() + + # Convert to numpy if cupy array (batch decode may return GPU arrays) + if hasattr(volume, "__cuda_array_interface__"): + import cupy as cp + + volume = cp.asnumpy(volume) + + # Compare with reference (HTJ2K is lossless, so should be identical) + np.testing.assert_allclose( + volume, reference, rtol=1e-5, atol=1e-3, err_msg="HTJ2K decoded volume 
differs from reference NIfTI" + ) + + print(f"✓ NvDicomReader HTJ2K DICOM series test passed") + + @unittest.skipIf(not HAS_NVIMGCODEC, "nvimgcodec not available for HTJ2K decoding") + def test_htj2k_vs_original_consistency(self): + """Test that HTJ2K decoding produces the same result as original DICOM.""" + # Check both datasets exist + if not self._check_test_data(self.original_series_dir, "original DICOM"): + self.skipTest(f"Original DICOM test data not found at {self.original_series_dir}") + + if not self._check_test_data(self.htj2k_series_dir, "HTJ2K DICOM"): + # Try to create HTJ2K data + print("\nHTJ2K test data not found. Attempting to create...") + import sys + + sys.path.insert(0, os.path.join(self.base_dir)) + try: + from prepare_htj2k_test_data import create_htj2k_data + + create_htj2k_data(os.path.join(self.base_dir, "data")) + except Exception as e: + self.skipTest(f"Could not create HTJ2K test data: {e}") + + # Re-check after creation attempt + if not self._check_test_data(self.htj2k_series_dir, "HTJ2K DICOM"): + self.skipTest(f"HTJ2K DICOM files not found at {self.htj2k_series_dir}") + + # Load original series (default depth_last=True for W,H,D layout) + reader_original = NvDicomReader(use_nvimgcodec=False) # Force pydicom for original + img_obj_orig = reader_original.read(self.original_series_dir) + volume_orig, metadata_orig = reader_original.get_data(img_obj_orig) + + # Load HTJ2K series with nvImageCodec (default depth_last=True for W,H,D layout) + reader_htj2k = NvDicomReader(use_nvimgcodec=True, prefer_gpu_output=False) + img_obj_htj2k = reader_htj2k.read(self.htj2k_series_dir) + volume_htj2k, metadata_htj2k = reader_htj2k.get_data(img_obj_htj2k) + + # Convert to numpy if cupy arrays + if hasattr(volume_orig, "__cuda_array_interface__"): + import cupy as cp + + volume_orig = cp.asnumpy(volume_orig) + if hasattr(volume_htj2k, "__cuda_array_interface__"): + import cupy as cp + + volume_htj2k = cp.asnumpy(volume_htj2k) + + # Verify shapes match + self.assertEqual(volume_orig.shape, volume_htj2k.shape, "Original and HTJ2K volumes should have the same shape") + + # Compare volumes (HTJ2K lossless should be identical) + np.testing.assert_allclose( + volume_orig, volume_htj2k, rtol=1e-5, atol=1e-3, err_msg="HTJ2K decoded volume differs from original DICOM" + ) + + # Verify metadata consistency + self.assertEqual( + metadata_orig["spacing"].tolist(), metadata_htj2k["spacing"].tolist(), "Spacing should be identical" + ) + + np.testing.assert_allclose( + metadata_orig["affine"], metadata_htj2k["affine"], rtol=1e-6, err_msg="Affine matrices should be identical" + ) + + print(f"✓ HTJ2K vs original consistency test passed") + + def test_nvdicomreader_metadata(self): + """Test that NvDicomReader extracts proper metadata.""" + if not self._check_test_data(self.original_series_dir): + self.skipTest(f"Original DICOM test data not found at {self.original_series_dir}") + + reader = NvDicomReader() # default depth_last=True + img_obj = reader.read(self.original_series_dir) + volume, metadata = reader.get_data(img_obj) + + # Check essential metadata fields + self.assertIn("affine", metadata, "Metadata should contain affine matrix") + self.assertIn("spacing", metadata, "Metadata should contain spacing") + self.assertIn("spatial_shape", metadata, "Metadata should contain spatial_shape") + + # Verify affine is 4x4 + self.assertEqual(metadata["affine"].shape, (4, 4), "Affine should be 4x4") + + # Verify spacing has 3 elements + self.assertEqual(len(metadata["spacing"]), 3, "Spacing should 
have 3 elements") + + # Verify spatial shape matches volume shape + np.testing.assert_array_equal( + metadata["spatial_shape"], volume.shape, err_msg="Spatial shape in metadata should match volume shape" + ) + + print(f"✓ NvDicomReader metadata test passed") + + def test_nvdicomreader_depth_last(self): + """Test NvDicomReader with depth_last option (ITK-style vs NumPy-style layout).""" + if not self._check_test_data(self.original_series_dir): + self.skipTest(f"Original DICOM test data not found at {self.original_series_dir}") + + # NumPy-style: depth_last=False -> (depth, height, width) + reader_numpy = NvDicomReader(depth_last=False) + img_obj_numpy = reader_numpy.read(self.original_series_dir) + volume_numpy, _ = reader_numpy.get_data(img_obj_numpy) + + # ITK-style (default): depth_last=True -> (width, height, depth) + reader_itk = NvDicomReader(depth_last=True) + img_obj_itk = reader_itk.read(self.original_series_dir) + volume_itk, _ = reader_itk.get_data(img_obj_itk) + + # Verify shapes are transposed correctly + self.assertEqual(volume_numpy.shape, (77, 512, 512)) + self.assertEqual(volume_itk.shape, (512, 512, 77)) + + # Verify data is the same (just transposed) + np.testing.assert_allclose( + volume_numpy.transpose(2, 1, 0), + volume_itk, + rtol=1e-6, + err_msg="depth_last should produce transposed volume", + ) + + print(f"✓ NvDicomReader depth_last test passed") + + +@unittest.skipIf(not HAS_NVIMGCODEC, "nvimgcodec not available") +@unittest.skipIf(not HAS_PYDICOM, "pydicom not available") +class TestNvDicomReaderHTJ2KPerformance(unittest.TestCase): + """Performance tests for HTJ2K decoding with NvDicomReader.""" + + base_dir = os.path.realpath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + dicom_dataset = os.path.join(base_dir, "data", "dataset", "dicomweb", "e7567e0a064f0c334226a0658de23afd") + test_series_uid = "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721" + + def setUp(self): + """Set up test fixtures.""" + self.htj2k_series_dir = os.path.join( + self.base_dir, "data", "dataset", "dicomweb_htj2k", "e7567e0a064f0c334226a0658de23afd", self.test_series_uid + ) + + def test_batch_decode_optimization(self): + """Test that batch decode is used for HTJ2K series.""" + # Skip if HTJ2K data not available + if not os.path.exists(self.htj2k_series_dir): + self.skipTest(f"HTJ2K test data not found at {self.htj2k_series_dir}") + + htj2k_files = list(Path(self.htj2k_series_dir).glob("*.dcm")) + if len(htj2k_files) == 0: + self.skipTest(f"No HTJ2K DICOM files found in {self.htj2k_series_dir}") + + # Verify HTJ2K encoding + first_dcm = pydicom.dcmread(str(htj2k_files[0])) + transfer_syntax = str(first_dcm.file_meta.TransferSyntaxUID) + htj2k_syntaxes = ["1.2.840.10008.1.2.4.201", "1.2.840.10008.1.2.4.202", "1.2.840.10008.1.2.4.203"] + if transfer_syntax not in htj2k_syntaxes: + self.skipTest(f"DICOM files are not HTJ2K encoded") + + # Load with batch decode enabled (default depth_last=True gives W,H,D layout) + reader = NvDicomReader(use_nvimgcodec=True, prefer_gpu_output=False) + img_obj = reader.read(self.htj2k_series_dir) + volume, metadata = reader.get_data(img_obj) + + # Verify successful decode + self.assertIsNotNone(volume, "Volume should be decoded successfully") + # With depth_last=True (default), shape is (W, H, D), so depth is at index 2 + self.assertEqual(volume.shape[2], len(htj2k_files), f"Volume should have {len(htj2k_files)} slices") + + print(f"✓ Batch decode optimization test passed ({len(htj2k_files)} slices)") + + +@unittest.skipIf(not 
HAS_NVDICOMREADER, "NvDicomReader not available") +@unittest.skipIf(not HAS_PYDICOM, "pydicom not available") +class TestNvDicomReaderMultiFrame(unittest.TestCase): + """Test suite for NvDicomReader with multi-frame DICOM files.""" + + base_dir = os.path.realpath(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) + + # Single-frame series paths + dicom_dataset = os.path.join(base_dir, "data", "dataset", "dicomweb", "e7567e0a064f0c334226a0658de23afd") + htj2k_single_base = os.path.join(base_dir, "data", "dataset", "dicomweb_htj2k", "e7567e0a064f0c334226a0658de23afd") + + # Multi-frame paths (organized by study UID directly) + htj2k_multiframe_base = os.path.join(base_dir, "data", "dataset", "dicomweb_htj2k_multiframe") + + # Test series UIDs + test_study_uid = "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656706" + test_series_uid = "1.2.826.0.1.3680043.8.274.1.1.8323329.686405.1629744173.656721" + + def setUp(self): + """Set up test fixtures.""" + self.original_series_dir = os.path.join(self.dicom_dataset, self.test_series_uid) + self.htj2k_series_dir = os.path.join(self.htj2k_single_base, self.test_series_uid) + self.multiframe_file = os.path.join( + self.htj2k_multiframe_base, self.test_study_uid, f"{self.test_series_uid}.dcm" + ) + + def _check_multiframe_data(self): + """Check if multi-frame test data exists.""" + if not os.path.exists(self.multiframe_file): + return False + return True + + def _check_single_frame_data(self): + """Check if single-frame test data exists.""" + if not os.path.exists(self.original_series_dir): + return False + dcm_files = list(Path(self.original_series_dir).glob("*.dcm")) + if len(dcm_files) == 0: + return False + return True + + @unittest.skipIf(not HAS_NVIMGCODEC, "nvimgcodec not available for HTJ2K decoding") + def test_multiframe_basic_read(self): + """Test that multi-frame DICOM can be read successfully.""" + if not self._check_multiframe_data(): + self.skipTest(f"Multi-frame DICOM not found at {self.multiframe_file}") + + # Read multi-frame DICOM + reader = NvDicomReader(use_nvimgcodec=True, prefer_gpu_output=False) + img_obj = reader.read(self.multiframe_file) + volume, metadata = reader.get_data(img_obj) + + # Convert to numpy if cupy array + if hasattr(volume, "__cuda_array_interface__"): + import cupy as cp + + volume = cp.asnumpy(volume) + + # Verify shape (should be W, H, D with depth_last=True) + self.assertEqual(len(volume.shape), 3, f"Volume should be 3D, got shape {volume.shape}") + self.assertEqual(volume.shape[2], 77, f"Expected 77 slices, got {volume.shape[2]}") + + # Verify metadata + self.assertIn("affine", metadata, "Metadata should contain affine matrix") + self.assertIn("spacing", metadata, "Metadata should contain spacing") + self.assertIn("ImagePositionPatient", metadata, "Metadata should contain ImagePositionPatient") + + print(f"✓ Multi-frame basic read test passed - shape: {volume.shape}") + + @unittest.skipIf(not HAS_NVIMGCODEC, "nvimgcodec not available for HTJ2K decoding") + def test_multiframe_vs_singleframe_consistency(self): + """Test that multi-frame DICOM produces identical results to single-frame series.""" + if not self._check_multiframe_data(): + self.skipTest(f"Multi-frame DICOM not found at {self.multiframe_file}") + + if not self._check_single_frame_data(): + self.skipTest(f"Single-frame series not found at {self.original_series_dir}") + + # Read single-frame series + reader_single = NvDicomReader(use_nvimgcodec=False, prefer_gpu_output=False) + img_obj_single = 
reader_single.read(self.original_series_dir) + volume_single, metadata_single = reader_single.get_data(img_obj_single) + + # Read multi-frame DICOM + reader_multi = NvDicomReader(use_nvimgcodec=True, prefer_gpu_output=False) + img_obj_multi = reader_multi.read(self.multiframe_file) + volume_multi, metadata_multi = reader_multi.get_data(img_obj_multi) + + # Convert to numpy if needed + if hasattr(volume_single, "__cuda_array_interface__"): + import cupy as cp + + volume_single = cp.asnumpy(volume_single) + if hasattr(volume_multi, "__cuda_array_interface__"): + import cupy as cp + + volume_multi = cp.asnumpy(volume_multi) + + # Verify shapes match + self.assertEqual( + volume_single.shape, + volume_multi.shape, + f"Single-frame and multi-frame volumes should have same shape. Single: {volume_single.shape}, Multi: {volume_multi.shape}", + ) + + # Compare pixel data (HTJ2K lossless should be identical) + np.testing.assert_allclose( + volume_single, + volume_multi, + rtol=1e-5, + atol=1e-3, + err_msg="Multi-frame DICOM pixel data differs from single-frame series", + ) + + # Compare spacing + np.testing.assert_allclose( + metadata_single["spacing"], metadata_multi["spacing"], rtol=1e-6, err_msg="Spacing should be identical" + ) + + # Compare affine matrices + np.testing.assert_allclose( + metadata_single["affine"], + metadata_multi["affine"], + rtol=1e-6, + atol=1e-3, + err_msg="Affine matrices should be identical", + ) + + print(f"✓ Multi-frame vs single-frame consistency test passed") + print(f" Shape: {volume_multi.shape}") + print(f" Spacing: {metadata_multi['spacing']}") + print(f" Affine origin: {metadata_multi['affine'][:3, 3]}") + + @unittest.skipIf(not HAS_NVIMGCODEC, "nvimgcodec not available") + def test_multiframe_per_frame_metadata(self): + """Test that per-frame metadata is correctly extracted from PerFrameFunctionalGroupsSequence.""" + if not self._check_multiframe_data(): + self.skipTest(f"Multi-frame DICOM not found at {self.multiframe_file}") + + # Read the DICOM file directly with pydicom to check PerFrameFunctionalGroupsSequence + ds = pydicom.dcmread(self.multiframe_file) + + # Verify it's actually multi-frame + self.assertTrue(hasattr(ds, "NumberOfFrames"), "Should have NumberOfFrames attribute") + self.assertGreater(ds.NumberOfFrames, 1, "Should have multiple frames") + + # Verify PerFrameFunctionalGroupsSequence exists + self.assertTrue( + hasattr(ds, "PerFrameFunctionalGroupsSequence"), + "Multi-frame DICOM should have PerFrameFunctionalGroupsSequence", + ) + + # Verify first frame has PlanePositionSequence + first_frame = ds.PerFrameFunctionalGroupsSequence[0] + self.assertTrue(hasattr(first_frame, "PlanePositionSequence"), "First frame should have PlanePositionSequence") + + first_pos = first_frame.PlanePositionSequence[0].ImagePositionPatient + self.assertEqual(len(first_pos), 3, "ImagePositionPatient should have 3 coordinates") + + # Now read with NvDicomReader and verify metadata is extracted + reader = NvDicomReader(use_nvimgcodec=True, prefer_gpu_output=False) + img_obj = reader.read(self.multiframe_file) + volume, metadata = reader.get_data(img_obj) + + # Verify ImagePositionPatient was extracted from per-frame metadata + self.assertIn("ImagePositionPatient", metadata, "Should have ImagePositionPatient in metadata") + + extracted_pos = metadata["ImagePositionPatient"] + self.assertEqual(len(extracted_pos), 3, "Extracted ImagePositionPatient should have 3 coordinates") + + # Verify it matches the first frame position + np.testing.assert_allclose( + extracted_pos, 
first_pos, rtol=1e-6, err_msg="Extracted ImagePositionPatient should match first frame" + ) + + print(f"✓ Multi-frame per-frame metadata test passed") + print(f" NumberOfFrames: {ds.NumberOfFrames}") + print(f" First frame ImagePositionPatient: {first_pos}") + + @unittest.skipIf(not HAS_NVIMGCODEC, "nvimgcodec not available") + def test_multiframe_affine_origin(self): + """Test that affine matrix origin is correctly extracted from multi-frame per-frame metadata.""" + if not self._check_multiframe_data(): + self.skipTest(f"Multi-frame DICOM not found at {self.multiframe_file}") + + # Read with pydicom to get expected origin + ds = pydicom.dcmread(self.multiframe_file) + first_frame = ds.PerFrameFunctionalGroupsSequence[0] + expected_origin = np.array(first_frame.PlanePositionSequence[0].ImagePositionPatient) + + # Read with NvDicomReader + reader = NvDicomReader(use_nvimgcodec=True, prefer_gpu_output=False, affine_lps_to_ras=True) + img_obj = reader.read(self.multiframe_file) + volume, metadata = reader.get_data(img_obj) + + # Extract origin from affine matrix (after LPS->RAS conversion) + # RAS affine has origin in last column, first 3 rows + affine_origin_ras = metadata["affine"][:3, 3] + + # Convert expected_origin from LPS to RAS for comparison + # LPS to RAS: negate X and Y + expected_origin_ras = expected_origin.copy() + expected_origin_ras[0] = -expected_origin_ras[0] + expected_origin_ras[1] = -expected_origin_ras[1] + + # Verify affine origin matches the first frame's ImagePositionPatient (in RAS) + np.testing.assert_allclose( + affine_origin_ras, + expected_origin_ras, + rtol=1e-6, + atol=1e-3, + err_msg=f"Affine origin should match first frame ImagePositionPatient. Got {affine_origin_ras}, expected {expected_origin_ras}", + ) + + print(f"✓ Multi-frame affine origin test passed") + print(f" ImagePositionPatient (LPS): {expected_origin}") + print(f" Affine origin (RAS): {affine_origin_ras}") + + @unittest.skipIf(not HAS_NVIMGCODEC, "nvimgcodec not available") + def test_multiframe_slice_spacing(self): + """Test that slice spacing is correctly calculated for multi-frame DICOMs.""" + if not self._check_multiframe_data(): + self.skipTest(f"Multi-frame DICOM not found at {self.multiframe_file}") + + # Read with pydicom to get first and last frame positions + ds = pydicom.dcmread(self.multiframe_file) + num_frames = ds.NumberOfFrames + + first_frame = ds.PerFrameFunctionalGroupsSequence[0] + last_frame = ds.PerFrameFunctionalGroupsSequence[num_frames - 1] + + first_pos = np.array(first_frame.PlanePositionSequence[0].ImagePositionPatient) + last_pos = np.array(last_frame.PlanePositionSequence[0].ImagePositionPatient) + + # Calculate expected slice spacing + # Distance between first and last divided by (number of slices - 1) + distance = np.linalg.norm(last_pos - first_pos) + expected_spacing = distance / (num_frames - 1) + + # Read with NvDicomReader + reader = NvDicomReader(use_nvimgcodec=True, prefer_gpu_output=False) + img_obj = reader.read(self.multiframe_file) + volume, metadata = reader.get_data(img_obj) + + # Get slice spacing (Z spacing, index 2) + slice_spacing = metadata["spacing"][2] + + # Verify it matches expected + self.assertAlmostEqual( + slice_spacing, + expected_spacing, + delta=0.1, + msg=f"Slice spacing should be ~{expected_spacing:.2f}mm, got {slice_spacing:.2f}mm", + ) + + print(f"✓ Multi-frame slice spacing test passed") + print(f" Number of frames: {num_frames}") + print(f" First position: {first_pos}") + print(f" Last position: {last_pos}") + print(f" Calculated 
spacing: {slice_spacing:.4f}mm (expected: {expected_spacing:.4f}mm)") + + +if __name__ == "__main__": + unittest.main()
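
Note (not part of the diff above): a minimal usage sketch of the NvDicomReader interface these tests exercise. It uses only calls visible in the tests (read/get_data, use_nvimgcodec, prefer_gpu_output, the default depth_last=True giving a W, H, D volume, and the optional CuPy-to-NumPy hop); the series path is a placeholder, and wiring the reader into MONAI's LoadImage is shown as the standard MONAI reader pattern rather than anything mandated by this patch.

# Sketch: load a DICOM series (single- or multi-frame, optionally HTJ2K) with NvDicomReader.
from monai.transforms import LoadImage

from monailabel.transform.reader import NvDicomReader

series_dir = "/path/to/dicom/series"  # placeholder path

# depth_last=True is the default, so the returned volume is (W, H, D).
reader = NvDicomReader(use_nvimgcodec=True, prefer_gpu_output=False)
img_obj = reader.read(series_dir)
volume, metadata = reader.get_data(img_obj)

# Batch HTJ2K decode may return a GPU array; move it to host memory if so.
if hasattr(volume, "__cuda_array_interface__"):
    import cupy as cp

    volume = cp.asnumpy(volume)

print(volume.shape)                      # e.g. (512, 512, 77) for the test series
print(metadata["spacing"])               # per-axis spacing in mm
print(metadata["affine"][:3, 3])         # volume origin (RAS) from the affine

# The reader can also be plugged into MONAI's LoadImage for transform pipelines.
loader = LoadImage(reader=reader, image_only=False)
image, meta = loader(series_dir)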