Skip to content

Commit 3536700

Browse files
author
Jan Ludwiczak
committed
Add helper plotting function, update requirements
1 parent 01307ae commit 3536700

File tree

3 files changed

+77
-1
lines changed

3 files changed

+77
-1
lines changed

deepcoil/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .utils import corr_seq, is_fasta, sharpen_preds
1+
from .utils import corr_seq, is_fasta, sharpen_preds, plot_preds

deepcoil/utils/utils.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import numpy as np
2+
import matplotlib.pyplot as plt
3+
import seaborn as sns
24
from itertools import groupby
35
from Bio import SeqIO
46
from scipy.signal import find_peaks
@@ -54,3 +56,76 @@ def sharpen_preds(probs):
5456
sharp_probs[beg:end] = max(sharp_probs[beg:end])
5557

5658
return sharp_probs
59+
60+
61+
def plot_preds(results, beg=0, end=-1, out_file=None):
62+
"""
63+
Helper function for plotting DeepCoil results
64+
:param results: results for given entry returned by DeepCoil
65+
:param beg: (optional) beginning aa of the range to use for plotting (useful for long sequences to see only subset of results)
66+
:param end: (optional) end aa of the range to use for plotting (useful for long sequences to see only subset of results)
67+
:param out_file: (optional) if specified results will also be dumped to file
68+
:return:
69+
"""
70+
71+
sharp_probs = sharpen_preds(results['cc'])[beg:end]
72+
probs, hept = results['cc'][beg:end], results['hept'][beg:end, :]
73+
74+
fig, (ax2, ax1) = plt.subplots(2, gridspec_kw={'height_ratios': [1, 9]}, figsize=(9, 7))
75+
76+
# Plot probs and sharpened probs
77+
ax1.plot(probs, linewidth=1, c='gray', linestyle='dashed')
78+
ax1.plot(sharp_probs, linewidth=3, c='black')
79+
80+
# Set axis limits
81+
ax1.set_ylim(0.01, 1)
82+
ax1.set_xlim(0, len(probs))
83+
84+
# Show grid
85+
ax1.grid(linestyle="dashed", color='gray', linewidth=0.5)
86+
87+
# Find and set appropriate spacing of xticks given the sequence length
88+
spacings = [10, 25] + list(range(50, 1000, 50))
89+
spacing = spacings[np.argmin([abs(10 - len([i for i in range(0, len(probs), spacing)])) for spacing in spacings])]
90+
ticks = [i for i in range(0, len(probs), spacing)]
91+
labels = [i + beg for i in range(0, len(probs), spacing)]
92+
labels[0] = ''
93+
ax1.set_xticks(ticks)
94+
ax1.set_xticklabels(labels)
95+
ax1.set_yticks([i / 10 for i in range(1, 10, 1)])
96+
ax1.xaxis.set_ticks_position('none')
97+
ax1.set_xlabel('Sequence position', fontsize=16)
98+
99+
# Parse a, d heptad annotations and plot
100+
a, d = [], []
101+
for i, pr in enumerate(hept):
102+
if pr[1] > 0.2 or pr[2] > 0.2:
103+
if pr[1] > pr[2]:
104+
a.append(pr[1])
105+
d.append(0)
106+
else:
107+
d.append(-pr[2])
108+
a.append(0)
109+
else:
110+
a.append(0)
111+
d.append(0)
112+
kk = np.vstack((a, d))
113+
sns.heatmap(np.asarray(kk), cmap='bwr', vmin=-1, vmax=1, cbar=None, ax=ax2)
114+
115+
# Hide bottom panel of the subplot
116+
for name, spine in ax2.spines.items():
117+
if name != 'bottom':
118+
spine.set_visible(True)
119+
120+
# Show the ticks corresponding to the bottom panel
121+
ax2.xaxis.grid(linestyle="dashed", color='gray', linewidth=0.5)
122+
ax2.set_xticks([i for i in range(0, len(probs), 50)])
123+
ax2.set_yticklabels(['a', 'd'], rotation=0, fontsize=12)
124+
ax2.set_xticks(ticks)
125+
ax2.set_xticklabels(labels)
126+
ax2.xaxis.set_ticks_position('none')
127+
128+
plt.tight_layout()
129+
plt.subplots_adjust(hspace=0)
130+
if out_file:
131+
plt.savefig(out_file, dpi=300)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ numpy<1.19.0,>=1.16.0
44
tensorflow>=2.3.0
55
allennlp>=0.9.0,<0.10.0
66
torch>=1.2,<2.0
7+
seaborn>=0.10

0 commit comments

Comments
 (0)