1212from deepcoil .utils import corr_seq
1313from tensorflow .keras .models import model_from_json
1414from tensorflow .keras .utils import get_file
15+ import tensorflow .keras .backend as K
1516from zipfile import ZipFile
1617
1718
@@ -126,15 +127,15 @@ def predict(self, data):
126127 # Encode SeqVec (#TODO: handle large inputs which may cause OOM errors)
127128 embeddings = self ._seqvec .encode (data , to_file = False )
128129
129- if self .use_gpu :
130- # If GPU is used free mem used by SeqVec
131- torch .cuda .empty_cache ()
132-
133130 # Setup generator that'll be evaluated
134131 seqvec_enc = SeqVecMemEncoder (embeddings , pad_length = 500 )
135132 gen = SeqChunker (data , batch_size = 64 , W_size = 500 , shuffle = False ,
136133 data_encoders = [seqvec_enc ], data_cols = ['sequence' ])
137134
135+ if self .use_gpu :
136+ # If GPU is used free mem used by SeqVec
137+ torch .cuda .empty_cache ()
138+
138139 # Predict with each of N predictors, depad predictions and average out for final output
139140 for i in range (1 , self ._n_weights + 1 ):
140141 self .model .load_weights (f'{ self ._weights_prefix } _{ i } .h5' )
@@ -154,6 +155,8 @@ def predict(self, data):
154155 for key , pred in hept_preds_depadded .items ():
155156 hept_preds_per_fold [key ].append (pred )
156157
158+ K .clear_session ()
159+
157160 # Average predictions between 5 predictors from CV training
158161 cc_preds_avg = {key : np .average (value , axis = 0 ).flatten () for key , value in cc_preds_per_fold .items ()}
159162 hept_preds_avg = {key : np .average (value , axis = 0 ) for key , value in hept_preds_per_fold .items ()}
0 commit comments