@@ -1653,6 +1653,74 @@ def test_set_dims_object_dtype(self):
16531653 expected = Variable (["x" ], exp_values )
16541654 assert_identical (actual , expected )
16551655
1656+ def test_set_dims_without_broadcast (self ):
1657+ class ArrayWithoutBroadcastTo (NDArrayMixin , indexing .ExplicitlyIndexed ):
1658+ def __init__ (self , array ):
1659+ self .array = array
1660+
1661+ # Broadcasting with __getitem__ is "easier" to implement
1662+ # especially for dims of 1
1663+ def __getitem__ (self , key ):
1664+ return self .array [key ]
1665+
1666+ def __array_function__ (self , * args , ** kwargs ):
1667+ raise NotImplementedError (
1668+ "Not we don't want to use broadcast_to here "
1669+ "https://github.com/pydata/xarray/issues/9462"
1670+ )
1671+
1672+ arr = ArrayWithoutBroadcastTo (np .zeros ((3 , 4 )))
1673+ # We should be able to add a new axis without broadcasting
1674+ assert arr [np .newaxis , :, :].shape == (1 , 3 , 4 )
1675+ with pytest .raises (NotImplementedError ):
1676+ np .broadcast_to (arr , (1 , 3 , 4 ))
1677+
1678+ v = Variable (["x" , "y" ], arr )
1679+ v_expanded = v .set_dims (["z" , "x" , "y" ])
1680+ assert v_expanded .dims == ("z" , "x" , "y" )
1681+ assert v_expanded .shape == (1 , 3 , 4 )
1682+
1683+ v_expanded = v .set_dims (["x" , "z" , "y" ])
1684+ assert v_expanded .dims == ("x" , "z" , "y" )
1685+ assert v_expanded .shape == (3 , 1 , 4 )
1686+
1687+ v_expanded = v .set_dims (["x" , "y" , "z" ])
1688+ assert v_expanded .dims == ("x" , "y" , "z" )
1689+ assert v_expanded .shape == (3 , 4 , 1 )
1690+
1691+ # Explicitly asking for a shape of 1 triggers a different
1692+ # codepath in set_dims
1693+ # https://github.com/pydata/xarray/issues/9462
1694+ v_expanded = v .set_dims (["z" , "x" , "y" ], shape = (1 , 3 , 4 ))
1695+ assert v_expanded .dims == ("z" , "x" , "y" )
1696+ assert v_expanded .shape == (1 , 3 , 4 )
1697+
1698+ v_expanded = v .set_dims (["x" , "z" , "y" ], shape = (3 , 1 , 4 ))
1699+ assert v_expanded .dims == ("x" , "z" , "y" )
1700+ assert v_expanded .shape == (3 , 1 , 4 )
1701+
1702+ v_expanded = v .set_dims (["x" , "y" , "z" ], shape = (3 , 4 , 1 ))
1703+ assert v_expanded .dims == ("x" , "y" , "z" )
1704+ assert v_expanded .shape == (3 , 4 , 1 )
1705+
1706+ v_expanded = v .set_dims ({"z" : 1 , "x" : 3 , "y" : 4 })
1707+ assert v_expanded .dims == ("z" , "x" , "y" )
1708+ assert v_expanded .shape == (1 , 3 , 4 )
1709+
1710+ v_expanded = v .set_dims ({"x" : 3 , "z" : 1 , "y" : 4 })
1711+ assert v_expanded .dims == ("x" , "z" , "y" )
1712+ assert v_expanded .shape == (3 , 1 , 4 )
1713+
1714+ v_expanded = v .set_dims ({"x" : 3 , "y" : 4 , "z" : 1 })
1715+ assert v_expanded .dims == ("x" , "y" , "z" )
1716+ assert v_expanded .shape == (3 , 4 , 1 )
1717+
1718+ with pytest .raises (NotImplementedError ):
1719+ v .set_dims ({"z" : 2 , "x" : 3 , "y" : 4 })
1720+
1721+ with pytest .raises (NotImplementedError ):
1722+ v .set_dims (["z" , "x" , "y" ], shape = (2 , 3 , 4 ))
1723+
16561724 def test_stack (self ):
16571725 v = Variable (["x" , "y" ], [[0 , 1 ], [2 , 3 ]], {"foo" : "bar" })
16581726 actual = v .stack (z = ("x" , "y" ))
0 commit comments