diff --git a/doc/dataframe.md b/doc/dataframe.md
index 6287f08f..bd4e4778 100644
--- a/doc/dataframe.md
+++ b/doc/dataframe.md
@@ -93,6 +93,40 @@ The keys in Redis:
 2) "person:Peter"
 ```

+The key column values are not persisted in the Redis hashes themselves:
+
+```bash
+127.0.0.1:6379> hgetall person:John
+1) "age"
+2) "30"
+```
+
+In order to load the keys back, you also need to specify
+the key column parameter when reading:
+
+```scala
+val df = spark.read
+  .format("org.apache.spark.sql.redis")
+  .option("table", "person")
+  .option("key.column", "name")
+  .load()
+```
+
+Otherwise, a field named `_id` of type `String` will be populated:
+
+```bash
+root
+ |-- _id: string (nullable = true)
+ |-- age: integer (nullable = false)
+
++-----+---+
+| _id|age|
++-----+---+
+| John| 30|
+|Peter| 45|
++-----+---+
+```
+
 ### Save Modes

 Spark-redis supports all DataFrame [SaveMode](https://spark.apache.org/docs/latest/sql-programming-guide.html#save-modes)'s: `Append`,
@@ -213,7 +247,7 @@ root
 +-----+---+
 | John| 30|
 |Peter| 45|
-+-----+---+ 
++-----+---+
 ```

 To read with a Spark SQL:
@@ -262,8 +296,42 @@ The output is:
 root
  |-- name: string (nullable = true)
  |-- age: string (nullable = true)
+ |-- _id: string (nullable = true)
 ```

+Note: if your schema has a field named `_id` (or the schema was inferred), the
+Redis key will be stored in that field. Spark-redis will also try to
+extract the key value based on your pattern (you can change the name
+of the key column, see [Specifying Redis key](#specifying-redis-key)):
+- if the pattern ends with `*` and it is the only wildcard, the trailing
+value will be extracted, e.g.
+  ```scala
+  df.show()
+  ```
+  ```bash
+  +-----+---+-----+
+  | name|age|  _id|
+  +-----+---+-----+
+  | John| 30| John|
+  |Peter| 45|Peter|
+  +-----+---+-----+
+  ```
+- otherwise, the full Redis key will be kept as is, e.g.
+  ```scala
+  val df = // code omitted...
+    .option("keys.pattern", "p*:*")
+    .load()
+  df.show()
+  ```
+  ```bash
+  +-----+---+------------+
+  | name|age|         _id|
+  +-----+---+------------+
+  | John| 30| person:John|
+  |Peter| 45|person:Peter|
+  +-----+---+------------+
+  ```
+
 ## DataFrame options

 | Name | Description | Type | Default |
@@ -279,4 +347,5 @@ root

 ## Known limitations

- - Nested DataFrame fields are not currently supported with Hash model. Consider making DataFrame schema flat or using Binary persistence model.
\ No newline at end of file
+ - Nested DataFrame fields are not currently supported with Hash model. Consider making DataFrame schema flat or using Binary persistence model.
+ - Key column deserialization relies on a key prefix pattern, e.g. `keysPattern:*`, where keys have the form `tableName:$key`
diff --git a/src/test/scala/com/redislabs/provider/redis/util/ConnectionUtils.scala b/src/main/scala/com/redislabs/provider/redis/util/ConnectionUtils.scala
similarity index 55%
rename from src/test/scala/com/redislabs/provider/redis/util/ConnectionUtils.scala
rename to src/main/scala/com/redislabs/provider/redis/util/ConnectionUtils.scala
index 3facb657..eebdf679 100644
--- a/src/test/scala/com/redislabs/provider/redis/util/ConnectionUtils.scala
+++ b/src/main/scala/com/redislabs/provider/redis/util/ConnectionUtils.scala
@@ -1,6 +1,5 @@
 package com.redislabs.provider.redis.util

-import com.redislabs.provider.redis.RedisEndpoint
 import redis.clients.jedis.Jedis

 /**
@@ -8,8 +7,7 @@ import redis.clients.jedis.Jedis
   */
 object ConnectionUtils {

-  def withConnection[A](endpoint: RedisEndpoint)(body: Jedis => A): A = {
-    val conn = endpoint.connect()
+  def withConnection[A](conn: Jedis)(body: Jedis => A): A = {
     val res = body(conn)
     conn.close()
     res
diff --git a/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
index ccd7cafc..c9b0a981 100644
--- a/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/BinaryRedisPersistence.scala
@@ -25,13 +25,14 @@ class BinaryRedisPersistence extends RedisPersistence[Array[Byte]] {
   override def load(pipeline: Pipeline, key: String, requiredColumns: Seq[String]): Unit =
     pipeline.get(key.getBytes(UTF_8))

-  override def encodeRow(value: Row): Array[Byte] = {
+  override def encodeRow(keyName: String, value: Row): Array[Byte] = {
     val fields = value.schema.fields.map(_.name)
     val valuesArray = fields.map(f => value.getAs[Any](f))
     SerializationUtils.serialize(valuesArray)
   }

-  override def decodeRow(value: Array[Byte], schema: => StructType, inferSchema: Boolean): Row = {
+  override def decodeRow(keyMap: (String, String), value: Array[Byte], schema: StructType,
+                         requiredColumns: Seq[String]): Row = {
     val valuesArray: Array[Any] = SerializationUtils.deserialize(value)
     new GenericRowWithSchema(valuesArray, schema)
   }
diff --git a/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
index badb2fe5..d76a5f13 100644
--- a/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/HashRedisPersistence.scala
@@ -1,6 +1,7 @@
 package org.apache.spark.sql.redis

 import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Long => JLong, Short => JShort}
+import java.util.{List => JList}

 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
@@ -12,10 +13,10 @@ import scala.collection.JavaConverters._
 /**
   * @author The Viet Nguyen
   */
-class HashRedisPersistence extends RedisPersistence[Map[String, String]] {
+class HashRedisPersistence extends RedisPersistence[Any] {

-  override def save(pipeline: Pipeline, key: String, value: Map[String, String], ttl: Int): Unit = {
-    val javaValue = value.asJava
+  override def save(pipeline: Pipeline, key: String, value: Any, ttl: Int): Unit = {
+    val javaValue = value.asInstanceOf[Map[String, String]].asJava
     pipeline.hmset(key, javaValue)
     if (ttl > 0) {
       pipeline.expire(key, ttl)
@@ -23,36 +24,33 @@ class HashRedisPersistence extends RedisPersistence[Map[String, String]] {
   }

   override def load(pipeline: Pipeline, key: String, requiredColumns: Seq[String]): Unit = {
-    if (requiredColumns.isEmpty) {
-      pipeline.hgetAll(key)
-    } else {
-      pipeline.hmget(key, requiredColumns: _*)
-    }
+    pipeline.hmget(key, requiredColumns: _*)
   }

-  override def encodeRow(value: Row): Map[String, String] = {
+  override def encodeRow(keyName: String, value: Row): Map[String, String] = {
     val fields = value.schema.fields.map(_.name)
     val kvMap = value.getValuesMap[Any](fields)
     kvMap
-      .filter { case (k, v) =>
+      .filter { case (_, v) =>
         // don't store null values
         v != null
       }
+      .filter { case (k, _) =>
+        // don't store key values
+        k != keyName
+      }
       .map { case (k, v) =>
         k -> String.valueOf(v)
       }
   }

-  override def decodeRow(value: Map[String, String], schema: => StructType,
-                         inferSchema: Boolean): Row = {
-    val actualSchema = if (!inferSchema) schema else {
-      val fields = value.keys
-        .map(StructField(_, StringType))
-        .toArray
-      StructType(fields)
-    }
-    val fieldsValue = parseFields(value, actualSchema)
-    new GenericRowWithSchema(fieldsValue, actualSchema)
+  override def decodeRow(keyMap: (String, String), value: Any, schema: StructType,
+                         requiredColumns: Seq[String]): Row = {
+    val scalaValue = value.asInstanceOf[JList[String]].asScala
+    val values = requiredColumns.zip(scalaValue)
+    val results = values :+ keyMap
+    val fieldsValue = parseFields(results.toMap, schema)
+    new GenericRowWithSchema(fieldsValue, schema)
   }

   private def parseFields(value: Map[String, String], schema: StructType): Array[Any] =
diff --git a/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala b/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
index ec9efaef..d69eef66 100644
--- a/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/RedisPersistence.scala
@@ -13,9 +13,26 @@ trait RedisPersistence[T] extends Serializable {

   def load(pipeline: Pipeline, key: String, requiredColumns: Seq[String]): Unit

-  def encodeRow(value: Row): T
-
-  def decodeRow(value: T, schema: => StructType, inferSchema: Boolean): Row
+  /**
+    * Encode dataframe row before storing it in Redis.
+    *
+    * @param keyName field name that should be encoded in a special way, e.g. in Redis keys.
+    * @param value   row to encode.
+    * @return encoded row
+    */
+  def encodeRow(keyName: String, value: Row): T
+
+  /**
+    * Decode dataframe row stored in Redis.
+    *
+    * @param keyMap          extracted name/value of key column from Redis key
+    * @param value           encoded row
+    * @param schema          row schema
+    * @param requiredColumns required columns to decode
+    * @return decoded row
+    */
+  def decodeRow(keyMap: (String, String), value: T, schema: StructType,
+                requiredColumns: Seq[String]): Row
 }

 object RedisPersistence {
diff --git a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
index d4487002..353dd339 100644
--- a/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
+++ b/src/main/scala/org/apache/spark/sql/redis/RedisSourceRelation.scala
@@ -1,8 +1,9 @@
 package org.apache.spark.sql.redis

-import java.util.{UUID, List => JList, Map => JMap}
+import java.util.UUID

 import com.redislabs.provider.redis.rdd.Keys
+import com.redislabs.provider.redis.util.ConnectionUtils.withConnection
 import com.redislabs.provider.redis.util.Logging
 import com.redislabs.provider.redis.util.PipelineUtils._
 import com.redislabs.provider.redis.{ReadWriteConfig, RedisConfig, RedisEndpoint, RedisNode, toRedisContext}
@@ -11,11 +12,12 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.expressions.GenericRow
 import org.apache.spark.sql.redis.RedisSourceRelation._
 import org.apache.spark.sql.sources.{BaseRelation, Filter, InsertableRelation, PrunedFilteredScan}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
 import org.apache.spark.sql.{DataFrame, Row, SQLContext}
 import redis.clients.jedis.{PipelineBase, Protocol}

 import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._

 class RedisSourceRelation(override val sqlContext: SQLContext,
                           parameters: Map[String, String],
@@ -54,12 +56,17 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
   logInfo(s"Redis config initial host: ${redisConfig.initialHost}")

   @transient private val sc = sqlContext.sparkContext
+
+  /**
+    * Will be filled while saving data to Redis or reading from Redis.
+    */
   @volatile private var currentSchema: StructType = _

   /** parameters **/
   private val tableNameOpt: Option[String] = parameters.get(SqlOptionTableName)
   private val keysPatternOpt: Option[String] = parameters.get(SqlOptionKeysPattern)
   private val keyColumn = parameters.get(SqlOptionKeyColumn)
+  private val keyName = keyColumn.getOrElse("_id")
   private val numPartitions = parameters.get(SqlOptionNumPartitions).map(_.toInt)
     .getOrElse(SqlOptionNumPartitionsDefault)
   private val inferSchemaEnabled = parameters.get(SqlOptionInferSchema).exists(_.toBoolean)
@@ -67,6 +74,27 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
   private val persistence = RedisPersistence(persistenceModel)
   private val ttl = parameters.get(SqlOptionTTL).map(_.toInt).getOrElse(0)

+  /**
+    * redis key pattern for rows, based either on the 'keys.pattern' or 'table' parameter
+    */
+  private val dataKeyPattern = keysPatternOpt
+    .orElse(tableNameOpt.map(tableName => tableDataKeyPattern(tableName)))
+    .getOrElse {
+      val msg = s"Neither '$SqlOptionKeysPattern' or '$SqlOptionTableName' option is set."
+      throw new IllegalArgumentException(msg)
+    }
+
+  /**
+    * Support key column extraction from Redis prefix pattern. Otherwise,
+    * return Redis key unmodified.
+    */
+  private val keysPrefixPattern =
+    if (dataKeyPattern.endsWith("*") && dataKeyPattern.count(_ == '*') == 1) {
+      dataKeyPattern
+    } else {
+      ""
+    }
+
   // check specified parameters
   if (tableNameOpt.isDefined && keysPatternOpt.isDefined) {
     throw new IllegalArgumentException(s"Both options '$SqlOptionTableName' and '$SqlOptionTableName' are set. " +
@@ -75,14 +103,9 @@ class RedisSourceRelation(override val sqlContext: SQLContext,

   override def schema: StructType = {
     if (currentSchema == null) {
-      currentSchema = userSpecifiedSchema
-        .getOrElse {
-          if (inferSchemaEnabled) {
-            inferSchema()
-          } else {
-            loadSchema()
-          }
-        }
+      currentSchema = userSpecifiedSchema.getOrElse {
+        if (inferSchemaEnabled) inferSchema() else loadSchema()
+      }
     }
     currentSchema
   }
@@ -93,7 +116,7 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
     currentSchema = saveSchema(schema)
     if (overwrite) {
       // truncate the table
-      sc.fromRedisKeyPattern(dataKeyPattern()).foreachPartition { partition =>
+      sc.fromRedisKeyPattern(dataKeyPattern).foreachPartition { partition =>
         groupKeysByNode(redisConfig.hosts, partition).foreach { case (node, keys) =>
           val conn = node.connect()
           foreachWithPipeline(conn, keys) { (pipeline, key) =>
@@ -111,7 +134,7 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
         val conn = node.connect()
         foreachWithPipeline(conn, keys) { (pipeline, key) =>
           val row = rowsWithKey(key)
-          val encodedRow = persistence.encodeRow(row)
+          val encodedRow = persistence.encodeRow(keyName, row)
           persistence.save(pipeline, key, encodedRow, ttl)
         }
         conn.close()
@@ -121,16 +144,24 @@ class RedisSourceRelation(override val sqlContext: SQLContext,

   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
     logInfo("build scan")
-    val keysRdd = sc.fromRedisKeyPattern(dataKeyPattern(), partitionNum = numPartitions)
+    val keysRdd = sc.fromRedisKeyPattern(dataKeyPattern, partitionNum = numPartitions)
     if (requiredColumns.isEmpty) {
       keysRdd.map { _ =>
         new GenericRow(Array[Any]())
       }
     } else {
+      val filteredSchema = {
+        val requiredColumnsSet = Set(requiredColumns: _*)
+        val filteredFields = schema.fields
+          .filter { f =>
+            requiredColumnsSet.contains(f.name)
+          }
+        StructType(filteredFields)
+      }
       keysRdd.mapPartitions { partition =>
         groupKeysByNode(redisConfig.hosts, partition)
           .flatMap { case (node, keys) =>
-            scanRows(node, keys, requiredColumns)
+            scanRows(node, keys, filteredSchema, requiredColumns)
           }
           .iterator
       }
@@ -144,7 +175,7 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
     * @return true if data exists in redis
     */
   def isEmpty: Boolean = {
-    sc.fromRedisKeyPattern(dataKeyPattern()).isEmpty()
+    sc.fromRedisKeyPattern(dataKeyPattern).isEmpty()
   }

   /**
@@ -169,36 +200,25 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
     dataKey(tableName(), id)
   }

-  /**
-    * redis key pattern for rows, based either on the 'keys.pattern' or 'table' parameter
-    */
-  private def dataKeyPattern(): String = {
-    keysPatternOpt
-      .orElse(
-        tableNameOpt.map(tableName => tableDataKeyPattern(tableName))
-      )
-      .getOrElse(throw new IllegalArgumentException(s"Neither '$SqlOptionKeysPattern' or '$SqlOptionTableName' option is set."))
-  }
-
   /**
     * infer schema from a random redis row
     */
   private def inferSchema(): StructType = {
-    val keys = sc.fromRedisKeyPattern(dataKeyPattern())
+    if (persistenceModel != SqlOptionModelHash) {
+      throw new IllegalArgumentException(s"Cannot infer schema from model '$persistenceModel'. " +
+        s"Currently, only '$SqlOptionModelHash' is supported")
+    }
+    val keys = sc.fromRedisKeyPattern(dataKeyPattern)
     if (keys.isEmpty()) {
       throw new IllegalStateException("No key is available")
     } else {
       val firstKey = keys.first()
       val node = getMasterNode(redisConfig.hosts, firstKey)
-      scanRows(node, Seq(firstKey), Seq())
-        .collectFirst {
-          case r: Row =>
-            logDebug(s"Row for schema inference: $r")
-            r.schema
-        }
-        .getOrElse {
-          throw new IllegalStateException("No row is available")
-        }
+      withConnection(node.connect()) { conn =>
+        val results = conn.hgetAll(firstKey).asScala.toSeq :+ keyName -> firstKey
+        val fields = results.map(kv => StructField(kv._1, StringType)).toArray
+        StructType(fields)
+      }
     }
   }
@@ -237,42 +257,18 @@ class RedisSourceRelation(override val sqlContext: SQLContext,
   /**
     * read rows from redis
     */
-  private def scanRows(node: RedisNode, keys: Seq[String], requiredColumns: Seq[String]): Seq[Row] = {
-    def filteredSchema(): StructType = {
-      val requiredColumnsSet = Set(requiredColumns: _*)
-      val filteredFields = schema.fields
-        .filter { f =>
-          requiredColumnsSet.contains(f.name)
-        }
-      StructType(filteredFields)
-    }
-
-    val conn = node.connect()
-
-    val pipelineValues = mapWithPipeline(conn, keys) { (pipeline, key) =>
-      persistence.load(pipeline, key, requiredColumns)
-    }
-
-    val rows =
-      if (requiredColumns.isEmpty || persistenceModel == SqlOptionModelBinary) {
-        pipelineValues
-          .map {
-            case jmap: JMap[_, _] => jmap.toMap
-            case value: Any => value
-          }
-          .map { value =>
-            persistence.decodeRow(value, schema, inferSchemaEnabled)
-          }
-      } else {
-        pipelineValues.map { case values: JList[_] =>
-          val value = requiredColumns.zip(values.asInstanceOf[JList[String]]).toMap
-          persistence.decodeRow(value, filteredSchema(), inferSchemaEnabled)
-        }
+  private def scanRows(node: RedisNode, keys: Seq[String], schema: StructType,
+                       requiredColumns: Seq[String]): Seq[Row] = {
+    withConnection(node.connect()) { conn =>
+      val pipelineValues = mapWithPipeline(conn, keys) { (pipeline, key) =>
+        persistence.load(pipeline, key, requiredColumns)
       }
-      conn.close()
-      rows
+      keys.zip(pipelineValues).map { case (key, value) =>
+        val keyMap = keyName -> tableKey(keysPrefixPattern, key)
+        persistence.decodeRow(keyMap, value, schema, requiredColumns)
+      }
+    }
   }
-  }

 object RedisSourceRelation {

   def uuid(): String = UUID.randomUUID().toString.replace("-", "")

   def tableDataKeyPattern(tableName: String): String = s"$tableName:*"
+
+  def tableKey(keysPrefixPattern: String, redisKey: String): String = {
+    if (keysPrefixPattern.endsWith("*")) {
+      // keysPattern*
+      redisKey.substring(keysPrefixPattern.length - 1)
+    } else {
+      redisKey
+    }
+  }
 }
diff --git a/src/test/scala/com/redislabs/provider/redis/df/CsvDataframeSuite.scala b/src/test/scala/com/redislabs/provider/redis/df/CsvDataframeSuite.scala
index 33d75e5f..7e6a09f7 100644
--- a/src/test/scala/com/redislabs/provider/redis/df/CsvDataframeSuite.scala
+++ b/src/test/scala/com/redislabs/provider/redis/df/CsvDataframeSuite.scala
@@ -25,6 +25,7 @@ trait CsvDataframeSuite extends RedisDataframeSuite with Matchers {

     val loadedDf = spark.read.format(RedisFormat)
       .option(SqlOptionTableName, tableName)
+      .option(SqlOptionKeyColumn, "id")
       .load()
       .cache()

diff --git a/src/test/scala/com/redislabs/provider/redis/df/DataframeSuite.scala b/src/test/scala/com/redislabs/provider/redis/df/DataframeSuite.scala
index 3b718dcc..e5ee70b9 100644
--- a/src/test/scala/com/redislabs/provider/redis/df/DataframeSuite.scala
+++ b/src/test/scala/com/redislabs/provider/redis/df/DataframeSuite.scala
@@ -2,10 +2,10 @@ package com.redislabs.provider.redis.df

 import com.redislabs.provider.redis.util.Person
 import com.redislabs.provider.redis.util.Person._
+import com.redislabs.provider.redis.util.TestUtils._
 import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.redis._
 import org.scalatest.Matchers
-import com.redislabs.provider.redis.util.TestUtils._

 trait DataframeSuite extends RedisDataframeSuite with Matchers {
@@ -176,10 +176,11 @@ trait DataframeSuite extends RedisDataframeSuite with Matchers {
     val df = spark.createDataFrame(data)
     df.write.format(RedisFormat)
       .option(SqlOptionTableName, tableName)
-      .option(SqlOptionKeyColumn, "name")
+      .option(SqlOptionKeyColumn, KeyName)
       .save()
     val loadedDf = spark.read.format(RedisFormat)
       .option(SqlOptionTableName, tableName)
+      .option(SqlOptionKeyColumn, KeyName)
       .load()
       .cache()
     verifyDf(loadedDf)
@@ -188,19 +189,20 @@ trait DataframeSuite extends RedisDataframeSuite with Matchers {
   test("user defined key column append") {
     val tableName = generateTableName(TableNamePrefix)
     spark.createDataFrame(data).write.format(RedisFormat)
-      .option(SqlOptionKeyColumn, "name")
       .option(SqlOptionTableName, tableName)
+      .option(SqlOptionKeyColumn, KeyName)
       .save()
     val head = data.head
     val appendData = Seq(head.copy(name = "Jack"), head.copy(age = 31))
     val df = spark.createDataFrame(appendData)
     df.write.format(RedisFormat)
       .mode(SaveMode.Append)
-      .option(SqlOptionKeyColumn, "name")
       .option(SqlOptionTableName, tableName)
+      .option(SqlOptionKeyColumn, KeyName)
       .save()
     val loadedDf = spark.read.format(RedisFormat)
       .option(SqlOptionTableName, tableName)
+      .option(SqlOptionKeyColumn, KeyName)
       .load()
       .cache()
     loadedDf.show()
@@ -221,10 +223,11 @@ trait DataframeSuite extends RedisDataframeSuite with Matchers {
       .write.format(RedisFormat)
       .mode(SaveMode.Overwrite)
       .option(SqlOptionTableName, tableName)
-      .option(SqlOptionKeyColumn, "name")
+      .option(SqlOptionKeyColumn, KeyName)
       .save()
     val loadedDf = spark.read.format(RedisFormat)
       .option(SqlOptionTableName, tableName)
+      .option(SqlOptionKeyColumn, KeyName)
       .load()
       .cache()
     verifyDf(loadedDf, overrideData)
diff --git a/src/test/scala/com/redislabs/provider/redis/df/HashDataframeSuite.scala b/src/test/scala/com/redislabs/provider/redis/df/HashDataframeSuite.scala
index 014d92b6..cdd85033 100644
--- a/src/test/scala/com/redislabs/provider/redis/df/HashDataframeSuite.scala
+++ b/src/test/scala/com/redislabs/provider/redis/df/HashDataframeSuite.scala
@@ -2,13 +2,14 @@ package com.redislabs.provider.redis.df

 import java.sql.{Date, Timestamp}

-import com.redislabs.provider.redis.util.Person
 import com.redislabs.provider.redis.util.Person.{data, _}
+import com.redislabs.provider.redis.util.TestUtils._
+import com.redislabs.provider.redis.util.{EntityId, Person}
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.redis.RedisSourceRelation.tableDataKeyPattern
 import org.apache.spark.sql.redis._
 import org.apache.spark.sql.types._
 import org.scalatest.Matchers
-import com.redislabs.provider.redis.util.TestUtils._

 /**
   * @author The Viet Nguyen
   */
@@ -198,6 +199,59 @@ trait HashDataframeSuite extends RedisDataframeSuite with Matchers {
     row.getAs[java.sql.Timestamp]("_10") should be(Timestamp.valueOf("2017-12-02 03:04:00"))
   }

+  test("read key column from Redis keys") {
+    val tableName = generateTableName("person")
+    saveHash(tableName, "John",
+      Map("age" -> "30", "address" -> "60 Wall Street", "salary" -> "150.5"))
+    val loadedPersons = spark.read.format(RedisFormat)
+      .option(SqlOptionTableName, tableName)
+      .option(SqlOptionKeyColumn, "name")
+      .schema(Person.schema)
+      .load()
+      .as[Person]
+      .collect()
+    loadedPersons should contain(Person.data.head)
+  }
+
+  test("read key column from Redis keys with prefix pattern") {
+    val tableName = generateTableName("person")
+    saveHash(tableName, "John",
+      Map("age" -> "30", "address" -> "60 Wall Street", "salary" -> "150.5"))
+    val loadedPersons = spark.read.format(RedisFormat)
+      .option(SqlOptionKeysPattern, tableDataKeyPattern(tableName))
+      .option(SqlOptionKeyColumn, "name")
+      .schema(Person.schema)
+      .load()
+      .as[Person]
+      .collect()
+    loadedPersons should contain(Person.data.head)
+  }
+
+  test("read key column from Redis keys (when _id field does not exist)") {
+    val tableName = generateTableName("person")
+    saveHash(tableName, "John",
+      Map("name" -> "John", "age" -> "30", "address" -> "60 Wall Street", "salary" -> "150.5"))
+    val loadedPersons = spark.read.format(RedisFormat)
+      .option(SqlOptionTableName, tableName)
+      .schema(Person.schema)
+      .load()
+      .as[Person]
+      .collect()
+    loadedPersons should contain(Person.data.head)
+  }
+
+  test("read default key column from Redis keys") {
+    val tableName = generateTableName("entityId")
+    saveHash(tableName, "id", Map("name" -> "name"))
+    val loadedPersons = spark.read.format(RedisFormat)
+      .option(SqlOptionTableName, tableName)
+      .schema(EntityId.schema)
+      .load()
+      .as[EntityId]
+      .collect()
+    loadedPersons should contain(EntityId("id", "name"))
+  }
+
   def saveMap(tableName: String): Unit = {
     val data = Seq(
       Map("name" -> "John", "age" -> "30", "address" -> "60 Wall Street", "salary" -> "150.5"),
diff --git a/src/test/scala/com/redislabs/provider/redis/df/cluster/HashDataframeClusterSuite.scala b/src/test/scala/com/redislabs/provider/redis/df/cluster/HashDataframeClusterSuite.scala
index a8717890..821bae2e 100644
--- a/src/test/scala/com/redislabs/provider/redis/df/cluster/HashDataframeClusterSuite.scala
+++ b/src/test/scala/com/redislabs/provider/redis/df/cluster/HashDataframeClusterSuite.scala
@@ -15,7 +15,7 @@ class HashDataframeClusterSuite extends HashDataframeSuite with RedisClusterEnv
     val host = redisConfig.initialHost
     val hostAndPort = new HostAndPort(host.host, host.port)
     val conn = new JedisCluster(hostAndPort)
-    conn.hmset(tableName + ":" + value("name"), value.asJava)
+    conn.hmset(tableName + ":" + key, value.asJava)
     conn.close()
   }
 }
diff --git a/src/test/scala/com/redislabs/provider/redis/df/standalone/HashDataframeStandaloneSuite.scala b/src/test/scala/com/redislabs/provider/redis/df/standalone/HashDataframeStandaloneSuite.scala
index 94ee6618..9da4f580 100644
--- a/src/test/scala/com/redislabs/provider/redis/df/standalone/HashDataframeStandaloneSuite.scala
+++ b/src/test/scala/com/redislabs/provider/redis/df/standalone/HashDataframeStandaloneSuite.scala
@@ -2,7 +2,7 @@ package com.redislabs.provider.redis.df.standalone

 import com.redislabs.provider.redis.df.HashDataframeSuite
 import com.redislabs.provider.redis.env.RedisStandaloneEnv
-import com.redislabs.provider.redis.util.ConnectionUtils
+import com.redislabs.provider.redis.util.ConnectionUtils.withConnection

 import scala.collection.JavaConverters._
@@ -12,8 +12,9 @@ import scala.collection.JavaConverters._
 class HashDataframeStandaloneSuite extends HashDataframeSuite with RedisStandaloneEnv {

   override def saveHash(tableName: String, key: String, value: Map[String, String]): Unit = {
-    ConnectionUtils.withConnection(redisConfig.initialHost) { conn =>
-      conn.hmset(tableName + ":" + value("name"), value.asJava)
+    val host = redisConfig.initialHost
+    withConnection(host.connect()) { conn =>
+      conn.hmset(tableName + ":" + key, value.asJava)
     }
   }
 }
diff --git a/src/test/scala/com/redislabs/provider/redis/util/EntityId.scala b/src/test/scala/com/redislabs/provider/redis/util/EntityId.scala
new file mode 100644
index 00000000..a28a44a7
--- /dev/null
+++ b/src/test/scala/com/redislabs/provider/redis/util/EntityId.scala
@@ -0,0 +1,16 @@
+package com.redislabs.provider.redis.util
+
+import org.apache.spark.sql.types._
+
+/**
+  * @author The Viet Nguyen
+  */
+case class EntityId(_id: String, name: String)
+
+object EntityId {
+
+  val schema = StructType(Array(
+    StructField("_id", StringType),
+    StructField("name", StringType)
+  ))
+}
diff --git a/src/test/scala/com/redislabs/provider/redis/util/Person.scala b/src/test/scala/com/redislabs/provider/redis/util/Person.scala
index 36ac0017..fa15635f 100644
--- a/src/test/scala/com/redislabs/provider/redis/util/Person.scala
+++ b/src/test/scala/com/redislabs/provider/redis/util/Person.scala
@@ -1,6 +1,7 @@
 package com.redislabs.provider.redis.util

 import com.redislabs.provider.redis.util.TestUtils._
+import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, SparkSession}

 /**
@@ -11,12 +12,20 @@ case class Person(name: String, age: Int, address: String, salary: Double)
 object Person {

   val TableNamePrefix = "person"
+  val KeyName = "name"

   val data = Seq(
     Person("John", 30, "60 Wall Street", 150.5),
     Person("Peter", 35, "110 Wall Street", 200.3)
   )

+  val schema = StructType(Array(
+    StructField("name", StringType),
+    StructField("age", IntegerType),
+    StructField("address", StringType),
+    StructField("salary", DoubleType)
+  ))
+
   def df(spark: SparkSession): DataFrame = spark.createDataFrame(data)

   def generatePersonTableName(): String = generateTableName(TableNamePrefix)
diff --git a/src/test/scala/org/apache/spark/sql/redis/RedisSourceRelationTest.scala b/src/test/scala/org/apache/spark/sql/redis/RedisSourceRelationTest.scala
new file mode 100644
index 00000000..c9223a1f
--- /dev/null
+++ b/src/test/scala/org/apache/spark/sql/redis/RedisSourceRelationTest.scala
@@ -0,0 +1,19 @@
+package org.apache.spark.sql.redis
+
+import org.scalatest.{FunSuite, Matchers}
+
+/**
+  * @author The Viet Nguyen
+  */
+class RedisSourceRelationTest extends FunSuite with Matchers {
+
+  test("redis key extractor with prefix pattern") {
+    val key = RedisSourceRelation.tableKey("table*", "tablekey")
+    key shouldBe "key"
+  }
+
+  test("redis key extractor with other patterns") {
+    val key = RedisSourceRelation.tableKey("*table", "key")
+    key shouldBe "key"
+  }
+}
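For reference, a minimal end-to-end sketch of the behaviour this change documents: writing a DataFrame with `key.column` so the column value becomes part of the Redis key, then reading it back with the same option. The `KeyColumnExample` object, the two-field `Person` case class and the `local[*]` master are illustrative assumptions; only the `org.apache.spark.sql.redis` format and the `table`/`key.column` options come from the patch itself.

```scala
import org.apache.spark.sql.SparkSession

object KeyColumnExample {

  // Illustrative row type; any flat schema works with the hash model.
  case class Person(name: String, age: Int)

  def main(args: Array[String]): Unit = {
    // Assumes spark-redis on the classpath and Redis reachable with default connection settings.
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("key-column-example")
      .getOrCreate()
    import spark.implicits._

    // Write: the 'name' column becomes part of the Redis key (person:John, person:Peter),
    // so it is not stored as a hash field.
    Seq(Person("John", 30), Person("Peter", 45)).toDF()
      .write
      .format("org.apache.spark.sql.redis")
      .option("table", "person")
      .option("key.column", "name")
      .save()

    // Read: passing key.column again restores 'name' from the Redis keys;
    // without it the key would be exposed through the generated '_id' column.
    val df = spark.read
      .format("org.apache.spark.sql.redis")
      .option("table", "person")
      .option("key.column", "name")
      .load()
    df.show()

    spark.stop()
  }
}
```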