Skip to content

Commit 206152f

Browse files
committed
more explicit handling of dimension
1 parent c33c731 commit 206152f

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/pymatgen/analysis/phase_diagram.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2328,10 +2328,14 @@ def get_plot(
23282328
Returns:
23292329
go.Figure | plt.Axes: Plotly figure or matplotlib axes object depending on backend.
23302330
"""
2331-
fig = None
2332-
data = []
2331+
if self._dim not in {1, 2, 3, 4}:
2332+
raise ValueError(
2333+
f"Plotting is only supported for unary/binary/ternary/quaternary phase diagrams — got {self._dim}D "
2334+
)
23332335

23342336
if self.backend == "plotly":
2337+
data: list = []
2338+
23352339
if self._dim != 1:
23362340
data.append(self._create_plotly_lines())
23372341

@@ -2349,7 +2353,7 @@ def get_plot(
23492353
if self._dim != 1 and not (self._dim == 3 and self.ternary_style == "2d"):
23502354
data.append(self._create_plotly_stable_labels(label_stable))
23512355

2352-
if fill and self._dim in [3, 4]:
2356+
if fill and self._dim in {3, 4}:
23532357
data.extend(self._create_plotly_fill())
23542358

23552359
data.extend([stable_marker_plot, unstable_marker_plot])
@@ -2361,20 +2365,22 @@ def get_plot(
23612365
fig.layout = self._create_plotly_figure_layout()
23622366
fig.update_layout(coloraxis_colorbar={"yanchor": "top", "y": 0.05, "x": 1})
23632367

2364-
elif self.backend == "matplotlib":
2365-
if self._dim <= 3:
2366-
fig = self._get_matplotlib_2d_plot(
2368+
return fig
2369+
2370+
if self.backend == "matplotlib":
2371+
if self._dim in {1, 2, 3}:
2372+
return self._get_matplotlib_2d_plot(
23672373
label_stable,
23682374
label_unstable,
23692375
ordering,
23702376
energy_colormap,
23712377
ax=ax,
23722378
process_attributes=process_attributes,
23732379
)
2374-
elif self._dim == 4:
2375-
fig = self._get_matplotlib_3d_plot(label_stable, ax=ax)
2380+
if self._dim == 4:
2381+
return self._get_matplotlib_3d_plot(label_stable, ax=ax)
23762382

2377-
return fig
2383+
return None
23782384

23792385
def show(self, *args, **kwargs) -> None:
23802386
"""

0 commit comments

Comments
 (0)