@@ -19,6 +19,7 @@ class DotPlot(object):
1919 MIN_FIGURE_HEIGHT = 3
2020 DEFAULT_BAND_ITEM_LENGTH = DEFAULT_ITEM_HEIGHT
2121
22+ # TODO implement annotation band
2223 def __init__ (self , df_size : pd .DataFrame ,
2324 df_color : Union [pd .DataFrame , None ] = None ,
2425 df_circle : Union [pd .DataFrame , None ] = None ,
@@ -55,11 +56,11 @@ def __get_figure(self):
5556 mainplot_width = (
5657 (_text_max + self .width_item ) * self .DEFAULT_ITEM_WIDTH
5758 )
58- if self .annotation_data is not None :
59- pass
60-
6159 figure_height = max ([self .MIN_FIGURE_HEIGHT , mainplot_height ])
6260 figure_width = mainplot_width + self .DEFAULT_LEGENDS_WIDTH
61+ if self .annotation_data is not None :
62+ # figure_width = figure_width + self.DEFAULT_BAND_ITEM_LENGTH * self.annotation_data.shape[1]
63+ ...
6364 plt .style .use ('seaborn-white' )
6465 fig = plt .figure (figsize = (figure_width , figure_height ))
6566 gs = gridspec .GridSpec (nrows = 3 , ncols = 2 , wspace = 0.15 , hspace = 0.15 ,
@@ -68,6 +69,9 @@ def __get_figure(self):
6869 ax_cbar = fig .add_subplot (gs [2 , 1 ])
6970 ax_sizes = fig .add_subplot (gs [0 , 1 ])
7071 ax_circles = fig .add_subplot (gs [1 , 1 ])
72+ if self .color_data is None :
73+ ax_cbar .axis ('off' )
74+ ax_circles .axis ('off' )
7175 return ax , ax_cbar , ax_sizes , ax_circles , fig
7276
7377 @classmethod
@@ -129,18 +133,24 @@ def __get_coordinates(self, size_factor):
129133 self .resized_circle_data = self .circle_data .applymap (func = lambda x : x * size_factor )
130134 return X , Y
131135
132- def __draw_dotplot (self , ax , size_factor , cmap , vmin , vmax ):
136+ def __draw_dotplot (self , ax , size_factor , cmap , vmin , vmax , ** kws ):
137+ dot_color = kws .get ('dot_color' , '#58000C' )
138+ circle_color = kws .get ('circle_color' , '#000000' )
139+ kws = kws .copy ()
140+ for _value in ['dot_title' , 'circle_title' , 'colorbar_title' , 'dot_color' , 'circle_color' ]:
141+ _ = kws .pop (_value , None )
142+
133143 X , Y = self .__get_coordinates (size_factor )
134144 if self .color_data is None :
135- sct = ax .scatter (X , Y , c = 'r' , cmap = cmap , s = self .resized_size_data .values .flatten (),
136- edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax )
145+ sct = ax .scatter (X , Y , c = dot_color , s = self .resized_size_data .values .flatten (),
146+ edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , cmap = cmap , ** kws )
137147 else :
138148 sct = ax .scatter (X , Y , c = self .color_data .values .flatten (), s = self .resized_size_data .values .flatten (),
139- edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , cmap = cmap )
149+ edgecolors = 'none' , linewidths = 0 , vmin = vmin , vmax = vmax , cmap = cmap , ** kws )
140150 sct_circle = None
141151 if self .circle_data is not None :
142- sct_circle = ax .scatter (X , Y , c = '' , edgecolors = 'k' , marker = 'o' , linestyle = '--' ,
143- s = self . resized_circle_data . values . flatten () )
152+ sct_circle = ax .scatter (X , Y , c = 'none ' , s = self . resized_circle_data . values . flatten () ,
153+ edgecolors = circle_color , marker = 'o' , vmin = vmin , vmax = vmax , linestyle = '--' )
144154 width , height = self .width_item , self .height_item
145155 ax .set_xlim ([0.5 , width + 0.5 ])
146156 ax .set_ylim ([0.6 , height + 0.6 ])
@@ -153,7 +163,7 @@ def __draw_dotplot(self, ax, size_factor, cmap, vmin, vmax):
153163 return sct , sct_circle
154164
155165 @staticmethod
156- def __draw_color_bar (ax , sct : mpl .collections .PathCollection , cmap , vmin , vmax ):
166+ def __draw_color_bar (ax , sct : mpl .collections .PathCollection , cmap , vmin , vmax , ylabel ):
157167 gradient = np .linspace (1 , 0 , 500 )
158168 gradient = gradient [:, np .newaxis ]
159169 _ = ax .imshow (gradient , aspect = 'auto' , cmap = cmap , origin = 'upper' , extent = [.2 , 0.3 , 0.5 , - 0.5 ])
@@ -166,7 +176,7 @@ def __draw_color_bar(ax, sct: mpl.collections.PathCollection, cmap, vmin, vmax):
166176 if vmin is None :
167177 vmin = math .floor (sct .get_array ().min ())
168178 _ = ax_cbar2 .set_yticklabels ([vmin , vmax ])
169- _ = ax_cbar2 .set_ylabel ('-log10(pvalue)' )
179+ _ = ax_cbar2 .set_ylabel (ylabel )
170180
171181 @staticmethod
172182 def __draw_legend (ax , sct : mpl .collections .PathCollection , size_factor , title , circle = False , color = None ):
@@ -197,24 +207,33 @@ def __draw_legend(ax, sct: mpl.collections.PathCollection, size_factor, title, c
197207 def plot (self , size_factor : float = 15 ,
198208 vmin : float = 0 , vmax : float = None ,
199209 path : Union [PathLike , None ] = None ,
200- cmap : Union [str , mpl .colors .Colormap ] = 'Reds' ):
210+ cmap : Union [str , mpl .colors .Colormap ] = 'Reds' ,
211+ ** kwargs
212+ ):
201213 """
202214
203215 :param size_factor: `size factor` * `value` for the actually representation of scatter size in the final figure
204216 :param vmin: `vmin` in `matplotlib.pyplot.scatter`
205217 :param vmax: `vmax` in `matplotlib.pyplot.scatter`
206218 :param path: path to save the figure
207219 :param cmap: color map supported by matplotlib
220+ :param kwargs: dot_title, circle_title, colorbar_title, dot_color, circle_color
221+ other kwargs are passed to `matplotlib.Axes.scatter`
208222 :return:
209223 """
210224 ax , ax_cbar , ax_sizes , ax_circles , fig = self .__get_figure ()
211225 scatter , sct_circle = self .__draw_dotplot (ax , size_factor , cmap , vmin , vmax )
212- self .__draw_legend (ax_sizes , scatter , size_factor , title = 'Sizes' , color = '#58000C' )
226+ self .__draw_legend (ax_sizes , scatter , size_factor ,
227+ color = kwargs .get ('dot_color' , '#58000C' ), # dot legend color
228+ title = kwargs .get ('dot_title' , 'Sizes' ))
213229 if sct_circle is not None :
214- self .__draw_legend (ax_circles , sct_circle , size_factor , title = 'Circles' , circle = True , color = 'k' )
215- else :
216- ax_circles .axis ('off' )
217- self .__draw_color_bar (ax_cbar , scatter , cmap , vmin , vmax )
230+ self .__draw_legend (ax_circles , sct_circle , size_factor ,
231+ color = kwargs .get ('circle_color' , '#000000' ),
232+ title = kwargs .get ('circle_title' , 'Circles' ),
233+ circle = True )
234+ if self .color_data is not None :
235+ self .__draw_color_bar (ax_cbar , scatter , cmap , vmin , vmax ,
236+ ylabel = kwargs .get ('colorbar_title' , '-log10(pvalue)' ))
218237 if path :
219238 fig .savefig (path , dpi = 300 , bbox_inches = 'tight' ) #
220239 return scatter
0 commit comments