diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index aef4e4e7c642..489b69e2e533 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -285,9 +285,9 @@ message Deduplicate { // A relation that does not need to be qualified by name. message LocalRelation { - // (Optional) A list qualified attributes. - repeated Expression.QualifiedAttribute attributes = 1; - // TODO: support local data. + // Local collection data serialized into Arrow IPC streaming format which contains + // the schema of the data. + bytes data = 1; } // Relation of type [[Sample]] that samples a fraction of the dataset. diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 96d0dbe35803..9c4299d652fd 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import com.google.common.collect.{Lists, Maps} +import org.apache.spark.TaskContext import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.WriteOperation @@ -29,7 +30,7 @@ import org.apache.spark.sql.{Column, Dataset, SparkSession} import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression, UnsafeProjection} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} @@ -37,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Interse import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.command.CreateViewCommand import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.types._ @@ -272,8 +274,12 @@ class SparkConnectPlanner(session: SparkSession) { } private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = { - val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq - new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes) + val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator( + Iterator(rel.getData.toByteArray), + TaskContext.get()) + val attributes = structType.toAttributes + val proj = UnsafeProjection.create(attributes, attributes) + new logical.LocalRelation(attributes, rows.map(r => 
proj(r).copy()).toSeq) } private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 9e5fc41a0c68..072af389513e 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -19,12 +19,19 @@ package org.apache.spark.sql.connect.planner import scala.collection.JavaConverters._ +import com.google.protobuf.ByteString + import org.apache.spark.SparkFunSuite import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Expression.UnresolvedStar -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String /** * Testing trait for SparkConnect tests with some helper methods to make it easier to create new @@ -55,17 +62,26 @@ trait SparkConnectPlanTest extends SharedSparkSession { * equivalent in Catalyst and can be easily used for planner testing. * * @param attrs + * the attributes of LocalRelation + * @param data + * the data of LocalRelation * @return */ - def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = { + def createLocalRelationProto( + attrs: Seq[AttributeReference], + data: Seq[InternalRow]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() - for (attr <- attrs) { - localRelationBuilder.addAttributes( - proto.Expression.QualifiedAttribute - .newBuilder() - .setName(attr.name) - .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))) - } + + val bytes = ArrowConverters + .toBatchWithSchemaIterator( + data.iterator, + StructType.fromAttributes(attrs.map(_.toAttribute)), + Long.MaxValue, + Long.MaxValue, + null) + .next() + + localRelationBuilder.setData(ByteString.copyFrom(bytes)) proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() } } @@ -96,7 +112,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { new SparkConnectPlanner(None.orNull) .transformRelation( proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build())) - } test("Simple Read") { @@ -199,7 +214,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("Simple Join") { - val incompleteJoin = proto.Relation.newBuilder.setJoin(proto.Join.newBuilder.setLeft(readRel)).build() intercept[AssertionError](transform(incompleteJoin)) @@ -267,7 +281,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { val res = transform(proto.Relation.newBuilder.setProject(project).build()) assert(res.nodeName == "Project") - } test("Simple Aggregation") { @@ -354,4 +367,61 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { transform(proto.Relation.newBuilder.setSetOp(intersect).build())) 
assert(e2.getMessage.contains("Intersect does not support union_by_name")) } + + test("transform LocalRelation") { + val rows = (0 until 10).map { i => + InternalRow(i, UTF8String.fromString(s"str-$i"), InternalRow(i)) + } + + val schema = StructType( + Seq( + StructField("int", IntegerType), + StructField("str", StringType), + StructField("struct", StructType(Seq(StructField("inner", IntegerType)))))) + val inputRows = rows.map { row => + val proj = UnsafeProjection.create(schema) + proj(row).copy() + } + + val localRelation = createLocalRelationProto(schema.toAttributes, inputRows) + val df = Dataset.ofRows(spark, transform(localRelation)) + val array = df.collect() + assertResult(10)(array.length) + assert(schema == df.schema) + for (i <- 0 until 10) { + assert(i == array(i).getInt(0)) + assert(s"str-$i" == array(i).getString(1)) + assert(i == array(i).getStruct(2).getInt(0)) + } + } + + test("Empty ArrowBatch") { + val schema = StructType(Seq(StructField("int", IntegerType))) + val data = ArrowConverters.createEmptyArrowBatch(schema, null) + val localRelation = proto.Relation + .newBuilder() + .setLocalRelation( + proto.LocalRelation + .newBuilder() + .setData(ByteString.copyFrom(data)) + .build()) + .build() + val df = Dataset.ofRows(spark, transform(localRelation)) + assert(schema == df.schema) + assert(df.isEmpty) + } + + test("Illegal LocalRelation data") { + intercept[Exception] { + transform( + proto.Relation + .newBuilder() + .setLocalRelation( + proto.LocalRelation + .newBuilder() + .setData(ByteString.copyFrom("illegal".getBytes())) + .build()) + .build()) + } + } } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 185548971ecb..4114175e5101 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.connect.planner import java.nio.file.{Files, Paths} +import com.google.protobuf.ByteString + import org.apache.spark.SparkClassNotFoundException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Join.JoinType @@ -31,6 +33,7 @@ import org.apache.spark.sql.connect.dsl.MockRemoteSession import org.apache.spark.sql.connect.dsl.commands._ import org.apache.spark.sql.connect.dsl.expressions._ import org.apache.spark.sql.connect.dsl.plans._ +import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{IntegerType, MapType, Metadata, StringType, StructField, StructType} @@ -44,14 +47,18 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { lazy val connectTestRelation = createLocalRelationProto( - Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)())) + Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()), + Seq.empty) lazy val connectTestRelation2 = createLocalRelationProto( - Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)())) + Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()), + Seq.empty) lazy val connectTestRelationMap = - createLocalRelationProto(Seq(AttributeReference("id", MapType(StringType, StringType))())) + createLocalRelationProto( + 
Seq(AttributeReference("id", MapType(StringType, StringType))()), + Seq.empty) lazy val sparkTestRelation: DataFrame = spark.createDataFrame( @@ -68,7 +75,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { new java.util.ArrayList[Row](), StructType(Seq(StructField("id", MapType(StringType, StringType))))) - lazy val localRelation = createLocalRelationProto(Seq(AttributeReference("id", IntegerType)())) + lazy val localRelation = + createLocalRelationProto(Seq(AttributeReference("id", IntegerType)()), Seq.empty) test("Basic select") { val connectPlan = connectTestRelation.select("id".protoAttr) @@ -500,10 +508,21 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { private def createLocalRelationProtoByQualifiedAttributes( attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() - for (attr <- attrs) { - localRelationBuilder.addAttributes(attr) - } - proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() + + val attributes = attrs.map(exp => + AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))()) + val buffer = ArrowConverters + .toBatchWithSchemaIterator( + Iterator.empty, + StructType.fromAttributes(attributes), + Long.MaxValue, + Long.MaxValue, + null) + .next() + proto.Relation + .newBuilder() + .setLocalRelation(localRelationBuilder.setData(ByteString.copyFrom(buffer)).build()) + .build() } // This is a function for testing only. This is used when the plan is ready and it only waits diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 344caa3ea37e..0ba0d8917cc2 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -33,7 +33,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xd1\x0b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 
\x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12|\n#rename_columns_by_same_length_names\x18\x12 \x01(\x0b\x32-.spark.connect.RenameColumnsBySameLengthNamesH\x00R\x1erenameColumnsBySameLengthNames\x12w\n"rename_columns_by_name_to_name_map\x18\x13 \x01(\x0b\x32+.spark.connect.RenameColumnsByNameToNameMapH\x00R\x1crenameColumnsByNameToNameMap\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"1\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\xaa\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x00R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xc2\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\x8c\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_name"L\n\x05Limit\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xd2\x01\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12H\n\x12result_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11resultExpressions"\xa6\x04\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42\x0c\n\n_is_global"d\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12-\n\x04\x63ols\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x04\x63ols"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xe0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x42\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8d\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x18\n\x07numRows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 
\x03(\tR\nstatistics"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"r\n\x1eRenameColumnsBySameLengthNames\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\x83\x02\n\x1cRenameColumnsByNameToNameMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12o\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x41.spark.connect.RenameColumnsByNameToNameMap.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xd1\x0b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12|\n#rename_columns_by_same_length_names\x18\x12 \x01(\x0b\x32-.spark.connect.RenameColumnsBySameLengthNamesH\x00R\x1erenameColumnsBySameLengthNames\x12w\n"rename_columns_by_name_to_name_map\x18\x13 \x01(\x0b\x32+.spark.connect.RenameColumnsByNameToNameMapH\x00R\x1crenameColumnsByNameToNameMap\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x33\n\x07unknown\x18\xe7\x07 
\x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"1\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\xaa\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x00R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xc2\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\x8c\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_name"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xd2\x01\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12H\n\x12result_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11resultExpressions"\xa6\x04\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 
\x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42\x0c\n\n_is_global"d\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12-\n\x04\x63ols\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x04\x63ols"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"#\n\rLocalRelation\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta"\xe0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x42\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8d\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x18\n\x07numRows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"r\n\x1eRenameColumnsBySameLengthNames\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\x83\x02\n\x1cRenameColumnsByNameToNameMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12o\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x41.spark.connect.RenameColumnsByNameToNameMap.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 
\x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -465,27 +465,27 @@ _DEDUPLICATE._serialized_start = 4206 _DEDUPLICATE._serialized_end = 4377 _LOCALRELATION._serialized_start = 4379 - _LOCALRELATION._serialized_end = 4472 - _SAMPLE._serialized_start = 4475 - _SAMPLE._serialized_end = 4699 - _RANGE._serialized_start = 4702 - _RANGE._serialized_end = 4847 - _SUBQUERYALIAS._serialized_start = 4849 - _SUBQUERYALIAS._serialized_end = 4963 - _REPARTITION._serialized_start = 4966 - _REPARTITION._serialized_end = 5108 - _SHOWSTRING._serialized_start = 5111 - _SHOWSTRING._serialized_end = 5252 - _STATSUMMARY._serialized_start = 5254 - _STATSUMMARY._serialized_end = 5346 - _STATCROSSTAB._serialized_start = 5348 - _STATCROSSTAB._serialized_end = 5449 - _NAFILL._serialized_start = 5452 - _NAFILL._serialized_end = 5586 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5588 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5702 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5705 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5964 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5897 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5964 + _LOCALRELATION._serialized_end = 4414 + _SAMPLE._serialized_start = 4417 + _SAMPLE._serialized_end = 4641 + _RANGE._serialized_start = 4644 + _RANGE._serialized_end = 4789 + _SUBQUERYALIAS._serialized_start = 4791 + _SUBQUERYALIAS._serialized_end = 4905 + _REPARTITION._serialized_start = 4908 + _REPARTITION._serialized_end = 5050 + _SHOWSTRING._serialized_start = 5053 + _SHOWSTRING._serialized_end = 5194 + _STATSUMMARY._serialized_start = 5196 + _STATSUMMARY._serialized_end = 5288 + _STATCROSSTAB._serialized_start = 5290 + _STATCROSSTAB._serialized_end = 5391 + _NAFILL._serialized_start = 5394 + _NAFILL._serialized_end = 5528 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5530 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5644 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5647 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5906 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5839 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5906 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 30e61282baaf..a6a16b448c91 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -1075,27 +1075,17 @@ class LocalRelation(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - ATTRIBUTES_FIELD_NUMBER: builtins.int - @property - def attributes( - self, - ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - pyspark.sql.connect.proto.expressions_pb2.Expression.QualifiedAttribute - ]: - """(Optional) A list qualified attributes. - TODO: support local data. - """ + DATA_FIELD_NUMBER: builtins.int + data: builtins.bytes + """Local collection data serialized into Arrow IPC streaming format which contains + the schema of the data. + """ def __init__( self, *, - attributes: collections.abc.Iterable[ - pyspark.sql.connect.proto.expressions_pb2.Expression.QualifiedAttribute - ] - | None = ..., - ) -> None: ... - def ClearField( - self, field_name: typing_extensions.Literal["attributes", b"attributes"] + data: builtins.bytes = ..., ) -> None: ... 
+ def ClearField(self, field_name: typing_extensions.Literal["data", b"data"]) -> None: ... global___LocalRelation = LocalRelation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index a60f7b5970d1..40de117f6f67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -21,11 +21,12 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, Ou import java.nio.channels.{Channels, ReadableByteChannel} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.arrow.flatbuf.MessageHeader import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer} import org.apache.spark.TaskContext @@ -214,57 +215,129 @@ private[sql] object ArrowConverters extends Logging { } /** - * Maps iterator from serialized ArrowRecordBatches to InternalRows. + * An InternalRow iterator that parses data from serialized ArrowRecordBatches; subclasses should + * implement [[nextBatch]] to parse data from binary records. */ - private[sql] def fromBatchIterator( + private[sql] abstract class InternalRowIterator( arrowBatchIter: Iterator[Array[Byte]], - schema: StructType, - timeZoneId: String, - context: TaskContext): Iterator[InternalRow] = { - val allocator = - ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue) - - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - - new Iterator[InternalRow] { - private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty - - if (context != null) context.addTaskCompletionListener[Unit] { _ => - root.close() - allocator.close() - } + context: TaskContext) + extends Iterator[InternalRow] { + // Keep all the resources we have opened in order; they should be closed in reverse order at the end. + val resources = new ArrayBuffer[AutoCloseable]() + protected val allocator: BufferAllocator = ArrowUtils.rootAllocator.newChildAllocator( + s"to${this.getClass.getSimpleName}", + 0, + Long.MaxValue) + resources.append(allocator) + + private var rowIterAndSchema = + if (arrowBatchIter.hasNext) nextBatch() else (Iterator.empty, null) + // We ensure that the schemas parsed from every batch are the same. 
+ val schema: StructType = rowIterAndSchema._2 + + if (context != null) context.addTaskCompletionListener[Unit] { _ => + closeAll(resources.toSeq.reverse: _*) + } - override def hasNext: Boolean = rowIter.hasNext || { - if (arrowBatchIter.hasNext) { - rowIter = nextBatch() - true - } else { - root.close() - allocator.close() - false + override def hasNext: Boolean = rowIterAndSchema._1.hasNext || { + if (arrowBatchIter.hasNext) { + rowIterAndSchema = nextBatch() + if (schema != rowIterAndSchema._2) { + throw new IllegalArgumentException( + s"ArrowBatch iterator contains two batches with" + + s" different schemas: $schema and ${rowIterAndSchema._2}") } + rowIterAndSchema._1.hasNext + } else { + closeAll(resources.toSeq.reverse: _*) + false } + } - override def next(): InternalRow = rowIter.next() + override def next(): InternalRow = rowIterAndSchema._1.next() - private def nextBatch(): Iterator[InternalRow] = { - val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) - val vectorLoader = new VectorLoader(root) - vectorLoader.load(arrowRecordBatch) - arrowRecordBatch.close() + def nextBatch(): (Iterator[InternalRow], StructType) } - val columns = root.getFieldVectors.asScala.map { vector => - new ArrowColumnVector(vector).asInstanceOf[ColumnVector] - }.toArray + /** + * Parse data from serialized ArrowRecordBatches. The [[arrowBatchIter]] contains only serialized + * Arrow batch records; the schema is passed in through [[schema]]. + */ + private[sql] class InternalRowIteratorWithoutSchema( + arrowBatchIter: Iterator[Array[Byte]], + schema: StructType, + timeZoneId: String, + context: TaskContext) + extends InternalRowIterator(arrowBatchIter, context) { + + override def nextBatch(): (Iterator[InternalRow], StructType) = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + resources.append(root) + val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) + val vectorLoader = new VectorLoader(root) + vectorLoader.load(arrowRecordBatch) + arrowRecordBatch.close() + (vectorSchemaRootToIter(root), schema) + } + } - val batch = new ColumnarBatch(columns) - batch.setNumRows(root.getRowCount) - batch.rowIterator().asScala + /** + * Parse data from serialized ArrowRecordBatches. Each Arrow batch in [[arrowBatchIter]] starts with + * the schema, so we parse the schema from it first. + */ + private[sql] class InternalRowIteratorWithSchema( + arrowBatchIter: Iterator[Array[Byte]], + context: TaskContext) + extends InternalRowIterator(arrowBatchIter, context) { + override def nextBatch(): (Iterator[InternalRow], StructType) = { + val reader = + new ArrowStreamReader(new ByteArrayInputStream(arrowBatchIter.next()), allocator) + val root = if (reader.loadNextBatch()) reader.getVectorSchemaRoot else null + resources.append(reader, root) + if (root == null) { + (Iterator.empty, null) + } else { + (vectorSchemaRootToIter(root), ArrowUtils.fromArrowSchema(root.getSchema)) } } } + /** + * Maps iterator from serialized ArrowRecordBatches to InternalRows. + */ + private[sql] def fromBatchIterator( + arrowBatchIter: Iterator[Array[Byte]], + schema: StructType, + timeZoneId: String, + context: TaskContext): Iterator[InternalRow] = new InternalRowIteratorWithoutSchema( + arrowBatchIter, schema, timeZoneId, context + ) + + /** + * Maps iterator from serialized ArrowRecordBatches to InternalRows. Unlike + * [[fromBatchIterator]], each input Arrow batch starts with the schema. 
+ */ + private[sql] def fromBatchWithSchemaIterator( + arrowBatchIter: Iterator[Array[Byte]], + context: TaskContext): (Iterator[InternalRow], StructType) = { + val iterator = new InternalRowIteratorWithSchema(arrowBatchIter, context) + (iterator, iterator.schema) + } + + /** + * Convert an arrow batch container into an iterator of InternalRow. + */ + private def vectorSchemaRootToIter(root: VectorSchemaRoot): Iterator[InternalRow] = { + val columns = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector).asInstanceOf[ColumnVector] + }.toArray + + val batch = new ColumnarBatch(columns) + batch.setNumRows(root.getRowCount) + batch.rowIterator().asScala + } + /** * Load a serialized ArrowRecordBatch. */ @@ -391,4 +464,12 @@ private[sql] object ArrowConverters extends Logging { } } } + + private def closeAll(closeables: AutoCloseable*): Unit = { + for (closeable <- closeables) { + if (closeable != null) { + closeable.close() + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index e876e9d6ff20..eb33e2e47caf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -31,11 +31,13 @@ import org.apache.arrow.vector.util.{ByteArrayReadableSeekableByteChannel, Valid import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{ArrayType, BinaryType, Decimal, IntegerType, NullType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, Decimal, IntegerType, NullType, StringType, StructField, StructType} import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -1426,6 +1428,73 @@ class ArrowConvertersSuite extends SharedSparkSession { assert(count == inputRows.length) } + test("roundtrip arrow batches with complex schema") { + val rows = (0 until 9).map { i => + InternalRow(i, UTF8String.fromString(s"str-$i"), InternalRow(i)) + } + + val schema = StructType(Seq( + StructField("int", IntegerType), + StructField("str", StringType), + StructField("struct", StructType(Seq(StructField("inner", IntegerType)))) + )) + val inputRows = rows.map { row => + val proj = UnsafeProjection.create(schema) + proj(row).copy() + } + val ctx = TaskContext.empty() + val batchIter = + ArrowConverters.toBatchWithSchemaIterator(inputRows.iterator, schema, 5, 1024 * 1024, null) + val (outputRowIter, outputType) = ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx) + + var count = 0 + outputRowIter.zipWithIndex.foreach { case (row, i) => + assert(row.getInt(0) == i) + assert(row.getString(1) == s"str-$i") + assert(row.getStruct(2, 1).getInt(0) == i) + count += 1 + } + + assert(count == inputRows.length) + assert(schema == outputType) + } + + test("roundtrip empty arrow batches") { + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + val ctx = TaskContext.empty() + val batchIter = + ArrowConverters.toBatchWithSchemaIterator(Iterator.empty, 
schema, 5, 1024 * 1024, null) + val (outputRowIter, outputType) = ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx) + + assert(0 == outputRowIter.length) + assert(outputType == null) + } + + test("two batches with different schema") { + val schema1 = StructType(Seq(StructField("field1", IntegerType, nullable = true))) + val inputRows1 = Array(InternalRow(1)).map { row => + val proj = UnsafeProjection.create(schema1) + proj(row).copy() + } + val batchIter1 = ArrowConverters.toBatchWithSchemaIterator( + inputRows1.iterator, schema1, 5, 1024 * 1024, null) + + val schema2 = StructType(Seq(StructField("field2", IntegerType, nullable = true))) + val inputRows2 = Array(InternalRow(1)).map { row => + val proj = UnsafeProjection.create(schema2) + proj(row).copy() + } + val batchIter2 = ArrowConverters.toBatchWithSchemaIterator( + inputRows2.iterator, schema2, 5, 1024 * 1024, null) + + val iter = batchIter1.toArray ++ batchIter2 + + val ctx = TaskContext.empty() + intercept[IllegalArgumentException] { + ArrowConverters.fromBatchWithSchemaIterator(iter.iterator, ctx)._1.toArray + } + } + /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ private def collectAndValidate( df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = {
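The round trip exercised by these changes can be sketched end to end as follows. This is a minimal illustration only, assuming access to the private[sql] ArrowConverters API (for example from test code under org.apache.spark.sql); all value names are illustrative and not part of the patch.

import com.google.protobuf.ByteString
import org.apache.spark.TaskContext
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Seq(StructField("id", IntegerType)))
val proj = UnsafeProjection.create(schema)
val rows = (0 until 3).map(i => proj(InternalRow(i)).copy())

// Client side: serialize the rows together with the schema into a single
// Arrow IPC stream and put the bytes into LocalRelation.data.
val bytes = ArrowConverters
  .toBatchWithSchemaIterator(rows.iterator, schema, Long.MaxValue, Long.MaxValue, null)
  .next()
val relation = proto.Relation
  .newBuilder()
  .setLocalRelation(proto.LocalRelation.newBuilder().setData(ByteString.copyFrom(bytes)).build())
  .build()

// Server side: the planner recovers both the rows and the schema from the bytes,
// mirroring transformLocalRelation above.
val (decodedRows, decodedSchema) = ArrowConverters.fromBatchWithSchemaIterator(
  Iterator(relation.getLocalRelation.getData.toByteArray),
  TaskContext.get())
assert(decodedSchema == schema)
assert(decodedRows.map(_.getInt(0)).toSeq == Seq(0, 1, 2))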