Skip to content

Commit b652ca7

Browse files
committed
transcode to htj2k function to use nvimgcodec for decoding + mini-batch processing for large directories
Signed-off-by: Joaquin Anton Guirao <janton@nvidia.com>
1 parent 67da848 commit b652ca7

File tree

4 files changed

+194
-191
lines changed

4 files changed

+194
-191
lines changed

monailabel/datastore/utils/convert.py

Lines changed: 152 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def dicom_to_nifti(series_dir, is_seg=False):
213213
from monailabel.transform.reader import NvDicomReader
214214

215215
# Use NvDicomReader with LoadImage
216-
reader = NvDicomReader(reverse_indexing=True)
216+
reader = NvDicomReader()
217217
loader = LoadImage(reader=reader, image_only=False)
218218

219219
# Load the DICOM (supports both directories and single files)
@@ -644,43 +644,78 @@ def transcode_dicom_to_htj2k(
644644
output_dir: str = None,
645645
num_resolutions: int = 6,
646646
code_block_size: tuple = (64, 64),
647-
verify: bool = False,
647+
max_batch_size: int = 256,
648648
) -> str:
649649
"""
650650
Transcode DICOM files to HTJ2K (High Throughput JPEG 2000) lossless compression.
651651
652652
HTJ2K is a faster variant of JPEG 2000 that provides better compression performance
653-
for medical imaging applications. This function uses nvidia-nvimgcodec for encoding
654-
with batch processing for improved performance. All transcoding is performed using
655-
lossless compression to preserve image quality.
653+
for medical imaging applications. This function uses nvidia-nvimgcodec for hardware-
654+
accelerated decoding and encoding with batch processing for optimal performance.
655+
All transcoding is performed using lossless compression to preserve image quality.
656656
657-
The function operates in three phases:
658-
1. Load all DICOM files and prepare pixel arrays
659-
2. Batch encode all images to HTJ2K in parallel
660-
3. Save encoded data back to DICOM files
657+
The function processes files in configurable batches:
658+
1. Categorizes files by transfer syntax (HTJ2K/JPEG2000/JPEG/uncompressed)
659+
2. Uses nvimgcodec decoder for compressed files (JPEG2000, JPEG)
660+
3. Falls back to pydicom pixel_array for all other transfer syntaxes (e.g., uncompressed files)
661+
4. Batch encodes all images to HTJ2K using nvimgcodec
662+
5. Saves transcoded files with updated transfer syntax
663+
6. Copies already-HTJ2K files directly (no re-encoding)
664+
665+
Supported source transfer syntaxes:
666+
- JPEG 2000 (lossless and lossy)
667+
- JPEG (baseline, extended, lossless)
668+
- Uncompressed (Explicit/Implicit VR Little/Big Endian)
669+
- Already HTJ2K files are copied without re-encoding
670+
671+
Typically achieves compression ratios of 60-70% while preserving lossless quality.
672+
Processing speed depends on batch size and GPU capabilities.
661673
662674
Args:
663675
input_dir: Path to directory containing DICOM files to transcode
664676
output_dir: Path to output directory for transcoded files. If None, creates temp directory
665-
num_resolutions: Number of resolution levels (default: 6)
677+
num_resolutions: Number of wavelet decomposition levels (default: 6)
678+
Higher values = better compression but slower encoding
666679
code_block_size: Code block size as (height, width) tuple (default: (64, 64))
667-
verify: If True, decode output to verify correctness (default: False)
680+
Must be powers of 2. Common values: (32,32), (64,64), (128,128)
681+
max_batch_size: Maximum number of DICOM files to process in each batch (default: 256)
682+
Lower values reduce memory usage, higher values may improve speed
668683
669684
Returns:
670-
Path to output directory containing transcoded DICOM files
685+
str: Path to output directory containing transcoded DICOM files
671686
672687
Raises:
673-
ImportError: If nvidia-nvimgcodec or pydicom are not available
674-
ValueError: If input directory doesn't exist or contains no DICOM files
688+
ImportError: If nvidia-nvimgcodec is not available
689+
ValueError: If input directory doesn't exist or contains no valid DICOM files
690+
ValueError: If DICOM files are missing required attributes (TransferSyntaxUID, PixelData)
675691
676692
Example:
693+
>>> # Basic usage with default settings
677694
>>> output_dir = transcode_dicom_to_htj2k("/path/to/dicoms")
678-
>>> # Transcoded files are now in output_dir with lossless HTJ2K compression
695+
>>> print(f"Transcoded files saved to: {output_dir}")
696+
697+
>>> # Custom output directory and batch size
698+
>>> output_dir = transcode_dicom_to_htj2k(
699+
... input_dir="/path/to/dicoms",
700+
... output_dir="/path/to/output",
701+
... max_batch_size=50,
702+
... num_resolutions=5
703+
... )
704+
705+
>>> # Process with smaller code blocks for memory efficiency
706+
>>> output_dir = transcode_dicom_to_htj2k(
707+
... input_dir="/path/to/dicoms",
708+
... code_block_size=(32, 32),
709+
... max_batch_size=5
710+
... )
679711
680712
Note:
681713
Requires nvidia-nvimgcodec to be installed:
682714
pip install nvidia-nvimgcodec-cu{XX}[all]
683715
Replace {XX} with your CUDA major version (e.g., 13 for CUDA 13.x, giving nvidia-nvimgcodec-cu13)
716+
717+
The function preserves all DICOM metadata including Patient, Study, and Series
718+
information. Only the transfer syntax and pixel data encoding are modified.
684719
"""
685720
import glob
686721
import shutil
@@ -735,7 +770,7 @@ def transcode_dicom_to_htj2k(
735770

736771
# Create encoder and decoder instances (reused for all files)
737772
encoder = _get_nvimgcodec_encoder()
738-
decoder = _get_nvimgcodec_decoder() if verify else None
773+
decoder = _get_nvimgcodec_decoder() # Always needed for decoding input DICOM images
739774

740775
# HTJ2K Transfer Syntax UID - Lossless Only
741776
# 1.2.840.10008.1.2.4.201 = HTJ2K Lossless Only
@@ -755,153 +790,124 @@ def transcode_dicom_to_htj2k(
755790
quality_type=quality_type,
756791
jpeg2k_encode_params=jpeg2k_encode_params,
757792
)
793+
794+
decode_params = nvimgcodec.DecodeParams(
795+
allow_any_depth=True,
796+
color_spec=nvimgcodec.ColorSpec.UNCHANGED,
797+
)
758798

759-
start_time = time.time()
760-
transcoded_count = 0
761-
skipped_count = 0
762-
failed_count = 0
763-
764-
# Phase 1: Load all DICOM files and prepare pixel arrays for batch encoding
765-
logger.info("Phase 1: Loading DICOM files and preparing pixel arrays...")
766-
dicom_datasets = []
767-
pixel_arrays = []
768-
files_to_encode = []
799+
# Define transfer syntax constants (use frozenset for O(1) membership testing)
800+
JPEG2000_SYNTAXES = frozenset([
801+
"1.2.840.10008.1.2.4.90", # JPEG 2000 Image Compression (Lossless Only)
802+
"1.2.840.10008.1.2.4.91", # JPEG 2000 Image Compression
803+
])
769804

770-
for i, input_file in enumerate(valid_dicom_files, 1):
771-
try:
772-
# Read DICOM
773-
ds = pydicom.dcmread(input_file)
774-
775-
# Check if already HTJ2K
776-
current_ts = getattr(ds, 'file_meta', {}).get('TransferSyntaxUID', None)
777-
if current_ts and str(current_ts).startswith('1.2.840.10008.1.2.4.20'):
778-
logger.debug(f"[{i}/{len(valid_dicom_files)}] Already HTJ2K: {os.path.basename(input_file)}")
779-
# Just copy the file
780-
output_file = os.path.join(output_dir, os.path.basename(input_file))
781-
shutil.copy2(input_file, output_file)
782-
skipped_count += 1
783-
continue
784-
785-
# Use pydicom's pixel_array to decode the source image
786-
# This handles all transfer syntaxes automatically
787-
source_pixel_array = ds.pixel_array
788-
789-
# Ensure it's a numpy array
790-
if not isinstance(source_pixel_array, np.ndarray):
791-
source_pixel_array = np.array(source_pixel_array)
792-
793-
# Add channel dimension if needed (nvimgcodec expects shape like (H, W, C))
794-
if source_pixel_array.ndim == 2:
795-
source_pixel_array = source_pixel_array[:, :, np.newaxis]
796-
797-
# Store for batch encoding
798-
dicom_datasets.append(ds)
799-
pixel_arrays.append(source_pixel_array)
800-
files_to_encode.append(input_file)
801-
802-
if i % 50 == 0 or i == len(valid_dicom_files):
803-
logger.info(f"Loading progress: {i}/{len(valid_dicom_files)} files loaded")
804-
805-
except Exception as e:
806-
logger.error(f"[{i}/{len(valid_dicom_files)}] Error loading {os.path.basename(input_file)}: {e}")
807-
failed_count += 1
808-
continue
805+
HTJ2K_SYNTAXES = frozenset([
806+
"1.2.840.10008.1.2.4.201", # High-Throughput JPEG 2000 Image Compression (Lossless Only)
807+
"1.2.840.10008.1.2.4.202", # High-Throughput JPEG 2000 with RPCL Options Image Compression (Lossless Only)
808+
"1.2.840.10008.1.2.4.203", # High-Throughput JPEG 2000 Image Compression
809+
])
809810

810-
if not pixel_arrays:
811-
logger.warning("No images to encode")
812-
return output_dir
811+
JPEG_SYNTAXES = frozenset([
812+
"1.2.840.10008.1.2.4.50", # JPEG Baseline (Process 1)
813+
"1.2.840.10008.1.2.4.51", # JPEG Extended (Process 2 & 4)
814+
"1.2.840.10008.1.2.4.57", # JPEG Lossless, Non-Hierarchical (Process 14)
815+
"1.2.840.10008.1.2.4.70", # JPEG Lossless, Non-Hierarchical, First-Order Prediction
816+
])
813817

814-
# Phase 2: Batch encode all images to HTJ2K
815-
logger.info(f"Phase 2: Batch encoding {len(pixel_arrays)} images to HTJ2K...")
816-
encode_start = time.time()
818+
# Pre-compute combined set for nvimgcodec-compatible formats
819+
NVIMGCODEC_SYNTAXES = JPEG2000_SYNTAXES | JPEG_SYNTAXES
817820

818-
try:
819-
encoded_htj2k_images = encoder.encode(
820-
pixel_arrays,
821-
codec="jpeg2k",
822-
params=encode_params,
823-
)
824-
encode_time = time.time() - encode_start
825-
logger.info(f"Batch encoding completed in {encode_time:.2f} seconds ({len(pixel_arrays)/encode_time:.1f} images/sec)")
826-
except Exception as e:
827-
logger.error(f"Batch encoding failed: {e}")
828-
# Fall back to individual encoding
829-
logger.warning("Falling back to individual encoding...")
830-
encoded_htj2k_images = []
831-
for idx, pixel_array in enumerate(pixel_arrays):
832-
try:
833-
encoded_image = encoder.encode(
834-
[pixel_array],
835-
codec="jpeg2k",
836-
params=encode_params,
837-
)
838-
encoded_htj2k_images.extend(encoded_image)
839-
except Exception as e2:
840-
logger.error(f"Failed to encode image {idx}: {e2}")
841-
encoded_htj2k_images.append(None)
821+
start_time = time.time()
822+
transcoded_count = 0
823+
skipped_count = 0
842824

843-
# Phase 3: Save encoded data back to DICOM files
844-
logger.info("Phase 3: Saving encoded DICOM files...")
845-
save_start = time.time()
825+
# Calculate batch info for logging
826+
total_files = len(valid_dicom_files)
827+
total_batches = (total_files + max_batch_size - 1) // max_batch_size
846828

847-
for idx, (ds, encoded_data, input_file) in enumerate(zip(dicom_datasets, encoded_htj2k_images, files_to_encode)):
848-
try:
849-
if encoded_data is None:
850-
logger.error(f"Skipping {os.path.basename(input_file)} - encoding failed")
851-
failed_count += 1
852-
continue
853-
854-
# Encapsulate encoded frames for DICOM
855-
new_encoded_frames = [bytes(encoded_data)]
856-
encapsulated_pixel_data = pydicom.encaps.encapsulate(new_encoded_frames)
857-
ds.PixelData = encapsulated_pixel_data
858-
859-
# Update transfer syntax UID
860-
ds.file_meta.TransferSyntaxUID = pydicom.uid.UID(target_transfer_syntax)
829+
for batch_start in range(0, total_files, max_batch_size):
830+
batch_end = min(batch_start + max_batch_size, total_files)
831+
current_batch = batch_start // max_batch_size + 1
832+
logger.info(f"[{batch_start}..{batch_end}] Processing batch {current_batch}/{total_batches}")
833+
batch_files = valid_dicom_files[batch_start:batch_end]
834+
batch_datasets = [pydicom.dcmread(file) for file in batch_files]
835+
nvimgcodec_batch = []
836+
pydicom_batch = []
837+
copy_batch = []
838+
for idx, ds in enumerate(batch_datasets):
839+
current_ts = getattr(ds, 'file_meta', {}).get('TransferSyntaxUID', None)
840+
if current_ts is None:
841+
raise ValueError(f"DICOM file {os.path.basename(batch_files[idx])} does not have a Transfer Syntax UID")
861842

862-
# Save to output directory
863-
output_file = os.path.join(output_dir, os.path.basename(input_file))
864-
ds.save_as(output_file)
843+
ts_str = str(current_ts)
844+
if ts_str in NVIMGCODEC_SYNTAXES:
845+
if not hasattr(ds, "PixelData") or ds.PixelData is None:
846+
raise ValueError(f"DICOM file {os.path.basename(batch_files[idx])} does not have a PixelData member")
847+
nvimgcodec_batch.append(idx)
848+
elif ts_str in HTJ2K_SYNTAXES:
849+
copy_batch.append(idx)
850+
else:
851+
pydicom_batch.append(idx)
852+
853+
if copy_batch:
854+
for idx in copy_batch:
855+
output_file = os.path.join(output_dir, os.path.basename(batch_files[idx]))
856+
shutil.copy2(batch_files[idx], output_file)
857+
skipped_count += len(copy_batch)
858+
859+
data_sequence = []
860+
decoded_data = []
861+
num_frames = []
862+
863+
# Decode using nvimgcodec for compressed formats
864+
if nvimgcodec_batch:
865+
for idx in nvimgcodec_batch:
866+
frames = [fragment for fragment in pydicom.encaps.generate_frames(batch_datasets[idx].PixelData)]
867+
num_frames.append(len(frames))
868+
data_sequence.extend(frames)
869+
decoder_output = decoder.decode(data_sequence, params=decode_params)
870+
decoded_data.extend(decoder_output)
871+
872+
# Decode using pydicom for all remaining formats (e.g., uncompressed)
873+
if pydicom_batch:
874+
for idx in pydicom_batch:
875+
source_pixel_array = batch_datasets[idx].pixel_array
876+
if not isinstance(source_pixel_array, np.ndarray):
877+
source_pixel_array = np.array(source_pixel_array)
878+
if source_pixel_array.ndim == 2:
879+
source_pixel_array = source_pixel_array[:, :, np.newaxis]
880+
for frame_idx in range(source_pixel_array.shape[-1]):
881+
decoded_data.append(source_pixel_array[:, :, frame_idx])
882+
num_frames.append(source_pixel_array.shape[-1])
883+
884+
# Encode all frames to HTJ2K
885+
encoded_data = encoder.encode(decoded_data, codec="jpeg2k", params=encode_params)
886+
887+
# Reassemble and save transcoded files
888+
frame_offset = 0
889+
files_to_process = nvimgcodec_batch + pydicom_batch
890+
891+
for list_idx, dataset_idx in enumerate(files_to_process):
892+
nframes = num_frames[list_idx]
893+
encoded_frames = [bytes(enc) for enc in encoded_data[frame_offset:frame_offset + nframes]]
894+
frame_offset += nframes
865895

866-
# Verify if requested
867-
if verify:
868-
ds_verify = pydicom.dcmread(output_file)
869-
pixel_data = ds_verify.PixelData
870-
data_sequence = [fragment for fragment in pydicom.encaps.generate_frames(pixel_data)]
871-
images_verify = decoder.decode(
872-
data_sequence,
873-
params=nvimgcodec.DecodeParams(
874-
allow_any_depth=True,
875-
color_spec=nvimgcodec.ColorSpec.UNCHANGED
876-
),
877-
)
878-
image_verify = np.array(images_verify[0].cpu()).squeeze()
879-
880-
if not np.allclose(image_verify, ds_verify.pixel_array):
881-
logger.warning(f"Verification failed for {os.path.basename(input_file)}")
882-
failed_count += 1
883-
continue
896+
# Update dataset with HTJ2K encoded data
897+
batch_datasets[dataset_idx].PixelData = pydicom.encaps.encapsulate(encoded_frames)
898+
batch_datasets[dataset_idx].file_meta.TransferSyntaxUID = pydicom.uid.UID(target_transfer_syntax)
884899

900+
# Save transcoded file
901+
output_file = os.path.join(output_dir, os.path.basename(batch_files[dataset_idx]))
902+
batch_datasets[dataset_idx].save_as(output_file)
885903
transcoded_count += 1
886-
887-
if (idx + 1) % 50 == 0 or (idx + 1) == len(dicom_datasets):
888-
logger.info(f"Saving progress: {idx + 1}/{len(dicom_datasets)} files saved")
889-
890-
except Exception as e:
891-
logger.error(f"Error saving {os.path.basename(input_file)}: {e}")
892-
failed_count += 1
893-
continue
894-
895-
save_time = time.time() - save_start
896-
logger.info(f"Saving completed in {save_time:.2f} seconds")
897904

898905
elapsed_time = time.time() - start_time
899-
906+
900907
logger.info(f"Transcoding complete:")
901908
logger.info(f" Total files: {len(valid_dicom_files)}")
902909
logger.info(f" Successfully transcoded: {transcoded_count}")
903910
logger.info(f" Already HTJ2K (copied): {skipped_count}")
904-
logger.info(f" Failed: {failed_count}")
905911
logger.info(f" Time elapsed: {elapsed_time:.2f} seconds")
906912
logger.info(f" Output directory: {output_dir}")
907913

0 commit comments

Comments
 (0)