Skip to content

Commit f752e29

Browse files
authored
Change predict, transform, predict_proba to infer metadata by default for ParallelPostFit (#862)
1 parent 11f9703 commit f752e29

File tree

6 files changed

+284
-17
lines changed

6 files changed

+284
-17
lines changed

dask_ml/model_selection/_hyperband.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,24 @@ class HyperbandSearchCV(BaseIncrementalSearchCV):
154154
prefix : str, optional, default=""
155155
While logging, add ``prefix`` to each message.
156156
157+
predict_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
158+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
159+
type of the estimator's ``predict`` call.
160+
This meta is necessary for some estimators to work with
161+
``dask.dataframe`` and ``dask.array``
162+
163+
predict_proba_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
164+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
165+
type of the estimator's ``predict_proba`` call.
166+
This meta is necessary for some estimators to work with
167+
``dask.dataframe`` and ``dask.array``
168+
169+
transform_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
170+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
171+
type of the estimator's ``transform`` call.
172+
This meta is necessary for some estimators to work with
173+
``dask.dataframe`` and ``dask.array``
174+
157175
Examples
158176
--------
159177
>>> import numpy as np
@@ -340,6 +358,9 @@ def __init__(
340358
scoring=None,
341359
verbose=False,
342360
prefix="",
361+
predict_meta=None,
362+
predict_proba_meta=None,
363+
transform_meta=None,
343364
):
344365
self.aggressiveness = aggressiveness
345366

@@ -354,6 +375,9 @@ def __init__(
354375
scoring=scoring,
355376
verbose=verbose,
356377
prefix=prefix,
378+
predict_meta=predict_meta,
379+
predict_proba_meta=predict_proba_meta,
380+
transform_meta=transform_meta,
357381
)
358382

359383
def _get_SHAs(self, brackets):

dask_ml/model_selection/_incremental.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def _score(
128128

129129

130130
def _create_model(model: Model, ident: Int, **params: Params) -> Tuple[Model, Meta]:
131-
""" Create a model by cloning and then setting params """
131+
"""Create a model by cloning and then setting params"""
132132
with log_errors():
133133
model = clone(model).set_params(**params)
134134
return model, {"model_id": ident, "params": params, "partial_fit_calls": 0}
@@ -515,6 +515,9 @@ def __init__(
515515
tol=1e-3,
516516
verbose=False,
517517
prefix="",
518+
predict_meta=None,
519+
predict_proba_meta=None,
520+
transform_meta=None,
518521
):
519522
self.parameters = parameters
520523
self.test_size = test_size
@@ -524,7 +527,13 @@ def __init__(
524527
self.tol = tol
525528
self.verbose = verbose
526529
self.prefix = prefix
527-
super(BaseIncrementalSearchCV, self).__init__(estimator, scoring=scoring)
530+
super(BaseIncrementalSearchCV, self).__init__(
531+
estimator,
532+
scoring=scoring,
533+
predict_meta=predict_meta,
534+
predict_proba_meta=predict_proba_meta,
535+
transform_meta=transform_meta,
536+
)
528537

529538
async def _validate_parameters(self, X, y):
530539
if (self.max_iter is not None) and self.max_iter < 1:
@@ -846,6 +855,24 @@ class IncrementalSearchCV(BaseIncrementalSearchCV):
846855
prefix : str, optional, default=""
847856
While logging, add ``prefix`` to each message.
848857
858+
predict_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
859+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
860+
type of the estimator's ``predict`` call.
861+
This meta is necessary for some estimators to work with
862+
``dask.dataframe`` and ``dask.array``
863+
864+
predict_proba_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
865+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
866+
type of the estimator's ``predict_proba`` call.
867+
This meta is necessary for some estimators to work with
868+
``dask.dataframe`` and ``dask.array``
869+
870+
transform_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
871+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
872+
type of the estimator's ``transform`` call.
873+
This meta is necessary for some estimators to work with
874+
``dask.dataframe`` and ``dask.array``
875+
849876
Attributes
850877
----------
851878
cv_results_ : dict of np.ndarrays
@@ -977,6 +1004,9 @@ def __init__(
9771004
verbose=False,
9781005
prefix="",
9791006
scores_per_fit=None,
1007+
predict_meta=None,
1008+
predict_proba_meta=None,
1009+
transform_meta=None,
9801010
):
9811011

9821012
self.n_initial_parameters = n_initial_parameters
@@ -995,6 +1025,9 @@ def __init__(
9951025
tol=tol,
9961026
verbose=verbose,
9971027
prefix=prefix,
1028+
predict_meta=predict_meta,
1029+
predict_proba_meta=predict_proba_meta,
1030+
transform_meta=transform_meta,
9981031
)
9991032

10001033
def _decay_deprecated(self):
@@ -1338,7 +1371,7 @@ def _adapt(self, info):
13381371
start = self.n_initial_parameters
13391372

13401373
def inverse(time):
1341-
""" Decrease target number of models inversely with time """
1374+
"""Decrease target number of models inversely with time"""
13421375
return int(start / (1 + time) ** self.decay_rate)
13431376

13441377
example = toolz.first(info.values())

dask_ml/wrappers.py

Lines changed: 127 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,24 @@ class ParallelPostFit(sklearn.base.BaseEstimator, sklearn.base.MetaEstimatorMixi
4646
a single NumPy array, which may exhaust the memory of your worker.
4747
You probably want to always specify `scoring`.
4848
49+
predict_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
50+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
51+
type of the estimator's ``predict`` call.
52+
This meta is necessary for some estimators to work with
53+
``dask.dataframe`` and ``dask.array``
54+
55+
predict_proba_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
56+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
57+
type of the estimator's ``predict_proba`` call.
58+
This meta is necessary for some estimators to work with
59+
``dask.dataframe`` and ``dask.array``
60+
61+
transform_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
62+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
63+
type of the estimator's ``transform`` call.
64+
This meta is necessary for some estimators to work with
65+
``dask.dataframe`` and ``dask.array``
66+
4967
Notes
5068
-----
5169
@@ -115,9 +133,19 @@ class ParallelPostFit(sklearn.base.BaseEstimator, sklearn.base.MetaEstimatorMixi
115133
[0.99407016, 0.00592984]])
116134
"""
117135

118-
def __init__(self, estimator=None, scoring=None):
136+
def __init__(
137+
self,
138+
estimator=None,
139+
scoring=None,
140+
predict_meta=None,
141+
predict_proba_meta=None,
142+
transform_meta=None,
143+
):
119144
self.estimator = estimator
120145
self.scoring = scoring
146+
self.predict_meta = predict_meta
147+
self.predict_proba_meta = predict_proba_meta
148+
self.transform_meta = transform_meta
121149

122150
def _check_array(self, X):
123151
"""Validate an array for post-fit tasks.
@@ -202,13 +230,24 @@ def transform(self, X):
202230
"""
203231
self._check_method("transform")
204232
X = self._check_array(X)
233+
meta = self.transform_meta
205234

206235
if isinstance(X, da.Array):
207-
xx = np.zeros((1, X.shape[1]), dtype=X.dtype)
208-
dt = _transform(xx, self._postfit_estimator).dtype
209-
return X.map_blocks(_transform, estimator=self._postfit_estimator, dtype=dt)
236+
if meta is None:
237+
meta = _get_output_dask_ar_meta_for_estimator(
238+
_transform, self._postfit_estimator, X
239+
)
240+
return X.map_blocks(
241+
_transform, estimator=self._postfit_estimator, meta=meta
242+
)
210243
elif isinstance(X, dd._Frame):
211-
return X.map_partitions(_transform, estimator=self._postfit_estimator)
244+
if meta is None:
245+
# dask-dataframe relies on dd.core.no_default
246+
# for inferring meta
247+
meta = dd.core.no_default
248+
return X.map_partitions(
249+
_transform, estimator=self._postfit_estimator, meta=meta
250+
)
212251
else:
213252
return _transform(X, estimator=self._postfit_estimator)
214253

@@ -271,18 +310,25 @@ def predict(self, X):
271310
"""
272311
self._check_method("predict")
273312
X = self._check_array(X)
313+
meta = self.predict_meta
274314

275315
if isinstance(X, da.Array):
316+
if meta is None:
317+
meta = _get_output_dask_ar_meta_for_estimator(
318+
_predict, self._postfit_estimator, X
319+
)
320+
276321
result = X.map_blocks(
277-
_predict, dtype="int", estimator=self._postfit_estimator, drop_axis=1
322+
_predict, estimator=self._postfit_estimator, drop_axis=1, meta=meta
278323
)
279324
return result
280325

281326
elif isinstance(X, dd._Frame):
327+
if meta is None:
328+
meta = dd.core.no_default
282329
return X.map_partitions(
283-
_predict, estimator=self._postfit_estimator, meta=np.array([1])
330+
_predict, estimator=self._postfit_estimator, meta=meta
284331
)
285-
286332
else:
287333
return _predict(X, estimator=self._postfit_estimator)
288334

@@ -308,16 +354,26 @@ def predict_proba(self, X):
308354

309355
self._check_method("predict_proba")
310356

357+
meta = self.predict_proba_meta
358+
311359
if isinstance(X, da.Array):
360+
if meta is None:
361+
meta = _get_output_dask_ar_meta_for_estimator(
362+
_predict_proba, self._postfit_estimator, X
363+
)
312364
# XXX: multiclass
313365
return X.map_blocks(
314366
_predict_proba,
315367
estimator=self._postfit_estimator,
316-
dtype="float",
368+
meta=meta,
317369
chunks=(X.chunks[0], len(self._postfit_estimator.classes_)),
318370
)
319371
elif isinstance(X, dd._Frame):
320-
return X.map_partitions(_predict_proba, estimator=self._postfit_estimator)
372+
if meta is None:
373+
meta = dd.core.no_default
374+
return X.map_partitions(
375+
_predict_proba, estimator=self._postfit_estimator, meta=meta
376+
)
321377
else:
322378
return _predict_proba(X, estimator=self._postfit_estimator)
323379

@@ -424,6 +480,24 @@ class Incremental(ParallelPostFit):
424480
of the Dask arrays (default), or to fit in sequential order. This does
425481
not control shuffle between blocks or shuffling each block.
426482
483+
predict_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
484+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
485+
type of the estimator's ``predict`` call.
486+
This meta is necessary for some estimators to work with
487+
``dask.dataframe`` and ``dask.array``
488+
489+
predict_proba_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
490+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
491+
type of the estimator's ``predict_proba`` call.
492+
This meta is necessary for some estimators to work with
493+
``dask.dataframe`` and ``dask.array``
494+
495+
transform_meta: pd.Series, pd.DataFrame, np.array default: None(infer)
496+
An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
497+
type of the estimator's ``transform`` call.
498+
This meta is necessary for some estimators to work with
499+
``dask.dataframe`` and ``dask.array``
500+
427501
Attributes
428502
----------
429503
estimator_ : Estimator
@@ -460,11 +534,20 @@ def __init__(
460534
shuffle_blocks=True,
461535
random_state=None,
462536
assume_equal_chunks=True,
537+
predict_meta=None,
538+
predict_proba_meta=None,
539+
transform_meta=None,
463540
):
464541
self.shuffle_blocks = shuffle_blocks
465542
self.random_state = random_state
466543
self.assume_equal_chunks = assume_equal_chunks
467-
super(Incremental, self).__init__(estimator=estimator, scoring=scoring)
544+
super(Incremental, self).__init__(
545+
estimator=estimator,
546+
scoring=scoring,
547+
predict_meta=predict_meta,
548+
predict_proba_meta=predict_proba_meta,
549+
transform_meta=transform_meta,
550+
)
468551

469552
@property
470553
def _postfit_estimator(self):
@@ -552,3 +635,36 @@ def _predict_proba(part, estimator):
552635

553636
def _transform(part, estimator):
554637
return estimator.transform(part)
638+
639+
640+
def _get_output_dask_ar_meta_for_estimator(model_fn, estimator, input_dask_ar):
641+
"""
642+
Returns the output metadata array
643+
for the model function (predict, transform etc)
644+
by running the appropriate function on dummy data
645+
of shape (1, n_features)
646+
647+
Parameters
648+
----------
649+
650+
model_fn: Model function
651+
_predict, _transform etc
652+
653+
estimator : Estimator
654+
The underlying estimator that is fit.
655+
656+
input_dask_ar: The input dask_array
657+
658+
Returns
659+
-------
660+
metadata: metadata of output dask array
661+
662+
"""
663+
# sklearn fails if input array has size 0
664+
# It requires at least 1 sample to run successfully
665+
ar = np.zeros(
666+
shape=(1, input_dask_ar.shape[1]),
667+
dtype=input_dask_ar.dtype,
668+
like=input_dask_ar._meta,
669+
)
670+
return model_fn(ar, estimator)

tests/model_selection/test_hyperband.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,14 @@ async def test_correct_params(c, s, a, b):
275275
"verbose",
276276
"prefix",
277277
}
278-
assert set(search.get_params().keys()) == base.union({"aggressiveness"})
278+
279+
search_keys = set(search.get_params().keys())
280+
# we remove the meta keys because they are dask-specific attributes
281+
search_keys.remove("predict_meta")
282+
search_keys.remove("predict_proba_meta")
283+
search_keys.remove("transform_meta")
284+
285+
assert search_keys == base.union({"aggressiveness"})
279286
meta = search.metadata
280287
SHAs_params = [
281288
bracket["SuccessiveHalvingSearchCV params"] for bracket in meta["brackets"]

tests/model_selection/test_incremental.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ async def test_explicit(c, s, a, b):
169169
params = [{"alpha": 0.1}, {"alpha": 0.2}]
170170

171171
def additional_calls(scores):
172-
""" Progress through predefined updates, checking along the way """
172+
"""Progress through predefined updates, checking along the way"""
173173
ts = scores[0][-1]["partial_fit_calls"]
174174
ts -= 1 # partial_fit_calls = time step + 1
175175
if ts == 0:

0 commit comments

Comments
 (0)