Skip to content

Commit 4ad6655

Browse files
author
Jan Ludwiczak
committed
Fix potential memory issues when running multiple inputs in the same session
1 parent be602ea commit 4ad6655

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

deepcoil/deepcoil.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from deepcoil.utils import corr_seq
1313
from tensorflow.keras.models import model_from_json
1414
from tensorflow.keras.utils import get_file
15+
import tensorflow.keras.backend as K
1516
from zipfile import ZipFile
1617

1718

@@ -126,15 +127,15 @@ def predict(self, data):
126127
# Encode SeqVec (#TODO: handle large inputs which may cause OOM errors)
127128
embeddings = self._seqvec.encode(data, to_file=False)
128129

129-
if self.use_gpu:
130-
# If GPU is used free mem used by SeqVec
131-
torch.cuda.empty_cache()
132-
133130
# Setup generator that'll be evaluated
134131
seqvec_enc = SeqVecMemEncoder(embeddings, pad_length=500)
135132
gen = SeqChunker(data, batch_size=64, W_size=500, shuffle=False,
136133
data_encoders=[seqvec_enc], data_cols=['sequence'])
137134

135+
if self.use_gpu:
136+
# If GPU is used free mem used by SeqVec
137+
torch.cuda.empty_cache()
138+
138139
# Predict with each of N predictors, depad predictions and average out for final output
139140
for i in range(1, self._n_weights+1):
140141
self.model.load_weights(f'{self._weights_prefix}_{i}.h5')
@@ -154,6 +155,8 @@ def predict(self, data):
154155
for key, pred in hept_preds_depadded.items():
155156
hept_preds_per_fold[key].append(pred)
156157

158+
K.clear_session()
159+
157160
# Average predictions between 5 predictors from CV training
158161
cc_preds_avg = {key: np.average(value, axis=0).flatten() for key, value in cc_preds_per_fold.items()}
159162
hept_preds_avg = {key: np.average(value, axis=0) for key, value in hept_preds_per_fold.items()}

0 commit comments

Comments
 (0)