77import numpy as np
88import pandas as pd
99from tqdm import tqdm
10+ from scipy import stats
1011
1112from typing import Any , Dict , Optional , List , Tuple
1213import matplotlib .pyplot as plt
@@ -334,12 +335,89 @@ def draw_samples(self,
334335 dist_samples = dist_samples .astype (int )
335336
336337 return dist_samples
338+
def get_moments(self,
                predt_params: pd.DataFrame,
                inference: str = "none",
                n_samples: int = 1000,
                seed: int = 123
                ) -> Optional[pd.DataFrame]:
    """
    Function that returns moments (mean, variance, mode) of a predicted distribution.

    Arguments
    ---------
    predt_params: pd.DataFrame
        pd.DataFrame with predicted distributional parameters. Columns are matched
        positionally to ``self.distribution_arg_names``.
    inference: str
        Type of inference from drawn samples:
        - "none" (default) Will return only the exact, implemented moments.
        - "missing" Will infer moments for missing implementations by drawing samples.
        - "all" Will infer all moments by drawing samples.
    n_samples: int
        Number of samples to draw from predicted response distribution. Only used
        when ``inference`` is not "none".
    seed: int
        Manual seed. Applied through ``torch.manual_seed`` before sampling, i.e. it
        resets torch's global random state.

    Returns
    -------
    pred_moments: pd.DataFrame or None
        DataFrame with (a subset of) the columns "mean", "variance" and "mode", one
        row per row of ``predt_params``. None is returned if ``self.tau`` is set or
        if no moment could be obtained at all.

    Raises
    ------
    ValueError
        If ``inference`` is not one of "none", "missing" or "all".
    """
    if inference not in ("none", "missing", "all"):
        raise ValueError(
            f"Invalid inference type '{inference}'. Please choose from 'none', 'missing' or 'all'."
        )

    # No moments are computed when tau is set.
    if self.tau is not None:
        return None

    pred_params = torch.tensor(predt_params.values)
    dist_kwargs = dict(zip(self.distribution_arg_names, pred_params.T))
    dist_pred = self.distribution(**dist_kwargs)

    # Sample-based estimator per moment; every estimator reduces an array of shape
    # (n_observations, n_samples) to a flat array of length n_observations.
    # stats.mode keeps or drops the reduced axis depending on the SciPy version, so
    # its result is flattened explicitly.
    sample_estimators = {
        "mean": lambda samples: np.mean(samples, axis=1),
        "variance": lambda samples: np.var(samples, axis=1),
        "mode": lambda samples: np.asarray(stats.mode(samples, axis=1).mode).reshape(-1),
    }

    dist_samples = None
    if inference != "none":
        torch.manual_seed(seed)
        drawn = dist_pred.sample((n_samples,)).detach().numpy()
        # Reshape instead of squeeze(): squeeze() would also drop the observation
        # axis for a single-row input (or the sample axis for n_samples == 1), which
        # breaks the axis=1 reductions above.
        dist_samples = drawn.reshape(n_samples, -1).T

    pred_moments = pd.DataFrame()
    for name, estimate in sample_estimators.items():
        exact = None
        if inference != "all":
            try:
                exact = getattr(dist_pred, name)
            except NotImplementedError:
                # The distribution has no closed-form implementation of this moment.
                exact = None

        if exact is not None:
            pred_moments[name] = exact.detach().numpy()
        elif dist_samples is not None:
            pred_moments[name] = estimate(dist_samples)

    if pred_moments.shape[1] == 0:
        return None
    return pred_moments
337414
338415 def predict_dist (self ,
339416 booster : lgb .Booster ,
340417 data : pd .DataFrame ,
341418 start_values : np .ndarray ,
342419 pred_type : str = "parameters" ,
420+ moments_inference : str = "none" ,
343421 n_samples : int = 1000 ,
344422 quantiles : list = [0.1 , 0.5 , 0.9 ],
345423 seed : str = 123
@@ -361,6 +439,12 @@ def predict_dist(self,
361439 - "quantiles" calculates the quantiles from the predicted distribution.
362440 - "parameters" returns the predicted distributional parameters.
363441 - "expectiles" returns the predicted expectiles.
442+ - "moments" returns the mean, variance, and (if implemented) mode.
443+ moments_inference: str
444+ Type of inference to use if the prediction type is "moments":
445+ - "none" (default) Will return only the exact, implemented moments.
446+ - "missing" Will infer moments for missing implementations by drawing samples.
447+ - "all" Will infer all moments by drawing samples.
364448 n_samples : int
365449 Number of samples to draw from the predicted distribution.
366450 quantiles : List[float]
@@ -398,18 +482,24 @@ def predict_dist(self,
398482 dist_params_predt = pd .DataFrame (dist_params_predt )
399483 dist_params_predt .columns = self .param_dict .keys ()
400484
401- # Draw samples from predicted response distribution
402- pred_samples_df = self .draw_samples (predt_params = dist_params_predt ,
403- n_samples = n_samples ,
404- seed = seed )
405-
406485 if pred_type == "parameters" :
407486 return dist_params_predt
408487
409488 elif pred_type == "expectiles" :
410489 return dist_params_predt
490+
491+ elif pred_type == "moments" :
492+ return self .get_moments (predt_params = dist_params_predt ,
493+ inference = moments_inference ,
494+ n_samples = n_samples ,
495+ seed = seed )
496+
497+ # Draw samples from predicted response distribution
498+ pred_samples_df = self .draw_samples (predt_params = dist_params_predt ,
499+ n_samples = n_samples ,
500+ seed = seed )
411501
412- elif pred_type == "samples" :
502+ if pred_type == "samples" :
413503 return pred_samples_df
414504
415505 elif pred_type == "quantiles" :
0 commit comments