@@ -274,23 +274,39 @@ def assert_dask_array(da, dask):
 
 
 @arm_xfail
-@pytest.mark.parametrize("dask", [False, True])
-def test_datetime_reduce(dask):
-    time = np.array(pd.date_range("15/12/1999", periods=11))
-    time[8:11] = np.nan
-    da = DataArray(np.linspace(0, 365, num=11), dims="time", coords={"time": time})
-
-    if dask and has_dask:
-        chunks = {"time": 5}
-        da = da.chunk(chunks)
-
-    actual = da["time"].mean()
-    assert not pd.isnull(actual)
-    actual = da["time"].mean(skipna=False)
-    assert pd.isnull(actual)
-
-    # test for a 0d array
-    assert da["time"][0].mean() == da["time"][:1].mean()
+@pytest.mark.parametrize("dask", [False, True] if has_dask else [False])
+def test_datetime_mean(dask):
+    # Note: only testing numpy, as dask is broken upstream
+    da = DataArray(
+        np.array(["2010-01-01", "NaT", "2010-01-03", "NaT", "NaT"], dtype="M8"),
+        dims=["time"],
+    )
+    if dask:
+        # Trigger use case where a chunk is full of NaT
+        da = da.chunk({"time": 3})
+
+    expect = DataArray(np.array("2010-01-02", dtype="M8"))
+    expect_nat = DataArray(np.array("NaT", dtype="M8"))
+
+    actual = da.mean()
+    if dask:
+        assert actual.chunks is not None
+    assert_equal(actual, expect)
+
+    actual = da.mean(skipna=False)
+    if dask:
+        assert actual.chunks is not None
+    assert_equal(actual, expect_nat)
+
+    # tests for 1d array full of NaT
+    assert_equal(da[[1]].mean(), expect_nat)
+    assert_equal(da[[1]].mean(skipna=False), expect_nat)
+
+    # tests for a 0d array
+    assert_equal(da[0].mean(), da[0])
+    assert_equal(da[0].mean(skipna=False), da[0])
+    assert_equal(da[1].mean(), expect_nat)
+    assert_equal(da[1].mean(skipna=False), expect_nat)
 
 
 @requires_cftime
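
A minimal standalone sketch (not part of the diff) of the behaviour the new test pins down: DataArray.mean() on datetime64 data skips NaT by default and propagates NaT when skipna=False. It assumes numpy plus an xarray version that includes this change.

import numpy as np
import xarray as xr

# Datetime array containing a missing value (NaT).
da = xr.DataArray(
    np.array(["2010-01-01", "NaT", "2010-01-03"], dtype="M8[ns]"),
    dims=["time"],
)

print(da.mean().values)              # 2010-01-02: NaT is skipped by default
print(da.mean(skipna=False).values)  # NaT: any NaT propagates into the result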