From 55d9cbed7f99d0047a8da9ff0d4ec9ebef9244d5 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Thu, 28 Apr 2016 17:28:38 -0700 Subject: [PATCH] fix memory leak during broadcasting --- src/main/scala/apps/CifarApp.scala | 3 +++ src/main/scala/apps/FeaturizerApp.scala | 3 +++ src/main/scala/apps/ImageNetApp.scala | 3 +++ src/main/scala/apps/MnistApp.scala | 3 +++ src/main/scala/apps/TFImageNetApp.scala | 3 +++ 5 files changed, 15 insertions(+) diff --git a/src/main/scala/apps/CifarApp.scala b/src/main/scala/apps/CifarApp.scala index b49a115..5c00797 100644 --- a/src/main/scala/apps/CifarApp.scala +++ b/src/main/scala/apps/CifarApp.scala @@ -103,6 +103,9 @@ object CifarApp { val broadcastWeights = sc.broadcast(netWeights) logger.log("setting weights on workers", i) workers.foreach(_ => workerStore.get[CaffeSolver]("solver").trainNet.setWeights(broadcastWeights.value)) + // avoiding a memory leak: + broadcastWeights.unpersist() + broadcastWeights.destroy() if (i % 5 == 0) { logger.log("testing", i) diff --git a/src/main/scala/apps/FeaturizerApp.scala b/src/main/scala/apps/FeaturizerApp.scala index be26c79..fb34665 100644 --- a/src/main/scala/apps/FeaturizerApp.scala +++ b/src/main/scala/apps/FeaturizerApp.scala @@ -77,6 +77,9 @@ object FeaturizerApp { val broadcastWeights = sc.broadcast(netWeights) logger.log("setting weights on workers") workers.foreach(_ => workerStore.get[CaffeNet]("net").setWeights(broadcastWeights.value)) + // avoiding a memory leak: + broadcastWeights.unpersist() + broadcastWeights.destroy() // featurize the images val featurizedDF = trainDF.mapPartitions( it => { diff --git a/src/main/scala/apps/ImageNetApp.scala b/src/main/scala/apps/ImageNetApp.scala index 04506ad..906f6ed 100644 --- a/src/main/scala/apps/ImageNetApp.scala +++ b/src/main/scala/apps/ImageNetApp.scala @@ -103,6 +103,9 @@ object ImageNetApp { val broadcastWeights = sc.broadcast(netWeights) logger.log("setting weights on workers", i) workers.foreach(_ => workerStore.get[CaffeSolver]("solver").trainNet.setWeights(broadcastWeights.value)) + // avoiding a memory leak: + broadcastWeights.unpersist() + broadcastWeights.destroy() if (i % 10 == 0) { logger.log("testing", i) diff --git a/src/main/scala/apps/MnistApp.scala b/src/main/scala/apps/MnistApp.scala index 5796868..873e1e2 100644 --- a/src/main/scala/apps/MnistApp.scala +++ b/src/main/scala/apps/MnistApp.scala @@ -94,6 +94,9 @@ object MnistApp { val broadcastWeights = sc.broadcast(netWeights) logger.log("setting weights on workers", i) workers.foreach(_ => workerStore.get[TensorFlowNet]("net").setWeights(broadcastWeights.value)) + // avoiding a memory leak: + broadcastWeights.unpersist() + broadcastWeights.destroy() if (i % 5 == 0) { logger.log("testing", i) diff --git a/src/main/scala/apps/TFImageNetApp.scala b/src/main/scala/apps/TFImageNetApp.scala index f6195aa..9b287c4 100644 --- a/src/main/scala/apps/TFImageNetApp.scala +++ b/src/main/scala/apps/TFImageNetApp.scala @@ -95,6 +95,9 @@ object TFImageNetApp { val broadcastWeights = sc.broadcast(netWeights) logger.log("setting weights on workers", i) workers.foreach(_ => workerStore.get[TensorFlowNet]("net").setWeights(broadcastWeights.value)) + // avoiding a memory leak + broadcastWeights.unpersist() + broadcastWeights.destroy() if (i % 5 == 0) { logger.log("testing", i)