1010import warnings
1111from collections import defaultdict
1212from functools import lru_cache
13- from typing import TYPE_CHECKING
13+ from typing import TYPE_CHECKING , Any
1414
1515import matplotlib .pyplot as plt
1616import numpy as np
4343 from numpy .typing import ArrayLike
4444 from typing_extensions import Self
4545
46+ from pymatgen .util .typing import CompositionLike , EntryLike
47+
4648logger = logging .getLogger (__name__ )
4749
4850with open (
@@ -69,14 +71,14 @@ class PDEntry(Entry):
6971
7072 def __init__ (
7173 self ,
72- composition : Composition ,
74+ composition : CompositionLike ,
7375 energy : float ,
7476 name : str | None = None ,
7577 attribute : object = None ,
7678 ):
7779 """
7880 Args:
79- composition (Composition ): Composition
81+ composition (CompositionLike ): Composition
8082 energy (float): Energy for composition.
8183 name (str): Optional parameter to name the entry. Defaults
8284 to the reduced chemical formula.
@@ -341,8 +343,8 @@ class PhaseDiagram(MSONable):
341343
342344 def __init__ (
343345 self ,
344- entries : Sequence [ PDEntry ] | set [ PDEntry ],
345- elements : Sequence [Element ] = (),
346+ entries : Collection [ EntryLike ],
347+ elements : Collection [Element ] = (),
346348 * ,
347349 computed_data : dict [str , Any ] | None = None ,
348350 ) -> None :
@@ -435,9 +437,9 @@ def _compute(self) -> dict[str, Any]:
435437
436438 entries = sorted (self .entries , key = lambda e : e .composition .reduced_composition )
437439
438- el_refs : dict [Element , PDEntry ] = {}
439- min_entries : list [PDEntry ] = []
440- all_entries : list [PDEntry ] = []
440+ el_refs : dict [Element , EntryLike ] = {}
441+ min_entries : list [EntryLike ] = []
442+ all_entries : list [EntryLike ] = []
441443 for composition , group_iter in itertools .groupby (entries , key = lambda e : e .composition .reduced_composition ):
442444 group = list (group_iter )
443445 min_entry = min (group , key = lambda e : e .energy_per_atom )
@@ -980,7 +982,7 @@ def get_phase_separation_energy(self, entry, **kwargs):
980982 """
981983 return self .get_decomp_and_phase_separation_energy (entry , ** kwargs )[1 ]
982984
983- def get_composition_chempots (self , comp ) :
985+ def get_composition_chempots (self , comp : Composition ) -> dict [ Element , float ] :
984986 """Get the chemical potentials for all elements at a given composition.
985987
986988 Args:
@@ -992,7 +994,7 @@ def get_composition_chempots(self, comp):
992994 facet = self ._get_facet_and_simplex (comp )[0 ]
993995 return self ._get_facet_chempots (facet )
994996
995- def get_all_chempots (self , comp ) :
997+ def get_all_chempots (self , comp : Composition ) -> dict [ str , dict [ Element , float ]] :
996998 """Get chemical potentials at a given composition.
997999
9981000 Args:
@@ -1010,7 +1012,7 @@ def get_all_chempots(self, comp):
10101012
10111013 return chempots
10121014
1013- def get_transition_chempots (self , element ) :
1015+ def get_transition_chempots (self , element : Element ) -> tuple [ float , ...] :
10141016 """Get the critical chemical potentials for an element in the Phase
10151017 Diagram.
10161018
@@ -1029,7 +1031,7 @@ def get_transition_chempots(self, element):
10291031 chempots = self ._get_facet_chempots (facet )
10301032 critical_chempots .append (chempots [element ])
10311033
1032- clean_pots = []
1034+ clean_pots : list [ float ] = []
10331035 for c in sorted (critical_chempots ):
10341036 if len (clean_pots ) == 0 or not math .isclose (
10351037 c , clean_pots [- 1 ], abs_tol = PhaseDiagram .numerical_tol , rel_tol = 0
@@ -1038,7 +1040,7 @@ def get_transition_chempots(self, element):
10381040 clean_pots .reverse ()
10391041 return tuple (clean_pots )
10401042
1041- def get_critical_compositions (self , comp1 , comp2 ) :
1043+ def get_critical_compositions (self , comp1 : Composition , comp2 : Composition ) -> list [ Composition ] :
10421044 """Get the critical compositions along the tieline between two
10431045 compositions. I.e. where the decomposition products change.
10441046 The endpoints are also returned.
@@ -1098,7 +1100,7 @@ def get_critical_compositions(self, comp1, comp2):
10981100
10991101 return [Composition ((elem , val ) for elem , val in zip (pd_els , m , strict = True )) for m in cs ]
11001102
1101- def get_element_profile (self , element , comp , comp_tol = 1e-5 ):
1103+ def get_element_profile (self , element : Element , comp : Composition , comp_tol : float = 1e-5 ) -> list [ dict [ str , Any ]] :
11021104 """
11031105 Provides the element evolution data for a composition. For example, can be used
11041106 to analyze Li conversion voltages by varying mu_Li and looking at the phases
@@ -1199,7 +1201,9 @@ def get_chempot_range_map(
11991201
12001202 return chempot_ranges
12011203
1202- def getmu_vertices_stability_phase (self , target_comp , dep_elt , tol_en = 1e-2 ):
1204+ def getmu_vertices_stability_phase (
1205+ self , target_comp : Composition , dep_elt : Element , tol_en : float = 1e-2
1206+ ) -> list [dict [Element , float ]] | None :
12031207 """Get a set of chemical potentials corresponding to the vertices of
12041208 the simplex in the chemical potential phase diagram.
12051209 The simplex is built using all elements in the target_composition
@@ -1233,11 +1237,11 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2):
12331237 if elem .composition .reduced_composition == target_comp .reduced_composition :
12341238 multiplier = elem .composition [dep_elt ] / target_comp [dep_elt ]
12351239 ef = elem .energy / multiplier
1236- all_coords = []
1240+ all_coords : list [ dict [ Element , float ]] = []
12371241 for simplex in chempots :
12381242 for v in simplex ._coords :
12391243 elements = [elem for elem in self .elements if elem != dep_elt ]
1240- res = {}
1244+ res : dict [ Element , float ] = {}
12411245 for idx , el in enumerate (elements ):
12421246 res [el ] = v [idx ] + mu_ref [idx ]
12431247 res [dep_elt ] = (np .dot (v + mu_ref , coeff ) + ef ) / target_comp [dep_elt ]
@@ -1257,7 +1261,9 @@ def getmu_vertices_stability_phase(self, target_comp, dep_elt, tol_en=1e-2):
12571261 return all_coords
12581262 return None
12591263
1260- def get_chempot_range_stability_phase (self , target_comp , open_elt ):
1264+ def get_chempot_range_stability_phase (
1265+ self , target_comp : Composition , open_elt : Element
1266+ ) -> dict [Element , tuple [float , float ]]:
12611267 """Get a set of chemical potentials corresponding to the max and min
12621268 chemical potential of the open element for a given composition. It is
12631269 quite common to have for instance a ternary oxide (e.g., ABO3) for
@@ -1408,18 +1414,25 @@ class GrandPotentialPhaseDiagram(PhaseDiagram):
14081414 doi:10.1016/j.elecom.2010.01.010
14091415 """
14101416
1411- def __init__ (self , entries , chempots , elements = None , * , computed_data = None ):
1417+ def __init__ (
1418+ self ,
1419+ entries : Collection [EntryLike ],
1420+ chempots : dict [Element , float ],
1421+ elements : Collection [Element ] | None = None ,
1422+ * ,
1423+ computed_data : dict [str , Any ] | None = None ,
1424+ ):
14121425 """Standard constructor for grand potential phase diagram.
14131426
14141427 Args:
1415- entries ([PDEntry ]): A list of PDEntry-like objects having an
1428+ entries (Sequence[EntryLike ]): A list of EntryLike objects having an
14161429 energy, energy_per_atom and composition.
1417- chempots ({ Element: float} ): Specify the chemical potentials
1430+ chempots (dict[ Element, float] ): Specify the chemical potentials
14181431 of the open elements.
1419- elements ([Element]): Optional list of elements in the phase
1432+ elements (Sequence [Element]): Optional list of elements in the phase
14201433 diagram. If set to None, the elements are determined from
14211434 the entries themselves.
1422- computed_data (dict): A dict containing pre-computed data. This allows
1435+ computed_data (dict[str, Any] ): A dict containing pre-computed data. This allows
14231436 PhaseDiagram object to be reconstituted without performing the
14241437 expensive convex hull computation. The dict is the output from the
14251438 PhaseDiagram._compute() method and is stored in PhaseDiagram.computed_data
@@ -1481,7 +1494,12 @@ class CompoundPhaseDiagram(PhaseDiagram):
14811494 # Tolerance for determining if amount of a composition is positive.
14821495 amount_tol = 1e-5
14831496
1484- def __init__ (self , entries , terminal_compositions , normalize_terminal_compositions = True ):
1497+ def __init__ (
1498+ self ,
1499+ entries : Sequence [EntryLike ],
1500+ terminal_compositions : Sequence [Composition ],
1501+ normalize_terminal_compositions : bool = True ,
1502+ ):
14851503 """Initialize a CompoundPhaseDiagram.
14861504
14871505 Args:
@@ -1532,7 +1550,9 @@ def num2str(num):
15321550
15331551 return ret
15341552
1535- def transform_entries (self , entries , terminal_compositions ):
1553+ def transform_entries (
1554+ self , entries : Sequence [EntryLike ], terminal_compositions : Sequence [Composition ]
1555+ ) -> tuple [list [TransformedPDEntry ], dict [Composition , DummySpecies ]]:
15361556 """
15371557 Method to transform all entries to the composition coordinate in the
15381558 terminal compositions. If the entry does not fall within the space
@@ -1624,16 +1644,16 @@ class PatchedPhaseDiagram(PhaseDiagram):
16241644
16251645 def __init__ (
16261646 self ,
1627- entries : Sequence [PDEntry ] | set [PDEntry ],
1647+ entries : Sequence [EntryLike ] | set [EntryLike ],
16281648 elements : Sequence [Element ] | None = None ,
16291649 keep_all_spaces : bool = False ,
16301650 verbose : bool = False ,
16311651 ) -> None :
16321652 """
16331653 Args:
1634- entries (list[PDEntry] ): A list of PDEntry-like objects having an
1654+ entries (Sequence[EntryLike] | set[EntryLike] ): A list of EntryLike objects having an
16351655 energy, energy_per_atom and composition.
1636- elements (list [Element], optional): Optional list of elements in the phase
1656+ elements (Sequence [Element], optional): Optional list of elements in the phase
16371657 diagram. If set to None, the elements are determined from
16381658 the entries themselves and are sorted alphabetically.
16391659 If specified, element ordering (e.g. for pd coordinates)
@@ -1649,9 +1669,9 @@ def __init__(
16491669
16501670 entries = sorted (entries , key = lambda e : e .composition .reduced_composition )
16511671
1652- el_refs : dict [Element , PDEntry ] = {}
1653- min_entries = []
1654- all_entries : list [PDEntry ] = []
1672+ el_refs : dict [Element , EntryLike ] = {}
1673+ min_entries : list [ EntryLike ] = []
1674+ all_entries : list [EntryLike ] = []
16551675 for composition , group_iter in itertools .groupby (entries , key = lambda e : e .composition .reduced_composition ):
16561676 group = list (group_iter )
16571677 min_entry = min (group , key = lambda e : e .energy_per_atom )
@@ -1781,17 +1801,19 @@ def from_dict(cls, dct: dict) -> Self:
17811801 return cls (entries , elements )
17821802
17831803 @staticmethod
1784- def remove_redundant_spaces (spaces , keep_all_spaces = False ):
1804+ def remove_redundant_spaces (
1805+ spaces : set [frozenset [Element ]], keep_all_spaces : bool = False
1806+ ) -> set [frozenset [Element ]]:
17851807 if keep_all_spaces or len (spaces ) <= 1 :
17861808 return spaces
17871809
17881810 # Sort spaces by size in descending order and pre-compute lengths
17891811 sorted_spaces = sorted (spaces , key = len , reverse = True )
17901812
1791- result = []
1813+ result = set ()
17921814 for idx , space_i in enumerate (sorted_spaces ):
17931815 if not any (space_i .issubset (larger_space ) for larger_space in sorted_spaces [:idx ]):
1794- result .append (space_i )
1816+ result .add (space_i )
17951817
17961818 return result
17971819
@@ -1808,7 +1830,7 @@ def remove_redundant_spaces(spaces, keep_all_spaces=False):
18081830 # get_decomp_and_phase_separation_energy(),
18091831 # get_phase_separation_energy()
18101832
1811- def get_pd_for_entry (self , entry : Entry | Composition ) -> PhaseDiagram :
1833+ def get_pd_for_entry (self , entry : EntryLike | Composition ) -> PhaseDiagram :
18121834 """Get the possible phase diagrams for an entry.
18131835
18141836 Args:
0 commit comments