Skip to content

Commit 6145ae3

Browse files
author
vidmantas zemleris
committed
[SPARK-6994][SQL] Allow to fetch field values by name on Row
- add fieldIndex(name: String)
- add getAs[T](fieldName: String)
- add getValuesMap[T] returning a map of values for the requested fieldNames
1 parent 9564ebb commit 6145ae3

File tree

4 files changed

+115
-0
lines changed

4 files changed

+115
-0
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,38 @@ trait Row extends Serializable {
306306
*/
307307
def getAs[T](i: Int): T = apply(i).asInstanceOf[T]
308308

309+
/**
310+
* Returns the value of a given fieldName.
311+
*
312+
* @throws UnsupportedOperationException when schema is not defined.
313+
* @throws IllegalArgumentException when fieldName do not exist.
314+
* @throws ClassCastException when data type does not match.
315+
*/
316+
def getAs[T](fieldName: String): T = getAs[T](fieldIndex(fieldName))
317+
318+
/**
319+
* Returns the index of a given field name.
320+
*
321+
* @throws UnsupportedOperationException when schema is not defined.
322+
* @throws IllegalArgumentException when fieldName do not exist.
323+
*/
324+
def fieldIndex(name: String): Int = {
325+
throw new UnsupportedOperationException("fieldIndex on a Row without schema is undefined.")
326+
}
327+
328+
/**
329+
* Returns a Map(name -> value) for the requested fieldNames
330+
*
331+
* @throws UnsupportedOperationException when schema is not defined.
332+
* @throws IllegalArgumentException when fieldName do not exist.
333+
* @throws ClassCastException when data type does not match.
334+
*/
335+
def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = {
336+
fieldNames.map { name =>
337+
name -> getAs[T](name)
338+
}.toMap
339+
}
340+
309341
override def toString(): String = s"[${this.mkString(",")}]"
310342

311343
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
181181

182182
/** No-arg constructor for serialization. */
183183
protected def this() = this(null, null)
184+
185+
  /** Resolves a field name to its ordinal by delegating to the attached schema. */
  override def fieldIndex(name: String): Int = schema.fieldIndex(name)
184186
}
185187

186188
class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema}
21+
import org.apache.spark.sql.types._
22+
import org.scalatest.{Matchers, FunSpec}
23+
24+
class RowTest extends FunSpec with Matchers {

  // Fixture: a three-column schema and one row of matching values.
  val schema = StructType(Seq(
    StructField("col1", StringType),
    StructField("col2", StringType),
    StructField("col3", IntegerType)))
  val values = Array("value1", "value2", 1)

  // Same values exposed both with and without schema information.
  val sampleRow: Row = new GenericRowWithSchema(values, schema)
  val noSchemaRow: Row = new GenericRow(values)

  describe("Row (without schema)") {
    it("throws an exception when accessing by fieldName") {
      // Both entry points must refuse name-based access when no schema exists.
      intercept[UnsupportedOperationException] {
        noSchemaRow.fieldIndex("col1")
      }
      intercept[UnsupportedOperationException] {
        noSchemaRow.getAs("col1")
      }
    }
  }

  describe("Row (with schema)") {
    it("fieldIndex(name) returns field index") {
      sampleRow.fieldIndex("col1") should be (0)
      sampleRow.fieldIndex("col3") should be (2)
    }

    it("getAs[T] retrieves a value by fieldname") {
      sampleRow.getAs[String]("col1") should be ("value1")
      sampleRow.getAs[Int]("col3") should be (1)
    }

    it("Accessing non existent field throws an exception") {
      intercept[IllegalArgumentException] {
        sampleRow.getAs[String]("non_existent")
      }
    }

    it("getValuesMap() retrieves values of multiple fields as a Map(field -> value)") {
      val expected = Map("col1" -> "value1", "col2" -> "value2")
      sampleRow.getValuesMap(List("col1", "col2")) should be (expected)
    }
  }
}

sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,14 @@ class RowSuite extends FunSuite {
6262
val de = instance.deserialize(ser).asInstanceOf[Row]
6363
assert(de === row)
6464
}
65+
66+
test("get values by field name on Row created via .toDF") {
67+
val row = Seq((1, Seq(1))).toDF("a", "b").first()
68+
assert(row.getAs[Int]("a") === 1)
69+
assert(row.getAs[Seq[Int]]("b") === Seq(1))
70+
71+
intercept[IllegalArgumentException]{
72+
row.getAs[Int]("c")
73+
}
74+
}
6575
}

0 commit comments

Comments
 (0)