2 | 2 | # -*- coding: utf-8 -*- |
3 | 3 | import numpy as np |
4 | 4 | from tqdm import tqdm |
| 5 | +import threading |
| 6 | +from joblib import Parallel, delayed |
| 7 | +from sklearn.utils.fixes import _joblib_parallel_args |
5 | 8 | from graphdot.linalg.cholesky import CholSolver |
6 | 9 |
7 | 10 |
8 | 11 | class NaiveLocalExpertGP: |
9 | 12 |     """Transductive Naive Local Experts of Gaussian process regression.
10 | 13 |
11 | 14 |     """
12 | | -    def __init__(self, kernel, alpha=1e-8, n_local=500, normalize_y=False, kernel_options={}):
| 15 | +    def __init__(self, kernel, alpha=1e-8, n_local=500, normalize_y=False,
| 16 | +                 n_jobs=1, kernel_options={}):
13 | 17 |         self.kernel = kernel
14 | 18 |         self.alpha = alpha
15 | 19 |         self.n_local = n_local
16 | 20 |         self.normalize_y = normalize_y
| 21 | +        self.n_jobs = n_jobs
17 | 22 |         self.kernel_options = kernel_options
18 | 23 |
19 | 24 |     @property
@@ -79,27 +84,48 @@ def fit(self, X, y): |
79 | 84 |         self.X = X
80 | 85 |         self.y = y
81 | 86 |
| 87 | +    def predict_(self, Z, return_std=False):
| 88 | +        Ks = self._gramian(Z, self.X)
| 89 | +        local_idx = np.argsort(-Ks)[:, :min(self.n_local, Ks.shape[1])][0]
| 90 | +        Ks_local = Ks[:, local_idx]
| 91 | +        X_local = self.X[local_idx]
| 92 | +        y_local = self.y[local_idx]
| 93 | +        K_local = self._gramian(X_local)
| 94 | +        Kinv_local, _ = self._invert(K_local)
| 95 | +        Ky_local = Kinv_local @ y_local
| 96 | +        y_mean = (Ks_local @ Ky_local) * self.y_std + self.y_mean
| 97 | +        if return_std:
| 98 | +            Kss = self._gramian(Z, diag=True)
| 99 | +            y_std = np.sqrt(
| 100 | +                np.maximum(0, Kss - (Ks_local @ (Kinv_local @ Ks_local.T)).diagonal())
| 101 | +            )
| 102 | +            return y_mean, y_std
| 103 | +        else:
| 104 | +            return y_mean
| 105 | + |
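The added `predict_` method is the per-query local expert: it ranks the training points by their kernel similarity to the query, keeps the `n_local` most similar ones, and solves a small GP system on that subset only. The following is a minimal sketch of that select-and-solve step, using a plain NumPy RBF kernel as a stand-in for `self._gramian` and omitting the target normalization; the data, `gamma`, and sizes are illustrative assumptions, not part of this commit.

```python
import numpy as np

def rbf_gramian(A, B, gamma=1.0):
    # Stand-in for self._gramian: RBF kernel between two point sets.
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-gamma * d2)

rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 3))        # hypothetical training inputs
y_train = np.sin(X_train).sum(axis=1)      # hypothetical training targets
z = rng.normal(size=(1, 3))                # a single query row, as predict_ receives

n_local, alpha = 25, 1e-8
Ks = rbf_gramian(z, X_train)                         # similarity of z to every training point
local_idx = np.argsort(-Ks)[:, :n_local][0]          # indices of the n_local most similar points
K_local = rbf_gramian(X_train[local_idx], X_train[local_idx]) + alpha * np.eye(n_local)
y_mean = Ks[:, local_idx] @ np.linalg.solve(K_local, y_train[local_idx])   # local GP mean for z
```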
| 106 | +    def _accumulate_prediction(self, Z, y_hat, u_hat, lock, return_std=False):
| 107 | +        if return_std:
| 108 | +            prediction, uncertainty = self.predict_(Z, return_std=True)
| 109 | +            with lock:
| 110 | +                y_hat.append(prediction)
| 111 | +                u_hat.append(uncertainty)
| 112 | +        else:
| 113 | +            prediction = self.predict_(Z, return_std=False)
| 114 | +            with lock:
| 115 | +                y_hat.append(prediction)
| 116 | + |
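`_accumulate_prediction` follows the lock-guarded accumulation idiom familiar from scikit-learn's ensemble predictors: each worker appends its chunk's prediction (and optionally its uncertainty) to shared lists while holding a `threading.Lock`. It is not called by `predict` below, which uses the process backend and collects return values instead, but a hedged sketch of how it could be driven with joblib's threading backend follows; `model` and `Z` are hypothetical names for a fitted `NaiveLocalExpertGP` and an array of query points, and chunks land in the lists in completion order rather than query order.

```python
# Hedged sketch only: `model` (a fitted NaiveLocalExpertGP) and `Z` (an array
# of query points) are assumed to exist; neither is defined in this commit.
import threading
import numpy as np
from joblib import Parallel, delayed

y_hat, u_hat = [], []
lock = threading.Lock()
Parallel(n_jobs=4, backend='threading')(
    delayed(model._accumulate_prediction)(z.reshape(1, -1), y_hat, u_hat, lock, True)
    for z in Z
)
# Each call appends a length-1 mean and std array under the lock; entries follow
# completion order, so order-sensitive callers would need to track indices.
y_mean, y_std = np.concatenate(y_hat), np.concatenate(u_hat)
```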
82 | 117 |     def predict(self, Z, return_std=False):
83 | | -        y_mean = []
84 | | -        y_std = []
85 | | -        for z in tqdm(Z, total=len(Z)):
86 | | -            Z_ = z.reshape(1, -1)
87 | | -            Ks = self._gramian(Z_, self.X)
88 | | -            local_idx = np.argsort(-Ks)[:, :min(self.n_local, Ks.shape[1])][0]
89 | | -            Ks_local = Ks[:, local_idx]
90 | | -            X_local = self.X[local_idx]
91 | | -            y_local = self.y[local_idx]
92 | | -            K_local = self._gramian(X_local)
93 | | -            Kinv_local, _ = self._invert(K_local)
94 | | -            Ky_local = Kinv_local @ y_local
95 | | -            y_mean.append((Ks_local @ Ky_local) * self.y_std + self.y_mean)
96 | | -            if return_std is True:
97 | | -                Kss = self._gramian(Z_, diag=True)
98 | | -                std = np.sqrt(
99 | | -                    np.maximum(0, Kss - (Ks_local @ (Kinv_local @ Ks_local.T)).diagonal())
100 | | -                )
101 | | -                y_std.append(std)
| 118 | +        results = Parallel(
| 119 | +            n_jobs=self.n_jobs, verbose=True,
| 120 | +            **_joblib_parallel_args(prefer='processes'))(
| 121 | +            delayed(self.predict_)(
| 122 | +                z.reshape(1, -1),
| 123 | +                return_std
| 124 | +            )
| 125 | +            for z in Z)
| 126 | +        y_mean = np.asarray([(result[0] if return_std else result)[0] for result in results])
102 | 127 |         if return_std:
103 | | -            return np.concatenate(y_mean), np.concatenate(y_std)
| 128 | +            y_std = np.asarray([result[1][0] for result in results])
| 129 | +            return y_mean, y_std
104 | 130 |         else:
105 | | -            return np.concatenate(y_mean)
| 131 | +            return y_mean
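With these changes, `predict` fans the per-query local solves out across joblib workers instead of looping serially under `tqdm`, and each worker returns its own mean (and standard deviation when requested). A hedged end-to-end sketch of calling the parallel path is given below; `kernel`, `X`, `y`, and `Z` are placeholders for a graphdot-compatible kernel and the training/query data expected by `fit` and `predict`, none of which appear in this diff.

```python
# Hypothetical usage; kernel, X, y and Z must be supplied by the caller.
gpr = NaiveLocalExpertGP(kernel, alpha=1e-8, n_local=500,
                         normalize_y=True, n_jobs=4)
gpr.fit(X, y)
y_pred, y_err = gpr.predict(Z, return_std=True)   # len(Z) local solves across 4 processes
```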