Skip to content
This repository has been archived by the owner on May 27, 2020. It is now read-only.

Commit

Permalink
fix connection leaks
Browse files Browse the repository at this point in the history
  • Loading branch information
darroyo-stratio committed May 10, 2016
1 parent 1f8e399 commit 064013d
Show file tree
Hide file tree
Showing 12 changed files with 138 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package com.stratio.datasource.mongodb

import com.stratio.datasource.mongodb.schema.MongodbRowConverter
import com.stratio.datasource.mongodb.writer.{MongodbSimpleWriter, MongodbBatchWriter}
import com.stratio.datasource.mongodb.writer.{MongodbBatchWriter, MongodbSimpleWriter}
import com.stratio.datasource.util.Config
import org.apache.spark.sql.DataFrame

Expand All @@ -39,10 +39,13 @@ class MongodbDataFrame(dataFrame: DataFrame) extends Serializable {
val writer =
if (batch) new MongodbBatchWriter(config)
else new MongodbSimpleWriter(config)
writer.saveWithPk(it.map(row =>
MongodbRowConverter.rowAsDBObject(row, schema)))
writer.freeConnection()

writer.saveWithPk(
it.map(row => MongodbRowConverter.rowAsDBObject(row, schema)))

})
}



}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

package com.stratio.datasource.mongodb

import com.stratio.datasource.mongodb.config.MongodbConfig
import com.mongodb.casbah.Imports._
import com.stratio.datasource.mongodb.client.MongodbClientFactory
import com.stratio.datasource.mongodb.config.{MongodbConfig, MongodbConfigReader}
import com.stratio.datasource.mongodb.partitioner.MongodbPartitioner
import com.stratio.datasource.mongodb.rdd.MongodbRDD
import com.stratio.datasource.mongodb.schema.{MongodbRowConverter, MongodbSchema}
import com.stratio.datasource.mongodb.writer.MongodbSimpleWriter
import com.stratio.datasource.mongodb.util.usingMongoClient
import com.stratio.datasource.util.Config
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, Filter, InsertableRelation, PrunedFilteredScan}
Expand All @@ -45,6 +47,7 @@ with PrunedFilteredScan with InsertableRelation {

implicit val _: Config = config

import MongodbConfigReader._
import MongodbRelation._

private val rddPartitioner: MongodbPartitioner =
Expand Down Expand Up @@ -80,7 +83,14 @@ with PrunedFilteredScan with InsertableRelation {
* Indicates if a collection is empty.
* @return Boolean
*/
def isEmptyCollection: Boolean = new MongodbSimpleWriter(config).isEmpty
def isEmptyCollection: Boolean =
usingMongoClient(MongodbClientFactory.getClient(config.hosts, config.credentials, config.sslOptions, config.clientOptions).clientConnection) { mongoClient =>
dbCollection(mongoClient).isEmpty
}





/**
* Insert data into the specified DataSource.
Expand All @@ -89,7 +99,9 @@ with PrunedFilteredScan with InsertableRelation {
*/
def insert(data: DataFrame, overwrite: Boolean): Unit = {
if (overwrite) {
new MongodbSimpleWriter(config).dropCollection
usingMongoClient(MongodbClientFactory.getClient(config.hosts, config.credentials, config.sslOptions, config.clientOptions).clientConnection) { mongoClient =>
dbCollection(mongoClient).dropCollection()
}
}

data.saveToMongodb(config)
Expand All @@ -111,6 +123,12 @@ with PrunedFilteredScan with InsertableRelation {
val state = Seq(schema, config)
state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
}

/**
* A MongoDB collection created from the specified database and collection.
*/
private def dbCollection(mongoClient: MongoClient): MongoCollection =
mongoClient(config(MongodbConfig.Database))(config(MongodbConfig.Collection))
}

object MongodbRelation {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,12 @@ import scala.annotation.tailrec
import scala.util.Try


// TODO Refactor - "The MongoClient class is designed to be thread safe and shared among threads"
class MongodbClientActor extends Actor {

private val KeySeparator = "-"

private val CloseSleepTime = 1000
private val CloseSleepTime = 100

private val mongoClient: scala.collection.mutable.Map[String, MongodbConnection] =
scala.collection.mutable.Map.empty[String, MongodbConnection]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import com.mongodb.ServerAddress
import com.mongodb.casbah.Imports._
import com.mongodb.casbah.MongoClient
import com.stratio.datasource.mongodb.client.MongodbClientActor._
import com.stratio.datasource.mongodb.config.MongodbSSLOptions
import com.stratio.datasource.mongodb.config.{MongodbConfig, MongodbSSLOptions}
import com.typesafe.config.ConfigFactory

import scala.concurrent.Await
Expand All @@ -35,6 +35,7 @@ import scala.concurrent.duration._
/**
* Different client configurations to Mongodb database
*/
// TODO Refactor - MongodbClientFactory should be used internally and should not delegate to other when closing/freeing connections
object MongodbClientFactory {

type Client = MongoClient
Expand All @@ -46,7 +47,7 @@ object MongodbClientFactory {
*/
private val actorSystem = ActorSystem("mongodbClientFactory", ConfigFactory.load(ConfigFactory.parseString("akka.daemonic=on")))
private val scheduler = actorSystem.scheduler
private val SecondsToCheckConnections = 60
private val SecondsToCheckConnections = MongodbConfig.DefaultConnectionsTime
private val mongoConnectionsActor = actorSystem.actorOf(Props(new MongodbClientActor), "mongoConnectionActor")

private implicit val executor = actorSystem.dispatcher
Expand All @@ -63,7 +64,7 @@ object MongodbClientFactory {
* @param host Ip or Dns to connect
* @return Client connection with identifier
*/
def getClient(host: String): ClientResponse = {
private[mongodb] def getClient(host: String): ClientResponse = {
val futureResult = mongoConnectionsActor ? GetClient(host)
Await.result(futureResult, timeout.duration) match {
case ClientResponse(key, clientConnection) => ClientResponse(key, clientConnection)
Expand All @@ -79,7 +80,7 @@ object MongodbClientFactory {
* @param password Password for credentials
* @return Client connection with identifier
*/
def getClient(host: String, port: Int, user: String, database: String, password: String): ClientResponse = {
private[mongodb] def getClient(host: String, port: Int, user: String, database: String, password: String): ClientResponse = {
val futureResult = mongoConnectionsActor ? GetClientWithUser(host, port, user, database, password)
Await.result(futureResult, timeout.duration) match {
case ClientResponse(key, clientConnection) => ClientResponse(key, clientConnection)
Expand All @@ -94,7 +95,7 @@ object MongodbClientFactory {
* @param clientOptions All options for the client connections
* @return Client connection with identifier
*/
def getClient(hostPort: List[ServerAddress],
private[mongodb] def getClient(hostPort: List[ServerAddress],
credentials: List[MongoCredential] = List(),
optionSSLOptions: Option[MongodbSSLOptions] = None,
clientOptions: Map[String, Any] = Map()): ClientResponse = {
Expand All @@ -110,7 +111,7 @@ object MongodbClientFactory {
* Close all client connections on the concurrent map
* @param gracefully Close the connections if is free
*/
def closeAll(gracefully: Boolean = true, attempts: Int = CloseAttempts): Unit = {
private[mongodb] def closeAll(gracefully: Boolean = true, attempts: Int = CloseAttempts): Unit = {
mongoConnectionsActor ! CloseAll(gracefully, attempts)
}

Expand All @@ -119,7 +120,7 @@ object MongodbClientFactory {
* @param client client value for connect to MongoDb
* @param gracefully Close the connection if is free
*/
def closeByClient(client: Client, gracefully: Boolean = true): Unit = {
private[mongodb] def closeByClient(client: Client, gracefully: Boolean = true): Unit = {
mongoConnectionsActor ! CloseByClient(client, gracefully)
}

Expand All @@ -128,27 +129,27 @@ object MongodbClientFactory {
* @param clientKey key pre calculated with the connection options
* @param gracefully Close the connection if is free
*/
def closeByKey(clientKey: String, gracefully: Boolean = true): Unit = {
private[mongodb] def closeByKey(clientKey: String, gracefully: Boolean = true): Unit = {
mongoConnectionsActor ! CloseByKey(clientKey, gracefully)
}

/**
* Set Free the connection that have the same client as the client param
* @param client client value for connect to MongoDb
*/
def setFreeConnectionByClient(client: Client, extendedTime: Option[Long] = None): Unit = {
private[mongodb] def setFreeConnectionByClient(client: Client, extendedTime: Option[Long] = None): Unit = {
mongoConnectionsActor ! SetFreeConnectionsByClient(client, extendedTime)
}

/**
* Set Free the connection that have the same key as the clientKey param
* @param clientKey key pre calculated with the connection options
*/
def setFreeConnectionByKey(clientKey: String, extendedTime: Option[Long] = None): Unit = {
private[mongodb] def setFreeConnectionByKey(clientKey: String, extendedTime: Option[Long] = None): Unit = {
mongoConnectionsActor ! SetFreeConnectionByKey(clientKey, extendedTime)
}

def getClientPoolSize: Int = {
private[client] def getClientPoolSize: Int = {
val futureResult = mongoConnectionsActor ? GetSize
Await.result(futureResult, timeout.duration) match {
case size: Int => size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ object MongodbConfig {
val DefaultSamplingRatio = 1.0
val DefaultSplitSize = 10
val DefaultSplitKey = "_id"
val DefaultConnectionsTime = 120000L
val DefaultConnectionsTime = 10000L
val DefaultCursorBatchSize = 101
val DefaultBulkBatchSize = 1000
val DefaultIdAsObjectId = "true"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.stratio.datasource.mongodb.config

import com.mongodb.casbah.Imports._
import com.mongodb.{MongoCredential, ServerAddress}
import com.stratio.datasource.mongodb.config.MongodbConfig._
import com.stratio.datasource.util.Config

object MongodbConfigReader {

implicit class MongodbConfigFunctions(config: Config) {
@transient protected[mongodb] val hosts : List[ServerAddress] =
config[List[String]](MongodbConfig.Host)
.map(add => new ServerAddress(add))

@transient protected[mongodb] val credentials: List[MongoCredential] =
config.getOrElse[List[MongodbCredentials]](MongodbConfig.Credentials, MongodbConfig.DefaultCredentials).map{
case MongodbCredentials(user,database,password) =>
MongoCredential.createCredential(user,database,password)
}

@transient protected[mongodb] val sslOptions: Option[MongodbSSLOptions] =
config.get[MongodbSSLOptions](MongodbConfig.SSLOptions)

@transient protected[mongodb] val writeConcern: WriteConcern = config.get[String](MongodbConfig.WriteConcern) match {
case Some(wConcern) => parseWriteConcern(wConcern)
case None => DefaultWriteConcern
}

protected[mongodb] val clientOptions = config.properties.filterKeys(_.contains(MongodbConfig.ListMongoClientOptions))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
package com.stratio.datasource.mongodb.partitioner

import java.text.SimpleDateFormat

import com.mongodb.casbah.Imports._
import com.mongodb.{MongoCredential, ServerAddress}
import com.stratio.datasource.mongodb.client.MongodbClientFactory
import com.stratio.datasource.mongodb.client.MongodbClientFactory.Client
import com.stratio.datasource.mongodb.config.{MongodbSSLOptions, MongodbCredentials, MongodbConfig}
import com.stratio.datasource.mongodb.client.MongodbClientFactory
import com.stratio.datasource.mongodb.config.{MongodbConfig, MongodbCredentials, MongodbSSLOptions}
import com.stratio.datasource.mongodb.partitioner.MongodbPartitioner._
import com.stratio.datasource.mongodb.util.usingMongoClient
import com.stratio.datasource.partitioner.{PartitionRange, Partitioner}
import com.stratio.datasource.util.Config

import scala.util.Try

/**
Expand Down Expand Up @@ -60,16 +63,13 @@ class MongodbPartitioner(config: Config) extends Partitioner[MongodbPartition] {

private val cursorBatchSize = config.getOrElse[Int](MongodbConfig.CursorBatchSize, MongodbConfig.DefaultCursorBatchSize)

override def computePartitions(): Array[MongodbPartition] = {
val mongoClient = MongodbClientFactory.getClient(hosts, credentials, ssloptions, clientOptions)

val result = if (isShardedCollection(mongoClient.clientConnection))
computeShardedChunkPartitions(mongoClient.clientConnection)
else
computeNotShardedPartitions(mongoClient.clientConnection)

result
}
override def computePartitions(): Array[MongodbPartition] =
usingMongoClient(MongodbClientFactory.getClient(hosts, credentials, ssloptions, clientOptions).clientConnection) { mongoClient =>
if (isShardedCollection(mongoClient))
computeShardedChunkPartitions(mongoClient)
else
computeNotShardedPartitions(mongoClient)
}

/**
* @return Whether this is a sharded collection or not
Expand Down Expand Up @@ -203,16 +203,16 @@ class MongodbPartitioner(config: Config) extends Partitioner[MongodbPartition] {
.find(MongoDBObject("_id" -> stats.getString("primary"))).batchSize(cursorBatchSize)
val shard = shards.next()
val shardHost: String = shard.as[String]("host").replace(shard.get("_id") + "/", "")
val shardClient = MongodbClientFactory.getClient(shardHost)
val data = shardClient.clientConnection.getDB("admin").command(cmd)
val splitKeys = data.as[List[DBObject]]("splitKeys").map(Option(_))
val ranges = (splitKeyMin +: splitKeys) zip (splitKeys :+ splitKeyMax )

shards.close()
usingMongoClient(MongodbClientFactory.getClient(shardHost).clientConnection){ mongoClient =>
val data = mongoClient.getDB("admin").command(cmd)
val splitKeys = data.as[List[DBObject]]("splitKeys").map(Option(_))
val ranges = (splitKeyMin +: splitKeys) zip (splitKeys :+ splitKeyMax )

MongodbClientFactory.setFreeConnectionByKey(shardClient.key, connectionsTime)
shards.close()
ranges.toSeq
}

ranges.toSeq
}.getOrElse(Seq((None, None)))

ranges
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,12 @@
*/
package com.stratio.datasource.mongodb.reader

import java.util.regex.Pattern

import com.mongodb.QueryBuilder
import com.mongodb.casbah.Imports._
import com.mongodb.casbah.MongoCursorBase
import com.stratio.datasource.mongodb.query.FilterSection
import com.stratio.datasource.mongodb.client.MongodbClientFactory
import com.stratio.datasource.mongodb.config.{MongodbSSLOptions, MongodbCredentials, MongodbConfig}
import com.stratio.datasource.mongodb.config.{MongodbConfig, MongodbCredentials, MongodbSSLOptions}
import com.stratio.datasource.mongodb.partitioner.MongodbPartition
import com.stratio.datasource.mongodb.query.FilterSection
import com.stratio.datasource.util.Config
import org.apache.spark.Partition

Expand Down Expand Up @@ -58,10 +55,8 @@ class MongodbReader(config: Config,

mongoClient.fold(ifEmpty = ()) { client =>
mongoClientKey.fold({
MongodbClientFactory.setFreeConnectionByClient(client, connectionsTime)
MongodbClientFactory.closeByClient(client)
}) {key =>
MongodbClientFactory.setFreeConnectionByKey(key, connectionsTime)
MongodbClientFactory.closeByKey(key)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.stratio.datasource.mongodb.util

import com.mongodb.casbah.MongoClient
import com.stratio.datasource.mongodb.client.MongodbClientFactory

import scala.util.Try

object usingMongoClient {

def apply[A](mongoClient: MongoClient)(code: MongoClient => A): A =
try {
code(mongoClient)
} finally {
Try(MongodbClientFactory.closeByClient(mongoClient))
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ import com.stratio.datasource.util.Config
*
* @param config Configuration parameters (host,database,collection,...)
*/
class MongodbBatchWriter(config: Config) extends MongodbWriter(config) {
private[mongodb] class MongodbBatchWriter(config: Config) extends MongodbWriter(config) {

private val IdKey = "_id"

private val bulkBatchSize = config.getOrElse[Int](MongodbConfig.BulkBatchSize, MongodbConfig.DefaultBulkBatchSize)

private val pkConfig: Option[Array[String]] = config.get[Array[String]](MongodbConfig.UpdateFields)

override def save(it: Iterator[DBObject]): Unit = {
override def save(it: Iterator[DBObject], mongoClient: MongoClient): Unit = {
it.grouped(bulkBatchSize).foreach { group =>
val bulkOperation = dbCollection.initializeUnorderedBulkOperation
val bulkOperation = dbCollection(mongoClient).initializeUnorderedBulkOperation
group.foreach { element =>
val query = getUpdateQuery(element)
if (query.isEmpty) bulkOperation.insert(element)
Expand Down
Loading

0 comments on commit 064013d

Please sign in to comment.