@@ -2328,10 +2328,14 @@ def get_plot(
23282328 Returns:
23292329 go.Figure | plt.Axes: Plotly figure or matplotlib axes object depending on backend.
23302330 """
2331- fig = None
2332- data = []
2331+ if self ._dim not in {1 , 2 , 3 , 4 }:
2332+ raise ValueError (
2333+ f"Plotting is only supported for unary/binary/ternary/quaternary phase diagrams — got { self ._dim } D "
2334+ )
23332335
23342336 if self .backend == "plotly" :
2337+ data : list = []
2338+
23352339 if self ._dim != 1 :
23362340 data .append (self ._create_plotly_lines ())
23372341
@@ -2349,7 +2353,7 @@ def get_plot(
23492353 if self ._dim != 1 and not (self ._dim == 3 and self .ternary_style == "2d" ):
23502354 data .append (self ._create_plotly_stable_labels (label_stable ))
23512355
2352- if fill and self ._dim in [ 3 , 4 ] :
2356+ if fill and self ._dim in { 3 , 4 } :
23532357 data .extend (self ._create_plotly_fill ())
23542358
23552359 data .extend ([stable_marker_plot , unstable_marker_plot ])
@@ -2361,20 +2365,22 @@ def get_plot(
23612365 fig .layout = self ._create_plotly_figure_layout ()
23622366 fig .update_layout (coloraxis_colorbar = {"yanchor" : "top" , "y" : 0.05 , "x" : 1 })
23632367
2364- elif self .backend == "matplotlib" :
2365- if self ._dim <= 3 :
2366- fig = self ._get_matplotlib_2d_plot (
2368+ return fig
2369+
2370+ if self .backend == "matplotlib" :
2371+ if self ._dim in {1 , 2 , 3 }:
2372+ return self ._get_matplotlib_2d_plot (
23672373 label_stable ,
23682374 label_unstable ,
23692375 ordering ,
23702376 energy_colormap ,
23712377 ax = ax ,
23722378 process_attributes = process_attributes ,
23732379 )
2374- elif self ._dim == 4 :
2375- fig = self ._get_matplotlib_3d_plot (label_stable , ax = ax )
2380+ if self ._dim == 4 :
2381+ return self ._get_matplotlib_3d_plot (label_stable , ax = ax )
23762382
2377- return fig
2383+ return None
23782384
23792385 def show (self , * args , ** kwargs ) -> None :
23802386 """
0 commit comments