|
1 | 1 | import argparse |
2 | | -from Bio import SeqIO |
3 | 2 | import os |
4 | 3 | import sys |
5 | | -import numpy as np |
6 | | -from utils import enc_seq_onehot, enc_pssm, is_fasta, get_pssm_sequence, DeepCoil_Model, decode |
7 | | -import keras.backend as K |
| 4 | + |
8 | 5 | import h5py |
| 6 | +import keras.backend as K |
| 7 | +import numpy as np |
| 8 | +from Bio import SeqIO |
9 | 9 |
|
| 10 | +from utils import enc_seq_onehot, enc_pssm, is_fasta, get_pssm_sequence, DeepCoil_Model, decode, SegmentResultFilter, \ |
| 11 | + ScoreResultFilter |
10 | 12 |
|
11 | 13 | # cx_freeze specific |
12 | 14 | if getattr(sys, 'frozen', False): |
|
parser.add_argument('-skip_checks',
                    action='store_true',
                    help='Skips input verification saving some time. Use only if entirely sure or in the re-runs')
# Optional post-processing thresholds. Both default to None, which the output
# code treats as "filter not requested". Explicit types make argparse reject
# non-numeric input up front instead of handing raw strings to the filters.
parser.add_argument('-min_residue_score',
                    type=float,
                    default=None,
                    help="minimum score to assign residue as part of coiled coil")
parser.add_argument('-min_segment_length',
                    type=int,
                    default=None,
                    help="minimum number of consecutive residues to report a segment as coiled coil")
args = parser.parse_args()
45 | 53 |
|
46 | 54 | # Verify whether weights files are present |
47 | 55 |
|
# The ensemble consists of five numbered models; for each index at least one
# of the two weight variants (sequence-only or sequence+PSSM) must exist.
for i in range(1, 6):
    seq_weights = '%s/weights/final_seq_%s.h5' % (my_loc, i)
    pssm_weights = '%s/weights/final_seq_pssm_%s.h5' % (my_loc, i)
    if not (os.path.isfile(seq_weights) or os.path.isfile(pssm_weights)):
        print("Weight files for the DeepCoil model are not available.")
        # sys.exit instead of the site-provided exit(): the latter is meant for
        # interactive use and may be absent in frozen (cx_freeze) builds.
        # A non-zero status lets calling scripts detect the failure.
        sys.exit(1)
53 | 61 |
|
|
# Parse the 20 score columns (file columns 2..21) of the PSSM file. The
# number of trailing non-data lines differs between files, so a 5-line footer
# is tried first and a 3-line footer second; a parse is accepted only when it
# yields exactly one row per residue of the sequence.
try:
    for footer_lines in (5, 3):
        parsed_pssm = np.genfromtxt(pssm_fn, skip_header=3, skip_footer=footer_lines,
                                    usecols=tuple(range(2, 22)))
        if parsed_pssm.shape[0] == len(seq):
            break
    else:
        # Neither footer length produced a matrix matching the sequence.
        raise ValueError
except ValueError:
    # Also reached when genfromtxt itself rejects the file contents.
    print("ERROR: Malformed PSSM file for entry %s!" % entry)
    # sys.exit instead of the site-provided exit(), which may be absent in
    # frozen (cx_freeze) builds; non-zero status signals the failure.
    sys.exit(1)
|
156 | 165 | predictions = model.predict(enc_sequences, verbose=1) |
157 | 166 | print() |
158 | 167 | decoded_predictions = [decode(pred, encoded_seq) for pred, encoded_seq in |
159 | | - zip(predictions, enc_sequences)] |
| 168 | + zip(predictions, enc_sequences)] |
160 | 169 | for decoded_prediction, entry in zip(decoded_predictions, entries): |
161 | 170 | if i == 1: |
162 | 171 | ensemble_results[entry] = decoded_prediction |
|
# Write one '<entry>.out' text file per sequence, plus the optional filtered
# outputs requested on the command line.
for entry, seq in zip(entries, sequences):
    # Average along axis 0 of the values collected for this entry
    # (presumably one prediction per ensemble model -- confirm upstream).
    final_results = np.average(ensemble_results[entry], axis=0)

    # Filters run only when their threshold was given. 'is not None' rather
    # than truthiness so that an explicit threshold of 0 is still honoured.
    res_filter = None
    if args.min_residue_score is not None:
        res_filter = ScoreResultFilter(final_results, args.min_residue_score)
        res_filter.write_results(entry, seq,
                                 os.path.join(args.out_path, 'residue_filter_{}'.format(args.min_residue_score)))
    if args.min_segment_length is not None:
        # The residue filter (or None when not requested) is handed on to the
        # segment filter via 'other_filter'.
        seg_filter = SegmentResultFilter(final_results, args.min_segment_length, other_filter=res_filter)
        seg_filter.write_results(entry, seq,
                                 os.path.join(args.out_path, 'segment_filter_{}'.format(args.min_segment_length)))

    # Context manager guarantees the file is closed even if a write fails.
    with open('%s/%s.out' % (args.out_path, entry), 'w') as f:
        # One line per residue: amino-acid letter and its score, 3 decimals.
        for aa, prob in zip(seq, final_results):
            f.write("%s %s\n" % (aa, "% .3f" % prob))
174 | 192 | elif args.out_type == 'h5': |
# Store the averaged prediction of every entry that passes the requested
# filters as a dataset named after the entry. The context manager closes the
# HDF5 file even if a filter or a write raises.
with h5py.File(args.out_filename, 'w') as f:
    for entry, _seq in zip(entries, sequences):
        final_results = np.average(ensemble_results[entry], axis=0)
        # A filter is consulted only when its threshold was supplied; without
        # thresholds every entry is written, as before the filters existed.
        # NOTE(review): argument order (results first, threshold second)
        # follows the ascii output branch of this script -- confirm against
        # the filter constructors in utils.
        can_pass = (
            (args.min_residue_score is None
             or ScoreResultFilter(final_results, args.min_residue_score).is_correct)
            and (args.min_segment_length is None
                 or SegmentResultFilter(final_results, args.min_segment_length).is_correct)
        )
        if can_pass:
            # create_dataset takes no 'axis' keyword; the averaging axis
            # belongs to np.average above.
            f.create_dataset(name=entry, data=final_results)
179 | 201 | print() |
180 | 202 | print("Done!") |
0 commit comments