2323from pymatgen .electronic_structure .core import Spin
2424
2525if TYPE_CHECKING :
26+ from numpy .typing import NDArray
2627 from typing_extensions import Any , Self
2728
2829 from pymatgen .core .structure import IStructure
@@ -62,8 +63,8 @@ def __init__(
6263 self ,
6364 structure : Structure | IStructure ,
6465 data : dict [str , np .ndarray ],
65- distance_matrix : np . ndarray | None = None ,
66- data_aug : np . ndarray | None = None ,
66+ distance_matrix : dict | None = None ,
67+ data_aug : dict [ str , NDArray ] | None = None ,
6768 ) -> None :
6869 """
6970 Typically, this constructor is not used directly and the static
@@ -85,11 +86,11 @@ def __init__(
8586 # convert data to numpy arrays in case they were jsanitized as lists
8687 self .data = {k : np .array (v ) for k , v in data .items ()}
8788 self .dim = self .data ["total" ].shape
88- self .data_aug = data_aug
89+ self .data_aug = data_aug or {}
8990 self .ngridpts = self .dim [0 ] * self .dim [1 ] * self .dim [2 ]
9091 # lazy init the spin data since this is not always needed.
9192 self ._spin_data : dict [Spin , float ] = {}
92- self ._distance_matrix = distance_matrix
93+ self ._distance_matrix = distance_matrix if distance_matrix is not None else {}
9394 self .xpoints = np .linspace (0.0 , 1.0 , num = self .dim [0 ])
9495 self .ypoints = np .linspace (0.0 , 1.0 , num = self .dim [1 ])
9596 self .zpoints = np .linspace (0.0 , 1.0 , num = self .dim [2 ])
@@ -168,7 +169,7 @@ def linear_add(self, other, scale_factor=1.0) -> VolumetricData:
168169
169170 new = deepcopy (self )
170171 new .data = data
171- new .data_aug = None
172+ new .data_aug = {}
172173 return new
173174
174175 def scale (self , factor ):
@@ -247,6 +248,7 @@ def get_integrated_diff(self, ind, radius, nbins=1):
247248
248249 struct = self .structure
249250 a = self .dim
251+ self ._distance_matrix = {} if self ._distance_matrix is None else self ._distance_matrix
250252 if ind not in self ._distance_matrix or self ._distance_matrix [ind ]["max_radius" ] < radius :
251253 coords = []
252254 for x , y , z in itertools .product (* (list (range (i )) for i in a )):
0 commit comments