Skip to content

Commit c18a166

Browse files
committed
Refactor HTJ2K transcoding and add multi-frame YBR JPEG test
- Extract helper functions for frame extraction and validation - _extract_frames_from_compressed: Extract frames from encapsulated DICOM - _extract_frames_from_uncompressed: Extract frames from pixel arrays - _validate_frames: Check for None values in decoded/encoded frames - _find_dicom_files: Recursively find DICOM files with proper sorting - Add PhotometricInterpretation update from YBR to RGB - Prevents double color space conversion by DICOM readers - Updates metadata to match actual RGB pixel data after nvimgcodec decoding - Add fancy_upsampling=1 option to nvimgcodec decoder - Add test_transcode_multiframe_jpeg_ybr_to_htj2k test - Tests transcoding of 30-frame JPEG file with YBR_FULL_422 color space - Verifies PhotometricInterpretation update and color space conversion - Uses pydicom's built-in examples_ybr_color.dcm test file - Validates pixel values match within tolerance (atol=5)
1 parent af3e93c commit c18a166

File tree

2 files changed

+310
-43
lines changed

2 files changed

+310
-43
lines changed

monailabel/datastore/utils/convert_htj2k.py

Lines changed: 181 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _get_nvimgcodec_decoder():
2929
global _NVIMGCODEC_DECODER
3030
if _NVIMGCODEC_DECODER is None:
3131
from nvidia import nvimgcodec
32-
_NVIMGCODEC_DECODER = nvimgcodec.Decoder()
32+
_NVIMGCODEC_DECODER = nvimgcodec.Decoder(options=':fancy_upsampling=1')
3333
return _NVIMGCODEC_DECODER
3434

3535

@@ -41,12 +41,10 @@ def _setup_htj2k_decode_params():
4141
nvimgcodec.DecodeParams: Decode parameters configured for DICOM
4242
"""
4343
from nvidia import nvimgcodec
44-
4544
decode_params = nvimgcodec.DecodeParams(
4645
allow_any_depth=True,
4746
color_spec=nvimgcodec.ColorSpec.UNCHANGED,
4847
)
49-
5048
return decode_params
5149

5250

@@ -82,6 +80,102 @@ def _setup_htj2k_encode_params(num_resolutions: int = 6, code_block_size: tuple
8280
return encode_params, target_transfer_syntax
8381

8482

83+
def _extract_frames_from_compressed(ds, number_of_frames=None):
84+
"""
85+
Extract frames from encapsulated (compressed) DICOM pixel data.
86+
87+
Args:
88+
ds: pydicom Dataset with encapsulated PixelData
89+
number_of_frames: Expected number of frames (from NumberOfFrames tag)
90+
91+
Returns:
92+
list: List of compressed frame data (bytes)
93+
"""
94+
frames = list(pydicom.encaps.generate_frames(ds.PixelData, number_of_frames=number_of_frames))
95+
return frames
96+
97+
98+
def _extract_frames_from_uncompressed(pixel_array, num_frames_tag):
99+
"""
100+
Extract individual frames from uncompressed pixel array.
101+
102+
Handles different array shapes:
103+
- 2D (H, W): single frame grayscale
104+
- 3D (N, H, W): multi-frame grayscale OR (H, W, C): single frame color
105+
- 4D (N, H, W, C): multi-frame color
106+
107+
Args:
108+
pixel_array: Numpy array of pixel data
109+
num_frames_tag: NumberOfFrames value from DICOM tag
110+
111+
Returns:
112+
list: List of frame arrays
113+
"""
114+
if not isinstance(pixel_array, np.ndarray):
115+
pixel_array = np.array(pixel_array)
116+
117+
# 2D: single frame grayscale
118+
if pixel_array.ndim == 2:
119+
return [pixel_array]
120+
121+
# 3D: multi-frame grayscale OR single-frame color
122+
if pixel_array.ndim == 3:
123+
if num_frames_tag > 1 or pixel_array.shape[0] == num_frames_tag:
124+
# Multi-frame grayscale: (N, H, W)
125+
return [pixel_array[i] for i in range(pixel_array.shape[0])]
126+
# Single-frame color: (H, W, C)
127+
return [pixel_array]
128+
129+
# 4D: multi-frame color
130+
if pixel_array.ndim == 4:
131+
return [pixel_array[i] for i in range(pixel_array.shape[0])]
132+
133+
raise ValueError(f"Unexpected pixel array dimensions: {pixel_array.ndim}")
134+
135+
136+
def _validate_frames(frames, context_msg="Frame"):
137+
"""
138+
Check for None values in decoded/encoded frames.
139+
140+
Args:
141+
frames: List of frames to validate
142+
context_msg: Context message for error reporting
143+
144+
Raises:
145+
ValueError: If any frame is None
146+
"""
147+
for idx, frame in enumerate(frames):
148+
if frame is None:
149+
raise ValueError(f"{context_msg} {idx} failed (returned None)")
150+
151+
152+
def _find_dicom_files(input_dir):
153+
"""
154+
Recursively find all valid DICOM files in a directory.
155+
156+
Args:
157+
input_dir: Directory to search
158+
159+
Returns:
160+
list: Sorted list of DICOM file paths
161+
"""
162+
valid_dicom_files = []
163+
for root, dirs, files in os.walk(input_dir):
164+
for f in files:
165+
file_path = os.path.join(root, f)
166+
if os.path.isfile(file_path):
167+
try:
168+
with open(file_path, "rb") as fp:
169+
fp.seek(128)
170+
if fp.read(4) == b"DICM":
171+
valid_dicom_files.append(file_path)
172+
except Exception:
173+
continue
174+
175+
valid_dicom_files.sort() # For reproducible processing order
176+
return valid_dicom_files
177+
178+
85179
def _get_transfer_syntax_constants():
86180
"""
87181
Get transfer syntax UID constants for categorizing DICOM files.
@@ -131,12 +225,17 @@ def transcode_dicom_to_htj2k(
131225
accelerated decoding and encoding with batch processing for optimal performance.
132226
All transcoding is performed using lossless compression to preserve image quality.
133227
134-
The function processes files in configurable batches:
228+
The function processes files with streaming decode-encode batches:
135229
1. Categorizes files by transfer syntax (HTJ2K/JPEG2000/JPEG/uncompressed)
136-
2. Uses nvimgcodec decoder for compressed files (HTJ2K, JPEG2000, JPEG)
137-
3. Falls back to pydicom pixel_array for uncompressed files
138-
4. Batch encodes all images to HTJ2K using nvimgcodec
139-
5. Saves transcoded files with updated transfer syntax and optional Basic Offset Table
230+
2. Extracts all frames from source files
231+
3. Processes frames in batches of max_batch_size:
232+
- Decodes batch using nvimgcodec (compressed) or pydicom (uncompressed)
233+
- Immediately encodes batch to HTJ2K
234+
- Discards decoded frames to save memory (streaming)
235+
4. Saves transcoded files with updated transfer syntax and optional Basic Offset Table
236+
237+
This streaming approach minimizes memory usage by never holding all decoded frames
238+
in memory simultaneously.
140239
141240
Supported source transfer syntaxes:
142241
- HTJ2K (High-Throughput JPEG 2000) - decoded and re-encoded to add BOT if needed
@@ -217,21 +316,8 @@ def transcode_dicom_to_htj2k(
217316
if not os.path.isdir(input_dir):
218317
raise ValueError(f"Input path is not a directory: {input_dir}")
219318

220-
# Recursively find all files under input_dir that have the DICOM magic bytes at offset 128
221-
valid_dicom_files = []
222-
for root, dirs, files in os.walk(input_dir):
223-
for f in files:
224-
file_path = os.path.join(root, f)
225-
if os.path.isfile(file_path):
226-
try:
227-
with open(file_path, "rb") as fp:
228-
fp.seek(128)
229-
magic = fp.read(4)
230-
if magic == b"DICM":
231-
valid_dicom_files.append(file_path)
232-
except Exception:
233-
continue
234-
319+
# Find all valid DICOM files
320+
valid_dicom_files = _find_dicom_files(input_dir)
235321
if not valid_dicom_files:
236322
raise ValueError(f"No valid DICOM files found in {input_dir}")
237323

@@ -288,33 +374,76 @@ def transcode_dicom_to_htj2k(
288374
else:
289375
pydicom_batch.append(idx)
290376

291-
data_sequence = []
292-
decoded_data = []
293377
num_frames = []
378+
encoded_data = []
294379

295-
# Decode using nvimgcodec for compressed formats
380+
# Process nvimgcodec_batch: extract frames, decode, encode in streaming batches
296381
if nvimgcodec_batch:
382+
# First, extract all compressed frames from all files
383+
all_compressed_frames = []
384+
385+
logger.info(f" Extracting frames from {len(nvimgcodec_batch)} nvimgcodec files:")
297386
for idx in nvimgcodec_batch:
298-
frames = [fragment for fragment in pydicom.encaps.generate_frames(batch_datasets[idx].PixelData)]
387+
ds = batch_datasets[idx]
388+
number_of_frames = int(ds.NumberOfFrames) if hasattr(ds, 'NumberOfFrames') else None
389+
frames = _extract_frames_from_compressed(ds, number_of_frames)
390+
logger.info(f" File idx={idx} ({os.path.basename(batch_files[idx])}): extracted {len(frames)} frames (expected: {number_of_frames})")
299391
num_frames.append(len(frames))
300-
data_sequence.extend(frames)
301-
decoder_output = decoder.decode(data_sequence, params=decode_params)
302-
decoded_data.extend(decoder_output)
392+
all_compressed_frames.extend(frames)
393+
394+
# Now decode and encode in batches (streaming to reduce memory)
395+
total_frames = len(all_compressed_frames)
396+
logger.info(f" Processing {total_frames} frames from {len(nvimgcodec_batch)} files in batches of {max_batch_size}")
397+
398+
for frame_batch_start in range(0, total_frames, max_batch_size):
399+
frame_batch_end = min(frame_batch_start + max_batch_size, total_frames)
400+
compressed_batch = all_compressed_frames[frame_batch_start:frame_batch_end]
401+
402+
if total_frames > max_batch_size:
403+
logger.info(f" Processing frames [{frame_batch_start}..{frame_batch_end}) of {total_frames}")
404+
405+
# Decode batch
406+
decoded_batch = decoder.decode(compressed_batch, params=decode_params)
407+
_validate_frames(decoded_batch, f"Decoded frame [{frame_batch_start}+")
408+
409+
# Encode batch immediately (streaming - no need to keep decoded data)
410+
encoded_batch = encoder.encode(decoded_batch, codec="jpeg2k", params=encode_params)
411+
_validate_frames(encoded_batch, f"Encoded frame [{frame_batch_start}+")
412+
413+
# Store encoded frames and discard decoded frames to save memory
414+
encoded_data.extend(encoded_batch)
415+
# decoded_batch is automatically freed here
303416

304-
# Decode using pydicom for uncompressed formats
417+
# Process pydicom_batch: extract frames and encode in streaming batches
305418
if pydicom_batch:
419+
# Extract all frames from uncompressed files
420+
all_decoded_frames = []
421+
306422
for idx in pydicom_batch:
307-
source_pixel_array = batch_datasets[idx].pixel_array
308-
if not isinstance(source_pixel_array, np.ndarray):
309-
source_pixel_array = np.array(source_pixel_array)
310-
if source_pixel_array.ndim == 2:
311-
source_pixel_array = source_pixel_array[:, :, np.newaxis]
312-
for frame_idx in range(source_pixel_array.shape[-1]):
313-
decoded_data.append(source_pixel_array[:, :, frame_idx])
314-
num_frames.append(source_pixel_array.shape[-1])
315-
316-
# Encode all frames to HTJ2K
317-
encoded_data = encoder.encode(decoded_data, codec="jpeg2k", params=encode_params)
423+
ds = batch_datasets[idx]
424+
num_frames_tag = int(ds.NumberOfFrames) if hasattr(ds, 'NumberOfFrames') else 1
425+
frames = _extract_frames_from_uncompressed(ds.pixel_array, num_frames_tag)
426+
all_decoded_frames.extend(frames)
427+
num_frames.append(len(frames))
428+
429+
# Encode in batches (streaming)
430+
total_frames = len(all_decoded_frames)
431+
if total_frames > 0:
432+
logger.info(f" Encoding {total_frames} uncompressed frames in batches of {max_batch_size}")
433+
434+
for frame_batch_start in range(0, total_frames, max_batch_size):
435+
frame_batch_end = min(frame_batch_start + max_batch_size, total_frames)
436+
decoded_batch = all_decoded_frames[frame_batch_start:frame_batch_end]
437+
438+
if total_frames > max_batch_size:
439+
logger.info(f" Encoding frames [{frame_batch_start}..{frame_batch_end}) of {total_frames}")
440+
441+
# Encode batch
442+
encoded_batch = encoder.encode(decoded_batch, codec="jpeg2k", params=encode_params)
443+
_validate_frames(encoded_batch, f"Encoded frame [{frame_batch_start}+")
444+
445+
# Store encoded frames
446+
encoded_data.extend(encoded_batch)
318447

319448
# Reassemble and save transcoded files
320449
frame_offset = 0
@@ -334,7 +463,16 @@ def transcode_dicom_to_htj2k(
334463
batch_datasets[dataset_idx].PixelData = pydicom.encaps.encapsulate(encoded_frames)
335464

336465
batch_datasets[dataset_idx].file_meta.TransferSyntaxUID = pydicom.uid.UID(target_transfer_syntax)
337-
466+
467+
# Update PhotometricInterpretation to RGB since we decoded with SRGB color_spec
468+
# The pixel data is now in RGB color space, so the metadata must reflect this
469+
# to prevent double conversion by DICOM readers
470+
if hasattr(batch_datasets[dataset_idx], 'PhotometricInterpretation'):
471+
original_pi = batch_datasets[dataset_idx].PhotometricInterpretation
472+
if original_pi.startswith('YBR'):
473+
batch_datasets[dataset_idx].PhotometricInterpretation = 'RGB'
474+
logger.info(f" Updated PhotometricInterpretation: {original_pi} -> RGB")
475+
338476
# Save transcoded file
339477
output_file = os.path.join(output_dir, os.path.basename(batch_files[dataset_idx]))
340478
batch_datasets[dataset_idx].save_as(output_file)

0 commit comments

Comments
 (0)