Skip to content

Commit 9313c90

Browse files
committed
Modify conversion to multiframe utility to allow for either original or htj2k encoding
Signed-off-by: Joaquin Anton Guirao <janton@nvidia.com>
1 parent c768909 commit 9313c90

File tree

2 files changed

+175
-104
lines changed

2 files changed

+175
-104
lines changed

monailabel/datastore/utils/convert.py

Lines changed: 172 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -639,22 +639,6 @@ def dicom_seg_to_itk_image(label, output_ext=".seg.nrrd"):
639639
return output_file
640640

641641

642-
def _create_basic_offset_table_pixel_data(encoded_frames: list) -> bytes:
643-
"""
644-
Create encapsulated pixel data with Basic Offset Table for multi-frame DICOM.
645-
646-
Uses pydicom's encapsulate() function to ensure 100% standard compliance.
647-
648-
Args:
649-
encoded_frames: List of encoded frame byte strings
650-
651-
Returns:
652-
bytes: Encapsulated pixel data with Basic Offset Table per DICOM Part 5 Section A.4
653-
"""
654-
return pydicom.encaps.encapsulate(encoded_frames, has_bot=True)
655-
656-
657-
658642
def _setup_htj2k_decode_params():
659643
"""
660644
Create nvimgcodec decoding parameters for DICOM images.
@@ -737,21 +721,6 @@ def _get_transfer_syntax_constants():
737721
}
738722

739723

740-
def _create_basic_offset_table_pixel_data(encoded_frames: list) -> bytes:
741-
"""
742-
Create encapsulated pixel data with Basic Offset Table for multi-frame DICOM.
743-
744-
Uses pydicom's encapsulate() function to ensure 100% standard compliance.
745-
746-
Args:
747-
encoded_frames: List of encoded frame byte strings
748-
749-
Returns:
750-
bytes: Encapsulated pixel data with Basic Offset Table per DICOM Part 5 Section A.4
751-
"""
752-
return pydicom.encaps.encapsulate(encoded_frames, has_bot=True)
753-
754-
755724
def transcode_dicom_to_htj2k(
756725
input_dir: str,
757726
output_dir: str = None,
@@ -926,10 +895,9 @@ def transcode_dicom_to_htj2k(
926895
if not hasattr(ds, "PixelData") or ds.PixelData is None:
927896
raise ValueError(f"DICOM file {os.path.basename(batch_files[idx])} does not have a PixelData member")
928897
nvimgcodec_batch.append(idx)
929-
930898
else:
931899
pydicom_batch.append(idx)
932-
900+
933901
data_sequence = []
934902
decoded_data = []
935903
num_frames = []
@@ -970,8 +938,8 @@ def transcode_dicom_to_htj2k(
970938
# Update dataset with HTJ2K encoded data
971939
# Create Basic Offset Table for multi-frame files if requested
972940
if add_basic_offset_table and nframes > 1:
973-
batch_datasets[dataset_idx].PixelData = _create_basic_offset_table_pixel_data(encoded_frames)
974-
logger.debug(f"Created Basic Offset Table for {os.path.basename(batch_files[dataset_idx])} ({nframes} frames)")
941+
batch_datasets[dataset_idx].PixelData = pydicom.encaps.encapsulate(encoded_frames, has_bot=True)
942+
logger.info(f"Basic Offset Table included for efficient frame access")
975943
else:
976944
batch_datasets[dataset_idx].PixelData = pydicom.encaps.encapsulate(encoded_frames)
977945

@@ -993,17 +961,19 @@ def transcode_dicom_to_htj2k(
993961
return output_dir
994962

995963

996-
def transcode_dicom_to_htj2k_multiframe(
964+
def convert_single_frame_dicom_series_to_multiframe(
997965
input_dir: str,
998966
output_dir: str = None,
967+
convert_to_htj2k: bool = False,
999968
num_resolutions: int = 6,
1000969
code_block_size: tuple = (64, 64),
970+
add_basic_offset_table: bool = True,
1001971
) -> str:
1002972
"""
1003-
Transcode DICOM files to HTJ2K and combine all frames from the same series into single multi-frame files.
973+
Convert single-frame DICOM series to multi-frame DICOM files, optionally with HTJ2K compression.
1004974
1005975
This function groups DICOM files by SeriesInstanceUID and combines all frames from each series
1006-
into a single multi-frame DICOM file with HTJ2K compression. This is useful for:
976+
into a single multi-frame DICOM file. This is useful for:
1007977
- Reducing file count (one file per series instead of many)
1008978
- Improving storage efficiency
1009979
- Enabling more efficient frame-level access patterns
@@ -1012,28 +982,38 @@ def transcode_dicom_to_htj2k_multiframe(
1012982
1. Scans input directory recursively for DICOM files
1013983
2. Groups files by StudyInstanceUID and SeriesInstanceUID
1014984
3. For each series, decodes all frames and combines them
1015-
4. Encodes combined frames to HTJ2K
985+
4. Optionally encodes combined frames to HTJ2K (if convert_to_htj2k=True)
1016986
5. Creates a Basic Offset Table for efficient frame access (per DICOM Part 5 Section A.4)
1017987
6. Saves as a single multi-frame DICOM file per series
1018988
1019989
Args:
1020990
input_dir: Path to directory containing DICOM files (will scan recursively)
1021991
output_dir: Path to output directory for transcoded files. If None, creates temp directory
1022-
num_resolutions: Number of wavelet decomposition levels (default: 6)
1023-
code_block_size: Code block size as (height, width) tuple (default: (64, 64))
992+
convert_to_htj2k: If True, convert frames to HTJ2K compression; if False, use uncompressed format (default: False)
993+
num_resolutions: Number of wavelet decomposition levels (default: 6, only used if convert_to_htj2k=True)
994+
code_block_size: Code block size as (height, width) tuple (default: (64, 64), only used if convert_to_htj2k=True)
995+
add_basic_offset_table: If True, creates Basic Offset Table for multi-frame DICOMs (default: True)
996+
BOT enables O(1) frame access without parsing entire pixel data stream
997+
Per DICOM Part 5 Section A.4. Only affects multi-frame files.
1024998
1025999
Returns:
1026-
str: Path to output directory containing transcoded multi-frame DICOM files
1000+
str: Path to output directory containing multi-frame DICOM files
10271001
10281002
Raises:
1029-
ImportError: If nvidia-nvimgcodec is not available
1003+
ImportError: If nvidia-nvimgcodec is not available and convert_to_htj2k=True
10301004
ValueError: If input directory doesn't exist or contains no valid DICOM files
10311005
10321006
Example:
1033-
>>> # Combine series and transcode to HTJ2K
1034-
>>> output_dir = transcode_dicom_to_htj2k_multiframe("/path/to/dicoms")
1007+
>>> # Combine series without HTJ2K conversion (uncompressed)
1008+
>>> output_dir = convert_single_frame_dicom_series_to_multiframe("/path/to/dicoms")
10351009
>>> print(f"Multi-frame files saved to: {output_dir}")
10361010
1011+
>>> # Combine series with HTJ2K conversion
1012+
>>> output_dir = convert_single_frame_dicom_series_to_multiframe(
1013+
... "/path/to/dicoms",
1014+
... convert_to_htj2k=True
1015+
... )
1016+
10371017
Note:
10381018
Each output file is named using the SeriesInstanceUID:
10391019
<StudyUID>/<SeriesUID>.dcm
@@ -1053,15 +1033,16 @@ def transcode_dicom_to_htj2k_multiframe(
10531033
from collections import defaultdict
10541034
from pathlib import Path
10551035

1056-
# Check for nvidia-nvimgcodec
1057-
try:
1058-
from nvidia import nvimgcodec
1059-
except ImportError:
1060-
raise ImportError(
1061-
"nvidia-nvimgcodec is required for HTJ2K transcoding. "
1062-
"Install it with: pip install nvidia-nvimgcodec-cu{XX}[all] "
1063-
"(replace {XX} with your CUDA version, e.g., cu13)"
1064-
)
1036+
# Check for nvidia-nvimgcodec only if HTJ2K conversion is requested
1037+
if convert_to_htj2k:
1038+
try:
1039+
from nvidia import nvimgcodec
1040+
except ImportError:
1041+
raise ImportError(
1042+
"nvidia-nvimgcodec is required for HTJ2K transcoding. "
1043+
"Install it with: pip install nvidia-nvimgcodec-cu{XX}[all] "
1044+
"(replace {XX} with your CUDA version, e.g., cu13)"
1045+
)
10651046

10661047
import pydicom
10671048
import numpy as np
@@ -1123,20 +1104,32 @@ def transcode_dicom_to_htj2k_multiframe(
11231104

11241105
# Create output directory
11251106
if output_dir is None:
1126-
output_dir = tempfile.mkdtemp(prefix="htj2k_multiframe_")
1107+
prefix = "htj2k_multiframe_" if convert_to_htj2k else "multiframe_"
1108+
output_dir = tempfile.mkdtemp(prefix=prefix)
11271109
else:
11281110
os.makedirs(output_dir, exist_ok=True)
11291111

1130-
# Create encoder and decoder instances
1131-
encoder = _get_nvimgcodec_encoder()
1132-
decoder = _get_nvimgcodec_decoder()
1133-
1134-
# Setup HTJ2K encoding and decoding parameters
1135-
encode_params, target_transfer_syntax = _setup_htj2k_encode_params(
1136-
num_resolutions=num_resolutions,
1137-
code_block_size=code_block_size
1138-
)
1139-
decode_params = _setup_htj2k_decode_params()
1112+
# Setup encoder/decoder and parameters based on conversion mode
1113+
if convert_to_htj2k:
1114+
# Create encoder and decoder instances for HTJ2K
1115+
encoder = _get_nvimgcodec_encoder()
1116+
decoder = _get_nvimgcodec_decoder()
1117+
1118+
# Setup HTJ2K encoding and decoding parameters
1119+
encode_params, target_transfer_syntax = _setup_htj2k_encode_params(
1120+
num_resolutions=num_resolutions,
1121+
code_block_size=code_block_size
1122+
)
1123+
decode_params = _setup_htj2k_decode_params()
1124+
logger.info("HTJ2K conversion enabled")
1125+
else:
1126+
# No conversion - preserve original transfer syntax
1127+
encoder = None
1128+
decoder = None
1129+
encode_params = None
1130+
decode_params = None
1131+
target_transfer_syntax = None # Will be determined from first dataset
1132+
logger.info("Preserving original transfer syntax (no HTJ2K conversion)")
11401133

11411134
# Get transfer syntax constants
11421135
ts_constants = _get_transfer_syntax_constants()
@@ -1175,53 +1168,122 @@ def transcode_dicom_to_htj2k_multiframe(
11751168
# Use first dataset as template
11761169
template_ds = datasets[0]
11771170

1171+
# Determine transfer syntax from first dataset
1172+
if target_transfer_syntax is None:
1173+
target_transfer_syntax = str(getattr(template_ds.file_meta, 'TransferSyntaxUID', '1.2.840.10008.1.2.1'))
1174+
logger.info(f" Using original transfer syntax: {target_transfer_syntax}")
1175+
1176+
# Check if we're dealing with encapsulated (compressed) data
1177+
is_encapsulated = hasattr(template_ds, 'PixelData') and template_ds.file_meta.TransferSyntaxUID != pydicom.uid.ExplicitVRLittleEndian
1178+
11781179
# Collect all frames from all instances
1179-
all_decoded_frames = []
1180+
all_frames = [] # Will contain either numpy arrays (for HTJ2K) or bytes (for preserving)
11801181

1181-
for ds in datasets:
1182-
current_ts = str(getattr(ds.file_meta, 'TransferSyntaxUID', None))
1182+
if convert_to_htj2k:
1183+
# HTJ2K mode: decode all frames
1184+
for ds in datasets:
1185+
current_ts = str(getattr(ds.file_meta, 'TransferSyntaxUID', None))
1186+
1187+
if current_ts in NVIMGCODEC_SYNTAXES:
1188+
# Compressed format - use nvimgcodec decoder
1189+
frames = [fragment for fragment in pydicom.encaps.generate_frames(ds.PixelData)]
1190+
decoded = decoder.decode(frames, params=decode_params)
1191+
all_frames.extend(decoded)
1192+
else:
1193+
# Uncompressed format - use pydicom
1194+
pixel_array = ds.pixel_array
1195+
if not isinstance(pixel_array, np.ndarray):
1196+
pixel_array = np.array(pixel_array)
1197+
1198+
# Handle single frame vs multi-frame
1199+
if pixel_array.ndim == 2:
1200+
all_frames.append(pixel_array)
1201+
elif pixel_array.ndim == 3:
1202+
for frame_idx in range(pixel_array.shape[0]):
1203+
all_frames.append(pixel_array[frame_idx, :, :])
1204+
else:
1205+
# Preserve original encoding: extract frames without decoding
1206+
first_ts = str(getattr(datasets[0].file_meta, 'TransferSyntaxUID', None))
11831207

1184-
if current_ts in NVIMGCODEC_SYNTAXES:
1185-
# Compressed format - use nvimgcodec decoder
1186-
frames = [fragment for fragment in pydicom.encaps.generate_frames(ds.PixelData)]
1187-
decoded = decoder.decode(frames, params=decode_params)
1188-
all_decoded_frames.extend(decoded)
1208+
if first_ts in NVIMGCODEC_SYNTAXES or pydicom.encaps.encapsulate_extended:
1209+
# Encapsulated data - extract compressed frames
1210+
for ds in datasets:
1211+
if hasattr(ds, 'PixelData'):
1212+
try:
1213+
# Extract compressed frames
1214+
frames = [fragment for fragment in pydicom.encaps.generate_frames(ds.PixelData)]
1215+
all_frames.extend(frames)
1216+
except:
1217+
# Fall back to pixel_array for uncompressed
1218+
pixel_array = ds.pixel_array
1219+
if not isinstance(pixel_array, np.ndarray):
1220+
pixel_array = np.array(pixel_array)
1221+
if pixel_array.ndim == 2:
1222+
all_frames.append(pixel_array)
1223+
elif pixel_array.ndim == 3:
1224+
for frame_idx in range(pixel_array.shape[0]):
1225+
all_frames.append(pixel_array[frame_idx, :, :])
11891226
else:
1190-
# Uncompressed format - use pydicom
1191-
pixel_array = ds.pixel_array
1192-
if not isinstance(pixel_array, np.ndarray):
1193-
pixel_array = np.array(pixel_array)
1194-
1195-
# Handle single frame vs multi-frame
1196-
if pixel_array.ndim == 2:
1197-
# Single frame
1198-
pixel_array = pixel_array[:, :, np.newaxis]
1199-
all_decoded_frames.append(pixel_array)
1200-
elif pixel_array.ndim == 3:
1201-
# Multi-frame (frames are first dimension)
1202-
for frame_idx in range(pixel_array.shape[0]):
1203-
frame_2d = pixel_array[frame_idx, :, :]
1204-
if frame_2d.ndim == 2:
1205-
frame_2d = frame_2d[:, :, np.newaxis]
1206-
all_decoded_frames.append(frame_2d)
1227+
# Uncompressed data - use pixel arrays
1228+
for ds in datasets:
1229+
pixel_array = ds.pixel_array
1230+
if not isinstance(pixel_array, np.ndarray):
1231+
pixel_array = np.array(pixel_array)
1232+
if pixel_array.ndim == 2:
1233+
all_frames.append(pixel_array)
1234+
elif pixel_array.ndim == 3:
1235+
for frame_idx in range(pixel_array.shape[0]):
1236+
all_frames.append(pixel_array[frame_idx, :, :])
12071237

1208-
total_frame_count = len(all_decoded_frames)
1238+
total_frame_count = len(all_frames)
12091239
logger.info(f" Total frames in series: {total_frame_count}")
12101240

1211-
# Encode all frames to HTJ2K
1212-
logger.info(f" Encoding {total_frame_count} frames to HTJ2K...")
1213-
encoded_frames = encoder.encode(all_decoded_frames, codec="jpeg2k", params=encode_params)
1214-
1215-
# Convert to bytes
1216-
encoded_frames_bytes = [bytes(enc) for enc in encoded_frames]
1241+
# Encode frames based on conversion mode
1242+
if convert_to_htj2k:
1243+
logger.info(f" Encoding {total_frame_count} frames to HTJ2K...")
1244+
# Ensure frames have channel dimension for encoder
1245+
frames_for_encoding = []
1246+
for frame in all_frames:
1247+
if frame.ndim == 2:
1248+
frame = frame[:, :, np.newaxis]
1249+
frames_for_encoding.append(frame)
1250+
encoded_frames = encoder.encode(frames_for_encoding, codec="jpeg2k", params=encode_params)
1251+
# Convert to bytes
1252+
encoded_frames_bytes = [bytes(enc) for enc in encoded_frames]
1253+
else:
1254+
logger.info(f" Preserving original encoding for {total_frame_count} frames...")
1255+
# Check if frames are already bytes (encapsulated) or numpy arrays (uncompressed)
1256+
if len(all_frames) > 0 and isinstance(all_frames[0], bytes):
1257+
# Already encapsulated - use as-is
1258+
encoded_frames_bytes = all_frames
1259+
else:
1260+
# Uncompressed numpy arrays
1261+
encoded_frames_bytes = None
12171262

12181263
# Create SIMPLE multi-frame DICOM file (like the user's example)
12191264
# Use first dataset as template, keeping its metadata
12201265
logger.info(f" Creating simple multi-frame DICOM from {total_frame_count} frames...")
12211266
output_ds = datasets[0].copy() # Start from first dataset
12221267

1223-
# Update pixel data with all HTJ2K encoded frames + Basic Offset Table
1224-
output_ds.PixelData = _create_basic_offset_table_pixel_data(encoded_frames_bytes)
1268+
# CRITICAL: Set SOP Instance UID to match the SeriesInstanceUID (which will be the filename)
1269+
# This ensures the file's internal SOP Instance UID matches its filename
1270+
output_ds.SOPInstanceUID = series_uid
1271+
1272+
# Update pixel data based on conversion mode
1273+
if encoded_frames_bytes is not None:
1274+
# Encapsulated data (HTJ2K or preserved compressed format)
1275+
# Use Basic Offset Table for multi-frame efficiency
1276+
if add_basic_offset_table:
1277+
output_ds.PixelData = pydicom.encaps.encapsulate(encoded_frames_bytes, has_bot=True)
1278+
logger.info(f" ✓ Basic Offset Table included for efficient frame access")
1279+
else:
1280+
output_ds.PixelData = pydicom.encaps.encapsulate(encoded_frames_bytes)
1281+
else:
1282+
# Uncompressed mode: combine all frames into a 3D array
1283+
# Stack frames: (frames, rows, cols)
1284+
combined_pixel_array = np.stack(all_frames, axis=0)
1285+
output_ds.PixelData = combined_pixel_array.tobytes()
1286+
12251287
output_ds.file_meta.TransferSyntaxUID = pydicom.uid.UID(target_transfer_syntax)
12261288

12271289
# Set NumberOfFrames (critical!)
@@ -1371,7 +1433,8 @@ def transcode_dicom_to_htj2k_multiframe(
13711433
logger.error(f" ❌ MISMATCH: Top-level Z={output_ds.ImagePositionPatient[2]} != Frame[0] Z={first_frame_pos[2]}")
13721434

13731435
logger.info(f" ✓ Created multi-frame with {total_frame_count} frames (OHIF-compatible)")
1374-
logger.info(f" ✓ Basic Offset Table included for efficient frame access")
1436+
if encoded_frames_bytes is not None:
1437+
logger.info(f" ✓ Basic Offset Table included for efficient frame access")
13751438

13761439
# Create output directory structure
13771440
study_output_dir = os.path.join(output_dir, study_uid)
@@ -1393,9 +1456,16 @@ def transcode_dicom_to_htj2k_multiframe(
13931456

13941457
elapsed_time = time.time() - start_time
13951458

1396-
logger.info(f"\nMulti-frame HTJ2K transcoding complete:")
1459+
if convert_to_htj2k:
1460+
logger.info(f"\nMulti-frame HTJ2K conversion complete:")
1461+
else:
1462+
logger.info(f"\nMulti-frame DICOM conversion complete:")
13971463
logger.info(f" Total series processed: {processed_series}")
1398-
logger.info(f" Total frames encoded: {total_frames}")
1464+
logger.info(f" Total frames combined: {total_frames}")
1465+
if convert_to_htj2k:
1466+
logger.info(f" Format: HTJ2K compressed")
1467+
else:
1468+
logger.info(f" Format: Original transfer syntax preserved")
13991469
logger.info(f" Time elapsed: {elapsed_time:.2f} seconds")
14001470
logger.info(f" Output directory: {output_dir}")
14011471

0 commit comments

Comments
 (0)