Skip to content

Commit 546e4dc

Browse files
committed
Refactor HTJ2K transcoding and add comprehensive test coverage
- Extract helper functions for frame extraction and validation - _extract_frames_from_compressed: Extract frames from encapsulated DICOM (now defaults to 1 frame for single-frame images without NumberOfFrames tag) - _extract_frames_from_uncompressed: Extract frames from pixel arrays - _validate_frames: Check for None values in decoded/encoded frames - _find_dicom_files: Recursively find DICOM files with proper sorting - Add PhotometricInterpretation update from YBR to RGB - Prevents double color space conversion by DICOM readers - Updates metadata to match actual RGB pixel data after nvimgcodec decoding - Add fancy_upsampling=1 option to nvimgcodec decoder - Add comprehensive test coverage using pydicom built-in examples: - test_transcode_multiframe_jpeg_ybr_to_htj2k: 30-frame JPEG with YBR_FULL_422 color space, verifies color space conversion and PhotometricInterpretation update (max_diff: 4.0, atol=5) - test_transcode_ct_example_to_htj2k: Uncompressed CT grayscale (MONOCHROME2), verifies lossless transcoding - test_transcode_mr_example_to_htj2k: Uncompressed MR grayscale (MONOCHROME2), verifies lossless transcoding - test_transcode_rgb_color_example_to_htj2k: Uncompressed RGB color image, verifies PhotometricInterpretation preservation and lossless transcoding - test_transcode_jpeg2k_example_to_htj2k: JPEG 2000 with YBR_RCT (reversible color transform), verifies PhotometricInterpretation update and perfect lossless conversion (max_diff: 0.0)
1 parent af3e93c commit 546e4dc

File tree

2 files changed

+534
-43
lines changed

2 files changed

+534
-43
lines changed

monailabel/datastore/utils/convert_htj2k.py

Lines changed: 185 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _get_nvimgcodec_decoder():
2929
global _NVIMGCODEC_DECODER
3030
if _NVIMGCODEC_DECODER is None:
3131
from nvidia import nvimgcodec
32-
_NVIMGCODEC_DECODER = nvimgcodec.Decoder()
32+
_NVIMGCODEC_DECODER = nvimgcodec.Decoder(options=':fancy_upsampling=1')
3333
return _NVIMGCODEC_DECODER
3434

3535

@@ -41,12 +41,10 @@ def _setup_htj2k_decode_params():
4141
nvimgcodec.DecodeParams: Decode parameters configured for DICOM
4242
"""
4343
from nvidia import nvimgcodec
44-
4544
decode_params = nvimgcodec.DecodeParams(
4645
allow_any_depth=True,
4746
color_spec=nvimgcodec.ColorSpec.UNCHANGED,
4847
)
49-
5048
return decode_params
5149

5250

@@ -82,6 +80,106 @@ def _setup_htj2k_encode_params(num_resolutions: int = 6, code_block_size: tuple
8280
return encode_params, target_transfer_syntax
8381

8482

83+
def _extract_frames_from_compressed(ds, number_of_frames=None):
84+
"""
85+
Extract frames from encapsulated (compressed) DICOM pixel data.
86+
87+
Args:
88+
ds: pydicom Dataset with encapsulated PixelData
89+
number_of_frames: Expected number of frames (from NumberOfFrames tag)
90+
91+
Returns:
92+
list: List of compressed frame data (bytes)
93+
"""
94+
# Default to 1 frame if not specified (for single-frame images without NumberOfFrames tag)
95+
if number_of_frames is None:
96+
number_of_frames = 1
97+
98+
frames = list(pydicom.encaps.generate_frames(ds.PixelData, number_of_frames=number_of_frames))
99+
return frames
100+
101+
102+
def _extract_frames_from_uncompressed(pixel_array, num_frames_tag):
103+
"""
104+
Extract individual frames from uncompressed pixel array.
105+
106+
Handles different array shapes:
107+
- 2D (H, W): single frame grayscale
108+
- 3D (N, H, W): multi-frame grayscale OR (H, W, C): single frame color
109+
- 4D (N, H, W, C): multi-frame color
110+
111+
Args:
112+
pixel_array: Numpy array of pixel data
113+
num_frames_tag: NumberOfFrames value from DICOM tag
114+
115+
Returns:
116+
list: List of frame arrays
117+
"""
118+
if not isinstance(pixel_array, np.ndarray):
119+
pixel_array = np.array(pixel_array)
120+
121+
# 2D: single frame grayscale
122+
if pixel_array.ndim == 2:
123+
return [pixel_array]
124+
125+
# 3D: multi-frame grayscale OR single-frame color
126+
if pixel_array.ndim == 3:
127+
if num_frames_tag > 1 or pixel_array.shape[0] == num_frames_tag:
128+
# Multi-frame grayscale: (N, H, W)
129+
return [pixel_array[i] for i in range(pixel_array.shape[0])]
130+
# Single-frame color: (H, W, C)
131+
return [pixel_array]
132+
133+
# 4D: multi-frame color
134+
if pixel_array.ndim == 4:
135+
return [pixel_array[i] for i in range(pixel_array.shape[0])]
136+
137+
raise ValueError(f"Unexpected pixel array dimensions: {pixel_array.ndim}")
138+
139+
140+
def _validate_frames(frames, context_msg="Frame"):
141+
"""
142+
Check for None values in decoded/encoded frames.
143+
144+
Args:
145+
frames: List of frames to validate
146+
context_msg: Context message for error reporting
147+
148+
Raises:
149+
ValueError: If any frame is None
150+
"""
151+
for idx, frame in enumerate(frames):
152+
if frame is None:
153+
raise ValueError(f"{context_msg} {idx} failed (returned None)")
154+
155+
156+
def _find_dicom_files(input_dir):
157+
"""
158+
Recursively find all valid DICOM files in a directory.
159+
160+
Args:
161+
input_dir: Directory to search
162+
163+
Returns:
164+
list: Sorted list of DICOM file paths
165+
"""
166+
valid_dicom_files = []
167+
for root, dirs, files in os.walk(input_dir):
168+
for f in files:
169+
file_path = os.path.join(root, f)
170+
if os.path.isfile(file_path):
171+
try:
172+
with open(file_path, "rb") as fp:
173+
fp.seek(128)
174+
if fp.read(4) == b"DICM":
175+
valid_dicom_files.append(file_path)
176+
except Exception:
177+
continue
178+
179+
valid_dicom_files.sort() # For reproducible processing order
180+
return valid_dicom_files
181+
182+
85183
def _get_transfer_syntax_constants():
86184
"""
87185
Get transfer syntax UID constants for categorizing DICOM files.
@@ -131,12 +229,17 @@ def transcode_dicom_to_htj2k(
131229
accelerated decoding and encoding with batch processing for optimal performance.
132230
All transcoding is performed using lossless compression to preserve image quality.
133231
134-
The function processes files in configurable batches:
232+
The function processes files with streaming decode-encode batches:
135233
1. Categorizes files by transfer syntax (HTJ2K/JPEG2000/JPEG/uncompressed)
136-
2. Uses nvimgcodec decoder for compressed files (HTJ2K, JPEG2000, JPEG)
137-
3. Falls back to pydicom pixel_array for uncompressed files
138-
4. Batch encodes all images to HTJ2K using nvimgcodec
139-
5. Saves transcoded files with updated transfer syntax and optional Basic Offset Table
234+
2. Extracts all frames from source files
235+
3. Processes frames in batches of max_batch_size:
236+
- Decodes batch using nvimgcodec (compressed) or pydicom (uncompressed)
237+
- Immediately encodes batch to HTJ2K
238+
- Discards decoded frames to save memory (streaming)
239+
4. Saves transcoded files with updated transfer syntax and optional Basic Offset Table
240+
241+
This streaming approach minimizes memory usage by never holding all decoded frames
242+
in memory simultaneously.
140243
141244
Supported source transfer syntaxes:
142245
- HTJ2K (High-Throughput JPEG 2000) - decoded and re-encoded to add BOT if needed
@@ -217,21 +320,8 @@ def transcode_dicom_to_htj2k(
217320
if not os.path.isdir(input_dir):
218321
raise ValueError(f"Input path is not a directory: {input_dir}")
219322

220-
# Recursively find all files under input_dir that have the DICOM magic bytes at offset 128
221-
valid_dicom_files = []
222-
for root, dirs, files in os.walk(input_dir):
223-
for f in files:
224-
file_path = os.path.join(root, f)
225-
if os.path.isfile(file_path):
226-
try:
227-
with open(file_path, "rb") as fp:
228-
fp.seek(128)
229-
magic = fp.read(4)
230-
if magic == b"DICM":
231-
valid_dicom_files.append(file_path)
232-
except Exception:
233-
continue
234-
323+
# Find all valid DICOM files
324+
valid_dicom_files = _find_dicom_files(input_dir)
235325
if not valid_dicom_files:
236326
raise ValueError(f"No valid DICOM files found in {input_dir}")
237327

@@ -288,33 +378,76 @@ def transcode_dicom_to_htj2k(
288378
else:
289379
pydicom_batch.append(idx)
290380

291-
data_sequence = []
292-
decoded_data = []
293381
num_frames = []
382+
encoded_data = []
294383

295-
# Decode using nvimgcodec for compressed formats
384+
# Process nvimgcodec_batch: extract frames, decode, encode in streaming batches
296385
if nvimgcodec_batch:
386+
# First, extract all compressed frames from all files
387+
all_compressed_frames = []
388+
389+
logger.info(f" Extracting frames from {len(nvimgcodec_batch)} nvimgcodec files:")
297390
for idx in nvimgcodec_batch:
298-
frames = [fragment for fragment in pydicom.encaps.generate_frames(batch_datasets[idx].PixelData)]
391+
ds = batch_datasets[idx]
392+
number_of_frames = int(ds.NumberOfFrames) if hasattr(ds, 'NumberOfFrames') else None
393+
frames = _extract_frames_from_compressed(ds, number_of_frames)
394+
logger.info(f" File idx={idx} ({os.path.basename(batch_files[idx])}): extracted {len(frames)} frames (expected: {number_of_frames})")
299395
num_frames.append(len(frames))
300-
data_sequence.extend(frames)
301-
decoder_output = decoder.decode(data_sequence, params=decode_params)
302-
decoded_data.extend(decoder_output)
396+
all_compressed_frames.extend(frames)
397+
398+
# Now decode and encode in batches (streaming to reduce memory)
399+
total_frames = len(all_compressed_frames)
400+
logger.info(f" Processing {total_frames} frames from {len(nvimgcodec_batch)} files in batches of {max_batch_size}")
401+
402+
for frame_batch_start in range(0, total_frames, max_batch_size):
403+
frame_batch_end = min(frame_batch_start + max_batch_size, total_frames)
404+
compressed_batch = all_compressed_frames[frame_batch_start:frame_batch_end]
405+
406+
if total_frames > max_batch_size:
407+
logger.info(f" Processing frames [{frame_batch_start}..{frame_batch_end}) of {total_frames}")
408+
409+
# Decode batch
410+
decoded_batch = decoder.decode(compressed_batch, params=decode_params)
411+
_validate_frames(decoded_batch, f"Decoded frame [{frame_batch_start}+")
412+
413+
# Encode batch immediately (streaming - no need to keep decoded data)
414+
encoded_batch = encoder.encode(decoded_batch, codec="jpeg2k", params=encode_params)
415+
_validate_frames(encoded_batch, f"Encoded frame [{frame_batch_start}+")
416+
417+
# Store encoded frames and discard decoded frames to save memory
418+
encoded_data.extend(encoded_batch)
419+
# decoded_batch is automatically freed here
303420

304-
# Decode using pydicom for uncompressed formats
421+
# Process pydicom_batch: extract frames and encode in streaming batches
305422
if pydicom_batch:
423+
# Extract all frames from uncompressed files
424+
all_decoded_frames = []
425+
306426
for idx in pydicom_batch:
307-
source_pixel_array = batch_datasets[idx].pixel_array
308-
if not isinstance(source_pixel_array, np.ndarray):
309-
source_pixel_array = np.array(source_pixel_array)
310-
if source_pixel_array.ndim == 2:
311-
source_pixel_array = source_pixel_array[:, :, np.newaxis]
312-
for frame_idx in range(source_pixel_array.shape[-1]):
313-
decoded_data.append(source_pixel_array[:, :, frame_idx])
314-
num_frames.append(source_pixel_array.shape[-1])
315-
316-
# Encode all frames to HTJ2K
317-
encoded_data = encoder.encode(decoded_data, codec="jpeg2k", params=encode_params)
427+
ds = batch_datasets[idx]
428+
num_frames_tag = int(ds.NumberOfFrames) if hasattr(ds, 'NumberOfFrames') else 1
429+
frames = _extract_frames_from_uncompressed(ds.pixel_array, num_frames_tag)
430+
all_decoded_frames.extend(frames)
431+
num_frames.append(len(frames))
432+
433+
# Encode in batches (streaming)
434+
total_frames = len(all_decoded_frames)
435+
if total_frames > 0:
436+
logger.info(f" Encoding {total_frames} uncompressed frames in batches of {max_batch_size}")
437+
438+
for frame_batch_start in range(0, total_frames, max_batch_size):
439+
frame_batch_end = min(frame_batch_start + max_batch_size, total_frames)
440+
decoded_batch = all_decoded_frames[frame_batch_start:frame_batch_end]
441+
442+
if total_frames > max_batch_size:
443+
logger.info(f" Encoding frames [{frame_batch_start}..{frame_batch_end}) of {total_frames}")
444+
445+
# Encode batch
446+
encoded_batch = encoder.encode(decoded_batch, codec="jpeg2k", params=encode_params)
447+
_validate_frames(encoded_batch, f"Encoded frame [{frame_batch_start}+")
448+
449+
# Store encoded frames
450+
encoded_data.extend(encoded_batch)
318451

319452
# Reassemble and save transcoded files
320453
frame_offset = 0
@@ -334,7 +467,16 @@ def transcode_dicom_to_htj2k(
334467
batch_datasets[dataset_idx].PixelData = pydicom.encaps.encapsulate(encoded_frames)
335468

336469
batch_datasets[dataset_idx].file_meta.TransferSyntaxUID = pydicom.uid.UID(target_transfer_syntax)
337-
470+
471+
# Update PhotometricInterpretation to RGB since we decoded with SRGB color_spec
472+
# The pixel data is now in RGB color space, so the metadata must reflect this
473+
# to prevent double conversion by DICOM readers
474+
if hasattr(batch_datasets[dataset_idx], 'PhotometricInterpretation'):
475+
original_pi = batch_datasets[dataset_idx].PhotometricInterpretation
476+
if original_pi.startswith('YBR'):
477+
batch_datasets[dataset_idx].PhotometricInterpretation = 'RGB'
478+
logger.info(f" Updated PhotometricInterpretation: {original_pi} -> RGB")
479+
338480
# Save transcoded file
339481
output_file = os.path.join(output_dir, os.path.basename(batch_files[dataset_idx]))
340482
batch_datasets[dataset_idx].save_as(output_file)

0 commit comments

Comments
 (0)