@@ -120,9 +120,9 @@ def test_fit_predict(self, coords, rng) -> None:
120120 ).shape == (20, 2 * 2)
121121 assert isinstance(score, pd.Series)
122122 assert score.shape == (2,)
123- # Test that the score follows the new standardized format
124- assert "unit_r2" in score.index
125- assert "unit_r2_std" in score.index
123+ # Test that the score follows the new unified format
124+ assert "unit_0_r2" in score.index
125+ assert "unit_0_r2_std" in score.index
126126 assert isinstance(predictions, az.InferenceData)
127127
128128
@@ -423,15 +423,15 @@ def test_scoring_multi_unit(self, synthetic_control_data):
423423 # Score should be a pandas Series with separate r2 and r2_std for each treated unit
424424 assert isinstance(score, pd.Series)
425425 
426- # Check that we have r2 and r2_std for each treated unit
427- for unit in treated_units:
428- assert f"{unit}_r2" in score.index
429- assert f"{unit}_r2_std" in score.index
426+ # Check that we have r2 and r2_std for each treated unit using unified format
427+ for i, unit in enumerate(treated_units):
428+ assert f"unit_{i}_r2" in score.index
429+ assert f"unit_{i}_r2_std" in score.index
430430 
431431 # R2 should be reasonable (between 0 and 1 typically, though can be negative)
432- assert score[f"{unit}_r2"] >= -1  # R2 can be negative for very bad fits
432+ assert score[f"unit_{i}_r2"] >= -1  # R2 can be negative for very bad fits
433433 assert (
434- score[f"{unit}_r2_std"] >= 0
434+ score[f"unit_{i}_r2_std"] >= 0
435435 )  # Standard deviation should be non-negative
436436
437437 def test_scoring_single_unit(self, single_treated_data):
@@ -444,16 +444,14 @@ def test_scoring_single_unit(self, single_treated_data):
444444 # Test scoring
445445 score = wsf.score(X, y)
446446 
447- # Now consistently uses treated unit name prefix even for single unit
447+ # Now consistently uses unified unit indexing even for single unit
448448 assert isinstance(score, pd.Series)
449- assert "treated_0_r2" in score.index
450- assert "treated_0_r2_std" in score.index
449+ assert "unit_0_r2" in score.index
450+ assert "unit_0_r2_std" in score.index
451451 
452452 # R2 should be reasonable
453- assert score["treated_0_r2"] >= -1  # R2 can be negative for very bad fits
454- assert (
455- score["treated_0_r2_std"] >= 0
456- )  # Standard deviation should be non-negative
453+ assert score["unit_0_r2"] >= -1  # R2 can be negative for very bad fits
454+ assert score["unit_0_r2_std"] >= 0  # Standard deviation should be non-negative
457455
458456 def test_r2_scores_differ_across_units(self, rng):
459457 """Test that R² scores are different for different treated units.
@@ -523,8 +521,8 @@ def test_r2_scores_differ_across_units(self, rng):
523521 wsf.fit(X, y, coords=coords)
524522 scores = wsf.score(X, y)
525523 
526- # Extract R² values for each treated unit
527- r2_values = [scores[f"{unit}_r2"] for unit in treated_units]
524+ # Extract R² values for each treated unit using unified format
525+ r2_values = [scores[f"unit_{i}_r2"] for i in range(len(treated_units))]
528526
529527 # Test that not all R² values are the same
530528 # Use a tolerance to avoid issues with floating point precision
0 commit comments