Skip to content

Commit 6d96f33

Browse files
authored
Pickle xarray.ufunc functions (#928)
Fixes GH901
1 parent d827dd0 commit 6d96f33

File tree

3 files changed

+24
-8
lines changed

3 files changed

+24
-8
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ Bug fixes
106106
use to numpy functions instead of dask.array functions (:issue:`876`). By
107107
`Stephan Hoyer <https://github.com/shoyer>`_.
108108

109+
- Support for pickling functions from ``xarray.ufuncs`` (:issue:`901`). By
110+
`Stephan Hoyer <https://github.com/shoyer>`_.
111+
109112
- ``Variable.copy(deep=True)`` no longer converts MultiIndex into a base Index
110113
(:issue:`769`). By `Benoit Bovy <https://github.com/benbovy>`_.
111114

xarray/test/test_ufuncs.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import pickle
2+
13
import numpy as np
24

35
import xarray.ufuncs as xu
@@ -56,3 +58,8 @@ def test_groupby(self):
5658

5759
with self.assertRaisesRegexp(TypeError, 'only support binary ops'):
5860
xu.maximum(ds.a.variable, ds_grouped)
61+
62+
def test_pickle(self):
63+
a = 1.0
64+
cos_pickled = pickle.loads(pickle.dumps(xu.cos))
65+
self.assertIdentical(cos_pickled(a), xu.cos(a))

xarray/ufuncs.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,32 +35,38 @@ def _dispatch_priority(obj):
3535
return -1
3636

3737

38-
def _create_op(name):
38+
class _UFuncDispatcher(object):
39+
"""Wrapper for dispatching ufuncs."""
40+
def __init__(self, name):
41+
self._name = name
3942

40-
def func(*args, **kwargs):
43+
def __call__(self, *args, **kwargs):
4144
new_args = args
42-
f = _dask_or_eager_func(name, n_array_args=len(args))
45+
f = _dask_or_eager_func(self._name, n_array_args=len(args))
4346
if len(args) > 2 or len(args) == 0:
4447
raise TypeError('cannot handle %s arguments for %r' %
45-
(len(args), name))
48+
(len(args), self._name))
4649
elif len(args) == 1:
4750
if isinstance(args[0], _xarray_types):
48-
f = args[0]._unary_op(func)
51+
f = args[0]._unary_op(self)
4952
else: # len(args) = 2
5053
p1, p2 = map(_dispatch_priority, args)
5154
if p1 >= p2:
5255
if isinstance(args[0], _xarray_types):
53-
f = args[0]._binary_op(func)
56+
f = args[0]._binary_op(self)
5457
else:
5558
if isinstance(args[1], _xarray_types):
56-
f = args[1]._binary_op(func, reflexive=True)
59+
f = args[1]._binary_op(self, reflexive=True)
5760
new_args = tuple(reversed(args))
5861
res = f(*new_args, **kwargs)
5962
if res is NotImplemented:
6063
raise TypeError('%r not implemented for types (%r, %r)'
61-
% (name, type(args[0]), type(args[1])))
64+
% (self._name, type(args[0]), type(args[1])))
6265
return res
6366

67+
68+
def _create_op(name):
69+
func = _UFuncDispatcher(name)
6470
func.__name__ = name
6571
doc = getattr(_np, name).__doc__
6672
func.__doc__ = ('xarray specific variant of numpy.%s. Handles '

0 commit comments

Comments
 (0)