diff --git a/src/main/scala/apps/CifarApp.scala b/src/main/scala/apps/CifarApp.scala index b49a115..5c00797 100644 --- a/src/main/scala/apps/CifarApp.scala +++ b/src/main/scala/apps/CifarApp.scala @@ -103,6 +103,9 @@ object CifarApp { val broadcastWeights = sc.broadcast(netWeights) logger.log("setting weights on workers", i) workers.foreach(_ => workerStore.get[CaffeSolver]("solver").trainNet.setWeights(broadcastWeights.value)) + // avoiding a memory leak: + broadcastWeights.unpersist() + broadcastWeights.destroy() if (i % 5 == 0) { logger.log("testing", i) diff --git a/src/main/scala/apps/FeaturizerApp.scala b/src/main/scala/apps/FeaturizerApp.scala index be26c79..fb34665 100644 --- a/src/main/scala/apps/FeaturizerApp.scala +++ b/src/main/scala/apps/FeaturizerApp.scala @@ -77,6 +77,9 @@ object FeaturizerApp { val broadcastWeights = sc.broadcast(netWeights) logger.log("setting weights on workers") workers.foreach(_ => workerStore.get[CaffeNet]("net").setWeights(broadcastWeights.value)) + // avoiding a memory leak: + broadcastWeights.unpersist() + broadcastWeights.destroy() // featurize the images val featurizedDF = trainDF.mapPartitions( it => { diff --git a/src/main/scala/apps/ImageNetApp.scala b/src/main/scala/apps/ImageNetApp.scala index 04506ad..906f6ed 100644 --- a/src/main/scala/apps/ImageNetApp.scala +++ b/src/main/scala/apps/ImageNetApp.scala @@ -103,6 +103,9 @@ object ImageNetApp { val broadcastWeights = sc.broadcast(netWeights) logger.log("setting weights on workers", i) workers.foreach(_ => workerStore.get[CaffeSolver]("solver").trainNet.setWeights(broadcastWeights.value)) + // avoiding a memory leak: + broadcastWeights.unpersist() + broadcastWeights.destroy() if (i % 10 == 0) { logger.log("testing", i) diff --git a/src/main/scala/apps/MnistApp.scala b/src/main/scala/apps/MnistApp.scala index 5796868..873e1e2 100644 --- a/src/main/scala/apps/MnistApp.scala +++ b/src/main/scala/apps/MnistApp.scala @@ -94,6 +94,9 @@ object MnistApp { val broadcastWeights = sc.broadcast(netWeights) logger.log("setting weights on workers", i) workers.foreach(_ => workerStore.get[TensorFlowNet]("net").setWeights(broadcastWeights.value)) + // avoiding a memory leak: + broadcastWeights.unpersist() + broadcastWeights.destroy() if (i % 5 == 0) { logger.log("testing", i) diff --git a/src/main/scala/apps/TFImageNetApp.scala b/src/main/scala/apps/TFImageNetApp.scala index f6195aa..9b287c4 100644 --- a/src/main/scala/apps/TFImageNetApp.scala +++ b/src/main/scala/apps/TFImageNetApp.scala @@ -95,6 +95,9 @@ object TFImageNetApp { val broadcastWeights = sc.broadcast(netWeights) logger.log("setting weights on workers", i) workers.foreach(_ => workerStore.get[TensorFlowNet]("net").setWeights(broadcastWeights.value)) + // avoiding a memory leak + broadcastWeights.unpersist() + broadcastWeights.destroy() if (i % 5 == 0) { logger.log("testing", i)