Skip to content

Commit 13f3377

Browse files
committed
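Set color_spec explicitly to RGB when decoding frames whose PhotometricInterpretation is a YBR variant.
Group frames by PhotometricInterpretation before sending them to the decoder, so that each group is decoded with decode parameters matching its color space. Signed-off-by: Joaquin Anton Guirao <janton@nvidia.com>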
1 parent 546e4dc commit 13f3377

File tree

1 file changed: 89 additions and 29 deletions.

monailabel/datastore/utils/convert_htj2k.py

Lines changed: 89 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -33,17 +33,22 @@ def _get_nvimgcodec_decoder():
3333
return _NVIMGCODEC_DECODER
3434

3535

36-
def _setup_htj2k_decode_params():
36+
def _setup_htj2k_decode_params(color_spec=None):
3737
"""
3838
Create nvimgcodec decoding parameters for DICOM images.
3939
40+
Args:
41+
color_spec: Color specification to use. If None, defaults to UNCHANGED.
42+
4043
Returns:
4144
nvimgcodec.DecodeParams: Decode parameters configured for DICOM
4245
"""
4346
from nvidia import nvimgcodec
47+
if color_spec is None:
48+
color_spec = nvimgcodec.ColorSpec.UNCHANGED
4449
decode_params = nvimgcodec.DecodeParams(
4550
allow_any_depth=True,
46-
color_spec=nvimgcodec.ColorSpec.UNCHANGED,
51+
color_spec=color_spec,
4752
)
4853
return decode_params
4954

@@ -337,12 +342,12 @@ def transcode_dicom_to_htj2k(
337342
encoder = _get_nvimgcodec_encoder()
338343
decoder = _get_nvimgcodec_decoder() # Always needed for decoding input DICOM images
339344

340-
# Setup HTJ2K encoding and decoding parameters
345+
# Setup HTJ2K encoding parameters
341346
encode_params, target_transfer_syntax = _setup_htj2k_encode_params(
342347
num_resolutions=num_resolutions,
343348
code_block_size=code_block_size
344349
)
345-
decode_params = _setup_htj2k_decode_params()
350+
# Note: decode_params is created per-PhotometricInterpretation group in the batch processing
346351
logger.info("Using lossless HTJ2K compression")
347352

348353
# Get transfer syntax constants
@@ -383,40 +388,78 @@ def transcode_dicom_to_htj2k(
383388

384389
# Process nvimgcodec_batch: extract frames, decode, encode in streaming batches
385390
if nvimgcodec_batch:
386-
# First, extract all compressed frames from all files
387-
all_compressed_frames = []
391+
from collections import defaultdict
392+
393+
# First, extract all compressed frames and group by PhotometricInterpretation
394+
grouped_frames = defaultdict(list) # Key: PhotometricInterpretation, Value: list of (file_idx, frame_data)
395+
frame_counts = {} # Track number of frames per file
388396

389397
logger.info(f" Extracting frames from {len(nvimgcodec_batch)} nvimgcodec files:")
390398
for idx in nvimgcodec_batch:
391399
ds = batch_datasets[idx]
392400
number_of_frames = int(ds.NumberOfFrames) if hasattr(ds, 'NumberOfFrames') else None
393401
frames = _extract_frames_from_compressed(ds, number_of_frames)
394402
logger.info(f" File idx={idx} ({os.path.basename(batch_files[idx])}): extracted {len(frames)} frames (expected: {number_of_frames})")
403+
404+
# Get PhotometricInterpretation for this file
405+
photometric = getattr(ds, 'PhotometricInterpretation', 'UNKNOWN')
406+
407+
# Store frames grouped by PhotometricInterpretation
408+
for frame in frames:
409+
grouped_frames[photometric].append((idx, frame))
410+
411+
frame_counts[idx] = len(frames)
395412
num_frames.append(len(frames))
396-
all_compressed_frames.extend(frames)
397413

398-
# Now decode and encode in batches (streaming to reduce memory)
399-
total_frames = len(all_compressed_frames)
400-
logger.info(f" Processing {total_frames} frames from {len(nvimgcodec_batch)} files in batches of {max_batch_size}")
414+
# Process each PhotometricInterpretation group separately
415+
logger.info(f" Found {len(grouped_frames)} unique PhotometricInterpretation groups")
416+
417+
# Track encoded frames per file to maintain order
418+
encoded_frames_by_file = {idx: [] for idx in nvimgcodec_batch}
401419

402-
for frame_batch_start in range(0, total_frames, max_batch_size):
403-
frame_batch_end = min(frame_batch_start + max_batch_size, total_frames)
404-
compressed_batch = all_compressed_frames[frame_batch_start:frame_batch_end]
420+
for photometric, frame_list in grouped_frames.items():
421+
# Determine color_spec based on PhotometricInterpretation
422+
if photometric.startswith('YBR'):
423+
color_spec = nvimgcodec.ColorSpec.RGB
424+
logger.info(f" Processing {len(frame_list)} frames with PhotometricInterpretation={photometric} using color_spec=RGB")
425+
else:
426+
color_spec = nvimgcodec.ColorSpec.UNCHANGED
427+
logger.info(f" Processing {len(frame_list)} frames with PhotometricInterpretation={photometric} using color_spec=UNCHANGED")
405428

406-
if total_frames > max_batch_size:
407-
logger.info(f" Processing frames [{frame_batch_start}..{frame_batch_end}) of {total_frames}")
429+
# Create decode params for this group
430+
group_decode_params = _setup_htj2k_decode_params(color_spec=color_spec)
408431

409-
# Decode batch
410-
decoded_batch = decoder.decode(compressed_batch, params=decode_params)
411-
_validate_frames(decoded_batch, f"Decoded frame [{frame_batch_start}+")
432+
# Extract just the frame data (without file index)
433+
compressed_frames = [frame_data for _, frame_data in frame_list]
412434

413-
# Encode batch immediately (streaming - no need to keep decoded data)
414-
encoded_batch = encoder.encode(decoded_batch, codec="jpeg2k", params=encode_params)
415-
_validate_frames(encoded_batch, f"Encoded frame [{frame_batch_start}+")
435+
# Decode and encode in batches (streaming to reduce memory)
436+
total_frames = len(compressed_frames)
416437

417-
# Store encoded frames and discard decoded frames to save memory
418-
encoded_data.extend(encoded_batch)
419-
# decoded_batch is automatically freed here
438+
for frame_batch_start in range(0, total_frames, max_batch_size):
439+
frame_batch_end = min(frame_batch_start + max_batch_size, total_frames)
440+
compressed_batch = compressed_frames[frame_batch_start:frame_batch_end]
441+
file_indices_batch = [file_idx for file_idx, _ in frame_list[frame_batch_start:frame_batch_end]]
442+
443+
if total_frames > max_batch_size:
444+
logger.info(f" Processing frames [{frame_batch_start}..{frame_batch_end}) of {total_frames} for {photometric}")
445+
446+
# Decode batch with appropriate color_spec
447+
decoded_batch = decoder.decode(compressed_batch, params=group_decode_params)
448+
_validate_frames(decoded_batch, f"Decoded frame [{frame_batch_start}+")
449+
450+
# Encode batch immediately (streaming - no need to keep decoded data)
451+
encoded_batch = encoder.encode(decoded_batch, codec="jpeg2k", params=encode_params)
452+
_validate_frames(encoded_batch, f"Encoded frame [{frame_batch_start}+")
453+
454+
# Store encoded frames by file index to maintain order
455+
for file_idx, encoded_frame in zip(file_indices_batch, encoded_batch):
456+
encoded_frames_by_file[file_idx].append(encoded_frame)
457+
458+
# decoded_batch is automatically freed here
459+
460+
# Reconstruct encoded_data in original file order
461+
for idx in nvimgcodec_batch:
462+
encoded_data.extend(encoded_frames_by_file[idx])
420463

421464
# Process pydicom_batch: extract frames and encode in streaming batches
422465
if pydicom_batch:
@@ -468,7 +511,7 @@ def transcode_dicom_to_htj2k(
468511

469512
batch_datasets[dataset_idx].file_meta.TransferSyntaxUID = pydicom.uid.UID(target_transfer_syntax)
470513

471-
# Update PhotometricInterpretation to RGB since we decoded with SRGB color_spec
514+
# Update PhotometricInterpretation to RGB for YBR images since we decoded with RGB color_spec
472515
# The pixel data is now in RGB color space, so the metadata must reflect this
473516
# to prevent double conversion by DICOM readers
474517
if hasattr(batch_datasets[dataset_idx], 'PhotometricInterpretation'):
@@ -647,19 +690,18 @@ def convert_single_frame_dicom_series_to_multiframe(
647690
encoder = _get_nvimgcodec_encoder()
648691
decoder = _get_nvimgcodec_decoder()
649692

650-
# Setup HTJ2K encoding and decoding parameters
693+
# Setup HTJ2K encoding parameters
651694
encode_params, target_transfer_syntax = _setup_htj2k_encode_params(
652695
num_resolutions=num_resolutions,
653696
code_block_size=code_block_size
654697
)
655-
decode_params = _setup_htj2k_decode_params()
698+
# Note: decode_params is created per-series based on PhotometricInterpretation
656699
logger.info("HTJ2K conversion enabled")
657700
else:
658701
# No conversion - preserve original transfer syntax
659702
encoder = None
660703
decoder = None
661704
encode_params = None
662-
decode_params = None
663705
target_transfer_syntax = None # Will be determined from first dataset
664706
logger.info("Preserving original transfer syntax (no HTJ2K conversion)")
665707

@@ -708,6 +750,17 @@ def convert_single_frame_dicom_series_to_multiframe(
708750
# Check if we're dealing with encapsulated (compressed) data
709751
is_encapsulated = hasattr(template_ds, 'PixelData') and template_ds.file_meta.TransferSyntaxUID != pydicom.uid.ExplicitVRLittleEndian
710752

753+
# Determine color_spec for this series based on PhotometricInterpretation
754+
if convert_to_htj2k:
755+
photometric = getattr(template_ds, 'PhotometricInterpretation', 'UNKNOWN')
756+
if photometric.startswith('YBR'):
757+
series_color_spec = nvimgcodec.ColorSpec.RGB
758+
logger.info(f" Series PhotometricInterpretation={photometric}, using color_spec=RGB")
759+
else:
760+
series_color_spec = nvimgcodec.ColorSpec.UNCHANGED
761+
logger.info(f" Series PhotometricInterpretation={photometric}, using color_spec=UNCHANGED")
762+
series_decode_params = _setup_htj2k_decode_params(color_spec=series_color_spec)
763+
711764
# Collect all frames from all instances
712765
all_frames = [] # Will contain either numpy arrays (for HTJ2K) or bytes (for preserving)
713766

@@ -719,7 +772,7 @@ def convert_single_frame_dicom_series_to_multiframe(
719772
if current_ts in NVIMGCODEC_SYNTAXES:
720773
# Compressed format - use nvimgcodec decoder
721774
frames = [fragment for fragment in pydicom.encaps.generate_frames(ds.PixelData)]
722-
decoded = decoder.decode(frames, params=decode_params)
775+
decoded = decoder.decode(frames, params=series_decode_params)
723776
all_frames.extend(decoded)
724777
else:
725778
# Uncompressed format - use pydicom
@@ -818,6 +871,13 @@ def convert_single_frame_dicom_series_to_multiframe(
818871

819872
output_ds.file_meta.TransferSyntaxUID = pydicom.uid.UID(target_transfer_syntax)
820873

874+
# Update PhotometricInterpretation if we converted from YBR to RGB
875+
if convert_to_htj2k and hasattr(output_ds, 'PhotometricInterpretation'):
876+
original_pi = output_ds.PhotometricInterpretation
877+
if original_pi.startswith('YBR'):
878+
output_ds.PhotometricInterpretation = 'RGB'
879+
logger.info(f" Updated PhotometricInterpretation: {original_pi} -> RGB")
880+
821881
# Set NumberOfFrames (critical!)
822882
output_ds.NumberOfFrames = total_frame_count
823883

0 commit comments

Comments
 (0)