1- import keras
1+ from collections .abc import Mapping
2+
23import numpy as np
4+
5+ import keras
36from keras .saving import (
47 register_keras_serializable as serializable ,
58)
@@ -21,7 +24,7 @@ class PointApproximator(ContinuousApproximator):
2124
2225 def estimate (
2326 self ,
24- conditions : dict [str , np .ndarray ],
27+ conditions : Mapping [str , np .ndarray ],
2528 split : bool = False ,
2629 ** kwargs ,
2730 ) -> dict [str , dict [str , np .ndarray | dict [str , np .ndarray ]]]:
@@ -33,7 +36,7 @@ def estimate(
3336
3437 Parameters
3538 ----------
36- conditions : dict [str, np.ndarray]
39+ conditions : Mapping [str, np.ndarray]
3740 A dictionary mapping variable names to arrays representing the conditions
3841 for the estimation process.
3942 split : bool, optional
@@ -71,7 +74,7 @@ def sample(
7174 self ,
7275 * ,
7376 num_samples : int ,
74- conditions : dict [str , np .ndarray ],
77+ conditions : Mapping [str , np .ndarray ],
7578 split : bool = False ,
7679 ** kwargs ,
7780 ) -> dict [str , dict [str , np .ndarray ]]:
@@ -111,7 +114,7 @@ def sample(
111114 # Optionally split the arrays along the last axis.
112115 if split :
113116 raise NotImplementedError ("split=True is currently not supported for `PointApproximator`." )
114- samples = split_arrays ( samples , axis = - 1 )
117+
115118 # Squeeze sample dictionary if there's only one key-value pair.
116119 samples = self ._squeeze_parametric_score_major_dict (samples )
117120
@@ -120,7 +123,7 @@ def sample(
120123 def log_prob (
121124 self ,
122125 * ,
123- data : dict [str , np .ndarray ],
126+ data : Mapping [str , np .ndarray ],
124127 ** kwargs ,
125128 ) -> np .ndarray | dict [str , np .ndarray ]:
126129 """
@@ -152,14 +155,14 @@ def log_prob(
152155
153156 return log_prob
154157
155- def _prepare_conditions (self , conditions : dict [str , np .ndarray ], ** kwargs ) -> dict [str , Tensor ]:
158+ def _prepare_conditions (self , conditions : Mapping [str , np .ndarray ], ** kwargs ) -> dict [str , Tensor ]:
156159 """Adapts and converts the conditions to tensors."""
157160 conditions = self .adapter (conditions , strict = False , stage = "inference" , ** kwargs )
158161 conditions .pop ("inference_variables" , None )
159162 return keras .tree .map_structure (keras .ops .convert_to_tensor , conditions )
160163
161164 def _apply_inverse_adapter_to_estimates (
162- self , estimates : dict [str , dict [str , Tensor ]], ** kwargs
165+ self , estimates : Mapping [str , Mapping [str , Tensor ]], ** kwargs
163166 ) -> dict [str , dict [str , dict [str , np .ndarray ]]]:
164167 """Applies the inverse adapter on each inner element of the _estimate output dictionary."""
165168 estimates = keras .tree .map_structure (keras .ops .convert_to_numpy , estimates )
@@ -183,7 +186,7 @@ def _apply_inverse_adapter_to_estimates(
183186 return processed
184187
185188 def _apply_inverse_adapter_to_samples (
186- self , samples : dict [str , Tensor ], ** kwargs
189+ self , samples : Mapping [str , Tensor ], ** kwargs
187190 ) -> dict [str , dict [str , np .ndarray ]]:
188191 """Applies the inverse adapter to a dictionary of samples."""
189192 samples = keras .tree .map_structure (keras .ops .convert_to_numpy , samples )
@@ -198,7 +201,7 @@ def _apply_inverse_adapter_to_samples(
198201 return processed
199202
200203 def _reorder_estimates (
201- self , estimates : dict [str , dict [str , dict [str , np .ndarray ]]]
204+ self , estimates : Mapping [str , Mapping [str , Mapping [str , np .ndarray ]]]
202205 ) -> dict [str , dict [str , dict [str , np .ndarray ]]]:
203206 """Reorders the nested dictionary so that the inference variable names become the top-level keys."""
204207 # Grab the variable names from one sample inner dictionary.
@@ -212,7 +215,7 @@ def _reorder_estimates(
212215 return reordered
213216
214217 def _squeeze_estimates (
215- self , estimates : dict [str , dict [str , dict [str , np .ndarray ]]]
218+ self , estimates : Mapping [str , Mapping [str , Mapping [str , np .ndarray ]]]
216219 ) -> dict [str , dict [str , np .ndarray ]]:
217220 """Squeezes each inner estimate dictionary to remove unnecessary nesting."""
218221 squeezed = {}
@@ -224,7 +227,7 @@ def _squeeze_estimates(
224227 return squeezed
225228
226229 def _squeeze_parametric_score_major_dict (
227- self , samples : dict [str , np .ndarray ]
230+ self , samples : Mapping [str , np .ndarray ]
228231 ) -> np .ndarray or dict [str , np .ndarray ]:
229232 """Squeezes the dictionary to just the value if there is only one key-value pair."""
230233 if len (samples ) == 1 :
0 commit comments