@@ -65,7 +65,14 @@ class TestDicomSegmentation(unittest.TestCase):
6565 "e7567e0a064f0c334226a0658de23afd" ,
6666 "1.2.826.0.1.3680043.8.274.1.1.8323329.686521.1629744176.620266"
6767 )
68-
68+
# Path to the multi-frame HTJ2K DICOM series used by the tests:
# <data_dir>/dataset/dicomweb_htj2k_multiframe/<series UID>.
69+ dicomweb_htj2k_multiframe_series = os .path .join (
70+ data_dir ,
71+ "dataset" ,
72+ "dicomweb_htj2k_multiframe" ,
73+ "1.2.826.0.1.3680043.8.274.1.1.8323329.686521.1629744176.620251"
74+ )
75+
6976 @classmethod
7077 def setUpClass (cls ) -> None :
7178 """Initialize MONAI Label app for direct usage without server."""
@@ -128,6 +135,25 @@ def _run_inference(self, image_path: str, model_name: str = "segmentation_spleen
128135
129136 return label_data , label_json , inference_time
130137
138+ def _load_segmentation_array (self , label_data ):
139+ """
140+ Load segmentation data as numpy array.
141+
142+ Args:
143+ label_data: File path (str) or numpy array
144+
145+ Returns:
146+ numpy array of segmentation
147+ """
148+ if isinstance (label_data , str ):
149+ import nibabel as nib
150+ nii = nib .load (label_data )
151+ return nii .get_fdata ()
152+ elif isinstance (label_data , np .ndarray ):
153+ return label_data
154+ else :
155+ raise ValueError (f"Unexpected label data type: { type (label_data )} " )
156+
131157 def _validate_segmentation_output (self , label_data , label_json ):
132158 """
133159 Validate that the segmentation output is correct.
@@ -146,9 +172,7 @@ def _validate_segmentation_output(self, label_data, label_json):
146172
147173 # Try to load and verify the file
148174 try :
149- import nibabel as nib
150- nii = nib .load (label_data )
151- array = nii .get_fdata ()
175+ array = self ._load_segmentation_array (label_data )
152176 self .assertGreater (array .size , 0 , "Segmentation array should not be empty" )
153177 logger .info (f"Segmentation shape: { array .shape } , dtype: { array .dtype } " )
154178 logger .info (f"Unique labels: { np .unique (array )} " )
@@ -166,6 +190,71 @@ def _validate_segmentation_output(self, label_data, label_json):
166190 self .assertIsInstance (label_json , dict , "Label JSON should be a dictionary" )
167191 logger .info (f"Label metadata keys: { list (label_json .keys ())} " )
168192
def _compare_segmentations(self, label_data_1, label_data_2, name_1="Reference", name_2="Comparison", tolerance=0.05):
    """
    Compare two segmentation outputs and assert that they are similar.

    A Dice coefficient is computed for every non-zero label that occurs in
    either segmentation. The test fails unless the average Dice is at
    least ``1.0 - tolerance``.

    Args:
        label_data_1: First segmentation (file path or array)
        label_data_2: Second segmentation (file path or array)
        name_1: Name for first segmentation (for logging)
        name_2: Name for second segmentation (for logging)
        tolerance: Maximum allowed shortfall of the average Dice from 1.0
            (0.0-1.0). ``0.0`` demands a perfect match.

    Returns:
        dict with keys ``'exact_match'`` (bool), ``'pixel_accuracy'``
        (float), ``'dice_scores'`` (label -> Dice) and ``'avg_dice'`` (float).

    Raises:
        AssertionError: If the shapes differ or the average Dice is below
            ``1.0 - tolerance``.
    """
    # Load arrays
    array_1 = self._load_segmentation_array(label_data_1)
    array_2 = self._load_segmentation_array(label_data_2)

    # Check shapes match
    self.assertEqual(array_1.shape, array_2.shape,
                     f"Segmentation shapes should match: {array_1.shape} vs {array_2.shape} ")

    # Every foreground label seen in either volume (0 is background)
    unique_labels = np.union1d(np.unique(array_1), np.unique(array_2))
    unique_labels = unique_labels[unique_labels != 0]

    dice_scores = {}
    for label in unique_labels:
        mask_1 = array_1 == label
        mask_2 = array_2 == label

        # Count on boolean masks so the voxel tallies are exact integers;
        # summing float32 masks can round once the counts grow large.
        intersection = np.count_nonzero(mask_1 & mask_2)
        sum_masks = np.count_nonzero(mask_1) + np.count_nonzero(mask_2)

        if sum_masks > 0:
            dice_scores[int(label)] = (2.0 * intersection) / sum_masks
        else:
            dice_scores[int(label)] = 0.0

    # Calculate overall metrics
    exact_match = bool(np.array_equal(array_1, array_2))
    pixel_accuracy = float(np.mean(array_1 == array_2))

    if dice_scores:
        avg_dice = float(np.mean(list(dice_scores.values())))
    else:
        # No foreground label in either volume: identical all-background
        # outputs are a perfect match, not a total mismatch.
        avg_dice = 1.0 if exact_match else 0.0

    comparison_result = {
        'exact_match': exact_match,
        'pixel_accuracy': pixel_accuracy,
        'dice_scores': dice_scores,
        'avg_dice': avg_dice,
    }

    # Log results
    logger.info(f"\n Comparing {name_1} vs {name_2} :")
    logger.info(f" Exact match: {exact_match} ")
    logger.info(f" Pixel accuracy: {pixel_accuracy:.4f} ")
    logger.info(f" Dice scores by label: {dice_scores} ")
    logger.info(f" Average Dice: {comparison_result['avg_dice']:.4f} ")

    # Inclusive bound, so that tolerance=0.0 accepts a perfect match.
    self.assertGreaterEqual(comparison_result['avg_dice'], 1.0 - tolerance,
                            f"Segmentations should be similar (Dice >= {1.0 - tolerance:.2f} ). "
                            f"Got {comparison_result['avg_dice']:.4f} ")

    return comparison_result
257+
169258 def test_01_app_initialized (self ):
170259 """Test that the app is properly initialized."""
171260 if not torch .cuda .is_available ():
@@ -223,53 +312,110 @@ def test_03_dicom_inference_dicomweb_htj2k(self):
223312 self .assertLess (inference_time , 60.0 , "Inference should complete within 60 seconds" )
224313 logger .info (f"✓ DICOM inference test passed (HTJ2K) in { inference_time :.3f} s" )
225314
def test_04_compare_all_formats(self):
    """
    Run inference on every available DICOM variant and cross-compare the results.

    Variants covered:
    - Standard DICOM
    - HTJ2K DICOM
    - Multi-frame HTJ2K

    Variants whose directory is missing, or whose inference or validation
    raises, are logged and left out. At least two variants must remain;
    every remaining pair is compared with a 5% Dice tolerance, and the
    mean Dice over all pairs must exceed 0.95.
    """
    if not torch.cuda.is_available():
        self.skipTest("CUDA not available")

    if not self.app:
        self.skipTest("App not initialized")

    logger.info(f"\n {'=' * 60} ")
    logger.info("Comparing Segmentation Outputs Across All Formats")
    logger.info(f"{'=' * 60} ")

    candidates = [
        ("Standard DICOM", self.dicomweb_series),
        ("HTJ2K DICOM", self.dicomweb_htj2k_series),
        ("Multi-frame HTJ2K", self.dicomweb_htj2k_multiframe_series),
    ]

    # Inference pass: collect one result record per usable format.
    results = {}
    for series_name, series_path in candidates:
        if not os.path.exists(series_path):
            logger.warning(f"Skipping {series_name} : not found")
            continue

        logger.info(f"\n Running {series_name} ...")
        try:
            label_data, label_json, inference_time = self._run_inference(series_path)
            self._validate_segmentation_output(label_data, label_json)

            record = {}
            record['label_data'] = label_data
            record['label_json'] = label_json
            record['time'] = inference_time
            results[series_name] = record
            logger.info(f" ✓ {series_name} completed in {inference_time:.3f} s")
        except Exception as e:
            # A failing format is logged and dropped; the count check below decides the outcome.
            logger.error(f" ✗ {series_name} failed: {e} ", exc_info=True)

    self.assertGreaterEqual(len(results), 2,
                            "Need at least 2 formats to compare. Check test data availability.")

    logger.info(f"\n {'=' * 60} ")
    logger.info("Cross-Format Comparison:")
    logger.info(f"{'=' * 60} ")

    # Comparison pass: each unordered pair once, in insertion order.
    format_names = list(results)
    comparison_results = []
    for position, name1 in enumerate(format_names):
        for name2 in format_names[position + 1:]:
            logger.info(f"\n Comparing: {name1} vs {name2} ")
            try:
                comparison = self._compare_segmentations(
                    results[name1]['label_data'],
                    results[name2]['label_data'],
                    name_1=name1,
                    name_2=name2,
                    tolerance=0.05,  # Allow 5% dice variation
                )
                comparison_results.append({
                    'pair': f"{name1} vs {name2} ",
                    'dice': comparison['avg_dice'],
                    'pixel_accuracy': comparison['pixel_accuracy'],
                })
            except Exception as e:
                logger.error(f"Comparison failed: {e} ", exc_info=True)
                raise

    logger.info(f"\n {'=' * 60} ")
    logger.info("Comparison Summary:")
    for comp in comparison_results:
        logger.info(f" {comp['pair']} : Dice={comp['dice']:.4f} , Accuracy={comp['pixel_accuracy']:.4f} ")
    logger.info(f"{'=' * 60} ")

    self.assertTrue(len(comparison_results) > 0, "Should have at least one comparison")
    pair_dice = [comp['dice'] for comp in comparison_results]
    avg_dice = np.mean(pair_dice)
    logger.info(f"\n Overall average Dice across all comparisons: {avg_dice:.4f} ")
    self.assertGreater(avg_dice, 0.95,
                       "All formats should produce highly similar segmentations (avg Dice > 0.95)")
270412
271413 def test_05_compare_dicom_vs_nifti (self ):
272- """Compare inference results between DICOM series and pre-converted NIfTI files."""
414+ """
415+ Compare inference results between DICOM series and pre-converted NIfTI files.
416+
417+ Validates that the DICOM reader produces identical results to pre-converted NIfTI.
418+ """
273419 if not torch .cuda .is_available ():
274420 self .skipTest ("CUDA not available" )
275421
@@ -286,29 +432,75 @@ def test_05_compare_dicom_vs_nifti(self):
286432 if not os .path .exists (nifti_file ):
287433 self .skipTest (f"Corresponding NIfTI file not found: { nifti_file } " )
288434
289- logger .info (f"Comparing DICOM vs NIfTI inference:" )
435+ logger .info (f"\n { '=' * 60 } " )
436+ logger .info ("Comparing DICOM vs NIfTI Segmentation" )
437+ logger .info (f"{ '=' * 60 } " )
290438 logger .info (f" DICOM: { dicom_dir } " )
291439 logger .info (f" NIfTI: { nifti_file } " )
292440
293441 # Run inference on DICOM
294442 logger .info ("\n --- Running inference on DICOM series ---" )
295443 dicom_label , dicom_json , dicom_time = self ._run_inference (dicom_dir )
444+ self ._validate_segmentation_output (dicom_label , dicom_json )
296445
297446 # Run inference on NIfTI
298447 logger .info ("\n --- Running inference on NIfTI file ---" )
299448 nifti_label , nifti_json , nifti_time = self ._run_inference (nifti_file )
300-
301- # Validate both
302- self ._validate_segmentation_output (dicom_label , dicom_json )
303449 self ._validate_segmentation_output (nifti_label , nifti_json )
304450
305- logger .info (f"\n Performance comparison:" )
451+ # Compare the segmentation outputs
452+ comparison = self ._compare_segmentations (
453+ dicom_label ,
454+ nifti_label ,
455+ name_1 = "DICOM" ,
456+ name_2 = "NIfTI" ,
457+ tolerance = 0.01 # Stricter tolerance - should be nearly identical
458+ )
459+
460+ logger .info (f"\n { '=' * 60 } " )
461+ logger .info ("Comparison Summary:" )
306462 logger .info (f" DICOM inference time: { dicom_time :.3f} s" )
307463 logger .info (f" NIfTI inference time: { nifti_time :.3f} s" )
464+ logger .info (f" Dice coefficient: { comparison ['avg_dice' ]:.4f} " )
465+ logger .info (f" Pixel accuracy: { comparison ['pixel_accuracy' ]:.4f} " )
466+ logger .info (f" Exact match: { comparison ['exact_match' ]} " )
467+ logger .info (f"{ '=' * 60 } " )
468+
469+ # Should be nearly identical (Dice > 0.99)
470+ self .assertGreater (comparison ['avg_dice' ], 0.99 ,
471+ "DICOM and NIfTI segmentations should be nearly identical" )
472+
def test_06_multiframe_htj2k_inference(self):
    """
    Smoke-test inference on the multi-frame HTJ2K DICOM series on its own.

    Skipped when CUDA is unavailable, the app is not initialized, or the
    series path does not exist. Otherwise the output is validated and the
    inference must finish in under 60 seconds. Cross-format comparison is
    handled by test_04.
    """
    if not torch.cuda.is_available():
        self.skipTest("CUDA not available")

    if not self.app:
        self.skipTest("App not initialized")

    series_path = self.dicomweb_htj2k_multiframe_series
    if not os.path.exists(series_path):
        self.skipTest(f"Multi-frame HTJ2K series not found: {series_path} ")

    rule = '=' * 60
    logger.info(f"\n {rule} ")
    logger.info("Testing Multi-Frame HTJ2K DICOM Inference")
    logger.info(f"{rule} ")
    logger.info(f"Series path: {series_path} ")

    # Inference followed by structural validation of its output
    label_data, label_json, inference_time = self._run_inference(series_path)
    self._validate_segmentation_output(label_data, label_json)

    # Upper bound on runtime
    self.assertLess(inference_time, 60.0, "Inference should complete within 60 seconds")

    logger.info(f"✓ Multi-frame HTJ2K inference test passed in {inference_time:.3f} s")
312504
313505
314506if __name__ == "__main__" :
0 commit comments