Skip to content

Commit b63ca3e

Browse files
committed
add types for Simplex
cast np.bool to bool
1 parent b1ea39b commit b63ca3e

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

src/pymatgen/util/coord.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -367,27 +367,40 @@ class Simplex(MSONable):
367367
simplex_dim (int): Dimension of the simplex coordinate space.
368368
"""
369369

370-
def __init__(self, coords) -> None:
370+
def __init__(self, coords: Sequence[Sequence[float]]) -> None:
371371
"""Initialize a Simplex from vertex coordinates.
372372
373373
Args:
374374
coords ([[float]]): Coords of the vertices of the simplex. e.g.
375375
[[1, 2, 3], [2, 4, 5], [6, 7, 8], [8, 9, 10].
376376
"""
377-
self._coords = np.array(coords)
377+
self._coords = np.asarray(coords)
378378
self.space_dim, self.simplex_dim = self._coords.shape
379379
self.origin = self._coords[-1]
380380
if self.space_dim == self.simplex_dim + 1:
381381
# pre-compute augmented matrix for calculating bary_coords
382382
self._aug = np.concatenate([coords, np.ones((self.space_dim, 1))], axis=-1)
383383
self._aug_inv = np.linalg.inv(self._aug)
384384

385+
def __eq__(self, other: object) -> bool:
386+
if not isinstance(other, Simplex):
387+
return NotImplemented
388+
return any(np.allclose(p, other.coords) for p in itertools.permutations(self._coords))
389+
390+
def __hash__(self) -> int:
391+
return len(self._coords)
392+
393+
def __repr__(self) -> str:
394+
output = [f"{self.simplex_dim}-simplex in {self.space_dim}D space\nVertices:"]
395+
output += [f"\t({', '.join(map(str, coord))})" for coord in self._coords]
396+
return "\n".join(output)
397+
385398
@property
386399
def volume(self) -> float:
387400
"""Volume of the simplex."""
388401
return abs(np.linalg.det(self._aug)) / math.factorial(self.simplex_dim)
389402

390-
def bary_coords(self, point):
403+
def bary_coords(self, point: ArrayLike) -> np.ndarray:
391404
"""
392405
Args:
393406
point (ArrayLike): Point coordinates.
@@ -400,7 +413,7 @@ def bary_coords(self, point):
400413
except AttributeError as exc:
401414
raise ValueError("Simplex is not full-dimensional") from exc
402415

403-
def point_from_bary_coords(self, bary_coords: ArrayLike):
416+
def point_from_bary_coords(self, bary_coords: ArrayLike) -> np.ndarray:
404417
"""
405418
Args:
406419
bary_coords (ArrayLike): Barycentric coordinates (d+1, d).
@@ -428,9 +441,14 @@ def in_simplex(self, point: Sequence[float], tolerance: float = 1e-8) -> bool:
428441
point (list[float]): Point to test
429442
tolerance (float): Tolerance to test if point is in simplex.
430443
"""
431-
return (self.bary_coords(point) >= -tolerance).all()
432-
433-
def line_intersection(self, point1: Sequence[float], point2: Sequence[float], tolerance: float = 1e-8):
444+
return bool((self.bary_coords(point) >= -tolerance).all())
445+
446+
def line_intersection(
447+
self,
448+
point1: Sequence[float],
449+
point2: Sequence[float],
450+
tolerance: float = 1e-8,
451+
) -> list[np.ndarray]:
434452
"""Compute the intersection points of a line with a simplex.
435453
436454
Args:
@@ -465,19 +483,6 @@ def line_intersection(self, point1: Sequence[float], point2: Sequence[float], to
465483
raise ValueError("More than 2 intersections found")
466484
return [self.point_from_bary_coords(b) for b in barys]
467485

468-
def __eq__(self, other: object) -> bool:
469-
if not isinstance(other, Simplex):
470-
return NotImplemented
471-
return any(np.allclose(p, other.coords) for p in itertools.permutations(self._coords))
472-
473-
def __hash__(self) -> int:
474-
return len(self._coords)
475-
476-
def __repr__(self) -> str:
477-
output = [f"{self.simplex_dim}-simplex in {self.space_dim}D space\nVertices:"]
478-
output += [f"\t({', '.join(map(str, coord))})" for coord in self._coords]
479-
return "\n".join(output)
480-
481486
@property
482487
def coords(self) -> np.ndarray:
483488
"""A copy of the vertex coordinates in the simplex."""

0 commit comments

Comments
 (0)