diff --git a/lightgbmlss/distributions/distribution_utils.py b/lightgbmlss/distributions/distribution_utils.py index 96a4173..c32b161 100644 --- a/lightgbmlss/distributions/distribution_utils.py +++ b/lightgbmlss/distributions/distribution_utils.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd from tqdm import tqdm +from scipy import stats from typing import Any, Dict, Optional, List, Tuple import matplotlib.pyplot as plt @@ -334,12 +335,89 @@ def draw_samples(self, dist_samples = dist_samples.astype(int) return dist_samples + + def get_moments(self, + predt_params: pd.DataFrame, + inference: str = "none", + n_samples: int = 1000, + seed: int = 123 + ) -> pd.DataFrame: + """ + Function that returns moments (mean, variance, mode) of a predicted distribution. + + Arguments + --------- + predt_params: pd.DataFrame + pd.DataFrame with predicted distributional parameters. + inference: str + Type of inference from drawn samples: + - "none" (default) Will return only the exact, implemented moments. + - "missing" Will infer moments for missing implementations by drawing samples. + - "all" Will infer all moments by drawing samples. + n_samples: int + Number of sample to draw from predicted response distribution. + seed: int + Manual seed. + + Returns + ------- + pred_dist: pd.DataFrame + DataFrame with mean, variance, and mode of predicted response distribution. + + """ + if self.tau is None: + pred_params = torch.tensor(predt_params.values) + dist_kwargs = {arg_name: param for arg_name, param in zip(self.distribution_arg_names, pred_params.T)} + dist_pred = self.distribution(**dist_kwargs) + pred_moments = pd.DataFrame() + + if inference != "none": + torch.manual_seed(seed) + dist_samples = dist_pred.sample((n_samples,)).squeeze().detach().numpy().T + + if inference == "all": + pred_moments["mean"] = np.mean(dist_samples, axis=1) + pred_moments["variance"] = np.var(dist_samples, axis=1) + pred_moments["mode"], _ = stats.mode(dist_samples, axis=1, keepdims=True) + return pred_moments + + try: + mean = dist_pred.mean + except NotImplementedError: + if inference == "missing": + pred_moments["mean"] = np.mean(dist_samples, axis=1) + else: + pred_moments["mean"] = mean.detach().numpy() + + try: + variance = dist_pred.variance + except NotImplementedError: + if inference == "missing": + pred_moments["variance"] = np.var(dist_samples, axis=1) + pass + else: + pred_moments["variance"] = variance.detach().numpy() + try: + mode = dist_pred.mode + except NotImplementedError: + if inference == "missing": + pred_moments["mode"], _ = stats.mode(dist_samples, axis=1) + else: + pred_moments["mode"] = mode.detach().numpy() + + if pred_moments.shape[1] == 0: + return None + else: + return pred_moments + else: + return None def predict_dist(self, booster: lgb.Booster, data: pd.DataFrame, start_values: np.ndarray, pred_type: str = "parameters", + moments_inference: str = "none", n_samples: int = 1000, quantiles: list = [0.1, 0.5, 0.9], seed: str = 123 @@ -361,6 +439,12 @@ def predict_dist(self, - "quantiles" calculates the quantiles from the predicted distribution. - "parameters" returns the predicted distributional parameters. - "expectiles" returns the predicted expectiles. + - "moments" returns the mean, variance, and (if implemented) mode. + moments_inference: str + Type of inference to use if the prediction type is "moments": + - "none" (default) Will return only the exact, implemented moments. + - "missing" Will infer moments for missing implementations by drawing samples. + - "all" Will infer all moments by drawing samples. n_samples : int Number of samples to draw from the predicted distribution. quantiles : List[float] @@ -398,18 +482,24 @@ def predict_dist(self, dist_params_predt = pd.DataFrame(dist_params_predt) dist_params_predt.columns = self.param_dict.keys() - # Draw samples from predicted response distribution - pred_samples_df = self.draw_samples(predt_params=dist_params_predt, - n_samples=n_samples, - seed=seed) - if pred_type == "parameters": return dist_params_predt elif pred_type == "expectiles": return dist_params_predt + + elif pred_type == "moments": + return self.get_moments(predt_params=dist_params_predt, + inference=moments_inference, + n_samples=n_samples, + seed=seed) + + # Draw samples from predicted response distribution + pred_samples_df = self.draw_samples(predt_params=dist_params_predt, + n_samples=n_samples, + seed=seed) - elif pred_type == "samples": + if pred_type == "samples": return pred_samples_df elif pred_type == "quantiles": diff --git a/lightgbmlss/model.py b/lightgbmlss/model.py index 60f9aaf..1276da9 100644 --- a/lightgbmlss/model.py +++ b/lightgbmlss/model.py @@ -436,6 +436,7 @@ def predict(self, pred_type: str = "parameters", n_samples: int = 1000, quantiles: list = [0.1, 0.5, 0.9], + moments_inference: str = "none", seed: str = 123): """ Function that predicts from the trained model. @@ -450,6 +451,7 @@ def predict(self, - "quantiles" calculates the quantiles from the predicted distribution. - "parameters" returns the predicted distributional parameters. - "expectiles" returns the predicted expectiles. + - "moments" returns the mean, variance, and (if implemented) mode. n_samples : int Number of samples to draw from the predicted distribution. quantiles : List[float] @@ -470,6 +472,7 @@ def predict(self, pred_type=pred_type, n_samples=n_samples, quantiles=quantiles, + moments_inference=moments_inference, seed=seed) return predt_df