Skip to content

Commit 01307ae

Browse files
author
Jan Ludwiczak
committed
Correct prediction processing function
1 parent 130352d commit 01307ae

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

deepcoil/utils/utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from itertools import groupby
23
from Bio import SeqIO
34
from scipy.signal import find_peaks
45

@@ -40,8 +41,16 @@ def sharpen_preds(probs):
4041
end = int(peaks[1]['right_ips'][i])
4142
prob = max(probs[beg:end])
4243

43-
for i in range(beg, end + 1):
44+
for j in range(beg, end + 1):
4445
if prob >= 0.1:
45-
sharp_probs[i] = prob
46+
sharp_probs[j] = prob
47+
48+
sharp_probs = sharp_probs.flatten()
49+
above_threshold = sharp_probs > 0
50+
for k, g in groupby(enumerate(above_threshold), key=lambda x: x[1]):
51+
if k:
52+
g = list(g)
53+
beg, end = g[0][0], g[-1][0]
54+
sharp_probs[beg:end] = max(sharp_probs[beg:end])
4655

4756
return sharp_probs

0 commit comments

Comments
 (0)