Skip to content

Commit 119f000

Browse files
committed
Add batch transcode function to convert utils
Signed-off-by: Joaquin Anton Guirao <janton@nvidia.com>
1 parent 3b5bd1c commit 119f000

File tree

3 files changed

+713
-200
lines changed

3 files changed

+713
-200
lines changed

monailabel/datastore/utils/convert.py

Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,3 +597,272 @@ def dicom_seg_to_itk_image(label, output_ext=".seg.nrrd"):
597597

598598
logger.info(f"Result/Output File: {output_file}")
599599
return output_file
600+
601+
602+
def transcode_dicom_to_htj2k(
603+
input_dir: str,
604+
output_dir: str = None,
605+
num_resolutions: int = 6,
606+
code_block_size: tuple = (64, 64),
607+
verify: bool = False,
608+
) -> str:
609+
"""
610+
Transcode DICOM files to HTJ2K (High Throughput JPEG 2000) lossless compression.
611+
612+
HTJ2K is a faster variant of JPEG 2000 that provides better compression performance
613+
for medical imaging applications. This function uses nvidia-nvimgcodec for encoding
614+
with batch processing for improved performance. All transcoding is performed using
615+
lossless compression to preserve image quality.
616+
617+
The function operates in three phases:
618+
1. Load all DICOM files and prepare pixel arrays
619+
2. Batch encode all images to HTJ2K in parallel
620+
3. Save encoded data back to DICOM files
621+
622+
Args:
623+
input_dir: Path to directory containing DICOM files to transcode
624+
output_dir: Path to output directory for transcoded files. If None, creates temp directory
625+
num_resolutions: Number of resolution levels (default: 6)
626+
code_block_size: Code block size as (height, width) tuple (default: (64, 64))
627+
verify: If True, decode output to verify correctness (default: False)
628+
629+
Returns:
630+
Path to output directory containing transcoded DICOM files
631+
632+
Raises:
633+
ImportError: If nvidia-nvimgcodec or pydicom are not available
634+
ValueError: If input directory doesn't exist or contains no DICOM files
635+
636+
Example:
637+
>>> output_dir = transcode_dicom_to_htj2k("/path/to/dicoms")
638+
>>> # Transcoded files are now in output_dir with lossless HTJ2K compression
639+
640+
Note:
641+
Requires nvidia-nvimgcodec to be installed:
642+
pip install nvidia-nvimgcodec-cu{XX}[all]
643+
Replace {XX} with your CUDA version (e.g., cu13 for CUDA 13.x)
644+
"""
645+
import glob
646+
import shutil
647+
from pathlib import Path
648+
649+
# Check for nvidia-nvimgcodec
650+
try:
651+
from nvidia import nvimgcodec
652+
except ImportError:
653+
raise ImportError(
654+
"nvidia-nvimgcodec is required for HTJ2K transcoding. "
655+
"Install it with: pip install nvidia-nvimgcodec-cu{XX}[all] "
656+
"(replace {XX} with your CUDA version, e.g., cu13)"
657+
)
658+
659+
# Validate input
660+
if not os.path.exists(input_dir):
661+
raise ValueError(f"Input directory does not exist: {input_dir}")
662+
663+
if not os.path.isdir(input_dir):
664+
raise ValueError(f"Input path is not a directory: {input_dir}")
665+
666+
# Get all DICOM files
667+
dicom_files = []
668+
for pattern in ["*.dcm", "*"]:
669+
dicom_files.extend(glob.glob(os.path.join(input_dir, pattern)))
670+
671+
# Filter to actual DICOM files
672+
valid_dicom_files = []
673+
for file_path in dicom_files:
674+
if os.path.isfile(file_path):
675+
try:
676+
# Quick check if it's a DICOM file
677+
with open(file_path, 'rb') as f:
678+
f.seek(128)
679+
magic = f.read(4)
680+
if magic == b'DICM':
681+
valid_dicom_files.append(file_path)
682+
except Exception:
683+
continue
684+
685+
if not valid_dicom_files:
686+
raise ValueError(f"No valid DICOM files found in {input_dir}")
687+
688+
logger.info(f"Found {len(valid_dicom_files)} DICOM files to transcode")
689+
690+
# Create output directory
691+
if output_dir is None:
692+
output_dir = tempfile.mkdtemp(prefix="htj2k_")
693+
else:
694+
os.makedirs(output_dir, exist_ok=True)
695+
696+
# Create encoder and decoder instances (reused for all files)
697+
encoder = nvimgcodec.Encoder()
698+
decoder = nvimgcodec.Decoder() if verify else None
699+
700+
# HTJ2K Transfer Syntax UID - Lossless Only
701+
# 1.2.840.10008.1.2.4.201 = HTJ2K Lossless Only
702+
target_transfer_syntax = "1.2.840.10008.1.2.4.201"
703+
quality_type = nvimgcodec.QualityType.LOSSLESS
704+
logger.info("Using lossless HTJ2K compression")
705+
706+
# Configure JPEG2K encoding parameters
707+
jpeg2k_encode_params = nvimgcodec.Jpeg2kEncodeParams()
708+
jpeg2k_encode_params.num_resolutions = num_resolutions
709+
jpeg2k_encode_params.code_block_size = code_block_size
710+
jpeg2k_encode_params.bitstream_type = nvimgcodec.Jpeg2kBitstreamType.JP2
711+
jpeg2k_encode_params.prog_order = nvimgcodec.Jpeg2kProgOrder.LRCP
712+
jpeg2k_encode_params.ht = True # Enable High Throughput mode
713+
714+
encode_params = nvimgcodec.EncodeParams(
715+
quality_type=quality_type,
716+
jpeg2k_encode_params=jpeg2k_encode_params,
717+
)
718+
719+
start_time = time.time()
720+
transcoded_count = 0
721+
skipped_count = 0
722+
failed_count = 0
723+
724+
# Phase 1: Load all DICOM files and prepare pixel arrays for batch encoding
725+
logger.info("Phase 1: Loading DICOM files and preparing pixel arrays...")
726+
dicom_datasets = []
727+
pixel_arrays = []
728+
files_to_encode = []
729+
730+
for i, input_file in enumerate(valid_dicom_files, 1):
731+
try:
732+
# Read DICOM
733+
ds = pydicom.dcmread(input_file)
734+
735+
# Check if already HTJ2K
736+
current_ts = getattr(ds, 'file_meta', {}).get('TransferSyntaxUID', None)
737+
if current_ts and str(current_ts).startswith('1.2.840.10008.1.2.4.20'):
738+
logger.debug(f"[{i}/{len(valid_dicom_files)}] Already HTJ2K: {os.path.basename(input_file)}")
739+
# Just copy the file
740+
output_file = os.path.join(output_dir, os.path.basename(input_file))
741+
shutil.copy2(input_file, output_file)
742+
skipped_count += 1
743+
continue
744+
745+
# Use pydicom's pixel_array to decode the source image
746+
# This handles all transfer syntaxes automatically
747+
source_pixel_array = ds.pixel_array
748+
749+
# Ensure it's a numpy array
750+
if not isinstance(source_pixel_array, np.ndarray):
751+
source_pixel_array = np.array(source_pixel_array)
752+
753+
# Add channel dimension if needed (nvimgcodec expects shape like (H, W, C))
754+
if source_pixel_array.ndim == 2:
755+
source_pixel_array = source_pixel_array[:, :, np.newaxis]
756+
757+
# Store for batch encoding
758+
dicom_datasets.append(ds)
759+
pixel_arrays.append(source_pixel_array)
760+
files_to_encode.append(input_file)
761+
762+
if i % 50 == 0 or i == len(valid_dicom_files):
763+
logger.info(f"Loading progress: {i}/{len(valid_dicom_files)} files loaded")
764+
765+
except Exception as e:
766+
logger.error(f"[{i}/{len(valid_dicom_files)}] Error loading {os.path.basename(input_file)}: {e}")
767+
failed_count += 1
768+
continue
769+
770+
if not pixel_arrays:
771+
logger.warning("No images to encode")
772+
return output_dir
773+
774+
# Phase 2: Batch encode all images to HTJ2K
775+
logger.info(f"Phase 2: Batch encoding {len(pixel_arrays)} images to HTJ2K...")
776+
encode_start = time.time()
777+
778+
try:
779+
encoded_htj2k_images = encoder.encode(
780+
pixel_arrays,
781+
codec="jpeg2k",
782+
params=encode_params,
783+
)
784+
encode_time = time.time() - encode_start
785+
logger.info(f"Batch encoding completed in {encode_time:.2f} seconds ({len(pixel_arrays)/encode_time:.1f} images/sec)")
786+
except Exception as e:
787+
logger.error(f"Batch encoding failed: {e}")
788+
# Fall back to individual encoding
789+
logger.warning("Falling back to individual encoding...")
790+
encoded_htj2k_images = []
791+
for idx, pixel_array in enumerate(pixel_arrays):
792+
try:
793+
encoded_image = encoder.encode(
794+
[pixel_array],
795+
codec="jpeg2k",
796+
params=encode_params,
797+
)
798+
encoded_htj2k_images.extend(encoded_image)
799+
except Exception as e2:
800+
logger.error(f"Failed to encode image {idx}: {e2}")
801+
encoded_htj2k_images.append(None)
802+
803+
# Phase 3: Save encoded data back to DICOM files
804+
logger.info("Phase 3: Saving encoded DICOM files...")
805+
save_start = time.time()
806+
807+
for idx, (ds, encoded_data, input_file) in enumerate(zip(dicom_datasets, encoded_htj2k_images, files_to_encode)):
808+
try:
809+
if encoded_data is None:
810+
logger.error(f"Skipping {os.path.basename(input_file)} - encoding failed")
811+
failed_count += 1
812+
continue
813+
814+
# Encapsulate encoded frames for DICOM
815+
new_encoded_frames = [bytes(encoded_data)]
816+
encapsulated_pixel_data = pydicom.encaps.encapsulate(new_encoded_frames)
817+
ds.PixelData = encapsulated_pixel_data
818+
819+
# Update transfer syntax UID
820+
ds.file_meta.TransferSyntaxUID = pydicom.uid.UID(target_transfer_syntax)
821+
822+
# Save to output directory
823+
output_file = os.path.join(output_dir, os.path.basename(input_file))
824+
ds.save_as(output_file)
825+
826+
# Verify if requested
827+
if verify:
828+
ds_verify = pydicom.dcmread(output_file)
829+
pixel_data = ds_verify.PixelData
830+
data_sequence = pydicom.encaps.decode_data_sequence(pixel_data)
831+
images_verify = decoder.decode(
832+
data_sequence,
833+
params=nvimgcodec.DecodeParams(
834+
allow_any_depth=True,
835+
color_spec=nvimgcodec.ColorSpec.UNCHANGED
836+
),
837+
)
838+
image_verify = np.array(images_verify[0].cpu()).squeeze()
839+
840+
if not np.allclose(image_verify, ds_verify.pixel_array):
841+
logger.warning(f"Verification failed for {os.path.basename(input_file)}")
842+
failed_count += 1
843+
continue
844+
845+
transcoded_count += 1
846+
847+
if (idx + 1) % 50 == 0 or (idx + 1) == len(dicom_datasets):
848+
logger.info(f"Saving progress: {idx + 1}/{len(dicom_datasets)} files saved")
849+
850+
except Exception as e:
851+
logger.error(f"Error saving {os.path.basename(input_file)}: {e}")
852+
failed_count += 1
853+
continue
854+
855+
save_time = time.time() - save_start
856+
logger.info(f"Saving completed in {save_time:.2f} seconds")
857+
858+
elapsed_time = time.time() - start_time
859+
860+
logger.info(f"Transcoding complete:")
861+
logger.info(f" Total files: {len(valid_dicom_files)}")
862+
logger.info(f" Successfully transcoded: {transcoded_count}")
863+
logger.info(f" Already HTJ2K (copied): {skipped_count}")
864+
logger.info(f" Failed: {failed_count}")
865+
logger.info(f" Time elapsed: {elapsed_time:.2f} seconds")
866+
logger.info(f" Output directory: {output_dir}")
867+
868+
return output_dir

0 commit comments

Comments
 (0)