11import numpy as np
2+ import matplotlib .pyplot as plt
3+ import seaborn as sns
24from itertools import groupby
35from Bio import SeqIO
46from scipy .signal import find_peaks
@@ -54,3 +56,76 @@ def sharpen_preds(probs):
5456 sharp_probs [beg :end ] = max (sharp_probs [beg :end ])
5557
5658 return sharp_probs
59+
60+
61+ def plot_preds (results , beg = 0 , end = - 1 , out_file = None ):
62+ """
63+ Helper function for plotting DeepCoil results
64+ :param results: results for given entry returned by DeepCoil
65+ :param beg: (optional) beginning aa of the range to use for plotting (useful for long sequences to see only subset of results)
66+ :param end: (optional) end aa of the range to use for plotting (useful for long sequences to see only subset of results)
67+ :param out_file: (optional) if specified results will also be dumped to file
68+ :return:
69+ """
70+
71+ sharp_probs = sharpen_preds (results ['cc' ])[beg :end ]
72+ probs , hept = results ['cc' ][beg :end ], results ['hept' ][beg :end , :]
73+
74+ fig , (ax2 , ax1 ) = plt .subplots (2 , gridspec_kw = {'height_ratios' : [1 , 9 ]}, figsize = (9 , 7 ))
75+
76+ # Plot probs and sharpened probs
77+ ax1 .plot (probs , linewidth = 1 , c = 'gray' , linestyle = 'dashed' )
78+ ax1 .plot (sharp_probs , linewidth = 3 , c = 'black' )
79+
80+ # Set axis limits
81+ ax1 .set_ylim (0.01 , 1 )
82+ ax1 .set_xlim (0 , len (probs ))
83+
84+ # Show grid
85+ ax1 .grid (linestyle = "dashed" , color = 'gray' , linewidth = 0.5 )
86+
87+ # Find and set appropriate spacing of xticks given the sequence length
88+ spacings = [10 , 25 ] + list (range (50 , 1000 , 50 ))
89+ spacing = spacings [np .argmin ([abs (10 - len ([i for i in range (0 , len (probs ), spacing )])) for spacing in spacings ])]
90+ ticks = [i for i in range (0 , len (probs ), spacing )]
91+ labels = [i + beg for i in range (0 , len (probs ), spacing )]
92+ labels [0 ] = ''
93+ ax1 .set_xticks (ticks )
94+ ax1 .set_xticklabels (labels )
95+ ax1 .set_yticks ([i / 10 for i in range (1 , 10 , 1 )])
96+ ax1 .xaxis .set_ticks_position ('none' )
97+ ax1 .set_xlabel ('Sequence position' , fontsize = 16 )
98+
99+ # Parse a, d heptad annotations and plot
100+ a , d = [], []
101+ for i , pr in enumerate (hept ):
102+ if pr [1 ] > 0.2 or pr [2 ] > 0.2 :
103+ if pr [1 ] > pr [2 ]:
104+ a .append (pr [1 ])
105+ d .append (0 )
106+ else :
107+ d .append (- pr [2 ])
108+ a .append (0 )
109+ else :
110+ a .append (0 )
111+ d .append (0 )
112+ kk = np .vstack ((a , d ))
113+ sns .heatmap (np .asarray (kk ), cmap = 'bwr' , vmin = - 1 , vmax = 1 , cbar = None , ax = ax2 )
114+
115+ # Hide bottom panel of the subplot
116+ for name , spine in ax2 .spines .items ():
117+ if name != 'bottom' :
118+ spine .set_visible (True )
119+
120+ # Show the ticks corresponding to the bottom panel
121+ ax2 .xaxis .grid (linestyle = "dashed" , color = 'gray' , linewidth = 0.5 )
122+ ax2 .set_xticks ([i for i in range (0 , len (probs ), 50 )])
123+ ax2 .set_yticklabels (['a' , 'd' ], rotation = 0 , fontsize = 12 )
124+ ax2 .set_xticks (ticks )
125+ ax2 .set_xticklabels (labels )
126+ ax2 .xaxis .set_ticks_position ('none' )
127+
128+ plt .tight_layout ()
129+ plt .subplots_adjust (hspace = 0 )
130+ if out_file :
131+ plt .savefig (out_file , dpi = 300 )
0 commit comments