diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 43ca359b5173..fb41c54c3430 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -83,7 +83,8 @@ private[sql] object CatalystConverter {
   protected[parquet] def createConverter(
       field: FieldType,
       fieldIndex: Int,
-      parent: CatalystConverter): Converter = {
+      parent: CatalystConverter,
+      fromProtobuf: Boolean = false): Converter = {
     val fieldType: DataType = field.dataType
     fieldType match {
       case udt: UserDefinedType[_] => {
@@ -91,11 +92,26 @@ private[sql] object CatalystConverter {
       }
       // For native JVM types we use a converter with native arrays
       case ArrayType(elementType: NativeType, false) => {
-        new CatalystNativeArrayConverter(elementType, fieldIndex, parent)
+        if (fromProtobuf) {
+          new CatalystProtobufNativeArrayConverter(field.name, elementType, fieldIndex, parent)
+        } else {
+          new CatalystNativeArrayConverter(elementType, fieldIndex, parent)
+        }
       }
       // This is for other types of arrays, including those with nested fields
       case ArrayType(elementType: DataType, false) => {
-        new CatalystArrayConverter(elementType, fieldIndex, parent)
+        if (fromProtobuf) {
+          elementType match {
+            case StructType(fields: Array[StructField]) => {
+              new CatalystProtobufStructArrayConverter(fields, fieldIndex, parent)
+            }
+            case _ => throw new RuntimeException(
+              s"unable to convert datatype ${field.dataType.toString} in CatalystConverter")
+          }
+        } else {
+          new CatalystArrayConverter(elementType, fieldIndex, parent)
+        }
       }
       case ArrayType(elementType: DataType, true) => {
         new CatalystArrayContainsNullConverter(elementType, fieldIndex, parent)
@@ -156,12 +172,13 @@ private[sql] object CatalystConverter {
   protected[parquet] def createRootConverter(
       parquetSchema: MessageType,
-      attributes: Seq[Attribute]): CatalystConverter = {
+      attributes: Seq[Attribute],
+      fromProtobuf: Boolean = false): CatalystConverter = {
     // For non-nested types we use the optimized Row converter
     if (attributes.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))) {
       new CatalystPrimitiveRowConverter(attributes.toArray)
     } else {
-      new CatalystGroupConverter(attributes.toArray)
+      new CatalystGroupConverter(attributes.toArray, fromProtobuf)
     }
   }
 }
@@ -279,27 +296,29 @@ private[parquet] class CatalystGroupConverter(
     protected[parquet] val index: Int,
     protected[parquet] val parent: CatalystConverter,
     protected[parquet] var current: ArrayBuffer[Any],
-    protected[parquet] var buffer: ArrayBuffer[Row])
+    protected[parquet] var buffer: ArrayBuffer[Row],
+    protected[parquet] val fromProtobuf: Boolean = false)
   extends CatalystConverter {
 
-  def this(schema: Array[FieldType], index: Int, parent: CatalystConverter) =
+  def this(schema: Array[FieldType], index: Int, parent: CatalystConverter, fromProtobuf: Boolean = false) =
     this(
       schema,
       index,
       parent,
       current = null,
       buffer = new ArrayBuffer[Row](
-        CatalystArrayConverter.INITIAL_ARRAY_SIZE))
+        CatalystArrayConverter.INITIAL_ARRAY_SIZE),
+      fromProtobuf)
 
   /**
    * This constructor is used for the root converter only!
    */
-  def this(attributes: Array[Attribute]) =
-    this(attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), 0, null)
+  def this(attributes: Array[Attribute], fromProtobuf: Boolean = false) =
+    this(attributes.map(a => new FieldType(a.name, a.dataType, a.nullable)), 0, null, fromProtobuf)
 
   protected [parquet] val converters: Array[Converter] =
     schema.zipWithIndex.map {
-      case (field, idx) => CatalystConverter.createConverter(field, idx, this)
+      case (field, idx) => CatalystConverter.createConverter(field, idx, this, fromProtobuf)
     }.toArray
 
   override val size = schema.size
@@ -746,6 +765,8 @@ private[parquet] class CatalystNativeArrayConverter(
   }
 }
 
+
+
 /**
  * A `parquet.io.api.GroupConverter` that converts a single-element groups that
  * match the characteristics of an array contains null (see
@@ -825,6 +846,82 @@ private[parquet] class CatalystArrayContainsNullConverter(
   }
 }
 
+/**
+ * Converts a Parquet field that is REPEATED at the field itself (as written by
+ * parquet-protobuf, with no wrapping LIST/bag group) into a Catalyst array of
+ * native values.
+ */
+private[parquet] class CatalystProtobufNativeArrayConverter(
+    val name: String,
+    val elementType: NativeType,
+    val fieldIndex: Int,
+    val parent: CatalystConverter,
+    var capacity: Int = CatalystArrayConverter.INITIAL_ARRAY_SIZE)
+  extends PrimitiveConverter {
+
+  type NativeType = elementType.JvmType
+
+  private var buffer: Array[NativeType] = elementType.classTag.newArray(capacity)
+
+  private var elements: Int = 0
+
+  private def addValue(value: NativeType): Unit = {
+    checkGrowBuffer()
+    buffer(elements) = value
+    elements += 1
+    parent.updateField(fieldIndex, buffer.slice(0, elements).toSeq)
+  }
+
+  override def addBinary(value: Binary): Unit = addValue(value.getBytes.asInstanceOf[NativeType])
+
+  override def addBoolean(value: Boolean): Unit = addValue(value.asInstanceOf[NativeType])
+
+  override def addDouble(value: Double): Unit = addValue(value.asInstanceOf[NativeType])
+
+  override def addFloat(value: Float): Unit = addValue(value.asInstanceOf[NativeType])
+
+  override def addInt(value: Int): Unit = addValue(value.asInstanceOf[NativeType])
+
+  override def addLong(value: Long): Unit = addValue(value.asInstanceOf[NativeType])
+
+  private def checkGrowBuffer(): Unit = {
+    if (elements >= capacity) {
+      val newCapacity = 2 * capacity
+      val tmp: Array[NativeType] = elementType.classTag.newArray(newCapacity)
+      Array.copy(buffer, 0, tmp, 0, capacity)
+      buffer = tmp
+      capacity = newCapacity
+    }
+  }
+}
+
+/**
+ * Converts a protobuf-style REPEATED group field into a Catalyst array of structs.
+ * Each completed group is appended to a row buffer that is pushed to the parent
+ * converter at the end of every element.
+ */
+private[parquet] class CatalystProtobufStructArrayConverter(
+    fields: Array[FieldType],
+    myFieldIndex: Int,
+    parent: CatalystConverter)
+  extends CatalystGroupConverter(fields, myFieldIndex, parent, fromProtobuf = true) {
+
+  val rowBuffer: ArrayBuffer[GenericRow] =
+    new ArrayBuffer[GenericRow](CatalystArrayConverter.INITIAL_ARRAY_SIZE)
+  var currentRow: Array[Any] = new Array[Any](fields.length)
+  var elements = 0
+
+  override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit = {
+    currentRow(fieldIndex) = value
+  }
+
+  override def end(): Unit = {
+    rowBuffer += new GenericRow(currentRow)
+    currentRow = new Array[Any](fields.length)
+    elements += 1
+    parent.updateField(myFieldIndex, rowBuffer.slice(0, elements))
+  }
+
+  override def start(): Unit = {
+    super.start()
+  }
+
+  override protected[parquet] def clearBuffer(): Unit = {
+    super.clearBuffer()
+    rowBuffer.clear()
+  }
+}
+
 /**
  * This converter is for multi-element groups of primitive or complex types
  * that have repetition level optional or required (so struct fields).
@@ -923,3 +1020,5 @@ private[parquet] class CatalystMapConverter(
   override protected[parquet] def updateField(fieldIndex: Int, value: Any): Unit =
     throw new UnsupportedOperationException
 }
+
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index 5a1b15490d27..83e732dfc06d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.parquet
 
+import java.util
 import java.util.{HashMap => JHashMap}
 
 import org.apache.hadoop.conf.Configuration
@@ -39,8 +40,8 @@ import org.apache.spark.sql.types._
 private[parquet] class RowRecordMaterializer(root: CatalystConverter)
   extends RecordMaterializer[Row] {
 
-  def this(parquetSchema: MessageType, attributes: Seq[Attribute]) =
-    this(CatalystConverter.createRootConverter(parquetSchema, attributes))
+  def this(parquetSchema: MessageType, attributes: Seq[Attribute], fromProtobuf: Boolean = false) =
+    this(CatalystConverter.createRootConverter(parquetSchema, attributes, fromProtobuf))
 
   override def getCurrentRecord: Row = root.getCurrentRecord
 
@@ -87,7 +88,8 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging {
         parquetSchema, false, true)
     }
     log.debug(s"list of attributes that will be read: $schema")
-    new RowRecordMaterializer(parquetSchema, schema)
+    val isProtobuf = "true".equals(readContext.getReadSupportMetadata.get(RowReadSupport.FROM_PROTOBUF))
+    new RowRecordMaterializer(parquetSchema, schema, fromProtobuf = isProtobuf)
   }
 
   override def init(
@@ -99,14 +101,18 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging {
 
     val requestedAttributes = RowReadSupport.getRequestedSchema(configuration)
 
     if (requestedAttributes != null) {
+      val keySet: util.Set[String] = keyValueMetaData.keySet()
       // If the parquet file is thrift derived, there is a good chance that
       // it will have the thrift class in metadata.
-      val isThriftDerived = keyValueMetaData.keySet().contains("thrift.class")
+      val isThriftDerived = keySet.contains("thrift.class")
+      val isProto = keySet.contains("parquet.proto.class")
+      metadata.put(RowReadSupport.FROM_PROTOBUF, isProto.toString)
       parquetSchema = ParquetTypesConverter
-        .convertFromAttributes(requestedAttributes, isThriftDerived)
+        .convertFromAttributes(requestedAttributes, isThriftDerived, isProto)
+      val requestedSchemaString: String = ParquetTypesConverter.convertToString(requestedAttributes)
       metadata.put(
         RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA,
-        ParquetTypesConverter.convertToString(requestedAttributes))
+        requestedSchemaString)
     }
 
     val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA)
@@ -121,6 +127,7 @@ private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging {
 private[parquet] object RowReadSupport {
   val SPARK_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema"
   val SPARK_METADATA_KEY = "org.apache.spark.sql.parquet.row.metadata"
+  val FROM_PROTOBUF = "org.apache.spark.sql.parquet.row.protobuf"
 
   private def getRequestedSchema(configuration: Configuration): Seq[Attribute] = {
     val schemaString = configuration.get(RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
index da668f068613..d001a6147d49 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala
@@ -31,13 +31,14 @@ import parquet.hadoop.util.ContextUtil
 import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter}
 import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName}
 import parquet.schema.Type.Repetition
-import parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes}
+import parquet.schema.{GroupType => ParquetGroupType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes, _}
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.types._
 import org.apache.spark.{Logging, SparkException}
 
 // Implicits
+
 import scala.collection.JavaConversions._
 
 /** A class representing Parquet info fields we care about, for passing back to Parquet */
@@ -129,8 +130,21 @@ private[parquet] object ParquetTypesConverter extends Logging {
       groupType.getFields.apply(0).getRepetition == Repetition.REPEATED
     }
 
-    if (parquetType.isPrimitive) {
+    if (parquetType.isPrimitive && !parquetType.isRepetition(Repetition.REPEATED)) {
       toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString, isInt96AsTimestamp)
+    } else if (parquetType.isRepetition(Repetition.REPEATED)) {
+      if (parquetType.isPrimitive) {
+        ArrayType(
+          toPrimitiveDataType(parquetType.asPrimitiveType(), isBinaryAsString, isInt96AsTimestamp),
+          containsNull = false)
+      } else {
+        val fields = parquetType.asGroupType()
+          .getFields
+          .map(ptype => new StructField(
+            ptype.getName,
+            toDataType(ptype, isBinaryAsString, isInt96AsTimestamp),
+            ptype.getRepetition != Repetition.REQUIRED))
+        ArrayType(StructType(fields), containsNull = false)
+      }
     } else {
       val groupType = parquetType.asGroupType()
       parquetType.getOriginalType match {
@@ -285,11 +299,12 @@ private[parquet] object ParquetTypesConverter extends Logging {
    * @return The corresponding Parquet type.
    */
   def fromDataType(
       ctype: DataType,
       name: String,
       nullable: Boolean = true,
       inArray: Boolean = false,
-      toThriftSchemaNames: Boolean = false): ParquetType = {
+      toThriftSchemaNames: Boolean = false,
+      isProtobufSchema: Boolean = false): ParquetType = {
     val repetition = if (inArray) {
       Repetition.REPEATED
@@ -323,8 +338,27 @@ private[parquet] object ParquetTypesConverter extends Logging {
           arraySchemaName,
           nullable = false,
           inArray = true,
-          toThriftSchemaNames)
-        ConversionPatterns.listType(repetition, name, parquetElementType)
+          toThriftSchemaNames,
+          isProtobufSchema)
+        if (isProtobufSchema) {
+          if (parquetElementType.isPrimitive) {
+            new ParquetPrimitiveType(
+              Repetition.REPEATED, parquetElementType.asPrimitiveType().getPrimitiveTypeName, name)
+          } else {
+            elementType match {
+              case StructType(structFields) => {
+                val fields = structFields.map { field =>
+                  fromDataType(field.dataType, field.name, field.nullable,
+                    inArray = false, toThriftSchemaNames, isProtobufSchema)
+                }
+                new ParquetGroupType(Repetition.REPEATED, name, fields.toSeq)
+              }
+              case _ => sys.error(s"Unsupported datatype $ctype")
+            }
+          }
+        } else {
+          ConversionPatterns.listType(repetition, name, parquetElementType)
+        }
       }
       case ArrayType(elementType, true) => {
         val parquetElementType = fromDataType(
           elementType,
@@ -332,19 +366,23 @@ private[parquet] object ParquetTypesConverter extends Logging {
           arraySchemaName,
           nullable = true,
           inArray = false,
-          toThriftSchemaNames)
-        ConversionPatterns.listType(
-          repetition,
-          name,
-          new ParquetGroupType(
-            Repetition.REPEATED,
-            CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME,
-            parquetElementType))
+          toThriftSchemaNames, isProtobufSchema)
+        if (isProtobufSchema) {
+          new ParquetPrimitiveType(
+            Repetition.REPEATED, parquetElementType.asPrimitiveType().getPrimitiveTypeName, name)
+        } else {
+          ConversionPatterns.listType(
+            repetition,
+            name,
+            new ParquetGroupType(
+              Repetition.REPEATED,
+              CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME,
+              parquetElementType))
+        }
       }
       case StructType(structFields) => {
         val fields = structFields.map {
           field => fromDataType(field.dataType, field.name, field.nullable,
-            inArray = false, toThriftSchemaNames)
+            inArray = false, toThriftSchemaNames, isProtobufSchema)
         }
         new ParquetGroupType(repetition, name, fields.toSeq)
       }
@@ -355,14 +393,16 @@ private[parquet] object ParquetTypesConverter extends Logging {
           CatalystConverter.MAP_KEY_SCHEMA_NAME,
           nullable = false,
           inArray = false,
-          toThriftSchemaNames)
+          toThriftSchemaNames,
+          isProtobufSchema)
         val parquetValueType = fromDataType(
           valueType,
           CatalystConverter.MAP_VALUE_SCHEMA_NAME,
           nullable = valueContainsNull,
           inArray = false,
-          toThriftSchemaNames)
+          toThriftSchemaNames,
+          isProtobufSchema)
         ConversionPatterns.mapType(
           repetition,
           name,
@@ -389,11 +429,11 @@ private[parquet] object ParquetTypesConverter extends Logging {
   }
 
   def convertFromAttributes(attributes: Seq[Attribute],
-      toThriftSchemaNames: Boolean = false): MessageType = {
+      toThriftSchemaNames: Boolean = false, isProtobufSchema: Boolean = false): MessageType = {
     val fields = attributes.map(
       attribute =>
         fromDataType(attribute.dataType, attribute.name, attribute.nullable,
-          toThriftSchemaNames = toThriftSchemaNames))
+          toThriftSchemaNames = toThriftSchemaNames, isProtobufSchema = isProtobufSchema))
     new MessageType("root", fields)
   }
 
diff --git a/sql/core/src/test/resources/nested-array-struct.parquet b/sql/core/src/test/resources/nested-array-struct.parquet
new file mode 100755
index 000000000000..41a43fa35d39
Binary files /dev/null and b/sql/core/src/test/resources/nested-array-struct.parquet differ
diff --git a/sql/core/src/test/resources/old-repeated-int.parquet b/sql/core/src/test/resources/old-repeated-int.parquet
new file mode 100755
index 000000000000..520922f73ebb
Binary files /dev/null and b/sql/core/src/test/resources/old-repeated-int.parquet differ
diff --git a/sql/core/src/test/resources/old-repeated-message.parquet b/sql/core/src/test/resources/old-repeated-message.parquet
new file mode 100755
index 000000000000..548db9916277
Binary files /dev/null and b/sql/core/src/test/resources/old-repeated-message.parquet differ
diff --git a/sql/core/src/test/resources/proto-repeated-struct.parquet b/sql/core/src/test/resources/proto-repeated-struct.parquet
new file mode 100755
index 000000000000..c29eee35c350
Binary files /dev/null and b/sql/core/src/test/resources/proto-repeated-struct.parquet differ
diff --git a/sql/core/src/test/resources/proto-struct-with-array-many.parquet b/sql/core/src/test/resources/proto-struct-with-array-many.parquet
new file mode 100755
index 000000000000..ff9809675fc0
Binary files /dev/null and b/sql/core/src/test/resources/proto-struct-with-array-many.parquet differ
diff --git a/sql/core/src/test/resources/proto-struct-with-array.parquet b/sql/core/src/test/resources/proto-struct-with-array.parquet
new file mode 100755
index 000000000000..325a8370ad20
Binary files /dev/null and b/sql/core/src/test/resources/proto-struct-with-array.parquet differ
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ProtoParquetTypesConverterTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ProtoParquetTypesConverterTest.scala
new file mode 100644
index 000000000000..efa5d4046118
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ProtoParquetTypesConverterTest.scala
@@ -0,0 +1,170 @@
+package org.apache.spark.sql.parquet
+
+import java.net.URL
+
+import scala.collection.mutable.ArrayBuffer
+
+import parquet.schema.Type.Repetition
+import parquet.schema.{GroupType, MessageType, PrimitiveType}
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow}
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.{DataFrame, QueryTest, Row, SQLContext}
+
+class ProtoParquetTypesConverterTest extends QueryTest with ParquetTest {
+  override val sqlContext: SQLContext = TestSQLContext
+
+  test("parquet-schema conversion retains repeated primitive type") {
+    val actualSchema: MessageType = new MessageType("root",
+      new PrimitiveType(Repetition.REPEATED, PrimitiveType.PrimitiveTypeName.INT32, "repeated_field"))
+    val attributes: Seq[Attribute] = ParquetTypesConverter.convertToAttributes(
+      actualSchema, isBinaryAsString = false, isInt96AsTimestamp = true)
+    val convertedSchema: MessageType =
+      ParquetTypesConverter.convertFromAttributes(attributes, isProtobufSchema = true)
+    assert(actualSchema === convertedSchema)
+  }
+
+  test("parquet-schema conversion retains repeated group type") {
+    val actualSchema: MessageType = new MessageType("root",
+      new GroupType(Repetition.REPEATED, "inner",
+        new PrimitiveType(Repetition.OPTIONAL, PrimitiveType.PrimitiveTypeName.DOUBLE, "something")))
+    val attributes: Seq[Attribute] = ParquetTypesConverter.convertToAttributes(
+      actualSchema, isBinaryAsString = false, isInt96AsTimestamp = true)
+    val convertedSchema: MessageType =
+      ParquetTypesConverter.convertFromAttributes(attributes, isProtobufSchema = true)
+    assert(actualSchema === convertedSchema)
+  }
+
+  test("parquet-schema conversion retains arrays nested within groups") {
+    val field1: PrimitiveType =
+      new PrimitiveType(Repetition.OPTIONAL, PrimitiveType.PrimitiveTypeName.DOUBLE, "something")
+    val repeated: PrimitiveType =
+      new PrimitiveType(Repetition.REPEATED, PrimitiveType.PrimitiveTypeName.INT96, "an_int")
+    val actualSchema: MessageType = new MessageType("root",
+      new GroupType(Repetition.REQUIRED, "my_struct", field1, repeated))
+    val attributes: Seq[Attribute] = ParquetTypesConverter.convertToAttributes(
+      actualSchema, isBinaryAsString = false, isInt96AsTimestamp = true)
+    val convertedSchema: MessageType =
+      ParquetTypesConverter.convertFromAttributes(attributes, isProtobufSchema = true)
+    assert(actualSchema === convertedSchema)
+  }
+
+  test("parquet-schema conversion retains multiple levels of nesting") {
+    val actualSchema: MessageType = new MessageType("root",
+      new GroupType(Repetition.REPEATED, "outer",
+        new GroupType(Repetition.OPTIONAL, "inner",
+          new PrimitiveType(Repetition.OPTIONAL, PrimitiveType.PrimitiveTypeName.DOUBLE, "something"),
+          new GroupType(Repetition.REPEATED, "inner_inner",
+            new PrimitiveType(
+              Repetition.OPTIONAL, PrimitiveType.PrimitiveTypeName.DOUBLE, "something_else")))))
+    val attributes: Seq[Attribute] = ParquetTypesConverter.convertToAttributes(
+      actualSchema, isBinaryAsString = false, isInt96AsTimestamp = true)
+    val convertedSchema: MessageType =
+      ParquetTypesConverter.convertFromAttributes(attributes, isProtobufSchema = true)
+    assert(actualSchema === convertedSchema)
+  }
+
+  test("should work with repeated primitive") {
+    val resource: URL = getClass.getResource("/old-repeated-int.parquet")
+    val pf: DataFrame = sqlContext.parquetFile(resource.toURI.toString)
+    pf.registerTempTable("my_test_table")
+    val rows: Array[Row] = sqlContext.sql("select * from my_test_table").collect()
+    val ints: ArrayBuffer[Int] = rows(0)(0).asInstanceOf[ArrayBuffer[Int]]
+    assert(ints.length === 3)
+    assert(ints(0) === 1)
+    assert(ints(1) === 2)
+    assert(ints(2) === 3)
+  }
+
+  test("should work with repeated complex") {
+    val resource: URL = getClass.getResource("/old-repeated-message.parquet")
+    val pf: DataFrame = sqlContext.parquetFile(resource.toURI.toString)
+    pf.registerTempTable("my_complex_table")
+    val rows: Array[Row] = sqlContext.sql("select * from my_complex_table").collect()
+    val array: ArrayBuffer[GenericRow] = rows(0)(0).asInstanceOf[ArrayBuffer[GenericRow]]
+    assert(array.length === 3)
+  }
+
+  test("should work with repeated struct") {
+    val resource: URL = getClass.getResource("/proto-repeated-struct.parquet")
+    val pf: DataFrame = sqlContext.parquetFile(resource.toURI.toString)
+    pf.registerTempTable("my_complex_table")
+    val rows: Array[Row] = sqlContext.sql("select * from my_complex_table").collect()
+    assert(rows.length === 1)
+    val array: ArrayBuffer[GenericRow] = rows(0)(0).asInstanceOf[ArrayBuffer[GenericRow]]
+    assert(array.length === 2)
+    assert(array(0)(0) === "0 - 1")
"0 - 2") + assert(array(0)(2) === "0 - 3") + assert(array(1)(0) === "1 - 1") + assert(array(1)(1) === "1 - 2") + assert(array(1)(2) === "1 - 3") + } + + test("should work with repeated complex with many rows") { + val resource: URL = getClass.getResource("/proto-struct-with-array-many.parquet") + val pf: DataFrame = sqlContext.parquetFile(resource.toURI.toString) + pf.registerTempTable("my_complex_table") + val rows: Array[Row] = sqlContext.sql("select * from my_complex_table").collect() + assert(rows.length === 3) + val row0: ArrayBuffer[GenericRow] = rows(0)(0).asInstanceOf[ArrayBuffer[GenericRow]] + val row1: ArrayBuffer[GenericRow] = rows(1)(0).asInstanceOf[ArrayBuffer[GenericRow]] + val row2: ArrayBuffer[GenericRow] = rows(2)(0).asInstanceOf[ArrayBuffer[GenericRow]] + assert(row0(0)(0) === "0 - 0 - 1") + assert(row0(0)(1) === "0 - 0 - 2") + assert(row0(0)(2) === "0 - 0 - 3") + assert(row0(1)(0) === "0 - 1 - 1") + assert(row0(1)(1) === "0 - 1 - 2") + assert(row0(1)(2) === "0 - 1 - 3") + assert(row1(0)(0) === "1 - 0 - 1") + assert(row1(0)(1) === "1 - 0 - 2") + assert(row1(0)(2) === "1 - 0 - 3") + assert(row1(1)(0) === "1 - 1 - 1") + assert(row1(1)(1) === "1 - 1 - 2") + assert(row1(1)(2) === "1 - 1 - 3") + assert(row2(0)(0) === "2 - 0 - 1") + assert(row2(0)(1) === "2 - 0 - 2") + assert(row2(0)(2) === "2 - 0 - 3") + assert(row2(1)(0) === "2 - 1 - 1") + assert(row2(1)(1) === "2 - 1 - 2") + assert(row2(1)(2) === "2 - 1 - 3") + } + + test("should work with complex type containing array") { + val resource: URL = getClass.getResource("/proto-struct-with-array.parquet") + val pf: DataFrame = sqlContext.parquetFile(resource.toURI.toString) + pf.registerTempTable("my_complex_struct") + val rows: Array[Row] = sqlContext.sql("select * from my_complex_struct").collect() + assert(rows.length === 1) + val theRow: GenericRow = rows(0).asInstanceOf[GenericRow] + val optionalStruct = theRow(3).asInstanceOf[GenericRow] + val requiredStruct = theRow(4).asInstanceOf[GenericRow] + val arrayOfStruct = theRow(5).asInstanceOf[ArrayBuffer[GenericRow]] + assert(theRow.length === 6) + assert(theRow(0) === 10) + assert(theRow(1) === 9) + assert(theRow(2) == null) + assert(optionalStruct === null) + assert(requiredStruct(0) === 9) + assert(arrayOfStruct(0)(0) === 9) + assert(arrayOfStruct(1)(0) === 10) + } + + test("should work with mulitple levels of nesting") { + val resource: URL = getClass.getResource("/nested-array-struct.parquet") + val pf: DataFrame = sqlContext.parquetFile(resource.toURI.toString) + pf.registerTempTable("my_nested_struct") + val rows: Array[Row] = sqlContext.sql("select * from my_nested_struct").collect() + assert(rows.length === 3) + val row0: GenericRow = rows(0).asInstanceOf[GenericRow] + val row1: GenericRow = rows(1).asInstanceOf[GenericRow] + val row2: GenericRow = rows(2).asInstanceOf[GenericRow] + val nestedR0: ArrayBuffer[GenericRow] = row0(1).asInstanceOf[ArrayBuffer[GenericRow]] + val nestedR1: ArrayBuffer[GenericRow] = row1(1).asInstanceOf[ArrayBuffer[GenericRow]] + val nestedR2: ArrayBuffer[GenericRow] = row2(1).asInstanceOf[ArrayBuffer[GenericRow]] + val nestedR0Array: ArrayBuffer[GenericRow] = nestedR0(0)(1).asInstanceOf[ArrayBuffer[GenericRow]] + val nestedR1Array: ArrayBuffer[GenericRow] = nestedR1(0)(1).asInstanceOf[ArrayBuffer[GenericRow]] + val nestedR2Array: ArrayBuffer[GenericRow] = nestedR2(0)(1).asInstanceOf[ArrayBuffer[GenericRow]] + assert(row0(0) === 2) + assert(row1(0) === 5) + assert(row2(0) === 8) + assert(nestedR0.length ===1) + assert(nestedR1.length ===1) + 
+    assert(nestedR2.length === 1)
+    assert(nestedR0(0)(0) === 1)
+    assert(nestedR1(0)(0) === 4)
+    assert(nestedR2(0)(0) === 7)
+    assert(nestedR0Array(0)(0) === 3)
+    assert(nestedR1Array(0)(0) === 6)
+    assert(nestedR2Array(0)(0) === 9)
+  }
+}
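
For reference, a minimal end-to-end sketch of the behaviour the new tests exercise, assuming a Spark 1.x shell with this patch applied and the bundled old-repeated-int.parquet test resource on the classpath (the table name below is just an example): with the patched converters, a protobuf-style REPEATED int32 field is surfaced as an array column instead of failing schema conversion.

  import scala.collection.mutable.ArrayBuffer
  import org.apache.spark.sql.DataFrame

  // Locate the bundled test file and read it through the patched Parquet read path.
  val path = getClass.getResource("/old-repeated-int.parquet").toURI.toString
  val df: DataFrame = sqlContext.parquetFile(path)      // same API the new tests use
  df.registerTempTable("repeated_ints")
  val first = sqlContext.sql("select * from repeated_ints").collect()(0)
  // The repeated field arrives as a Catalyst array of native values.
  val values = first(0).asInstanceOf[ArrayBuffer[Int]]  // expected: ArrayBuffer(1, 2, 3)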