@@ -46,6 +46,24 @@ class ParallelPostFit(sklearn.base.BaseEstimator, sklearn.base.MetaEstimatorMixi
4646 a single NumPy array, which may exhaust the memory of your worker.
4747 You probably want to always specify `scoring`.
4848
49+ predict_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer)
50+ An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
51+ type of the estimators ``predict`` call.
52+ This meta is necessary for for some estimators to work with
53+ ``dask.dataframe`` and ``dask.array``
54+
55+ predict_proba_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer)
56+ An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
57+ type of the estimators ``predict_proba`` call.
58+ This meta is necessary for for some estimators to work with
59+ ``dask.dataframe`` and ``dask.array``
60+
61+ transform_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer)
62+ An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
63+ type of the estimators ``transform`` call.
64+ This meta is necessary for for some estimators to work with
65+ ``dask.dataframe`` and ``dask.array``
66+
4967 Notes
5068 -----
5169
@@ -115,9 +133,19 @@ class ParallelPostFit(sklearn.base.BaseEstimator, sklearn.base.MetaEstimatorMixi
115133 [0.99407016, 0.00592984]])
116134 """
117135
118- def __init__ (self , estimator = None , scoring = None ):
136+ def __init__ (
137+ self ,
138+ estimator = None ,
139+ scoring = None ,
140+ predict_meta = None ,
141+ predict_proba_meta = None ,
142+ transform_meta = None ,
143+ ):
119144 self .estimator = estimator
120145 self .scoring = scoring
146+ self .predict_meta = predict_meta
147+ self .predict_proba_meta = predict_proba_meta
148+ self .transform_meta = transform_meta
121149
122150 def _check_array (self , X ):
123151 """Validate an array for post-fit tasks.
@@ -202,13 +230,24 @@ def transform(self, X):
202230 """
203231 self ._check_method ("transform" )
204232 X = self ._check_array (X )
233+ meta = self .transform_meta
205234
206235 if isinstance (X , da .Array ):
207- xx = np .zeros ((1 , X .shape [1 ]), dtype = X .dtype )
208- dt = _transform (xx , self ._postfit_estimator ).dtype
209- return X .map_blocks (_transform , estimator = self ._postfit_estimator , dtype = dt )
236+ if meta is None :
237+ meta = _get_output_dask_ar_meta_for_estimator (
238+ _transform , self ._postfit_estimator , X
239+ )
240+ return X .map_blocks (
241+ _transform , estimator = self ._postfit_estimator , meta = meta
242+ )
210243 elif isinstance (X , dd ._Frame ):
211- return X .map_partitions (_transform , estimator = self ._postfit_estimator )
244+ if meta is None :
245+ # dask-dataframe relies on dd.core.no_default
246+ # for infering meta
247+ meta = dd .core .no_default
248+ return X .map_partitions (
249+ _transform , estimator = self ._postfit_estimator , meta = meta
250+ )
212251 else :
213252 return _transform (X , estimator = self ._postfit_estimator )
214253
@@ -271,18 +310,25 @@ def predict(self, X):
271310 """
272311 self ._check_method ("predict" )
273312 X = self ._check_array (X )
313+ meta = self .predict_meta
274314
275315 if isinstance (X , da .Array ):
316+ if meta is None :
317+ meta = _get_output_dask_ar_meta_for_estimator (
318+ _predict , self ._postfit_estimator , X
319+ )
320+
276321 result = X .map_blocks (
277- _predict , dtype = "int" , estimator = self ._postfit_estimator , drop_axis = 1
322+ _predict , estimator = self ._postfit_estimator , drop_axis = 1 , meta = meta
278323 )
279324 return result
280325
281326 elif isinstance (X , dd ._Frame ):
327+ if meta is None :
328+ meta = dd .core .no_default
282329 return X .map_partitions (
283- _predict , estimator = self ._postfit_estimator , meta = np . array ([ 1 ])
330+ _predict , estimator = self ._postfit_estimator , meta = meta
284331 )
285-
286332 else :
287333 return _predict (X , estimator = self ._postfit_estimator )
288334
@@ -308,16 +354,26 @@ def predict_proba(self, X):
308354
309355 self ._check_method ("predict_proba" )
310356
357+ meta = self .predict_proba_meta
358+
311359 if isinstance (X , da .Array ):
360+ if meta is None :
361+ meta = _get_output_dask_ar_meta_for_estimator (
362+ _predict_proba , self ._postfit_estimator , X
363+ )
312364 # XXX: multiclass
313365 return X .map_blocks (
314366 _predict_proba ,
315367 estimator = self ._postfit_estimator ,
316- dtype = "float" ,
368+ meta = meta ,
317369 chunks = (X .chunks [0 ], len (self ._postfit_estimator .classes_ )),
318370 )
319371 elif isinstance (X , dd ._Frame ):
320- return X .map_partitions (_predict_proba , estimator = self ._postfit_estimator )
372+ if meta is None :
373+ meta = dd .core .no_default
374+ return X .map_partitions (
375+ _predict_proba , estimator = self ._postfit_estimator , meta = meta
376+ )
321377 else :
322378 return _predict_proba (X , estimator = self ._postfit_estimator )
323379
@@ -424,6 +480,24 @@ class Incremental(ParallelPostFit):
424480 of the Dask arrays (default), or to fit in sequential order. This does
425481 not control shuffle between blocks or shuffling each block.
426482
483+ predict_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer)
484+ An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
485+ type of the estimators ``predict`` call.
486+ This meta is necessary for for some estimators to work with
487+ ``dask.dataframe`` and ``dask.array``
488+
489+ predict_proba_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer)
490+ An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
491+ type of the estimators ``predict_proba`` call.
492+ This meta is necessary for for some estimators to work with
493+ ``dask.dataframe`` and ``dask.array``
494+
495+ transform_meta: pd.Series, pd.DataFrame, np.array deafult: None(infer)
496+ An empty ``pd.Series``, ``pd.DataFrame``, ``np.array`` that matches the output
497+ type of the estimators ``transform`` call.
498+ This meta is necessary for for some estimators to work with
499+ ``dask.dataframe`` and ``dask.array``
500+
427501 Attributes
428502 ----------
429503 estimator_ : Estimator
@@ -460,11 +534,20 @@ def __init__(
460534 shuffle_blocks = True ,
461535 random_state = None ,
462536 assume_equal_chunks = True ,
537+ predict_meta = None ,
538+ predict_proba_meta = None ,
539+ transform_meta = None ,
463540 ):
464541 self .shuffle_blocks = shuffle_blocks
465542 self .random_state = random_state
466543 self .assume_equal_chunks = assume_equal_chunks
467- super (Incremental , self ).__init__ (estimator = estimator , scoring = scoring )
544+ super (Incremental , self ).__init__ (
545+ estimator = estimator ,
546+ scoring = scoring ,
547+ predict_meta = predict_meta ,
548+ predict_proba_meta = predict_proba_meta ,
549+ transform_meta = transform_meta ,
550+ )
468551
469552 @property
470553 def _postfit_estimator (self ):
@@ -552,3 +635,36 @@ def _predict_proba(part, estimator):
552635
553636def _transform (part , estimator ):
554637 return estimator .transform (part )
638+
639+
640+ def _get_output_dask_ar_meta_for_estimator (model_fn , estimator , input_dask_ar ):
641+ """
642+ Returns the output metadata array
643+ for the model function (predict, transform etc)
644+ by running the appropriate function on dummy data
645+ of shape (1, n_features)
646+
647+ Parameters
648+ ----------
649+
650+ model_fun: Model function
651+ _predict, _transform etc
652+
653+ estimator : Estimator
654+ The underlying estimator that is fit.
655+
656+ input_dask_ar: The input dask_array
657+
658+ Returns
659+ -------
660+ metadata: metadata of output dask array
661+
662+ """
663+ # sklearn fails if input array has size size
664+ # It requires at least 1 sample to run successfully
665+ ar = np .zeros (
666+ shape = (1 , input_dask_ar .shape [1 ]),
667+ dtype = input_dask_ar .dtype ,
668+ like = input_dask_ar ._meta ,
669+ )
670+ return model_fn (ar , estimator )
0 commit comments