Skip to content

Commit 0a3fd79

Browse files
committed
Add comprehensive multi-frame HTJ2K DICOM testing and improve segmentation validation
This commit adds extensive test coverage for multi-frame HTJ2K DICOM handling and improves segmentation output validation across different DICOM formats. Test Improvements - test_dicom_segmentation.py: - Add _load_segmentation_array() helper for consistent segmentation loading - Add _compare_segmentations() helper using Dice coefficient and pixel accuracy - Refactor test_04 to test_04_compare_all_formats for comprehensive cross-format comparison * Compares Standard DICOM, HTJ2K, and Multi-frame HTJ2K outputs * Validates all formats produce highly similar segmentations (Dice > 0.95) - Improve test_05_compare_dicom_vs_nifti with actual segmentation comparison logic - Update test_06_multiframe_htj2k_inference with corrected test data path - Remove redundant tests (test_07, test_08, test_09) - functionality consolidated in test_04 Multi-frame HTJ2K Tests - test_convert.py: - Add HTJ2K_TRANSFER_SYNTAXES constant for explicit transfer syntax validation - Add test_transcode_dicom_to_htj2k_multiframe_metadata() * Validates all DICOM metadata preservation (ImagePositionPatient, ImageOrientationPatient, etc.) * Verifies per-frame functional groups match original files * Checks frame ordering and spatial attributes - Add test_transcode_dicom_to_htj2k_multiframe_lossless() * Validates pixel-perfect lossless compression * Verifies all frames match original pixel data - Add test_transcode_dicom_to_htj2k_multiframe_nifti_consistency() * Ensures multi-frame HTJ2K produces identical NIfTI output as original series - Update all transfer syntax checks to use HTJ2K_TRANSFER_SYNTAXES constant * Replaces .startswith("1.2.840.10008.1.2.4.20") with explicit UID list * Covers all three HTJ2K variants (lossless, RPCL, lossy) Code Cleanup: - Revert debug logging in monailabel/endpoints/infer.py - Add HTJ2K transfer syntax documentation in convert.py All tests pass successfully, validating that: 1. Segmentation outputs are consistent across all DICOM formats 2. Multi-frame HTJ2K transcoding preserves all metadata correctly 3. Multi-frame HTJ2K compression is lossless 4. Multi-frame HTJ2K produces identical results to single-frame series Signed-off-by: Joaquin Anton Guirao <janton@nvidia.com>
1 parent fe3ec21 commit 0a3fd79

File tree

4 files changed

+667
-60
lines changed

4 files changed

+667
-60
lines changed

monailabel/datastore/utils/convert.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,22 @@ def dicom_seg_to_itk_image(label, output_ext=".seg.nrrd"):
639639
return output_file
640640

641641

642+
def _create_basic_offset_table_pixel_data(encoded_frames: list) -> bytes:
643+
"""
644+
Create encapsulated pixel data with Basic Offset Table for multi-frame DICOM.
645+
646+
Uses pydicom's encapsulate() function to ensure 100% standard compliance.
647+
648+
Args:
649+
encoded_frames: List of encoded frame byte strings
650+
651+
Returns:
652+
bytes: Encapsulated pixel data with Basic Offset Table per DICOM Part 5 Section A.4
653+
"""
654+
return pydicom.encaps.encapsulate(encoded_frames, has_bot=True)
655+
656+
657+
642658
def _setup_htj2k_decode_params():
643659
"""
644660
Create nvimgcodec decoding parameters for DICOM images.

monailabel/endpoints/infer.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -92,20 +92,6 @@ def send_response(datastore, result, output, background_tasks):
9292
return res_json
9393

9494
if output == "image":
95-
# Log NRRD metadata before sending response
96-
try:
97-
import nrrd
98-
if res_img and os.path.exists(res_img) and (res_img.endswith('.nrrd') or res_img.endswith('.nrrd.gz')):
99-
_, header = nrrd.read(res_img, index_order='C')
100-
logger.info(f"[NRRD Geometry] File: {os.path.basename(res_img)}")
101-
logger.info(f"[NRRD Geometry] Dimensions: {header.get('sizes')}")
102-
logger.info(f"[NRRD Geometry] Space Origin: {header.get('space origin')}")
103-
logger.info(f"[NRRD Geometry] Space Directions: {header.get('space directions')}")
104-
logger.info(f"[NRRD Geometry] Space: {header.get('space')}")
105-
logger.info(f"[NRRD Geometry] Type: {header.get('type')}")
106-
logger.info(f"[NRRD Geometry] Encoding: {header.get('encoding')}")
107-
except Exception as e:
108-
logger.warning(f"Failed to read NRRD metadata: {e}")
10995
return FileResponse(res_img, media_type=get_mime_type(res_img), filename=os.path.basename(res_img))
11096

11197
if output == "dicom_seg":

tests/integration/radiology_serverless/test_dicom_segmentation.py

Lines changed: 228 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,14 @@ class TestDicomSegmentation(unittest.TestCase):
6565
"e7567e0a064f0c334226a0658de23afd",
6666
"1.2.826.0.1.3680043.8.274.1.1.8323329.686521.1629744176.620266"
6767
)
68-
68+
69+
dicomweb_htj2k_multiframe_series = os.path.join(
70+
data_dir,
71+
"dataset",
72+
"dicomweb_htj2k_multiframe",
73+
"1.2.826.0.1.3680043.8.274.1.1.8323329.686521.1629744176.620251"
74+
)
75+
6976
@classmethod
7077
def setUpClass(cls) -> None:
7178
"""Initialize MONAI Label app for direct usage without server."""
@@ -128,6 +135,25 @@ def _run_inference(self, image_path: str, model_name: str = "segmentation_spleen
128135

129136
return label_data, label_json, inference_time
130137

138+
def _load_segmentation_array(self, label_data):
139+
"""
140+
Load segmentation data as numpy array.
141+
142+
Args:
143+
label_data: File path (str) or numpy array
144+
145+
Returns:
146+
numpy array of segmentation
147+
"""
148+
if isinstance(label_data, str):
149+
import nibabel as nib
150+
nii = nib.load(label_data)
151+
return nii.get_fdata()
152+
elif isinstance(label_data, np.ndarray):
153+
return label_data
154+
else:
155+
raise ValueError(f"Unexpected label data type: {type(label_data)}")
156+
131157
def _validate_segmentation_output(self, label_data, label_json):
132158
"""
133159
Validate that the segmentation output is correct.
@@ -146,9 +172,7 @@ def _validate_segmentation_output(self, label_data, label_json):
146172

147173
# Try to load and verify the file
148174
try:
149-
import nibabel as nib
150-
nii = nib.load(label_data)
151-
array = nii.get_fdata()
175+
array = self._load_segmentation_array(label_data)
152176
self.assertGreater(array.size, 0, "Segmentation array should not be empty")
153177
logger.info(f"Segmentation shape: {array.shape}, dtype: {array.dtype}")
154178
logger.info(f"Unique labels: {np.unique(array)}")
@@ -166,6 +190,71 @@ def _validate_segmentation_output(self, label_data, label_json):
166190
self.assertIsInstance(label_json, dict, "Label JSON should be a dictionary")
167191
logger.info(f"Label metadata keys: {list(label_json.keys())}")
168192

193+
def _compare_segmentations(self, label_data_1, label_data_2, name_1="Reference", name_2="Comparison", tolerance=0.05):
194+
"""
195+
Compare two segmentation outputs to verify they are similar.
196+
197+
Args:
198+
label_data_1: First segmentation (file path or array)
199+
label_data_2: Second segmentation (file path or array)
200+
name_1: Name for first segmentation (for logging)
201+
name_2: Name for second segmentation (for logging)
202+
tolerance: Maximum allowed dice coefficient difference (0.0-1.0)
203+
204+
Returns:
205+
dict with comparison metrics
206+
"""
207+
# Load arrays
208+
array_1 = self._load_segmentation_array(label_data_1)
209+
array_2 = self._load_segmentation_array(label_data_2)
210+
211+
# Check shapes match
212+
self.assertEqual(array_1.shape, array_2.shape,
213+
f"Segmentation shapes should match: {array_1.shape} vs {array_2.shape}")
214+
215+
# Calculate dice coefficient for each label
216+
unique_labels = np.union1d(np.unique(array_1), np.unique(array_2))
217+
unique_labels = unique_labels[unique_labels != 0] # Exclude background
218+
219+
dice_scores = {}
220+
for label in unique_labels:
221+
mask_1 = (array_1 == label).astype(np.float32)
222+
mask_2 = (array_2 == label).astype(np.float32)
223+
224+
intersection = np.sum(mask_1 * mask_2)
225+
sum_masks = np.sum(mask_1) + np.sum(mask_2)
226+
227+
if sum_masks > 0:
228+
dice = (2.0 * intersection) / sum_masks
229+
dice_scores[int(label)] = dice
230+
else:
231+
dice_scores[int(label)] = 0.0
232+
233+
# Calculate overall metrics
234+
exact_match = np.array_equal(array_1, array_2)
235+
pixel_accuracy = np.mean(array_1 == array_2)
236+
237+
comparison_result = {
238+
'exact_match': exact_match,
239+
'pixel_accuracy': pixel_accuracy,
240+
'dice_scores': dice_scores,
241+
'avg_dice': np.mean(list(dice_scores.values())) if dice_scores else 0.0
242+
}
243+
244+
# Log results
245+
logger.info(f"\nComparing {name_1} vs {name_2}:")
246+
logger.info(f" Exact match: {exact_match}")
247+
logger.info(f" Pixel accuracy: {pixel_accuracy:.4f}")
248+
logger.info(f" Dice scores by label: {dice_scores}")
249+
logger.info(f" Average Dice: {comparison_result['avg_dice']:.4f}")
250+
251+
# Assert high similarity
252+
self.assertGreater(comparison_result['avg_dice'], 1.0 - tolerance,
253+
f"Segmentations should be similar (Dice > {1.0 - tolerance:.2f}). "
254+
f"Got {comparison_result['avg_dice']:.4f}")
255+
256+
return comparison_result
257+
169258
def test_01_app_initialized(self):
170259
"""Test that the app is properly initialized."""
171260
if not torch.cuda.is_available():
@@ -223,53 +312,110 @@ def test_03_dicom_inference_dicomweb_htj2k(self):
223312
self.assertLess(inference_time, 60.0, "Inference should complete within 60 seconds")
224313
logger.info(f"✓ DICOM inference test passed (HTJ2K) in {inference_time:.3f}s")
225314

226-
def test_04_dicom_inference_both_formats(self):
227-
"""Test inference on both standard and HTJ2K compressed DICOM series."""
315+
def test_04_compare_all_formats(self):
316+
"""
317+
Compare segmentation outputs across all DICOM format variations.
318+
319+
This is the KEY test that validates:
320+
- Standard DICOM (uncompressed, single-frame)
321+
- HTJ2K compressed DICOM (single-frame)
322+
- Multi-frame HTJ2K DICOM
323+
324+
All produce IDENTICAL or highly similar segmentation results.
325+
"""
228326
if not torch.cuda.is_available():
229327
self.skipTest("CUDA not available")
230328

231329
if not self.app:
232330
self.skipTest("App not initialized")
233331

234-
# Test both series types
332+
logger.info(f"\n{'='*60}")
333+
logger.info("Comparing Segmentation Outputs Across All Formats")
334+
logger.info(f"{'='*60}")
335+
336+
# Test all series types
235337
test_series = [
236338
("Standard DICOM", self.dicomweb_series),
237339
("HTJ2K DICOM", self.dicomweb_htj2k_series),
340+
("Multi-frame HTJ2K", self.dicomweb_htj2k_multiframe_series),
238341
]
239342

240-
total_time = 0
241-
successful = 0
242-
243-
for series_type, dicom_dir in test_series:
244-
if not os.path.exists(dicom_dir):
245-
logger.warning(f"Skipping {series_type}: {dicom_dir} not found")
343+
# Run inference on all available formats
344+
results = {}
345+
for series_name, series_path in test_series:
346+
if not os.path.exists(series_path):
347+
logger.warning(f"Skipping {series_name}: not found")
246348
continue
247349

248-
logger.info(f"\nProcessing {series_type}: {dicom_dir}")
249-
350+
logger.info(f"\nRunning {series_name}...")
250351
try:
251-
label_data, label_json, inference_time = self._run_inference(dicom_dir)
352+
label_data, label_json, inference_time = self._run_inference(series_path)
252353
self._validate_segmentation_output(label_data, label_json)
253354

254-
total_time += inference_time
255-
successful += 1
256-
logger.info(f"✓ {series_type} success in {inference_time:.3f}s")
257-
355+
results[series_name] = {
356+
'label_data': label_data,
357+
'label_json': label_json,
358+
'time': inference_time
359+
}
360+
logger.info(f" ✓ {series_name} completed in {inference_time:.3f}s")
258361
except Exception as e:
259-
logger.error(f"✗ {series_type} failed: {e}", exc_info=True)
362+
logger.error(f" {series_name} failed: {e}", exc_info=True)
260363

364+
# Require at least 2 formats to compare
365+
self.assertGreaterEqual(len(results), 2,
366+
"Need at least 2 formats to compare. Check test data availability.")
367+
368+
# Compare all pairs
369+
logger.info(f"\n{'='*60}")
370+
logger.info("Cross-Format Comparison:")
371+
logger.info(f"{'='*60}")
372+
373+
format_names = list(results.keys())
374+
comparison_results = []
375+
376+
for i in range(len(format_names)):
377+
for j in range(i + 1, len(format_names)):
378+
name1 = format_names[i]
379+
name2 = format_names[j]
380+
381+
logger.info(f"\nComparing: {name1} vs {name2}")
382+
try:
383+
comparison = self._compare_segmentations(
384+
results[name1]['label_data'],
385+
results[name2]['label_data'],
386+
name_1=name1,
387+
name_2=name2,
388+
tolerance=0.05 # Allow 5% dice variation
389+
)
390+
comparison_results.append({
391+
'pair': f"{name1} vs {name2}",
392+
'dice': comparison['avg_dice'],
393+
'pixel_accuracy': comparison['pixel_accuracy']
394+
})
395+
except Exception as e:
396+
logger.error(f"Comparison failed: {e}", exc_info=True)
397+
raise
398+
399+
# Summary
261400
logger.info(f"\n{'='*60}")
262-
logger.info(f"Summary: {successful}/{len(test_series)} series processed successfully")
263-
if successful > 0:
264-
logger.info(f"Total inference time: {total_time:.3f}s")
265-
logger.info(f"Average time per series: {total_time/successful:.3f}s")
401+
logger.info("Comparison Summary:")
402+
for comp in comparison_results:
403+
logger.info(f" {comp['pair']}: Dice={comp['dice']:.4f}, Accuracy={comp['pixel_accuracy']:.4f}")
266404
logger.info(f"{'='*60}")
267405

268-
# At least one should succeed
269-
self.assertGreater(successful, 0, "At least one DICOM series should be processed successfully")
406+
# All comparisons should show high similarity
407+
self.assertTrue(len(comparison_results) > 0, "Should have at least one comparison")
408+
avg_dice = np.mean([c['dice'] for c in comparison_results])
409+
logger.info(f"\nOverall average Dice across all comparisons: {avg_dice:.4f}")
410+
self.assertGreater(avg_dice, 0.95,
411+
"All formats should produce highly similar segmentations (avg Dice > 0.95)")
270412

271413
def test_05_compare_dicom_vs_nifti(self):
272-
"""Compare inference results between DICOM series and pre-converted NIfTI files."""
414+
"""
415+
Compare inference results between DICOM series and pre-converted NIfTI files.
416+
417+
Validates that the DICOM reader produces identical results to pre-converted NIfTI.
418+
"""
273419
if not torch.cuda.is_available():
274420
self.skipTest("CUDA not available")
275421

@@ -286,29 +432,75 @@ def test_05_compare_dicom_vs_nifti(self):
286432
if not os.path.exists(nifti_file):
287433
self.skipTest(f"Corresponding NIfTI file not found: {nifti_file}")
288434

289-
logger.info(f"Comparing DICOM vs NIfTI inference:")
435+
logger.info(f"\n{'='*60}")
436+
logger.info("Comparing DICOM vs NIfTI Segmentation")
437+
logger.info(f"{'='*60}")
290438
logger.info(f" DICOM: {dicom_dir}")
291439
logger.info(f" NIfTI: {nifti_file}")
292440

293441
# Run inference on DICOM
294442
logger.info("\n--- Running inference on DICOM series ---")
295443
dicom_label, dicom_json, dicom_time = self._run_inference(dicom_dir)
444+
self._validate_segmentation_output(dicom_label, dicom_json)
296445

297446
# Run inference on NIfTI
298447
logger.info("\n--- Running inference on NIfTI file ---")
299448
nifti_label, nifti_json, nifti_time = self._run_inference(nifti_file)
300-
301-
# Validate both
302-
self._validate_segmentation_output(dicom_label, dicom_json)
303449
self._validate_segmentation_output(nifti_label, nifti_json)
304450

305-
logger.info(f"\nPerformance comparison:")
451+
# Compare the segmentation outputs
452+
comparison = self._compare_segmentations(
453+
dicom_label,
454+
nifti_label,
455+
name_1="DICOM",
456+
name_2="NIfTI",
457+
tolerance=0.01 # Stricter tolerance - should be nearly identical
458+
)
459+
460+
logger.info(f"\n{'='*60}")
461+
logger.info("Comparison Summary:")
306462
logger.info(f" DICOM inference time: {dicom_time:.3f}s")
307463
logger.info(f" NIfTI inference time: {nifti_time:.3f}s")
464+
logger.info(f" Dice coefficient: {comparison['avg_dice']:.4f}")
465+
logger.info(f" Pixel accuracy: {comparison['pixel_accuracy']:.4f}")
466+
logger.info(f" Exact match: {comparison['exact_match']}")
467+
logger.info(f"{'='*60}")
468+
469+
# Should be nearly identical (Dice > 0.99)
470+
self.assertGreater(comparison['avg_dice'], 0.99,
471+
"DICOM and NIfTI segmentations should be nearly identical")
472+
473+
def test_06_multiframe_htj2k_inference(self):
474+
"""
475+
Test basic inference on multi-frame HTJ2K compressed DICOM series.
476+
477+
Note: Comprehensive cross-format comparison is done in test_04.
478+
This test ensures multi-frame HTJ2K inference works standalone.
479+
"""
480+
if not torch.cuda.is_available():
481+
self.skipTest("CUDA not available")
482+
483+
if not self.app:
484+
self.skipTest("App not initialized")
485+
486+
if not os.path.exists(self.dicomweb_htj2k_multiframe_series):
487+
self.skipTest(f"Multi-frame HTJ2K series not found: {self.dicomweb_htj2k_multiframe_series}")
488+
489+
logger.info(f"\n{'='*60}")
490+
logger.info("Testing Multi-Frame HTJ2K DICOM Inference")
491+
logger.info(f"{'='*60}")
492+
logger.info(f"Series path: {self.dicomweb_htj2k_multiframe_series}")
493+
494+
# Run inference
495+
label_data, label_json, inference_time = self._run_inference(self.dicomweb_htj2k_multiframe_series)
496+
497+
# Validate output
498+
self._validate_segmentation_output(label_data, label_json)
499+
500+
# Performance check
501+
self.assertLess(inference_time, 60.0, "Inference should complete within 60 seconds")
308502

309-
# Both should complete successfully
310-
self.assertIsNotNone(dicom_label, "DICOM inference should succeed")
311-
self.assertIsNotNone(nifti_label, "NIfTI inference should succeed")
503+
logger.info(f"✓ Multi-frame HTJ2K inference test passed in {inference_time:.3f}s")
312504

313505

314506
if __name__ == "__main__":

0 commit comments

Comments
 (0)