Skip to content

Commit 5b7c0e0

Browse files
committed
typing: add some typing to phase diagram code
1 parent 20b68ff commit 5b7c0e0

File tree

6 files changed

+89
-56
lines changed

6 files changed

+89
-56
lines changed

src/pymatgen/analysis/phase_diagram.py

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import warnings
1111
from collections import defaultdict
1212
from functools import lru_cache
13-
from typing import TYPE_CHECKING
13+
from typing import TYPE_CHECKING, Any
1414

1515
import matplotlib.pyplot as plt
1616
import numpy as np
@@ -43,6 +43,8 @@
4343
from numpy.typing import ArrayLike
4444
from typing_extensions import Self
4545

46+
from pymatgen.util.typing import CompositionLike, EntryLike
47+
4648
logger = logging.getLogger(__name__)
4749

4850
with open(
@@ -69,14 +71,14 @@ class PDEntry(Entry):
6971

7072
def __init__(
7173
self,
72-
composition: Composition,
74+
composition: CompositionLike,
7375
energy: float,
7476
name: str | None = None,
7577
attribute: object = None,
7678
):
7779
"""
7880
Args:
79-
composition (Composition): Composition
81+
composition (CompositionLike): Composition
8082
energy (float): Energy for composition.
8183
name (str): Optional parameter to name the entry. Defaults
8284
to the reduced chemical formula.
@@ -341,8 +343,8 @@ class PhaseDiagram(MSONable):
341343

342344
def __init__(
343345
self,
344-
entries: Sequence[PDEntry] | set[PDEntry],
345-
elements: Sequence[Element] = (),
346+
entries: Collection[EntryLike],
347+
elements: Collection[Element] = (),
346348
*,
347349
computed_data: dict[str, Any] | None = None,
348350
) -> None:
@@ -435,9 +437,9 @@ def _compute(self) -> dict[str, Any]:
435437

436438
entries = sorted(self.entries, key=lambda e: e.composition.reduced_composition)
437439

438-
el_refs: dict[Element, PDEntry] = {}
439-
min_entries: list[PDEntry] = []
440-
all_entries: list[PDEntry] = []
440+
el_refs: dict[Element, EntryLike] = {}
441+
min_entries: list[EntryLike] = []
442+
all_entries: list[EntryLike] = []
441443
for composition, group_iter in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition):
442444
group = list(group_iter)
443445
min_entry = min(group, key=lambda e: e.energy_per_atom)
@@ -980,7 +982,7 @@ def get_phase_separation_energy(self, entry, **kwargs):
980982
"""
981983
return self.get_decomp_and_phase_separation_energy(entry, **kwargs)[1]
982984

983-
def get_composition_chempots(self, comp):
985+
def get_composition_chempots(self, comp: Composition) -> dict[Element, float]:
984986
"""Get the chemical potentials for all elements at a given composition.
985987
986988
Args:
@@ -992,7 +994,7 @@ def get_composition_chempots(self, comp):
992994
facet = self._get_facet_and_simplex(comp)[0]
993995
return self._get_facet_chempots(facet)
994996

995-
def get_all_chempots(self, comp):
997+
def get_all_chempots(self, comp: Composition) -> dict[str, dict[Element, float]]:
996998
"""Get chemical potentials at a given composition.
997999
9981000
Args:
@@ -1010,7 +1012,7 @@ def get_all_chempots(self, comp):
10101012

10111013
return chempots
10121014

1013-
def get_transition_chempots(self, element):
1015+
def get_transition_chempots(self, element: Element) -> tuple[float, ...]:
10141016
"""Get the critical chemical potentials for an element in the Phase
10151017
Diagram.
10161018
@@ -1029,7 +1031,7 @@ def get_transition_chempots(self, element):
10291031
chempots = self._get_facet_chempots(facet)
10301032
critical_chempots.append(chempots[element])
10311033

1032-
clean_pots = []
1034+
clean_pots: list[float] = []
10331035
for c in sorted(critical_chempots):
10341036
if len(clean_pots) == 0 or not math.isclose(
10351037
c, clean_pots[-1], abs_tol=PhaseDiagram.numerical_tol, rel_tol=0
@@ -1038,7 +1040,7 @@ def get_transition_chempots(self, element):
10381040
clean_pots.reverse()
10391041
return tuple(clean_pots)
10401042

1041-
def get_critical_compositions(self, comp1, comp2):
1043+
def get_critical_compositions(self, comp1: Composition, comp2: Composition) -> list[Composition]:
10421044
"""Get the critical compositions along the tieline between two
10431045
compositions. I.e. where the decomposition products change.
10441046
The endpoints are also returned.
@@ -1098,7 +1100,7 @@ def get_critical_compositions(self, comp1, comp2):
10981100

10991101
return [Composition((elem, val) for elem, val in zip(pd_els, m, strict=True)) for m in cs]
11001102

1101-
def get_element_profile(self, element, comp, comp_tol=1e-5):
1103+
def get_element_profile(self, element: Element, comp: Composition, comp_tol: float = 1e-5) -> list[dict[str, Any]]:
11021104
"""
11031105
Provides the element evolution data for a composition. For example, can be used
11041106
to analyze Li conversion voltages by varying mu_Li and looking at the phases
@@ -1199,7 +1201,9 @@ def get_chempot_range_map(
11991201

12001202
return chempot_ranges
12011203

1202-
def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2):
1204+
def getmu_vertices_stability_phase(
1205+
self, target_comp: Composition, dep_elt: Element, tol_en: float = 1e-2
1206+
) -> list[dict[Element, float]] | None:
12031207
"""Get a set of chemical potentials corresponding to the vertices of
12041208
the simplex in the chemical potential phase diagram.
12051209
The simplex is built using all elements in the target_composition
@@ -1233,11 +1237,11 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2):
12331237
if elem.composition.reduced_composition == target_comp.reduced_composition:
12341238
multiplier = elem.composition[dep_elt] / target_comp[dep_elt]
12351239
ef = elem.energy / multiplier
1236-
all_coords = []
1240+
all_coords: list[dict[Element, float]] = []
12371241
for simplex in chempots:
12381242
for v in simplex._coords:
12391243
elements = [elem for elem in self.elements if elem != dep_elt]
1240-
res = {}
1244+
res: dict[Element, float] = {}
12411245
for idx, el in enumerate(elements):
12421246
res[el] = v[idx] + mu_ref[idx]
12431247
res[dep_elt] = (np.dot(v + mu_ref, coeff) + ef) / target_comp[dep_elt]
@@ -1257,7 +1261,9 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2):
12571261
return all_coords
12581262
return None
12591263

1260-
def get_chempot_range_stability_phase(self, target_comp, open_elt):
1264+
def get_chempot_range_stability_phase(
1265+
self, target_comp: Composition, open_elt: Element
1266+
) -> dict[Element, tuple[float, float]]:
12611267
"""Get a set of chemical potentials corresponding to the max and min
12621268
chemical potential of the open element for a given composition. It is
12631269
quite common to have for instance a ternary oxide (e.g., ABO3) for
@@ -1408,18 +1414,25 @@ class GrandPotentialPhaseDiagram(PhaseDiagram):
14081414
doi:10.1016/j.elecom.2010.01.010
14091415
"""
14101416

1411-
def __init__(self, entries, chempots, elements=None, *, computed_data=None):
1417+
def __init__(
1418+
self,
1419+
entries: Collection[EntryLike],
1420+
chempots: dict[Element, float],
1421+
elements: Collection[Element] | None = None,
1422+
*,
1423+
computed_data: dict[str, Any] | None = None,
1424+
):
14121425
"""Standard constructor for grand potential phase diagram.
14131426
14141427
Args:
1415-
entries ([PDEntry]): A list of PDEntry-like objects having an
1428+
entries (Sequence[EntryLike]): A list of EntryLike objects having an
14161429
energy, energy_per_atom and composition.
1417-
chempots ({Element: float}): Specify the chemical potentials
1430+
chempots (dict[Element, float]): Specify the chemical potentials
14181431
of the open elements.
1419-
elements ([Element]): Optional list of elements in the phase
1432+
elements (Sequence[Element]): Optional list of elements in the phase
14201433
diagram. If set to None, the elements are determined from
14211434
the entries themselves.
1422-
computed_data (dict): A dict containing pre-computed data. This allows
1435+
computed_data (dict[str, Any]): A dict containing pre-computed data. This allows
14231436
PhaseDiagram object to be reconstituted without performing the
14241437
expensive convex hull computation. The dict is the output from the
14251438
PhaseDiagram._compute() method and is stored in PhaseDiagram.computed_data
@@ -1481,7 +1494,12 @@ class CompoundPhaseDiagram(PhaseDiagram):
14811494
# Tolerance for determining if amount of a composition is positive.
14821495
amount_tol = 1e-5
14831496

1484-
def __init__(self, entries, terminal_compositions, normalize_terminal_compositions=True):
1497+
def __init__(
1498+
self,
1499+
entries: Sequence[EntryLike],
1500+
terminal_compositions: Sequence[Composition],
1501+
normalize_terminal_compositions: bool = True,
1502+
):
14851503
"""Initialize a CompoundPhaseDiagram.
14861504
14871505
Args:
@@ -1532,7 +1550,9 @@ def num2str(num):
15321550

15331551
return ret
15341552

1535-
def transform_entries(self, entries, terminal_compositions):
1553+
def transform_entries(
1554+
self, entries: Sequence[EntryLike], terminal_compositions: Sequence[Composition]
1555+
) -> tuple[list[TransformedPDEntry], dict[Composition, DummySpecies]]:
15361556
"""
15371557
Method to transform all entries to the composition coordinate in the
15381558
terminal compositions. If the entry does not fall within the space
@@ -1624,16 +1644,16 @@ class PatchedPhaseDiagram(PhaseDiagram):
16241644

16251645
def __init__(
16261646
self,
1627-
entries: Sequence[PDEntry] | set[PDEntry],
1647+
entries: Sequence[EntryLike] | set[EntryLike],
16281648
elements: Sequence[Element] | None = None,
16291649
keep_all_spaces: bool = False,
16301650
verbose: bool = False,
16311651
) -> None:
16321652
"""
16331653
Args:
1634-
entries (list[PDEntry]): A list of PDEntry-like objects having an
1654+
entries (Sequence[EntryLike] | set[EntryLike]): A list of EntryLike objects having an
16351655
energy, energy_per_atom and composition.
1636-
elements (list[Element], optional): Optional list of elements in the phase
1656+
elements (Sequence[Element], optional): Optional list of elements in the phase
16371657
diagram. If set to None, the elements are determined from
16381658
the entries themselves and are sorted alphabetically.
16391659
If specified, element ordering (e.g. for pd coordinates)
@@ -1649,9 +1669,9 @@ def __init__(
16491669

16501670
entries = sorted(entries, key=lambda e: e.composition.reduced_composition)
16511671

1652-
el_refs: dict[Element, PDEntry] = {}
1653-
min_entries = []
1654-
all_entries: list[PDEntry] = []
1672+
el_refs: dict[Element, EntryLike] = {}
1673+
min_entries: list[EntryLike] = []
1674+
all_entries: list[EntryLike] = []
16551675
for composition, group_iter in itertools.groupby(entries, key=lambda e: e.composition.reduced_composition):
16561676
group = list(group_iter)
16571677
min_entry = min(group, key=lambda e: e.energy_per_atom)
@@ -1781,17 +1801,19 @@ def from_dict(cls, dct: dict) -> Self:
17811801
return cls(entries, elements)
17821802

17831803
@staticmethod
1784-
def remove_redundant_spaces(spaces, keep_all_spaces=False):
1804+
def remove_redundant_spaces(
1805+
spaces: set[frozenset[Element]], keep_all_spaces: bool = False
1806+
) -> set[frozenset[Element]]:
17851807
if keep_all_spaces or len(spaces) <= 1:
17861808
return spaces
17871809

17881810
# Sort spaces by size in descending order and pre-compute lengths
17891811
sorted_spaces = sorted(spaces, key=len, reverse=True)
17901812

1791-
result = []
1813+
result = set()
17921814
for idx, space_i in enumerate(sorted_spaces):
17931815
if not any(space_i.issubset(larger_space) for larger_space in sorted_spaces[:idx]):
1794-
result.append(space_i)
1816+
result.add(space_i)
17951817

17961818
return result
17971819

@@ -1808,7 +1830,7 @@ def remove_redundant_spaces(spaces, keep_all_spaces=False):
18081830
# get_decomp_and_phase_separation_energy(),
18091831
# get_phase_separation_energy()
18101832

1811-
def get_pd_for_entry(self, entry: Entry | Composition) -> PhaseDiagram:
1833+
def get_pd_for_entry(self, entry: EntryLike | Composition) -> PhaseDiagram:
18121834
"""Get the possible phase diagrams for an entry.
18131835
18141836
Args:

src/pymatgen/core/sites.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,10 +269,11 @@ def as_dict(self) -> dict:
269269
@classmethod
270270
def from_dict(cls, dct: dict) -> Self:
271271
"""Create Site from dict representation."""
272-
atoms_n_occu = {}
272+
atoms_n_occu: dict[SpeciesLike, float] = {}
273273
for sp_occu in dct["species"]:
274+
sp: SpeciesLike
274275
if "oxidation_state" in sp_occu and Element.is_valid_symbol(sp_occu["element"]):
275-
sp: Species | DummySpecies | Element = Species.from_dict(sp_occu)
276+
sp = Species.from_dict(sp_occu)
276277
elif "oxidation_state" in sp_occu:
277278
sp = DummySpecies.from_dict(sp_occu)
278279
else:
@@ -636,10 +637,11 @@ def from_dict(cls, dct: dict, lattice: Lattice | None = None) -> Self:
636637
Returns:
637638
PeriodicSite
638639
"""
639-
species = {}
640+
species: dict[SpeciesLike, float] = {}
640641
for sp_occu in dct["species"]:
642+
sp: SpeciesLike
641643
if "oxidation_state" in sp_occu and Element.is_valid_symbol(sp_occu["element"]):
642-
sp: Species | DummySpecies | Element = Species.from_dict(sp_occu)
644+
sp = Species.from_dict(sp_occu)
643645
elif "oxidation_state" in sp_occu:
644646
sp = DummySpecies.from_dict(sp_occu)
645647
else:

src/pymatgen/entries/__init__.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import Literal
2020

2121
from pymatgen.core import DummySpecies, Element, Species
22+
from pymatgen.util.typing import CompositionLike
2223

2324

2425
__author__ = "Shyue Ping Ong, Anubhav Jain, Ayush Gupta"
@@ -37,20 +38,19 @@ class Entry(MSONable, ABC):
3738
which inherit from Entry must define a .energy property.
3839
"""
3940

40-
def __init__(self, composition: Composition | str | dict[str, float], energy: float) -> None:
41+
def __init__(self, composition: CompositionLike, energy: float) -> None:
4142
"""Initialize an Entry.
4243
4344
Args:
44-
composition (Composition): Composition of the entry. For
45+
composition (CompositionLike): Composition of the entry. For
4546
flexibility, this can take the form of all the typical input taken by a
4647
Composition, including a {symbol: amt} dict, a string formula, and others.
4748
energy (float): Energy of the entry.
4849
"""
49-
if isinstance(composition, Composition):
50-
self._composition = composition
51-
else:
52-
self._composition = Composition(composition)
53-
# self._composition = Composition(composition)
50+
if not isinstance(composition, Composition):
51+
composition = Composition(composition)
52+
53+
self._composition = composition
5454
self._energy = energy
5555

5656
@property

src/pymatgen/entries/computed_entries.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
from pymatgen.analysis.phase_diagram import PhaseDiagram
3434
from pymatgen.core import Structure
35+
from pymatgen.util.typing import CompositionLike
3536

3637
__author__ = "Ryan Kingsbury, Matt McDermott, Shyue Ping Ong, Anubhav Jain"
3738
__copyright__ = "Copyright 2011-2020, The Materials Project"
@@ -292,7 +293,7 @@ class ComputedEntry(Entry):
292293

293294
def __init__(
294295
self,
295-
composition: Composition | str | dict[str, float],
296+
composition: CompositionLike,
296297
energy: float,
297298
correction: float = 0.0,
298299
energy_adjustments: list | None = None,
@@ -558,7 +559,7 @@ def __init__(
558559
structure: Structure,
559560
energy: float,
560561
correction: float = 0.0,
561-
composition: Composition | str | dict[str, float] | None = None,
562+
composition: CompositionLike | None = None,
562563
energy_adjustments: list | None = None,
563564
parameters: dict | None = None,
564565
data: dict | None = None,
@@ -585,12 +586,10 @@ def __init__(
585586
with the entry. Defaults to None.
586587
entry_id: An optional id to uniquely identify the entry.
587588
"""
588-
if composition:
589-
if isinstance(composition, Composition):
590-
pass
591-
else:
589+
if composition is not None:
590+
if not isinstance(composition, Composition):
592591
composition = Composition(composition)
593-
# composition = Composition(composition)
592+
594593
if (
595594
composition.get_integer_formula_and_factor()[0]
596595
!= structure.composition.get_integer_formula_and_factor()[0]
@@ -706,7 +705,7 @@ def __init__(
706705
formation_enthalpy_per_atom: float,
707706
temp: float = 300,
708707
gibbs_model: Literal["SISSO"] = "SISSO",
709-
composition: Composition | None = None,
708+
composition: CompositionLike | None = None,
710709
correction: float = 0.0,
711710
energy_adjustments: list | None = None,
712711
parameters: dict | None = None,

0 commit comments

Comments
 (0)