@@ -367,27 +367,40 @@ class Simplex(MSONable):
367367 simplex_dim (int): Dimension of the simplex coordinate space.
368368 """
369369
370- def __init__ (self , coords ) -> None :
370+ def __init__ (self , coords : Sequence [ Sequence [ float ]] ) -> None :
371371 """Initialize a Simplex from vertex coordinates.
372372
373373 Args:
374374 coords ([[float]]): Coords of the vertices of the simplex. e.g.
375375 [[1, 2, 3], [2, 4, 5], [6, 7, 8], [8, 9, 10].
376376 """
377- self ._coords = np .array (coords )
377+ self ._coords = np .asarray (coords )
378378 self .space_dim , self .simplex_dim = self ._coords .shape
379379 self .origin = self ._coords [- 1 ]
380380 if self .space_dim == self .simplex_dim + 1 :
381381 # pre-compute augmented matrix for calculating bary_coords
382382 self ._aug = np .concatenate ([coords , np .ones ((self .space_dim , 1 ))], axis = - 1 )
383383 self ._aug_inv = np .linalg .inv (self ._aug )
384384
385+ def __eq__ (self , other : object ) -> bool :
386+ if not isinstance (other , Simplex ):
387+ return NotImplemented
388+ return any (np .allclose (p , other .coords ) for p in itertools .permutations (self ._coords ))
389+
390+ def __hash__ (self ) -> int :
391+ return len (self ._coords )
392+
393+ def __repr__ (self ) -> str :
394+ output = [f"{ self .simplex_dim } -simplex in { self .space_dim } D space\n Vertices:" ]
395+ output += [f"\t ({ ', ' .join (map (str , coord ))} )" for coord in self ._coords ]
396+ return "\n " .join (output )
397+
385398 @property
386399 def volume (self ) -> float :
387400 """Volume of the simplex."""
388401 return abs (np .linalg .det (self ._aug )) / math .factorial (self .simplex_dim )
389402
390- def bary_coords (self , point ) :
403+ def bary_coords (self , point : ArrayLike ) -> np . ndarray :
391404 """
392405 Args:
393406 point (ArrayLike): Point coordinates.
@@ -400,7 +413,7 @@ def bary_coords(self, point):
400413 except AttributeError as exc :
401414 raise ValueError ("Simplex is not full-dimensional" ) from exc
402415
403- def point_from_bary_coords (self , bary_coords : ArrayLike ):
416+ def point_from_bary_coords (self , bary_coords : ArrayLike ) -> np . ndarray :
404417 """
405418 Args:
406419 bary_coords (ArrayLike): Barycentric coordinates (d+1, d).
@@ -428,9 +441,14 @@ def in_simplex(self, point: Sequence[float], tolerance: float = 1e-8) -> bool:
428441 point (list[float]): Point to test
429442 tolerance (float): Tolerance to test if point is in simplex.
430443 """
431- return (self .bary_coords (point ) >= - tolerance ).all ()
432-
433- def line_intersection (self , point1 : Sequence [float ], point2 : Sequence [float ], tolerance : float = 1e-8 ):
444+ return bool ((self .bary_coords (point ) >= - tolerance ).all ())
445+
446+ def line_intersection (
447+ self ,
448+ point1 : Sequence [float ],
449+ point2 : Sequence [float ],
450+ tolerance : float = 1e-8 ,
451+ ) -> list [np .ndarray ]:
434452 """Compute the intersection points of a line with a simplex.
435453
436454 Args:
@@ -465,19 +483,6 @@ def line_intersection(self, point1: Sequence[float], point2: Sequence[float], to
465483 raise ValueError ("More than 2 intersections found" )
466484 return [self .point_from_bary_coords (b ) for b in barys ]
467485
468- def __eq__ (self , other : object ) -> bool :
469- if not isinstance (other , Simplex ):
470- return NotImplemented
471- return any (np .allclose (p , other .coords ) for p in itertools .permutations (self ._coords ))
472-
473- def __hash__ (self ) -> int :
474- return len (self ._coords )
475-
476- def __repr__ (self ) -> str :
477- output = [f"{ self .simplex_dim } -simplex in { self .space_dim } D space\n Vertices:" ]
478- output += [f"\t ({ ', ' .join (map (str , coord ))} )" for coord in self ._coords ]
479- return "\n " .join (output )
480-
481486 @property
482487 def coords (self ) -> np .ndarray :
483488 """A copy of the vertex coordinates in the simplex."""
0 commit comments