diff --git a/src/pymatgen/analysis/phase_diagram.py b/src/pymatgen/analysis/phase_diagram.py
index 2610d96049b..a18b4202923 100644
--- a/src/pymatgen/analysis/phase_diagram.py
+++ b/src/pymatgen/analysis/phase_diagram.py
@@ -38,11 +38,14 @@
if TYPE_CHECKING:
from collections.abc import Collection, Iterator, Sequence
from io import StringIO
- from typing import Any, Literal
+ from typing import Any, ClassVar, Literal
- from numpy.typing import ArrayLike
+ from matplotlib.colors import Colormap
+ from numpy.typing import ArrayLike, NDArray
from typing_extensions import Self
+ from pymatgen.entries.computed_entries import ComputedEntry
+
logger = logging.getLogger(__name__)
with open(
@@ -73,7 +76,7 @@ def __init__(
energy: float,
name: str | None = None,
attribute: object = None,
- ):
+ ) -> None:
"""
Args:
composition (Composition): Composition
@@ -86,7 +89,7 @@ def __init__(
self.name = name or self.reduced_formula
self.attribute = attribute
- def __repr__(self):
+ def __repr__(self) -> str:
name = ""
if self.name != self.reduced_formula:
name = f" ({self.name})"
@@ -97,12 +100,12 @@ def energy(self) -> float:
"""The entry's energy."""
return self._energy
- def as_dict(self):
+ def as_dict(self) -> dict[str, Any]:
"""Get MSONable dict representation of PDEntry."""
return super().as_dict() | {"name": self.name, "attribute": self.attribute}
@classmethod
- def from_dict(cls, dct: dict) -> Self:
+ def from_dict(cls, dct: dict[str, Any]) -> Self:
"""
Args:
dct (dict): dictionary representation of PDEntry.
@@ -125,7 +128,7 @@ class GrandPotPDEntry(PDEntry):
dict.
"""
- def __init__(self, entry, chempots, name=None):
+ def __init__(self, entry: PDEntry, chempots: dict[Element, float], name: str | None = None) -> None:
"""
Args:
entry: A PDEntry-like object.
@@ -146,6 +149,16 @@ def __init__(self, entry, chempots, name=None):
self.original_comp = self._composition
self.chempots = chempots
+ def __repr__(self):
+ output = [
+ (
+ f"GrandPotPDEntry with original composition {self.original_entry.composition}, "
+ f"energy = {self.original_entry.energy:.4f}, "
+ ),
+ "chempots = " + ", ".join(f"mu_{el} = {mu:.4f}" for el, mu in self.chempots.items()),
+ ]
+ return "".join(output)
+
@property
def composition(self) -> Composition:
"""The composition after removing free species.
@@ -156,7 +169,7 @@ def composition(self) -> Composition:
return Composition({el: self._composition[el] for el in self._composition.elements if el not in self.chempots})
@property
- def chemical_energy(self):
+ def chemical_energy(self) -> float:
"""The chemical energy term mu*N in the grand potential.
Returns:
@@ -169,17 +182,7 @@ def energy(self) -> float:
"""Grand potential energy."""
return self._energy - self.chemical_energy
- def __repr__(self):
- output = [
- (
- f"GrandPotPDEntry with original composition {self.original_entry.composition}, "
- f"energy = {self.original_entry.energy:.4f}, "
- ),
- "chempots = " + ", ".join(f"mu_{el} = {mu:.4f}" for el, mu in self.chempots.items()),
- ]
- return "".join(output)
-
- def as_dict(self):
+ def as_dict(self) -> dict[str, Any]:
"""Get MSONable dict representation of GrandPotPDEntry."""
return {
"@module": type(self).__module__,
@@ -212,7 +215,7 @@ class TransformedPDEntry(PDEntry):
"""
# Tolerance for determining if amount of a composition is positive.
- amount_tol = 1e-5
+ amount_tol: ClassVar[float] = 1e-5
def __init__(self, entry, sp_mapping, name=None):
"""
@@ -236,6 +239,14 @@ def __init__(self, entry, sp_mapping, name=None):
if not all(self.rxn.get_coeff(comp) <= TransformedPDEntry.amount_tol for comp in self.sp_mapping):
raise TransformedPDEntryError("Only reactions with positive amounts of reactants allowed")
+ def __repr__(self):
+ output = [
+ f"TransformedPDEntry {self.composition}",
+ f" with original composition {self.original_entry.composition}",
+ f", energy = {self.original_entry.energy:.4f}",
+ ]
+ return "".join(output)
+
@property
def composition(self) -> Composition:
"""The composition in the dummy species space.
@@ -254,14 +265,6 @@ def composition(self) -> Composition:
return Composition(trans_comp)
- def __repr__(self):
- output = [
- f"TransformedPDEntry {self.composition}",
- f" with original composition {self.original_entry.composition}",
- f", energy = {self.original_entry.energy:.4f}",
- ]
- return "".join(output)
-
def as_dict(self):
"""Get MSONable dict representation of TransformedPDEntry."""
return {
@@ -336,8 +339,8 @@ class PhaseDiagram(MSONable):
"""
# Tolerance for determining if formation energy is positive.
- formation_energy_tol = 1e-11
- numerical_tol = 1e-8
+ formation_energy_tol: ClassVar[float] = 1e-11
+ numerical_tol: ClassVar[float] = 1e-8
def __init__(
self,
@@ -366,6 +369,7 @@ def __init__(
self.elements = elements
self.entries = entries
+
if computed_data is None:
computed_data = self._compute()
else:
@@ -375,19 +379,32 @@ def __init__(
# Update keys to be Element objects in case they are strings in pre-computed data
computed_data["el_refs"] = [(Element(el_str), entry) for el_str, entry in computed_data["el_refs"]]
+
self.computed_data = computed_data
- self.facets = computed_data["facets"]
- self.simplexes = computed_data["simplexes"]
- self.all_entries = computed_data["all_entries"]
- self.qhull_data = computed_data["qhull_data"]
- self.dim = computed_data["dim"]
- self.el_refs = dict(computed_data["el_refs"])
- self.qhull_entries = tuple(computed_data["qhull_entries"])
- self._qhull_spaces = tuple(frozenset(e.elements) for e in self.qhull_entries)
- self._stable_entries = tuple({self.qhull_entries[idx] for idx in set(itertools.chain(*self.facets))})
+
+ self.facets: list[NDArray[int]] = computed_data["facets"]
+ self.simplexes: list[Simplex] = computed_data["simplexes"]
+ self.all_entries: list[PDEntry] = computed_data["all_entries"]
+ self.qhull_data: np.ndarray = computed_data["qhull_data"]
+ self.dim: int = computed_data["dim"]
+ self.el_refs: dict[Element, PDEntry] = dict(computed_data["el_refs"])
+ self.qhull_entries: tuple[PDEntry, ...] = tuple(computed_data["qhull_entries"])
+ self._qhull_spaces: tuple = tuple(frozenset(e.elements) for e in self.qhull_entries)
+ self._stable_entries: tuple[PDEntry, ...] = tuple(
+ {self.qhull_entries[idx] for idx in set(itertools.chain(*self.facets))}
+ )
self._stable_spaces = tuple(frozenset(e.elements) for e in self._stable_entries)
- def as_dict(self):
+ def __repr__(self) -> str:
+ symbols = [el.symbol for el in self.elements]
+ output = [
+ f"{'-'.join(symbols)} phase diagram",
+ f"{len(self.stable_entries)} stable phases: ",
+ ", ".join(entry.name for entry in sorted(self.stable_entries, key=str)),
+ ]
+ return "\n".join(output)
+
+ def as_dict(self) -> dict[str, Any]:
"""Get MSONable dict representation of PhaseDiagram."""
qhull_entry_indices = [self.all_entries.index(e) for e in self.qhull_entries]
@@ -517,13 +534,13 @@ def pd_coords(self, comp: Composition) -> np.ndarray:
return np.array([comp.get_atomic_fraction(el) for el in self.elements[1:]])
@property
- def all_entries_hulldata(self):
+ def all_entries_hulldata(self) -> np.ndarray:
"""The ndarray used to construct the convex hull."""
data = [
[e.composition.get_atomic_fraction(el) for el in self.elements] + [e.energy_per_atom]
for e in self.all_entries
]
- return np.array(data)[:, 1:]
+ return np.asarray(data)[:, 1:]
@property
def unstable_entries(self) -> set[Entry]:
@@ -542,7 +559,7 @@ def stable_entries(self) -> set[Entry]:
return set(self._stable_entries)
@lru_cache(1) # noqa: B019
- def _get_stable_entries_in_space(self, space) -> list[Entry]:
+ def _get_stable_entries_in_space(self, space: set[Element]) -> list[Entry]:
"""
Args:
space (set[Element]): set of Element objects.
@@ -599,15 +616,6 @@ def get_form_energy_per_atom(self, entry: PDEntry) -> float:
"""
return self.get_form_energy(entry) / entry.composition.num_atoms
- def __repr__(self) -> str:
- symbols = [el.symbol for el in self.elements]
- output = [
- f"{'-'.join(symbols)} phase diagram",
- f"{len(self.stable_entries)} stable phases: ",
- ", ".join(entry.name for entry in sorted(self.stable_entries, key=str)),
- ]
- return "\n".join(output)
-
@lru_cache(1) # noqa: B019
def _get_facet_and_simplex(self, comp: Composition) -> tuple[Simplex, Simplex]:
"""Get any facet that a composition falls into. Cached so successive
@@ -623,7 +631,7 @@ def _get_facet_and_simplex(self, comp: Composition) -> tuple[Simplex, Simplex]:
raise RuntimeError(f"No facet found for {comp = }")
- def _get_all_facets_and_simplexes(self, comp):
+ def _get_all_facets_and_simplexes(self, comp: Composition) -> list:
"""Get all facets that a composition falls into.
Args:
@@ -675,7 +683,7 @@ def _get_simplex_intersections(self, c1, c2):
for sc in self.simplexes:
intersections.extend(sc.line_intersection(c1, c2))
- return np.array(intersections)
+ return np.asarray(intersections)
def get_decomposition(self, comp: Composition) -> dict[PDEntry, float]:
"""
@@ -915,7 +923,7 @@ def get_decomp_and_phase_separation_energy(
return self.get_decomp_and_e_above_hull(entry, allow_negative=True, **kwargs)
# take entries with negative e_form and different compositions as competing entries
- competing_entries = {c for c in compare_entries if id(c) not in same_comp_mem_ids}
+ competing_entries: set[PDEntry] = {c for c in compare_entries if id(c) not in same_comp_mem_ids}
# NOTE SLSQP optimizer doesn't scale well for > 300 competing entries.
if len(competing_entries) > space_limit and not stable_only:
@@ -942,7 +950,7 @@ def get_decomp_and_phase_separation_energy(
stacklevel=2,
)
- decomp = _get_slsqp_decomp(entry.composition, competing_entries, tols, maxiter)
+ decomp = _get_slsqp_decomp(entry.composition, list(competing_entries), tols, maxiter)
# find the minimum alternative formation energy for the decomposition
decomp_enthalpy = np.sum([c.energy_per_atom * amt for c, amt in decomp.items()])
@@ -951,7 +959,7 @@ def get_decomp_and_phase_separation_energy(
return decomp, decomp_enthalpy
- def get_phase_separation_energy(self, entry, **kwargs):
+ def get_phase_separation_energy(self, entry: PDEntry, **kwargs):
"""
Provides the energy to the convex hull for the given entry. For stable entries
already in the phase diagram the algorithm provides the phase separation energy
@@ -980,7 +988,7 @@ def get_phase_separation_energy(self, entry, **kwargs):
"""
return self.get_decomp_and_phase_separation_energy(entry, **kwargs)[1]
- def get_composition_chempots(self, comp):
+ def get_composition_chempots(self, comp: Composition) -> dict[Element, float]:
"""Get the chemical potentials for all elements at a given composition.
Args:
@@ -992,7 +1000,7 @@ def get_composition_chempots(self, comp):
facet = self._get_facet_and_simplex(comp)[0]
return self._get_facet_chempots(facet)
- def get_all_chempots(self, comp):
+ def get_all_chempots(self, comp: Composition) -> dict[str, dict[Element, float]]:
"""Get chemical potentials at a given composition.
Args:
@@ -1010,7 +1018,7 @@ def get_all_chempots(self, comp):
return chempots
- def get_transition_chempots(self, element):
+ def get_transition_chempots(self, element: Element) -> tuple[float, ...]:
"""Get the critical chemical potentials for an element in the Phase
Diagram.
@@ -1029,7 +1037,7 @@ def get_transition_chempots(self, element):
chempots = self._get_facet_chempots(facet)
critical_chempots.append(chempots[element])
- clean_pots = []
+ clean_pots: list[float] = []
for c in sorted(critical_chempots):
if len(clean_pots) == 0 or not math.isclose(
c, clean_pots[-1], abs_tol=PhaseDiagram.numerical_tol, rel_tol=0
@@ -1038,7 +1046,11 @@ def get_transition_chempots(self, element):
clean_pots.reverse()
return tuple(clean_pots)
- def get_critical_compositions(self, comp1, comp2):
+ def get_critical_compositions(
+ self,
+ comp1: Composition,
+ comp2: Composition,
+ ) -> list[Composition]:
"""Get the critical compositions along the tieline between two
compositions. I.e. where the decomposition products change.
The endpoints are also returned.
@@ -1098,7 +1110,12 @@ def get_critical_compositions(self, comp1, comp2):
return [Composition((elem, val) for elem, val in zip(pd_els, m, strict=True)) for m in cs]
- def get_element_profile(self, element, comp, comp_tol=1e-5):
+ def get_element_profile(
+ self,
+ element: Element,
+ comp: Composition,
+ comp_tol: float = 1e-5,
+ ) -> list[dict[str, Any]]:
"""
Provides the element evolution data for a composition. For example, can be used
to analyze Li conversion voltages by varying mu_Li and looking at the phases
@@ -1146,7 +1163,10 @@ def get_element_profile(self, element, comp, comp_tol=1e-5):
return evolution
def get_chempot_range_map(
- self, elements: Sequence[Element], referenced: bool = True, joggle: bool = True
+ self,
+ elements: Sequence[Element],
+ referenced: bool = True,
+ joggle: bool = True,
) -> dict[Element, list[Simplex]]:
"""Get a chemical potential range map for each stable entry.
@@ -1199,7 +1219,12 @@ def get_chempot_range_map(
return chempot_ranges
- def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2):
+ def getmu_vertices_stability_phase(
+ self,
+ target_comp: Composition,
+ dep_elt: Element,
+ tol_en: float = 1e-2,
+ ) -> list[dict[Element, float]] | None:
"""Get a set of chemical potentials corresponding to the vertices of
the simplex in the chemical potential phase diagram.
The simplex is built using all elements in the target_composition
@@ -1233,7 +1258,7 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2):
if elem.composition.reduced_composition == target_comp.reduced_composition:
multiplier = elem.composition[dep_elt] / target_comp[dep_elt]
ef = elem.energy / multiplier
- all_coords = []
+ all_coords: list[dict] = []
for simplex in chempots:
for v in simplex._coords:
elements = [elem for elem in self.elements if elem != dep_elt]
@@ -1257,7 +1282,11 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2):
return all_coords
return None
- def get_chempot_range_stability_phase(self, target_comp, open_elt):
+ def get_chempot_range_stability_phase(
+ self,
+ target_comp: Composition,
+ open_elt: Element,
+ ) -> dict[Element, tuple[float, float]]:
"""Get a set of chemical potentials corresponding to the max and min
chemical potential of the open element for a given composition. It is
quite common to have for instance a ternary oxide (e.g., ABO3) for
@@ -1317,14 +1346,14 @@ def get_plot(
ternary_style: Literal["2d", "3d"] = "2d",
label_stable: bool = True,
label_unstable: bool = True,
- ordering: Sequence[str] | None = None,
- energy_colormap=None,
+ ordering: Sequence[Literal["Up", "Left", "Right"]] | None = None,
+ energy_colormap: str | Colormap | None = None,
process_attributes: bool = False,
- ax: plt.Axes = None,
+ ax: plt.Axes | None = None,
label_uncertainties: bool = False,
fill: bool = True,
**kwargs,
- ):
+ ) -> go.Figure | plt.Axes:
"""
Convenient wrapper for PDPlotter. Initializes a PDPlotter object and calls
get_plot() with provided combined arguments.
@@ -1479,7 +1508,7 @@ class CompoundPhaseDiagram(PhaseDiagram):
"""
# Tolerance for determining if amount of a composition is positive.
- amount_tol = 1e-5
+ amount_tol: ClassVar[float] = 1e-5
def __init__(self, entries, terminal_compositions, normalize_terminal_compositions=True):
"""Initialize a CompoundPhaseDiagram.
@@ -1620,6 +1649,19 @@ class PatchedPhaseDiagram(PhaseDiagram):
These are entries corresponding to the lowest energy element entries for
simple compositional phase diagrams.
elements (list[Element]): List of elements in the phase diagram.
+
+ NOTE following methods are inherited unchanged from `PhaseDiagram`:
+ - __repr__
+ - all_entries_hulldata
+ - unstable_entries
+ - stable_entries
+ - get_form_energy
+ - get_form_energy_per_atom
+ - get_hull_energy
+ - get_e_above_hull
+ - get_decomp_and_e_above_hull
+ - get_decomp_and_phase_separation_energy
+ - get_phase_separation_energ
"""
def __init__(
@@ -1795,19 +1837,6 @@ def remove_redundant_spaces(spaces, keep_all_spaces=False):
return result
- # NOTE following methods are inherited unchanged from PhaseDiagram:
- # __repr__,
- # all_entries_hulldata,
- # unstable_entries,
- # stable_entries,
- # get_form_energy(),
- # get_form_energy_per_atom(),
- # get_hull_energy(),
- # get_e_above_hull(),
- # get_decomp_and_e_above_hull(),
- # get_decomp_and_phase_separation_energy(),
- # get_phase_separation_energy()
-
def get_pd_for_entry(self, entry: Entry | Composition) -> PhaseDiagram:
"""Get the possible phase diagrams for an entry.
@@ -1952,7 +1981,14 @@ class ReactionDiagram:
an electrolyte and an electrode.
"""
- def __init__(self, entry1, entry2, all_entries, tol: float = 1e-4, float_fmt="%.4f"):
+ def __init__(
+ self,
+ entry1: ComputedEntry,
+ entry2: ComputedEntry,
+ all_entries: list[ComputedEntry],
+ tol: float = 1e-4,
+ float_fmt: str = "%.4f",
+ ) -> None:
"""
Args:
entry1 (ComputedEntry): Entry for 1st component. Note that
@@ -2008,7 +2044,7 @@ def fmt(fl):
try:
mat = [[entry.composition.get_atomic_fraction(el) for el in elements] for entry in face_entries]
mat.append(comp_vec2 - comp_vec1)
- matrix = np.array(mat).T
+ matrix = np.asarray(mat).T
coeffs = np.linalg.solve(matrix, comp_vec2)
x = coeffs[-1]
@@ -2071,7 +2107,7 @@ def fmt(fl):
self.all_entries = all_entries
self.pd = pd
- def get_compound_pd(self):
+ def get_compound_pd(self) -> CompoundPhaseDiagram:
"""Get the CompoundPhaseDiagram object, which can then be used for
plotting.
@@ -2117,11 +2153,11 @@ def get_facets(qhull_data: ArrayLike, joggle: bool = False) -> ConvexHull:
def _get_slsqp_decomp(
- comp,
- competing_entries,
- tols=(1e-8,),
- maxiter=1000,
-):
+ comp: Composition,
+ competing_entries: Sequence[PDEntry],
+ tols: Sequence[float] = (1e-8,),
+ maxiter: int = 1000,
+) -> dict:
"""Find the amounts of competing compositions that minimize the energy of a
given composition.
@@ -2260,20 +2296,22 @@ def get_plot(
self,
label_stable: bool = True,
label_unstable: bool = True,
- ordering: Sequence[str] | None = None,
- energy_colormap=None,
+ # `matplotlib` only
+ ordering: Sequence[Literal["Up", "Left", "Right"]] | None = None,
+ energy_colormap: str | Colormap | None = None,
process_attributes: bool = False,
- ax: plt.Axes = None,
+ ax: plt.Axes | None = None,
+ # `plotly` only
label_uncertainties: bool = False,
fill: bool = True,
highlight_entries: Collection[PDEntry] | None = None,
- ) -> go.Figure | plt.Axes:
+ ) -> go.Figure | plt.Axes | None:
"""
Args:
label_stable: Whether to label stable compounds.
label_unstable: Whether to label unstable compounds.
- ordering: Ordering of vertices, given as a list ['Up',
- 'Left','Right'] (matplotlib only).
+ ordering: Ordering of vertices, given as a list ["Up",
+ "Left", "Right"] (matplotlib only).
energy_colormap: Colormap for coloring energy (matplotlib only).
process_attributes: Whether to process the attributes (matplotlib only).
ax: Existing matplotlib Axes object if plotting multiple phase diagrams
@@ -2290,10 +2328,14 @@ def get_plot(
Returns:
go.Figure | plt.Axes: Plotly figure or matplotlib axes object depending on backend.
"""
- fig = None
- data = []
+ if self._dim not in {1, 2, 3, 4}:
+ raise ValueError(
+ f"Plotting is only supported for unary/binary/ternary/quaternary phase diagrams — got {self._dim}D "
+ )
if self.backend == "plotly":
+ data: list = []
+
if self._dim != 1:
data.append(self._create_plotly_lines())
@@ -2311,7 +2353,7 @@ def get_plot(
if self._dim != 1 and not (self._dim == 3 and self.ternary_style == "2d"):
data.append(self._create_plotly_stable_labels(label_stable))
- if fill and self._dim in [3, 4]:
+ if fill and self._dim in {3, 4}:
data.extend(self._create_plotly_fill())
data.extend([stable_marker_plot, unstable_marker_plot])
@@ -2323,9 +2365,11 @@ def get_plot(
fig.layout = self._create_plotly_figure_layout()
fig.update_layout(coloraxis_colorbar={"yanchor": "top", "y": 0.05, "x": 1})
- elif self.backend == "matplotlib":
- if self._dim <= 3:
- fig = self._get_matplotlib_2d_plot(
+ return fig
+
+ if self.backend == "matplotlib":
+ if self._dim in {1, 2, 3}:
+ return self._get_matplotlib_2d_plot(
label_stable,
label_unstable,
ordering,
@@ -2333,10 +2377,10 @@ def get_plot(
ax=ax,
process_attributes=process_attributes,
)
- elif self._dim == 4:
- fig = self._get_matplotlib_3d_plot(label_stable, ax=ax)
+ if self._dim == 4:
+ return self._get_matplotlib_3d_plot(label_stable, ax=ax)
- return fig
+ return None
def show(self, *args, **kwargs) -> None:
"""
@@ -2373,7 +2417,13 @@ def write_image(self, stream: str | StringIO, image_format: str = "svg", **kwarg
fig = self.get_plot(**kwargs)
fig.write_image(stream, format=image_format)
- def plot_element_profile(self, element, comp, show_label_index=None, xlim=5):
+ def plot_element_profile(
+ self,
+ element: Element,
+ comp: Composition,
+ show_label_index: list[int] | None = None,
+ xlim: float = 5,
+ ) -> plt.Axes:
"""
Draw the element profile plot for a composition varying different
chemical potential of an element.
@@ -2438,7 +2488,7 @@ def plot_element_profile(self, element, comp, show_label_index=None, xlim=5):
return ax
- def plot_chempot_range_map(self, elements, referenced=True) -> None:
+ def plot_chempot_range_map(self, elements: Sequence[Element], referenced: bool = True) -> None:
"""
Plot the chemical potential range _map using matplotlib. Currently works only for
3-component PDs. This shows the plot but does not return it.
@@ -2456,7 +2506,7 @@ class (pymatgen.analysis.chempot_diagram).
"""
self.get_chempot_range_map_plot(elements, referenced=referenced).show()
- def get_chempot_range_map_plot(self, elements, referenced=True):
+ def get_chempot_range_map_plot(self, elements: Sequence[Element], referenced: bool = True) -> plt.Axes:
"""Get a plot of the chemical potential range _map. Currently works
only for 3-component PDs.
@@ -2482,7 +2532,7 @@ class (pymatgen.analysis.chempot_diagram).
for entry, lines in chempot_ranges.items():
comp = entry.composition
center_x = center_y = 0
- coords = []
+ coords: list[list] = []
contain_zero = any(comp.get_atomic_fraction(el) == 0 for el in elements)
is_boundary = (not contain_zero) and sum(comp.get_atomic_fraction(el) for el in elements) == 1
for line in lines:
@@ -2560,7 +2610,7 @@ class (pymatgen.analysis.chempot_diagram).
plt.tight_layout()
return ax
- def get_contour_pd_plot(self):
+ def get_contour_pd_plot(self) -> plt.Axes:
"""
Plot a contour phase diagram plot, where phase triangles are colored
according to degree of instability by interpolation. Currently only
@@ -2571,7 +2621,7 @@ def get_contour_pd_plot(self):
"""
pd = self._pd
entries = pd.qhull_entries
- data = np.array(pd.qhull_data)
+ data = np.asarray(pd.qhull_data)
ax = self._get_matplotlib_2d_plot()
data[:, 0:2] = triangular_coord(data[:, 0:2]).transpose()
@@ -2596,7 +2646,7 @@ def get_contour_pd_plot(self):
@property
@lru_cache(1) # noqa: B019
- def pd_plot_data(self):
+ def pd_plot_data(self) -> tuple[list, dict, dict]:
"""
Plotting data for phase diagram. Cached for repetitive calls.
@@ -2605,18 +2655,18 @@ def pd_plot_data(self):
Returns:
A tuple containing three objects (lines, stable_entries, unstable_entries):
- - lines is a list of list of coordinates for lines in the PD.
- - stable_entries is a dict of {coordinates : entry} for each stable node
- in the phase diagram. (Each coordinate can only have one
- stable phase)
- - unstable_entries is a dict of {entry: coordinates} for all unstable
- nodes in the phase diagram.
+ - lines: a list of list of coordinates for lines in the PD.
+ - stable_entries: a dict of {coordinates: entry} for each stable node
+ in the phase diagram. (Each coordinate can only have one
+ stable phase)
+ - unstable_entries: a dict of {entry: coordinates} for all unstable
+ nodes in the phase diagram.
"""
pd = self._pd
entries = pd.qhull_entries
- data = np.array(pd.qhull_data)
- lines = []
- stable_entries = {}
+ data = np.asarray(pd.qhull_data)
+ lines: list = []
+ stable_entries: dict = {}
for line in self.lines:
entry1 = entries[line[0]]
@@ -2638,7 +2688,7 @@ def pd_plot_data(self):
stable_entries[label_coord[1]] = entry2
all_entries = pd.all_entries
- all_data = np.array(pd.all_entries_hulldata)
+ all_data = np.asarray(pd.all_entries_hulldata)
unstable_entries = {}
stable = pd.stable_entries
@@ -2660,7 +2710,7 @@ def pd_plot_data(self):
return lines, stable_entries, unstable_entries
- def _create_plotly_figure_layout(self, label_stable=True):
+ def _create_plotly_figure_layout(self, label_stable: bool = True) -> dict[str, Any]:
"""
Creates layout for plotly phase diagram figure and updates with
figure annotations.
@@ -2704,7 +2754,7 @@ def _create_plotly_figure_layout(self, label_stable=True):
return layout
- def _create_plotly_lines(self):
+ def _create_plotly_lines(self) -> go.Scatter | go.Scatterternary | go.Scatter3d | None:
"""
Create Plotly scatter plots containing line traces of phase diagram facets.
@@ -2712,12 +2762,12 @@ def _create_plotly_lines(self):
Either a go.Scatter (binary), go.Scatterternary (ternary_2d), or
go.Scatter3d plot (ternary_3d, quaternary)
"""
- line_plot = None
+
x, y, z, energies = [], [], [], []
pd = self._pd
- plot_args = {
+ plot_args: dict[str, Any] = {
"mode": "lines",
"hoverinfo": "none",
"line": {"color": "black", "width": 4.0},
@@ -2755,18 +2805,21 @@ def _create_plotly_lines(self):
z += [*line[2], None]
if self._dim == 2:
- line_plot = go.Scatter(x=x, y=y, **plot_args)
- elif self._dim == 3 and self.ternary_style == "2d":
- line_plot = go.Scatterternary(a=x, b=y, c=z, **plot_args)
- elif self._dim == 3 and self.ternary_style == "3d":
- line_plot = go.Scatter3d(x=y, y=x, z=z, **plot_args)
- elif self._dim == 4:
+ return go.Scatter(x=x, y=y, **plot_args)
+
+ if self._dim == 3:
+ if self.ternary_style == "2d":
+ return go.Scatterternary(a=x, b=y, c=z, **plot_args)
+ if self.ternary_style == "3d":
+ return go.Scatter3d(x=y, y=x, z=z, **plot_args)
+
+ if self._dim == 4:
plot_args["line"]["width"] = 1.5
- line_plot = go.Scatter3d(x=x, y=y, z=z, **plot_args)
+ return go.Scatter3d(x=x, y=y, z=z, **plot_args)
- return line_plot
+ return None
- def _create_plotly_fill(self):
+ def _create_plotly_fill(self) -> list[go.Mesh3d]:
"""
Creates shaded mesh traces for coloring the hull.
@@ -2819,7 +2872,7 @@ def _create_plotly_fill(self):
)
]
elif self._dim == 3 and self.ternary_style == "3d":
- facets = np.array(self._pd.facets)
+ facets = np.asarray(self._pd.facets)
coords = np.array(
[
triangular_coord(c)
@@ -2862,7 +2915,7 @@ def _create_plotly_fill(self):
)
)
elif self._dim == 4:
- all_data = np.array(pd.qhull_data)
+ all_data = np.asarray(pd.qhull_data)
fillcolors = itertools.cycle(plotly_layouts["default_fill_colors"])
for _idx, facet in enumerate(pd.facets):
xs, ys, zs = [], [], []
@@ -2905,7 +2958,7 @@ def _create_plotly_fill(self):
return traces
- def _create_plotly_stable_labels(self, label_stable=True):
+ def _create_plotly_stable_labels(self, label_stable: bool = True) -> go.Scatter | go.Scatter3d:
"""
Creates a (hidable) scatter trace containing labels of stable phases.
Contains some functionality for creating sensible label positions. This method
@@ -2968,7 +3021,7 @@ def _create_plotly_stable_labels(self, label_stable=True):
formula = comp.reduced_formula
text.append(htmlify(formula))
- visible = True
+ visible: str | bool = True
if not label_stable or self._dim == 4:
visible = "legendonly"
@@ -2992,7 +3045,7 @@ def _create_plotly_stable_labels(self, label_stable=True):
return stable_labels_plot
- def _create_plotly_element_annotations(self):
+ def _create_plotly_element_annotations(self) -> list[dict] | None:
"""
Creates terminal element annotations for Plotly phase diagrams. This method does
not apply to ternary_2d plots.
@@ -3069,20 +3122,29 @@ def _create_plotly_element_annotations(self):
return annotations_list
- def _create_plotly_markers(self, highlight_entries=None, label_uncertainties=False):
+ def _create_plotly_markers(
+ self,
+ highlight_entries: Collection[PDEntry] | None = None,
+ label_uncertainties: bool = False,
+ ) -> tuple:
"""
- Creates stable and unstable marker plots for overlaying on the phase diagram.
+ Creates stable, unstable and highlight marker plots for overlaying on the phase diagram.
Returns:
tuple[go.Scatter]: Plotly Scatter objects (unary, binary), go.Scatterternary(ternary_2d),
- or go.Scatter3d (ternary_3d, quaternary) objects in order: (stable markers, unstable markers)
+ or go.Scatter3d (ternary_3d, quaternary) objects in order: (stable, unstable and highlight markers)
"""
- def get_marker_props(coords, entries):
+ def get_marker_props(coords, entries) -> dict[str, Any]:
"""Get marker locations, hovertext, and error bars from pd_plot_data."""
- x, y, z, texts, energies, uncertainties = [], [], [], [], [], []
-
- is_stable = [entry in self._pd.stable_entries for entry in entries]
+ x: list[float] = []
+ y: list[float] = []
+ z: list[float] = []
+ texts: list[str] = []
+ energies: list[float] = []
+ uncertainties: list[float] = []
+
+ is_stable: list[bool] = [entry in self._pd.stable_entries for entry in entries]
for coord, entry, stable in zip(coords, entries, is_stable, strict=True):
energy = round(self._pd.get_form_energy_per_atom(entry), 3)
@@ -3097,10 +3159,12 @@ def get_marker_props(coords, entries):
formula = comp.reduced_formula
clean_formula = htmlify(formula)
label = f"{clean_formula} ({entry_id})
Formation energy: {energy} eV/atom
"
+
if not stable:
- e_above_hull = round(self._pd.get_e_above_hull(entry), 3)
- if e_above_hull > self.show_unstable:
+ e_above_hull = self._pd.get_e_above_hull(entry)
+ if e_above_hull is None or e_above_hull > self.show_unstable:
continue
+ e_above_hull = round(e_above_hull, 3)
label += f" Energy Above Hull: ({e_above_hull:+} eV/atom)"
energies.append(e_above_hull)
else:
@@ -3121,6 +3185,7 @@ def get_marker_props(coords, entries):
_cartesian_positions = [x, y, z]
_cartesian_positions[axis].append(entry.composition[el])
label += f"
{el}: {round(entry.composition[el] / total_sum_el, 6)}"
+
elif self._dim == 3 and self.ternary_style == "3d":
x.append(coord[0])
y.append(coord[1])
@@ -3132,6 +3197,7 @@ def get_marker_props(coords, entries):
)
for el, _axis in zip(self._pd.elements, range(self._dim), strict=True):
label += f"
{el}: {round(entry.composition[el] / total_sum_el, 6)}"
+
elif self._dim == 4:
x.append(coord[0])
y.append(coord[1])
@@ -3143,6 +3209,7 @@ def get_marker_props(coords, entries):
)
for el, _axis in zip(self._pd.elements, range(self._dim), strict=True):
label += f"
{el}: {round(entry.composition[el] / total_sum_el, 6)}"
+
else:
x.append(coord[0])
y.append(coord[1])
@@ -3161,10 +3228,13 @@ def get_marker_props(coords, entries):
if highlight_entries is None:
highlight_entries = []
- stable_coords, stable_entries = [], []
- unstable_coords, unstable_entries = [], []
- highlight_coords, highlight_ents = [], []
+ stable_coords: list[Sequence[float]] = []
+ highlight_coords: list[Sequence[float]] = []
+ stable_entries: list[PDEntry] = []
+ highlight_ents: list[PDEntry] = []
+
+ # Stable entries
for coord, entry in zip(self.pd_plot_data[1], self.pd_plot_data[1].values(), strict=True):
if entry in highlight_entries:
highlight_coords.append(coord)
@@ -3173,13 +3243,23 @@ def get_marker_props(coords, entries):
stable_coords.append(coord)
stable_entries.append(entry)
+ # Unstable entries (lowest energy only per composition)
+ min_unstable: dict[str, tuple[Sequence[float], PDEntry]] = {}
+
for coord, entry in zip(self.pd_plot_data[2].values(), self.pd_plot_data[2], strict=True):
if entry in highlight_entries:
highlight_coords.append(coord)
highlight_ents.append(entry)
- else:
- unstable_coords.append(coord)
- unstable_entries.append(entry)
+ continue
+
+ formula = entry.composition.reduced_formula
+ e_above_hull = self._pd.get_e_above_hull(entry)
+
+ if formula not in min_unstable or e_above_hull < self._pd.get_e_above_hull(min_unstable[formula][1]):
+ min_unstable[formula] = (coord, entry)
+
+ unstable_coords = [coord for coord, _ in min_unstable.values()]
+ unstable_entries = [entry for _, entry in min_unstable.values()]
stable_props = get_marker_props(stable_coords, stable_entries)
unstable_props = get_marker_props(unstable_coords, unstable_entries)
@@ -3502,29 +3582,31 @@ def get_marker_props(coords, entries):
highlight_marker_plot = None
- if self._dim in [1, 2]:
+ if self._dim in {1, 2}:
stable_marker_plot, unstable_marker_plot = (
- go.Scatter(**markers) for markers in [stable_markers, unstable_markers]
+ go.Scatter(**markers) for markers in (stable_markers, unstable_markers)
)
if highlight_entries:
highlight_marker_plot = go.Scatter(**highlight_markers)
+
elif self._dim == 3 and self.ternary_style == "2d":
stable_marker_plot, unstable_marker_plot = (
- go.Scatterternary(**markers) for markers in [stable_markers, unstable_markers]
+ go.Scatterternary(**markers) for markers in (stable_markers, unstable_markers)
)
if highlight_entries:
highlight_marker_plot = go.Scatterternary(**highlight_markers)
+
else:
stable_marker_plot, unstable_marker_plot = (
- go.Scatter3d(**markers) for markers in [stable_markers, unstable_markers]
+ go.Scatter3d(**markers) for markers in (stable_markers, unstable_markers)
)
if highlight_entries:
highlight_marker_plot = go.Scatter3d(**highlight_markers)
return stable_marker_plot, unstable_marker_plot, highlight_marker_plot
- def _create_plotly_uncertainty_shading(self, stable_marker_plot):
+ def _create_plotly_uncertainty_shading(self, stable_marker_plot: go.Scatter) -> go.Scatter:
"""
Creates shaded uncertainty region for stable entries. Currently only works
for binary (dim=2) phase diagrams.
@@ -3555,7 +3637,7 @@ def _create_plotly_uncertainty_shading(self, stable_marker_plot):
outline = points[:, :2].copy()
outline[:, 1] += points[:, 2]
- last = -1
+ last: int | None = -1
if transformed:
last = None # allows for uncertainty in terminal compounds
@@ -3577,7 +3659,7 @@ def _create_plotly_uncertainty_shading(self, stable_marker_plot):
return uncertainty_plot
- def _create_plotly_ternary_support_lines(self):
+ def _create_plotly_ternary_support_lines(self) -> go.Scatter3d:
"""
Creates support lines which aid in seeing the ternary hull in three
dimensions.
@@ -3585,7 +3667,7 @@ def _create_plotly_ternary_support_lines(self):
Returns:
go.Scatter3d plot of support lines for ternary phase diagram.
"""
- stable_entry_coords = dict(map(reversed, self.pd_plot_data[1].items()))
+ stable_entry_coords: dict = {v: k for k, v in self.pd_plot_data[1].items()}
elem_coords = [stable_entry_coords[entry] for entry in self._pd.el_refs.values()]
@@ -3614,20 +3696,17 @@ def _create_plotly_ternary_support_lines(self):
def _get_matplotlib_2d_plot(
self,
- label_stable=True,
- label_unstable=True,
- ordering=None,
- energy_colormap=None,
- vmin_mev=-60.0,
- vmax_mev=60.0,
- show_colorbar=True,
- process_attributes=False,
- ax: plt.Axes = None,
- ):
- """Show the plot using matplotlib.
-
- Imports are done within the function as matplotlib is no longer the default.
- """
+ label_stable: bool = True,
+ label_unstable: bool = True,
+ ordering: Sequence[Literal["Up", "Left", "Right"]] | None = None,
+ energy_colormap: str | Colormap | None = None,
+ vmin_mev: float = -60.0,
+ vmax_mev: float = 60.0,
+ show_colorbar: bool = True,
+ process_attributes: bool = False,
+ ax: plt.Axes | None = None,
+ ) -> plt.Axes:
+ """Show the plot using matplotlib."""
ax = ax or pretty_plot(8, 6)
if ordering is None:
@@ -3727,7 +3806,7 @@ def _get_matplotlib_2d_plot(
# The follow defines an offset for the annotation text emanating
# from the center of the PD. Results in fairly nice layouts for the
# most part.
- vec = np.array(coords) - center
+ vec = np.asarray(coords) - center
vec = vec / np.linalg.norm(vec) * 10 if np.linalg.norm(vec) != 0 else vec
valign = "bottom" if vec[1] > 0 else "top"
if vec[0] < -0.01:
@@ -3773,7 +3852,7 @@ def _get_matplotlib_2d_plot(
for entry, coords in unstable.items():
ehull = self._pd.get_e_above_hull(entry)
if ehull is not None and ehull < self.show_unstable:
- vec = np.array(coords) - center
+ vec = np.asarray(coords) - center
vec = vec / np.linalg.norm(vec) * 10 if np.linalg.norm(vec) != 0 else vec
label = entry.name
if energy_colormap is None:
@@ -3823,7 +3902,11 @@ def _get_matplotlib_2d_plot(
plt.subplots_adjust(left=0.09, right=0.98, top=0.98, bottom=0.07)
return ax
- def _get_matplotlib_3d_plot(self, label_stable=True, ax: plt.Axes = None):
+ def _get_matplotlib_3d_plot(
+ self,
+ label_stable: bool = True,
+ ax: plt.Axes | None = None,
+ ) -> plt.Axes:
"""Show the plot using matplotlib.
Args:
@@ -3866,23 +3949,23 @@ def _get_matplotlib_3d_plot(self, label_stable=True, ax: plt.Axes = None):
return ax
-def uniquelines(q):
+def uniquelines(q: list[NDArray[int]]) -> set[tuple[int, int]]:
"""
Given all the facets, convert it into a set of unique lines. Specifically
used for converting convex hull facets into line pairs of coordinates.
Args:
q: A 2-dim sequence, where each row represents a facet. e.g.
- [[1,2,3],[3,6,7],...]
+ [[1, 2, 3], [3, 6, 7], ...]
Returns:
setoflines:
- A set of tuple of lines. e.g. ((1,2), (1,3), (2,3), ....)
+ A set of tuple of lines. e.g. ((1, 2), (1, 3), ...)
"""
return {tuple(sorted(line)) for facets in q for line in itertools.combinations(facets, 2)}
-def triangular_coord(coord):
+def triangular_coord(coord: ArrayLike) -> np.ndarray:
"""
Convert a 2D coordinate into a triangle-based coordinate system for a
prettier phase diagram.
@@ -3895,11 +3978,11 @@ def triangular_coord(coord):
"""
unit_vec = np.array([[1, 0], [0.5, math.sqrt(3) / 2]])
- result = np.dot(np.array(coord), unit_vec)
+ result = np.dot(np.asarray(coord), unit_vec)
return result.transpose()
-def tet_coord(coord):
+def tet_coord(coord: ArrayLike) -> np.ndarray:
"""
Convert a 3D coordinate into a tetrahedron based coordinate system for a
prettier phase diagram.
@@ -3917,11 +4000,16 @@ def tet_coord(coord):
[0.5, 1 / 3 * math.sqrt(3) / 2, math.sqrt(6) / 3],
]
)
- result = np.dot(np.array(coord), unitvec)
+ result = np.dot(np.asarray(coord), unitvec)
return result.transpose()
-def order_phase_diagram(lines, stable_entries, unstable_entries, ordering):
+def order_phase_diagram(
+ lines: list,
+ stable_entries: dict[Any, PDEntry],
+ unstable_entries: dict[PDEntry, Any],
+ ordering: Sequence[Literal["Up", "Left", "Right"]],
+) -> tuple[list, dict[Any, PDEntry], dict[PDEntry, Any]]:
"""
Orders the entries (their coordinates) in a phase diagram plot according
to the user specified ordering.
@@ -3940,12 +4028,11 @@ def order_phase_diagram(lines, stable_entries, unstable_entries, ordering):
Returns:
tuple[list, dict, dict]:
- - new_lines is a list of list of coordinates for lines in the PD.
- - new_stable_entries is a {coordinate: entry} for each stable node
+ - a list of list of coordinates for lines in the PD.
+ - a {coordinate: entry} for each stable node
in the phase diagram. (Each coordinate can only have one
stable phase)
- - new_unstable_entries is a {entry: coordinates} for all unstable
- nodes in the phase diagram.
+ - a {entry: coordinates} for all unstable nodes in the phase diagram.
"""
yup = -1000.0
xleft = 1000.0
@@ -3954,16 +4041,16 @@ def order_phase_diagram(lines, stable_entries, unstable_entries, ordering):
nameup = ""
nameleft = ""
nameright = ""
- for coord in stable_entries:
+ for coord, entry in stable_entries.items():
if coord[0] > xright:
xright = coord[0]
- nameright = stable_entries[coord].name
+ nameright = entry.name
if coord[0] < xleft:
xleft = coord[0]
- nameleft = stable_entries[coord].name
+ nameleft = entry.name
if coord[1] > yup:
yup = coord[1]
- nameup = stable_entries[coord].name
+ nameup = entry.name
if (nameup not in ordering) or (nameright not in ordering) or (nameleft not in ordering):
raise ValueError(
diff --git a/src/pymatgen/util/coord.py b/src/pymatgen/util/coord.py
index 2186eec84e3..038e3f490e6 100644
--- a/src/pymatgen/util/coord.py
+++ b/src/pymatgen/util/coord.py
@@ -367,14 +367,14 @@ class Simplex(MSONable):
simplex_dim (int): Dimension of the simplex coordinate space.
"""
- def __init__(self, coords) -> None:
+ def __init__(self, coords: Sequence[Sequence[float]]) -> None:
"""Initialize a Simplex from vertex coordinates.
Args:
coords ([[float]]): Coords of the vertices of the simplex. e.g.
[[1, 2, 3], [2, 4, 5], [6, 7, 8], [8, 9, 10].
"""
- self._coords = np.array(coords)
+ self._coords = np.asarray(coords)
self.space_dim, self.simplex_dim = self._coords.shape
self.origin = self._coords[-1]
if self.space_dim == self.simplex_dim + 1:
@@ -382,12 +382,25 @@ def __init__(self, coords) -> None:
self._aug = np.concatenate([coords, np.ones((self.space_dim, 1))], axis=-1)
self._aug_inv = np.linalg.inv(self._aug)
+ def __eq__(self, other: object) -> bool:
+ if not isinstance(other, Simplex):
+ return NotImplemented
+ return any(np.allclose(p, other.coords) for p in itertools.permutations(self._coords))
+
+ def __hash__(self) -> int:
+ return len(self._coords)
+
+ def __repr__(self) -> str:
+ output = [f"{self.simplex_dim}-simplex in {self.space_dim}D space\nVertices:"]
+ output += [f"\t({', '.join(map(str, coord))})" for coord in self._coords]
+ return "\n".join(output)
+
@property
def volume(self) -> float:
"""Volume of the simplex."""
return abs(np.linalg.det(self._aug)) / math.factorial(self.simplex_dim)
- def bary_coords(self, point):
+ def bary_coords(self, point: ArrayLike) -> np.ndarray:
"""
Args:
point (ArrayLike): Point coordinates.
@@ -400,7 +413,7 @@ def bary_coords(self, point):
except AttributeError as exc:
raise ValueError("Simplex is not full-dimensional") from exc
- def point_from_bary_coords(self, bary_coords: ArrayLike):
+ def point_from_bary_coords(self, bary_coords: ArrayLike) -> np.ndarray:
"""
Args:
bary_coords (ArrayLike): Barycentric coordinates (d+1, d).
@@ -428,9 +441,14 @@ def in_simplex(self, point: Sequence[float], tolerance: float = 1e-8) -> bool:
point (list[float]): Point to test
tolerance (float): Tolerance to test if point is in simplex.
"""
- return (self.bary_coords(point) >= -tolerance).all()
-
- def line_intersection(self, point1: Sequence[float], point2: Sequence[float], tolerance: float = 1e-8):
+ return bool((self.bary_coords(point) >= -tolerance).all())
+
+ def line_intersection(
+ self,
+ point1: Sequence[float],
+ point2: Sequence[float],
+ tolerance: float = 1e-8,
+ ) -> list[np.ndarray]:
"""Compute the intersection points of a line with a simplex.
Args:
@@ -465,19 +483,6 @@ def line_intersection(self, point1: Sequence[float], point2: Sequence[float], to
raise ValueError("More than 2 intersections found")
return [self.point_from_bary_coords(b) for b in barys]
- def __eq__(self, other: object) -> bool:
- if not isinstance(other, Simplex):
- return NotImplemented
- return any(np.allclose(p, other.coords) for p in itertools.permutations(self._coords))
-
- def __hash__(self) -> int:
- return len(self._coords)
-
- def __repr__(self) -> str:
- output = [f"{self.simplex_dim}-simplex in {self.space_dim}D space\nVertices:"]
- output += [f"\t({', '.join(map(str, coord))})" for coord in self._coords]
- return "\n".join(output)
-
@property
def coords(self) -> np.ndarray:
"""A copy of the vertex coordinates in the simplex."""