From 70a42a967193ef4b3f9c1ccf75fefccfbcde0f41 Mon Sep 17 00:00:00 2001 From: dengziming Date: Mon, 14 Nov 2022 15:52:25 +0800 Subject: [PATCH 1/5] [SPARK-41114][CONNECT] Support local data for LocalRelation && resolve comments --- .../protobuf/spark/connect/relations.proto | 4 +- .../connect/planner/SparkConnectPlanner.scala | 12 +- .../planner/SparkConnectPlannerSuite.scala | 49 ++++-- .../planner/SparkConnectProtoSuite.scala | 36 +++- .../scala/org/apache/spark/util/Utils.scala | 8 + .../sql/connect/proto/relations_pb2.py | 150 ++++++++--------- .../sql/connect/proto/relations_pb2.pyi | 66 +------- .../sql/execution/arrow/ArrowConverters.scala | 155 +++++++++++++----- .../arrow/ArrowConvertersSuite.scala | 67 ++++++-- 9 files changed, 320 insertions(+), 227 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index aef4e4e7c642..8030d5b888c1 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -285,9 +285,7 @@ message Deduplicate { // A relation that does not need to be qualified by name. message LocalRelation { - // (Optional) A list qualified attributes. - repeated Expression.QualifiedAttribute attributes = 1; - // TODO: support local data. + bytes data = 1; } // Relation of type [[Sample]] that samples a fraction of the dataset. diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 96d0dbe35803..d21479352fea 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import com.google.common.collect.{Lists, Maps} +import org.apache.spark.TaskContext import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.WriteOperation @@ -29,7 +30,7 @@ import org.apache.spark.sql.{Column, Dataset, SparkSession} import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression, UnsafeProjection} import org.apache.spark.sql.catalyst.optimizer.CombineUnions import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin} @@ -37,6 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Interse import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.command.CreateViewCommand import 
org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.types._ @@ -272,8 +274,12 @@ class SparkConnectPlanner(session: SparkSession) { } private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = { - val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq - new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes) + val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator( + Seq(rel.getData.toByteArray).iterator, + TaskContext.get()) + val attributes = structType.toAttributes + val proj = UnsafeProjection.create(attributes, attributes) + new logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq) } private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = { diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 9e5fc41a0c68..1d6cec7f7dcb 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -17,14 +17,19 @@ package org.apache.spark.sql.connect.planner -import scala.collection.JavaConverters._ +import com.google.protobuf.ByteString +import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Expression.UnresolvedStar -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} /** * Testing trait for SparkConnect tests with some helper methods to make it easier to create new @@ -55,17 +60,26 @@ trait SparkConnectPlanTest extends SharedSparkSession { * equivalent in Catalyst and can be easily used for planner testing. 
* * @param attrs + * the attributes of LocalRelation + * @param data + * the data of LocalRelation * @return */ - def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = { + def createLocalRelationProto( + attrs: Seq[AttributeReference], + data: Seq[InternalRow]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() - for (attr <- attrs) { - localRelationBuilder.addAttributes( - proto.Expression.QualifiedAttribute - .newBuilder() - .setName(attr.name) - .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType))) - } + + val bytes = ArrowConverters + .toBatchWithSchemaIterator( + data.iterator, + StructType.fromAttributes(attrs.map(_.toAttribute)), + Long.MaxValue, + Long.MaxValue, + null) + .next() + + localRelationBuilder.setData(ByteString.copyFrom(bytes)) proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() } } @@ -96,7 +110,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { new SparkConnectPlanner(None.orNull) .transformRelation( proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build())) - } test("Simple Read") { @@ -199,7 +212,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("Simple Join") { - val incompleteJoin = proto.Relation.newBuilder.setJoin(proto.Join.newBuilder.setLeft(readRel)).build() intercept[AssertionError](transform(incompleteJoin)) @@ -267,7 +279,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { val res = transform(proto.Relation.newBuilder.setProject(project).build()) assert(res.nodeName == "Project") - } test("Simple Aggregation") { @@ -354,4 +365,16 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { transform(proto.Relation.newBuilder.setSetOp(intersect).build())) assert(e2.getMessage.contains("Intersect does not support union_by_name")) } + + test("transform LocalRelation") { + val inputRows = (0 until 10).map(InternalRow(_)) + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + val rows = inputRows.map { row => + val proj = UnsafeProjection.create(schema) + proj(row).copy() + } + + val localRelation = createLocalRelationProto(schema.toAttributes, rows) + assertResult(10)(Dataset.ofRows(spark, transform(localRelation)).count()) + } } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 185548971ecb..9e08c72a41aa 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -16,8 +16,9 @@ */ package org.apache.spark.sql.connect.planner -import java.nio.file.{Files, Paths} +import com.google.protobuf.ByteString +import java.nio.file.{Files, Paths} import org.apache.spark.SparkClassNotFoundException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Join.JoinType @@ -31,6 +32,7 @@ import org.apache.spark.sql.connect.dsl.MockRemoteSession import org.apache.spark.sql.connect.dsl.commands._ import org.apache.spark.sql.connect.dsl.expressions._ import org.apache.spark.sql.connect.dsl.plans._ +import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{IntegerType, MapType, 
Metadata, StringType, StructField, StructType} @@ -44,14 +46,18 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { lazy val connectTestRelation = createLocalRelationProto( - Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)())) + Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()), + Seq()) lazy val connectTestRelation2 = createLocalRelationProto( - Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)())) + Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()), + Seq()) lazy val connectTestRelationMap = - createLocalRelationProto(Seq(AttributeReference("id", MapType(StringType, StringType))())) + createLocalRelationProto( + Seq(AttributeReference("id", MapType(StringType, StringType))()), + Seq()) lazy val sparkTestRelation: DataFrame = spark.createDataFrame( @@ -68,7 +74,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { new java.util.ArrayList[Row](), StructType(Seq(StructField("id", MapType(StringType, StringType))))) - lazy val localRelation = createLocalRelationProto(Seq(AttributeReference("id", IntegerType)())) + lazy val localRelation = + createLocalRelationProto(Seq(AttributeReference("id", IntegerType)()), Seq()) test("Basic select") { val connectPlan = connectTestRelation.select("id".protoAttr) @@ -500,10 +507,21 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { private def createLocalRelationProtoByQualifiedAttributes( attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = { val localRelationBuilder = proto.LocalRelation.newBuilder() - for (attr <- attrs) { - localRelationBuilder.addAttributes(attr) - } - proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build() + + val attributes = attrs.map(exp => + AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))()) + val buffer = ArrowConverters + .toBatchWithSchemaIterator( + Iterator.empty, + StructType.fromAttributes(attributes), + Long.MaxValue, + Long.MaxValue, + null) + .next() + proto.Relation + .newBuilder() + .setLocalRelation(localRelationBuilder.setData(ByteString.copyFrom(buffer)).build()) + .build() } // This is a function for testing only. 
This is used when the plan is ready and it only waits diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 70477a5c9c08..2b596ace78c6 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -3257,6 +3257,14 @@ private[spark] object Utils extends Logging { case _ => math.max(sortedSize(len / 2), 1) } } + + def closeAll(closeables: AutoCloseable*): Unit = { + for (closeable <- closeables) { + if (closeable != null) { + closeable.close() + } + } + } } private[util] object CallerContext extends Logging { diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 344caa3ea37e..787a055773b4 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -33,7 +33,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xd1\x0b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12|\n#rename_columns_by_same_length_names\x18\x12 \x01(\x0b\x32-.spark.connect.RenameColumnsBySameLengthNamesH\x00R\x1erenameColumnsBySameLengthNames\x12w\n"rename_columns_by_name_to_name_map\x18\x13 \x01(\x0b\x32+.spark.connect.RenameColumnsByNameToNameMapH\x00R\x1crenameColumnsByNameToNameMap\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"1\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 
\x01(\tR\nsourceInfo"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\xaa\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x00R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xc2\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\x8c\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_name"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xd2\x01\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12H\n\x12result_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11resultExpressions"\xa6\x04\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 
\x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42\x0c\n\n_is_global"d\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12-\n\x04\x63ols\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x04\x63ols"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xe0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x42\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8d\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x18\n\x07numRows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"r\n\x1eRenameColumnsBySameLengthNames\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\x83\x02\n\x1cRenameColumnsByNameToNameMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12o\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x41.spark.connect.RenameColumnsByNameToNameMap.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 
\x01(\tR\x05value:\x02\x38\x01\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xa6\x0b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12|\n#rename_columns_by_same_length_names\x18\x12 \x01(\x0b\x32-.spark.connect.RenameColumnsBySameLengthNamesH\x00R\x1erenameColumnsBySameLengthNames\x12w\n"rename_columns_by_name_to_name_map\x18\x13 \x01(\x0b\x32+.spark.connect.RenameColumnsByNameToNameMapH\x00R\x1crenameColumnsByNameToNameMap\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"1\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\xaa\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x00R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xc2\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\x8c\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_name"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xd2\x01\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12H\n\x12result_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11resultExpressions"\xa6\x04\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42\x0c\n\n_is_global"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"#\n\rLocalRelation\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta"\xe0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 
\x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x42\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8d\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x18\n\x07numRows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"r\n\x1eRenameColumnsBySameLengthNames\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\x83\x02\n\x1cRenameColumnsByNameToNameMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12o\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x41.spark.connect.RenameColumnsByNameToNameMap.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -54,7 +54,6 @@ _AGGREGATE = DESCRIPTOR.message_types_by_name["Aggregate"] _SORT = DESCRIPTOR.message_types_by_name["Sort"] _SORT_SORTFIELD = _SORT.nested_types_by_name["SortField"] -_DROP = DESCRIPTOR.message_types_by_name["Drop"] _DEDUPLICATE = DESCRIPTOR.message_types_by_name["Deduplicate"] _LOCALRELATION = DESCRIPTOR.message_types_by_name["LocalRelation"] _SAMPLE = DESCRIPTOR.message_types_by_name["Sample"] @@ -257,17 +256,6 @@ _sym_db.RegisterMessage(Sort) _sym_db.RegisterMessage(Sort.SortField) -Drop = _reflection.GeneratedProtocolMessageType( - "Drop", - (_message.Message,), - { - "DESCRIPTOR": _DROP, - "__module__": "spark.connect.relations_pb2" - # @@protoc_insertion_point(class_scope:spark.connect.Drop) - }, -) -_sym_db.RegisterMessage(Drop) - Deduplicate = _reflection.GeneratedProtocolMessageType( "Deduplicate", (_message.Message,), @@ -419,73 +407,71 @@ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 82 - _RELATION._serialized_end = 1571 - _UNKNOWN._serialized_start = 1573 - _UNKNOWN._serialized_end = 1582 - 
_RELATIONCOMMON._serialized_start = 1584 - _RELATIONCOMMON._serialized_end = 1633 - _SQL._serialized_start = 1635 - _SQL._serialized_end = 1662 - _READ._serialized_start = 1665 - _READ._serialized_end = 2091 - _READ_NAMEDTABLE._serialized_start = 1807 - _READ_NAMEDTABLE._serialized_end = 1868 - _READ_DATASOURCE._serialized_start = 1871 - _READ_DATASOURCE._serialized_end = 2078 - _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2009 - _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2067 - _PROJECT._serialized_start = 2093 - _PROJECT._serialized_end = 2210 - _FILTER._serialized_start = 2212 - _FILTER._serialized_end = 2324 - _JOIN._serialized_start = 2327 - _JOIN._serialized_end = 2777 - _JOIN_JOINTYPE._serialized_start = 2590 - _JOIN_JOINTYPE._serialized_end = 2777 - _SETOPERATION._serialized_start = 2780 - _SETOPERATION._serialized_end = 3176 - _SETOPERATION_SETOPTYPE._serialized_start = 3039 - _SETOPERATION_SETOPTYPE._serialized_end = 3153 - _LIMIT._serialized_start = 3178 - _LIMIT._serialized_end = 3254 - _OFFSET._serialized_start = 3256 - _OFFSET._serialized_end = 3335 - _AGGREGATE._serialized_start = 3338 - _AGGREGATE._serialized_end = 3548 - _SORT._serialized_start = 3551 - _SORT._serialized_end = 4101 - _SORT_SORTFIELD._serialized_start = 3705 - _SORT_SORTFIELD._serialized_end = 3893 - _SORT_SORTDIRECTION._serialized_start = 3895 - _SORT_SORTDIRECTION._serialized_end = 4003 - _SORT_SORTNULLS._serialized_start = 4005 - _SORT_SORTNULLS._serialized_end = 4087 - _DROP._serialized_start = 4103 - _DROP._serialized_end = 4203 - _DEDUPLICATE._serialized_start = 4206 - _DEDUPLICATE._serialized_end = 4377 - _LOCALRELATION._serialized_start = 4379 - _LOCALRELATION._serialized_end = 4472 - _SAMPLE._serialized_start = 4475 - _SAMPLE._serialized_end = 4699 - _RANGE._serialized_start = 4702 - _RANGE._serialized_end = 4847 - _SUBQUERYALIAS._serialized_start = 4849 - _SUBQUERYALIAS._serialized_end = 4963 - _REPARTITION._serialized_start = 4966 - _REPARTITION._serialized_end = 5108 - _SHOWSTRING._serialized_start = 5111 - _SHOWSTRING._serialized_end = 5252 - _STATSUMMARY._serialized_start = 5254 - _STATSUMMARY._serialized_end = 5346 - _STATCROSSTAB._serialized_start = 5348 - _STATCROSSTAB._serialized_end = 5449 - _NAFILL._serialized_start = 5452 - _NAFILL._serialized_end = 5586 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5588 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5702 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5705 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5964 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5897 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5964 + _RELATION._serialized_end = 1528 + _UNKNOWN._serialized_start = 1530 + _UNKNOWN._serialized_end = 1539 + _RELATIONCOMMON._serialized_start = 1541 + _RELATIONCOMMON._serialized_end = 1590 + _SQL._serialized_start = 1592 + _SQL._serialized_end = 1619 + _READ._serialized_start = 1622 + _READ._serialized_end = 2048 + _READ_NAMEDTABLE._serialized_start = 1764 + _READ_NAMEDTABLE._serialized_end = 1825 + _READ_DATASOURCE._serialized_start = 1828 + _READ_DATASOURCE._serialized_end = 2035 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1966 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2024 + _PROJECT._serialized_start = 2050 + _PROJECT._serialized_end = 2167 + _FILTER._serialized_start = 2169 + _FILTER._serialized_end = 2281 + _JOIN._serialized_start = 2284 + _JOIN._serialized_end = 2734 + _JOIN_JOINTYPE._serialized_start = 2547 + 
_JOIN_JOINTYPE._serialized_end = 2734 + _SETOPERATION._serialized_start = 2737 + _SETOPERATION._serialized_end = 3133 + _SETOPERATION_SETOPTYPE._serialized_start = 2996 + _SETOPERATION_SETOPTYPE._serialized_end = 3110 + _LIMIT._serialized_start = 3135 + _LIMIT._serialized_end = 3211 + _OFFSET._serialized_start = 3213 + _OFFSET._serialized_end = 3292 + _AGGREGATE._serialized_start = 3295 + _AGGREGATE._serialized_end = 3505 + _SORT._serialized_start = 3508 + _SORT._serialized_end = 4058 + _SORT_SORTFIELD._serialized_start = 3662 + _SORT_SORTFIELD._serialized_end = 3850 + _SORT_SORTDIRECTION._serialized_start = 3852 + _SORT_SORTDIRECTION._serialized_end = 3960 + _SORT_SORTNULLS._serialized_start = 3962 + _SORT_SORTNULLS._serialized_end = 4044 + _DEDUPLICATE._serialized_start = 4061 + _DEDUPLICATE._serialized_end = 4232 + _LOCALRELATION._serialized_start = 4234 + _LOCALRELATION._serialized_end = 4269 + _SAMPLE._serialized_start = 4272 + _SAMPLE._serialized_end = 4496 + _RANGE._serialized_start = 4499 + _RANGE._serialized_end = 4644 + _SUBQUERYALIAS._serialized_start = 4646 + _SUBQUERYALIAS._serialized_end = 4760 + _REPARTITION._serialized_start = 4763 + _REPARTITION._serialized_end = 4905 + _SHOWSTRING._serialized_start = 4908 + _SHOWSTRING._serialized_end = 5049 + _STATSUMMARY._serialized_start = 5051 + _STATSUMMARY._serialized_end = 5143 + _STATCROSSTAB._serialized_start = 5145 + _STATCROSSTAB._serialized_end = 5246 + _NAFILL._serialized_start = 5249 + _NAFILL._serialized_end = 5383 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5385 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5499 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5502 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5761 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5694 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5761 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 30e61282baaf..ef28e5567374 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -79,7 +79,6 @@ class Relation(google.protobuf.message.Message): RENAME_COLUMNS_BY_SAME_LENGTH_NAMES_FIELD_NUMBER: builtins.int RENAME_COLUMNS_BY_NAME_TO_NAME_MAP_FIELD_NUMBER: builtins.int SHOW_STRING_FIELD_NUMBER: builtins.int - DROP_FIELD_NUMBER: builtins.int FILL_NA_FIELD_NUMBER: builtins.int SUMMARY_FIELD_NUMBER: builtins.int CROSSTAB_FIELD_NUMBER: builtins.int @@ -125,8 +124,6 @@ class Relation(google.protobuf.message.Message): @property def show_string(self) -> global___ShowString: ... @property - def drop(self) -> global___Drop: ... 
- @property def fill_na(self) -> global___NAFill: """NA functions""" @property @@ -159,7 +156,6 @@ class Relation(google.protobuf.message.Message): rename_columns_by_same_length_names: global___RenameColumnsBySameLengthNames | None = ..., rename_columns_by_name_to_name_map: global___RenameColumnsByNameToNameMap | None = ..., show_string: global___ShowString | None = ..., - drop: global___Drop | None = ..., fill_na: global___NAFill | None = ..., summary: global___StatSummary | None = ..., crosstab: global___StatCrosstab | None = ..., @@ -176,8 +172,6 @@ class Relation(google.protobuf.message.Message): b"crosstab", "deduplicate", b"deduplicate", - "drop", - b"drop", "fill_na", b"fill_na", "filter", @@ -233,8 +227,6 @@ class Relation(google.protobuf.message.Message): b"crosstab", "deduplicate", b"deduplicate", - "drop", - b"drop", "fill_na", b"fill_na", "filter", @@ -301,7 +293,6 @@ class Relation(google.protobuf.message.Message): "rename_columns_by_same_length_names", "rename_columns_by_name_to_name_map", "show_string", - "drop", "fill_na", "summary", "crosstab", @@ -970,42 +961,6 @@ class Sort(google.protobuf.message.Message): global___Sort = Sort -class Drop(google.protobuf.message.Message): - """Drop specified columns.""" - - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - INPUT_FIELD_NUMBER: builtins.int - COLS_FIELD_NUMBER: builtins.int - @property - def input(self) -> global___Relation: - """(Required) The input relation.""" - @property - def cols( - self, - ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - pyspark.sql.connect.proto.expressions_pb2.Expression - ]: - """(Required) columns to drop. - - Should contain at least 1 item. - """ - def __init__( - self, - *, - input: global___Relation | None = ..., - cols: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression] - | None = ..., - ) -> None: ... - def HasField( - self, field_name: typing_extensions.Literal["input", b"input"] - ) -> builtins.bool: ... - def ClearField( - self, field_name: typing_extensions.Literal["cols", b"cols", "input", b"input"] - ) -> None: ... - -global___Drop = Drop - class Deduplicate(google.protobuf.message.Message): """Relation of type [[Deduplicate]] which have duplicate rows removed, could consider either only the subset of columns or all the columns. @@ -1075,27 +1030,14 @@ class LocalRelation(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - ATTRIBUTES_FIELD_NUMBER: builtins.int - @property - def attributes( - self, - ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ - pyspark.sql.connect.proto.expressions_pb2.Expression.QualifiedAttribute - ]: - """(Optional) A list qualified attributes. - TODO: support local data. - """ + DATA_FIELD_NUMBER: builtins.int + data: builtins.bytes def __init__( self, *, - attributes: collections.abc.Iterable[ - pyspark.sql.connect.proto.expressions_pb2.Expression.QualifiedAttribute - ] - | None = ..., - ) -> None: ... - def ClearField( - self, field_name: typing_extensions.Literal["attributes", b"attributes"] + data: builtins.bytes = ..., ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["data", b"data"]) -> None: ... 
global___LocalRelation = LocalRelation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index a60f7b5970d1..20142c549c92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -19,15 +19,12 @@ package org.apache.spark.sql.execution.arrow import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} import java.nio.channels.{Channels, ReadableByteChannel} - import scala.collection.JavaConverters._ - import org.apache.arrow.flatbuf.MessageHeader import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ -import org.apache.arrow.vector.ipc.{ArrowStreamWriter, ReadChannel, WriteChannel} +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer} - import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils @@ -37,10 +34,9 @@ import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils} - /** * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format. */ @@ -76,21 +72,26 @@ private[sql] object ArrowConverters extends Logging { schema: StructType, maxRecordsPerBatch: Long, timeZoneId: String, - context: TaskContext) extends Iterator[Array[Byte]] { + context: TaskContext) + extends Iterator[Array[Byte]] { protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) + s"to${this.getClass.getSimpleName}", + 0, + Long.MaxValue) private val root = VectorSchemaRoot.create(arrowSchema, allocator) protected val unloader = new VectorUnloader(root) protected val arrowWriter = ArrowWriter.create(root) - Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => - root.close() - allocator.close() - }} + Option(context).foreach { + _.addTaskCompletionListener[Unit] { _ => + root.close() + allocator.close() + } + } override def hasNext: Boolean = rowIter.hasNext || { root.close() @@ -128,8 +129,7 @@ private[sql] object ArrowConverters extends Logging { maxEstimatedBatchSize: Long, timeZoneId: String, context: TaskContext) - extends ArrowBatchIterator( - rowIter, schema, maxRecordsPerBatch, timeZoneId, context) { + extends ArrowBatchIterator(rowIter, schema, maxRecordsPerBatch, timeZoneId, context) { private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema) var rowCountInLastBatch: Long = 0 @@ -146,15 +146,15 @@ private[sql] object ArrowConverters extends Logging { // Always write the first row. while (rowIter.hasNext && ( - // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller. - // If the size in bytes is positive (set properly), always write the first row. 
- rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 || - // If the size in bytes of rows are 0 or negative, unlimit it. - estimatedBatchSize <= 0 || - estimatedBatchSize < maxEstimatedBatchSize || - // If the size of rows are 0 or negative, unlimit it. - maxRecordsPerBatch <= 0 || - rowCountInLastBatch < maxRecordsPerBatch)) { + // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller. + // If the size in bytes is positive (set properly), always write the first row. + rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 || + // If the size in bytes of rows are 0 or negative, unlimit it. + estimatedBatchSize <= 0 || + estimatedBatchSize < maxEstimatedBatchSize || + // If the size of rows are 0 or negative, unlimit it. + maxRecordsPerBatch <= 0 || + rowCountInLastBatch < maxRecordsPerBatch)) { val row = rowIter.next() arrowWriter.write(row) estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes @@ -186,13 +186,12 @@ private[sql] object ArrowConverters extends Logging { maxRecordsPerBatch: Long, timeZoneId: String, context: TaskContext): ArrowBatchIterator = { - new ArrowBatchIterator( - rowIter, schema, maxRecordsPerBatch, timeZoneId, context) + new ArrowBatchIterator(rowIter, schema, maxRecordsPerBatch, timeZoneId, context) } /** - * Convert the input rows into fully contained arrow batches. - * Different from [[toBatchIterator]], each output arrow batch starts with the schema. + * Convert the input rows into fully contained arrow batches. Different from + * [[toBatchIterator]], each output arrow batch starts with the schema. */ private[sql] def toBatchWithSchemaIterator( rowIter: Iterator[InternalRow], @@ -201,14 +200,22 @@ private[sql] object ArrowConverters extends Logging { maxEstimatedBatchSize: Long, timeZoneId: String): ArrowBatchWithSchemaIterator = { new ArrowBatchWithSchemaIterator( - rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize, timeZoneId, TaskContext.get) + rowIter, + schema, + maxRecordsPerBatch, + maxEstimatedBatchSize, + timeZoneId, + TaskContext.get) } - private[sql] def createEmptyArrowBatch( - schema: StructType, - timeZoneId: String): Array[Byte] = { + private[sql] def createEmptyArrowBatch(schema: StructType, timeZoneId: String): Array[Byte] = { new ArrowBatchWithSchemaIterator( - Iterator.empty, schema, 0L, 0L, timeZoneId, TaskContext.get) { + Iterator.empty, + schema, + 0L, + 0L, + timeZoneId, + TaskContext.get) { override def hasNext: Boolean = true }.next() } @@ -253,16 +260,76 @@ private[sql] object ArrowConverters extends Logging { val vectorLoader = new VectorLoader(root) vectorLoader.load(arrowRecordBatch) arrowRecordBatch.close() + vectorSchemaRootToIter(root) + } + } + } + + /** + * // TODO docs + */ + private[sql] def fromBatchWithSchemaIterator( + arrowBatchIter: Iterator[Array[Byte]], + context: TaskContext): (Iterator[InternalRow], StructType) = { + var structType = new StructType() + val allocator = + ArrowUtils.rootAllocator.newChildAllocator("fromBatchWithSchemaIterator", 0, Long.MaxValue) + + val iter = new Iterator[InternalRow] { + private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty - val columns = root.getFieldVectors.asScala.map { vector => - new ArrowColumnVector(vector).asInstanceOf[ColumnVector] - }.toArray + if (context != null) context.addTaskCompletionListener[Unit] { _ => + allocator.close() + } + + override def hasNext: Boolean = rowIter.hasNext || { + if (arrowBatchIter.hasNext) { + rowIter = nextBatch() + rowIter.hasNext + } else { + // Utils.closeAll(allocator) TODO 
memory LEAK + false + } + } - val batch = new ColumnarBatch(columns) - batch.setNumRows(root.getRowCount) - batch.rowIterator().asScala + override def next(): InternalRow = { + rowIter.next() + } + + private def nextBatch(): Iterator[InternalRow] = { + val rowsAndType = fromBatchWithSchemaBuffer(arrowBatchIter.next(), allocator, context) + structType = structType.merge(rowsAndType._2) + rowsAndType._1 } } + (iter, structType) + } + + // TODO THREAD LEAK + private def fromBatchWithSchemaBuffer( + arrowBuffer: Array[Byte], + allocator: BufferAllocator, + context: TaskContext): (Iterator[InternalRow], StructType) = { + val reader = new ArrowStreamReader(new ByteArrayInputStream(arrowBuffer), allocator) + + val root = if (reader.loadNextBatch()) reader.getVectorSchemaRoot else null + val structType = + if (root == null) new StructType() else ArrowUtils.fromArrowSchema(root.getSchema) + + if (context != null) context.addTaskCompletionListener[Unit] { _ => + Utils.closeAll(root, reader) + } + (vectorSchemaRootToIter(root), structType) + } + + private def vectorSchemaRootToIter(root: VectorSchemaRoot): Iterator[InternalRow] = { + val columns = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector).asInstanceOf[ColumnVector] + }.toArray + + val batch = new ColumnarBatch(columns) + batch.setNumRows(root.getRowCount) + batch.rowIterator().asScala } /** @@ -273,7 +340,9 @@ private[sql] object ArrowConverters extends Logging { allocator: BufferAllocator): ArrowRecordBatch = { val in = new ByteArrayInputStream(batchBytes) MessageSerializer.deserializeRecordBatch( - new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException + new ReadChannel(Channels.newChannel(in)), + allocator + ) // throws IOException } /** @@ -289,13 +358,15 @@ private[sql] object ArrowConverters extends Logging { val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] val attrs = schema.toAttributes val batchesInDriver = arrowBatches.toArray - val shouldUseRDD = session.sessionState.conf - .arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum + val shouldUseRDD = session.sessionState.conf.arrowLocalRelationThreshold < batchesInDriver + .map(_.length.toLong) + .sum if (shouldUseRDD) { logDebug("Using RDD-based createDataFrame with Arrow optimization.") val timezone = session.sessionState.conf.sessionLocalTimeZone - val rdd = session.sparkContext.parallelize(batchesInDriver, batchesInDriver.length) + val rdd = session.sparkContext + .parallelize(batchesInDriver, batchesInDriver.length) .mapPartitions { batchesInExecutors => ArrowConverters.fromBatchIterator( batchesInExecutors, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index e876e9d6ff20..56c0666e3d25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -21,16 +21,15 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Locale - import com.google.common.io.Files import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.ipc.JsonFileReader import org.apache.arrow.vector.util.{ByteArrayReadableSeekableByteChannel, Validator} - import org.apache.spark.{SparkException, 
TaskContext} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -38,7 +37,6 @@ import org.apache.spark.sql.types.{ArrayType, BinaryType, Decimal, IntegerType, import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils - class ArrowConvertersSuite extends SharedSparkSession { import testImplicits._ @@ -361,7 +359,13 @@ class ArrowConvertersSuite extends SharedSparkSession { """.stripMargin val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0, 30.0).map(Decimal(_)) - val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3)), + val b_d = List( + Some(Decimal(1.1)), + None, + None, + Some(Decimal(2.2)), + None, + Some(Decimal(3.3)), Some(Decimal("123456789012345678901234567890"))) val df = a_d.zip(b_d).toDF("a_d", "b_d") @@ -676,8 +680,8 @@ class ArrowConvertersSuite extends SharedSparkSession { |} """.stripMargin - val d1 = DateTimeUtils.toJavaDate(-1) // "1969-12-31" - val d2 = DateTimeUtils.toJavaDate(0) // "1970-01-01" + val d1 = DateTimeUtils.toJavaDate(-1) // "1969-12-31" + val d2 = DateTimeUtils.toJavaDate(0) // "1970-01-01" val d3 = Date.valueOf("2015-04-08") val d4 = Date.valueOf("3017-07-18") @@ -768,7 +772,7 @@ class ArrowConvertersSuite extends SharedSparkSession { |} """.stripMargin - val fnan = Seq(1.2F, Float.NaN) + val fnan = Seq(1.2f, Float.NaN) val dnan = Seq(Double.NaN, 1.2) val df = fnan.zip(dnan).toDF("NaN_f", "NaN_d") @@ -915,9 +919,14 @@ class ArrowConvertersSuite extends SharedSparkSession { val cArr = Seq(Seq(Some(1), Some(2)), Seq(Some(3), None), Seq(), Seq(Some(5))) val dArr = Seq(Seq(Seq(1, 2)), Seq(Seq(3), Seq()), Seq(), Seq(Seq(5))) - val df = aArr.zip(bArr).zip(cArr).zip(dArr).map { - case (((a, b), c), d) => (a, b, c, d) - }.toDF("a_arr", "b_arr", "c_arr", "d_arr") + val df = aArr + .zip(bArr) + .zip(cArr) + .zip(dArr) + .map { case (((a, b), c), d) => + (a, b, c, d) + } + .toDF("a_arr", "b_arr", "c_arr", "d_arr") collectAndValidate(df, json, "arrayData.json") } @@ -1056,8 +1065,8 @@ class ArrowConvertersSuite extends SharedSparkSession { val bStruct = Seq(Row(1), null, Row(3)) val cStruct = Seq(Row(1), Row(null), Row(3)) val dStruct = Seq(Row(Row(1)), null, Row(null)) - val data = aStruct.zip(bStruct).zip(cStruct).zip(dStruct).map { - case (((a, b), c), d) => Row(a, b, c, d) + val data = aStruct.zip(bStruct).zip(cStruct).zip(dStruct).map { case (((a, b), c), d) => + Row(a, b, c, d) } val rdd = sparkContext.parallelize(data) @@ -1426,9 +1435,41 @@ class ArrowConvertersSuite extends SharedSparkSession { assert(count == inputRows.length) } + test("roundtrip arrow batches with schema") { + val rows = (0 until 9).map { i => + InternalRow(i) + } :+ InternalRow(null) + + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + val inputRows = rows.map { row => + val proj = UnsafeProjection.create(schema) + proj(row).copy() + } + val ctx = TaskContext.empty() + val batchIter = + ArrowConverters.toBatchWithSchemaIterator(inputRows.iterator, schema, 5, 4 * 1024, null) + val (outputRowIter, outputType) = ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx) + + var count = 0 + outputRowIter.zipWithIndex.foreach { case (row, i) => + if (i != 9) { + assert(row.getInt(0) == i) + } else { + assert(row.isNullAt(0)) + } + count += 1 
+ } + + assert(count == inputRows.length) + assert(schema == outputType) + } + /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ private def collectAndValidate( - df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { + df: DataFrame, + json: String, + file: String, + timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head val tempFile = new File(tempDataPath, file) From 2db0b5a906d55ce7b63a9d342db5d4e78acd2fdc Mon Sep 17 00:00:00 2001 From: dengziming Date: Mon, 21 Nov 2022 23:14:59 +0800 Subject: [PATCH 2/5] resolve comments && fix scalastyle --- .../protobuf/spark/connect/relations.proto | 2 + .../planner/SparkConnectPlannerSuite.scala | 61 ++++++++-- .../planner/SparkConnectProtoSuite.scala | 3 +- .../sql/execution/arrow/ArrowConverters.scala | 111 ++++++++++-------- .../arrow/ArrowConvertersSuite.scala | 98 ++++++++++------ 5 files changed, 180 insertions(+), 95 deletions(-) diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto index 8030d5b888c1..489b69e2e533 100644 --- a/connector/connect/src/main/protobuf/spark/connect/relations.proto +++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto @@ -285,6 +285,8 @@ message Deduplicate { // A relation that does not need to be qualified by name. message LocalRelation { + // Local collection data serialized into Arrow IPC streaming format which contains + // the schema of the data. bytes data = 1; } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 1d6cec7f7dcb..072af389513e 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.connect.planner +import scala.collection.JavaConverters._ + import com.google.protobuf.ByteString -import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Expression.UnresolvedStar @@ -29,7 +30,8 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProj import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String /** * Testing trait for SparkConnect tests with some helper methods to make it easier to create new @@ -367,14 +369,59 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { } test("transform LocalRelation") { - val inputRows = (0 until 10).map(InternalRow(_)) - val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) - val rows = inputRows.map { row => + val rows = (0 until 10).map { i => + InternalRow(i, UTF8String.fromString(s"str-$i"), InternalRow(i)) + } + + val schema = StructType( + Seq( + StructField("int", IntegerType), + StructField("str", 
StringType), + StructField("struct", StructType(Seq(StructField("inner", IntegerType)))))) + val inputRows = rows.map { row => val proj = UnsafeProjection.create(schema) proj(row).copy() } - val localRelation = createLocalRelationProto(schema.toAttributes, rows) - assertResult(10)(Dataset.ofRows(spark, transform(localRelation)).count()) + val localRelation = createLocalRelationProto(schema.toAttributes, inputRows) + val df = Dataset.ofRows(spark, transform(localRelation)) + val array = df.collect() + assertResult(10)(array.length) + assert(schema == df.schema) + for (i <- 0 until 10) { + assert(i == array(i).getInt(0)) + assert(s"str-$i" == array(i).getString(1)) + assert(i == array(i).getStruct(2).getInt(0)) + } + } + + test("Empty ArrowBatch") { + val schema = StructType(Seq(StructField("int", IntegerType))) + val data = ArrowConverters.createEmptyArrowBatch(schema, null) + val localRelation = proto.Relation + .newBuilder() + .setLocalRelation( + proto.LocalRelation + .newBuilder() + .setData(ByteString.copyFrom(data)) + .build()) + .build() + val df = Dataset.ofRows(spark, transform(localRelation)) + assert(schema == df.schema) + assert(df.isEmpty) + } + + test("Illegal LocalRelation data") { + intercept[Exception] { + transform( + proto.Relation + .newBuilder() + .setLocalRelation( + proto.LocalRelation + .newBuilder() + .setData(ByteString.copyFrom("illegal".getBytes())) + .build()) + .build()) + } } } diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 9e08c72a41aa..f924f3811c25 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -16,9 +16,10 @@ */ package org.apache.spark.sql.connect.planner +import java.nio.file.{Files, Paths} + import com.google.protobuf.ByteString -import java.nio.file.{Files, Paths} import org.apache.spark.SparkClassNotFoundException import org.apache.spark.connect.proto import org.apache.spark.connect.proto.Join.JoinType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 20142c549c92..223f9652e063 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.execution.arrow import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, OutputStream} import java.nio.channels.{Channels, ReadableByteChannel} + import scala.collection.JavaConverters._ + import org.apache.arrow.flatbuf.MessageHeader import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector._ import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter, ReadChannel, WriteChannel} import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, IpcOption, MessageSerializer} + import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils @@ -34,9 +37,10 @@ import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ import 
org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.{ByteBufferOutputStream, SizeEstimator, Utils} + /** * Writes serialized ArrowRecordBatches to a DataOutputStream in the Arrow stream format. */ @@ -72,26 +76,21 @@ private[sql] object ArrowConverters extends Logging { schema: StructType, maxRecordsPerBatch: Long, timeZoneId: String, - context: TaskContext) - extends Iterator[Array[Byte]] { + context: TaskContext) extends Iterator[Array[Byte]] { protected val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"to${this.getClass.getSimpleName}", - 0, - Long.MaxValue) + s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) private val root = VectorSchemaRoot.create(arrowSchema, allocator) protected val unloader = new VectorUnloader(root) protected val arrowWriter = ArrowWriter.create(root) - Option(context).foreach { - _.addTaskCompletionListener[Unit] { _ => - root.close() - allocator.close() - } - } + Option(context).foreach {_.addTaskCompletionListener[Unit] { _ => + root.close() + allocator.close() + }} override def hasNext: Boolean = rowIter.hasNext || { root.close() @@ -129,7 +128,8 @@ private[sql] object ArrowConverters extends Logging { maxEstimatedBatchSize: Long, timeZoneId: String, context: TaskContext) - extends ArrowBatchIterator(rowIter, schema, maxRecordsPerBatch, timeZoneId, context) { + extends ArrowBatchIterator( + rowIter, schema, maxRecordsPerBatch, timeZoneId, context) { private val arrowSchemaSize = SizeEstimator.estimate(arrowSchema) var rowCountInLastBatch: Long = 0 @@ -146,15 +146,15 @@ private[sql] object ArrowConverters extends Logging { // Always write the first row. while (rowIter.hasNext && ( - // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller. - // If the size in bytes is positive (set properly), always write the first row. - rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 || - // If the size in bytes of rows are 0 or negative, unlimit it. - estimatedBatchSize <= 0 || - estimatedBatchSize < maxEstimatedBatchSize || - // If the size of rows are 0 or negative, unlimit it. - maxRecordsPerBatch <= 0 || - rowCountInLastBatch < maxRecordsPerBatch)) { + // For maxBatchSize and maxRecordsPerBatch, respect whatever smaller. + // If the size in bytes is positive (set properly), always write the first row. + rowCountInLastBatch == 0 && maxEstimatedBatchSize > 0 || + // If the size in bytes of rows are 0 or negative, unlimit it. + estimatedBatchSize <= 0 || + estimatedBatchSize < maxEstimatedBatchSize || + // If the size of rows are 0 or negative, unlimit it. + maxRecordsPerBatch <= 0 || + rowCountInLastBatch < maxRecordsPerBatch)) { val row = rowIter.next() arrowWriter.write(row) estimatedBatchSize += row.asInstanceOf[UnsafeRow].getSizeInBytes @@ -186,12 +186,13 @@ private[sql] object ArrowConverters extends Logging { maxRecordsPerBatch: Long, timeZoneId: String, context: TaskContext): ArrowBatchIterator = { - new ArrowBatchIterator(rowIter, schema, maxRecordsPerBatch, timeZoneId, context) + new ArrowBatchIterator( + rowIter, schema, maxRecordsPerBatch, timeZoneId, context) } /** - * Convert the input rows into fully contained arrow batches. Different from - * [[toBatchIterator]], each output arrow batch starts with the schema. 
+ * Convert the input rows into fully contained arrow batches. + * Different from [[toBatchIterator]], each output arrow batch starts with the schema. */ private[sql] def toBatchWithSchemaIterator( rowIter: Iterator[InternalRow], @@ -200,22 +201,14 @@ private[sql] object ArrowConverters extends Logging { maxEstimatedBatchSize: Long, timeZoneId: String): ArrowBatchWithSchemaIterator = { new ArrowBatchWithSchemaIterator( - rowIter, - schema, - maxRecordsPerBatch, - maxEstimatedBatchSize, - timeZoneId, - TaskContext.get) + rowIter, schema, maxRecordsPerBatch, maxEstimatedBatchSize, timeZoneId, TaskContext.get) } - private[sql] def createEmptyArrowBatch(schema: StructType, timeZoneId: String): Array[Byte] = { + private[sql] def createEmptyArrowBatch( + schema: StructType, + timeZoneId: String): Array[Byte] = { new ArrowBatchWithSchemaIterator( - Iterator.empty, - schema, - 0L, - 0L, - timeZoneId, - TaskContext.get) { + Iterator.empty, schema, 0L, 0L, timeZoneId, TaskContext.get) { override def hasNext: Boolean = true }.next() } @@ -266,12 +259,13 @@ private[sql] object ArrowConverters extends Logging { } /** - * // TODO docs + * Maps iterator from serialized ArrowRecordBatches to InternalRows. Different from + * [[fromBatchIterator]], each input arrow batch starts with the schema. */ private[sql] def fromBatchWithSchemaIterator( arrowBatchIter: Iterator[Array[Byte]], context: TaskContext): (Iterator[InternalRow], StructType) = { - var structType = new StructType() + var structType: StructType = null val allocator = ArrowUtils.rootAllocator.newChildAllocator("fromBatchWithSchemaIterator", 0, Long.MaxValue) @@ -287,7 +281,7 @@ private[sql] object ArrowConverters extends Logging { rowIter = nextBatch() rowIter.hasNext } else { - // Utils.closeAll(allocator) TODO memory LEAK + Utils.closeAll(allocator) false } } @@ -298,14 +292,18 @@ private[sql] object ArrowConverters extends Logging { private def nextBatch(): Iterator[InternalRow] = { val rowsAndType = fromBatchWithSchemaBuffer(arrowBatchIter.next(), allocator, context) - structType = structType.merge(rowsAndType._2) + if (structType == null) { + structType = rowsAndType._2 + } else if (structType != rowsAndType._2) { + throw new IllegalArgumentException(s"ArrowBatch iterator contain 2 batches with" + + s" different schema: $structType and ${rowsAndType._2}") + } rowsAndType._1 } } (iter, structType) } - // TODO THREAD LEAK private def fromBatchWithSchemaBuffer( arrowBuffer: Array[Byte], allocator: BufferAllocator, @@ -319,7 +317,20 @@ private[sql] object ArrowConverters extends Logging { if (context != null) context.addTaskCompletionListener[Unit] { _ => Utils.closeAll(root, reader) } - (vectorSchemaRootToIter(root), structType) + val inner = vectorSchemaRootToIter(root) + val iter = new Iterator[InternalRow] { + override def hasNext: Boolean = { + if (inner.hasNext) { + true + } else { + Utils.closeAll(root, reader) + false + } + } + + override def next(): InternalRow = inner.next() + } + (iter, structType) } private def vectorSchemaRootToIter(root: VectorSchemaRoot): Iterator[InternalRow] = { @@ -340,9 +351,7 @@ private[sql] object ArrowConverters extends Logging { allocator: BufferAllocator): ArrowRecordBatch = { val in = new ByteArrayInputStream(batchBytes) MessageSerializer.deserializeRecordBatch( - new ReadChannel(Channels.newChannel(in)), - allocator - ) // throws IOException + new ReadChannel(Channels.newChannel(in)), allocator) // throws IOException } /** @@ -358,15 +367,13 @@ private[sql] object ArrowConverters extends Logging { val 
schema = DataType.fromJson(schemaString).asInstanceOf[StructType] val attrs = schema.toAttributes val batchesInDriver = arrowBatches.toArray - val shouldUseRDD = session.sessionState.conf.arrowLocalRelationThreshold < batchesInDriver - .map(_.length.toLong) - .sum + val shouldUseRDD = session.sessionState.conf + .arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum if (shouldUseRDD) { logDebug("Using RDD-based createDataFrame with Arrow optimization.") val timezone = session.sessionState.conf.sessionLocalTimeZone - val rdd = session.sparkContext - .parallelize(batchesInDriver, batchesInDriver.length) + val rdd = session.sparkContext.parallelize(batchesInDriver, batchesInDriver.length) .mapPartitions { batchesInExecutors => ArrowConverters.fromBatchIterator( batchesInExecutors, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 56c0666e3d25..eb33e2e47caf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -21,11 +21,13 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import java.util.Locale + import com.google.common.io.Files import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{VectorLoader, VectorSchemaRoot} import org.apache.arrow.vector.ipc.JsonFileReader import org.apache.arrow.vector.util.{ByteArrayReadableSeekableByteChannel, Validator} + import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.InternalRow @@ -33,10 +35,12 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeProjection import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{ArrayType, BinaryType, Decimal, IntegerType, NullType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, Decimal, IntegerType, NullType, StringType, StructField, StructType} import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils + class ArrowConvertersSuite extends SharedSparkSession { import testImplicits._ @@ -359,13 +363,7 @@ class ArrowConvertersSuite extends SharedSparkSession { """.stripMargin val a_d = List(1.0, 2.0, 0.01, 200.0, 0.0001, 20000.0, 30.0).map(Decimal(_)) - val b_d = List( - Some(Decimal(1.1)), - None, - None, - Some(Decimal(2.2)), - None, - Some(Decimal(3.3)), + val b_d = List(Some(Decimal(1.1)), None, None, Some(Decimal(2.2)), None, Some(Decimal(3.3)), Some(Decimal("123456789012345678901234567890"))) val df = a_d.zip(b_d).toDF("a_d", "b_d") @@ -680,8 +678,8 @@ class ArrowConvertersSuite extends SharedSparkSession { |} """.stripMargin - val d1 = DateTimeUtils.toJavaDate(-1) // "1969-12-31" - val d2 = DateTimeUtils.toJavaDate(0) // "1970-01-01" + val d1 = DateTimeUtils.toJavaDate(-1) // "1969-12-31" + val d2 = DateTimeUtils.toJavaDate(0) // "1970-01-01" val d3 = Date.valueOf("2015-04-08") val d4 = Date.valueOf("3017-07-18") @@ -772,7 +770,7 @@ class ArrowConvertersSuite extends SharedSparkSession { |} """.stripMargin - val fnan = Seq(1.2f, Float.NaN) + val fnan = Seq(1.2F, Float.NaN) val dnan = Seq(Double.NaN, 1.2) val df 
= fnan.zip(dnan).toDF("NaN_f", "NaN_d") @@ -919,14 +917,9 @@ class ArrowConvertersSuite extends SharedSparkSession { val cArr = Seq(Seq(Some(1), Some(2)), Seq(Some(3), None), Seq(), Seq(Some(5))) val dArr = Seq(Seq(Seq(1, 2)), Seq(Seq(3), Seq()), Seq(), Seq(Seq(5))) - val df = aArr - .zip(bArr) - .zip(cArr) - .zip(dArr) - .map { case (((a, b), c), d) => - (a, b, c, d) - } - .toDF("a_arr", "b_arr", "c_arr", "d_arr") + val df = aArr.zip(bArr).zip(cArr).zip(dArr).map { + case (((a, b), c), d) => (a, b, c, d) + }.toDF("a_arr", "b_arr", "c_arr", "d_arr") collectAndValidate(df, json, "arrayData.json") } @@ -1065,8 +1058,8 @@ class ArrowConvertersSuite extends SharedSparkSession { val bStruct = Seq(Row(1), null, Row(3)) val cStruct = Seq(Row(1), Row(null), Row(3)) val dStruct = Seq(Row(Row(1)), null, Row(null)) - val data = aStruct.zip(bStruct).zip(cStruct).zip(dStruct).map { case (((a, b), c), d) => - Row(a, b, c, d) + val data = aStruct.zip(bStruct).zip(cStruct).zip(dStruct).map { + case (((a, b), c), d) => Row(a, b, c, d) } val rdd = sparkContext.parallelize(data) @@ -1435,28 +1428,30 @@ class ArrowConvertersSuite extends SharedSparkSession { assert(count == inputRows.length) } - test("roundtrip arrow batches with schema") { + test("roundtrip arrow batches with complex schema") { val rows = (0 until 9).map { i => - InternalRow(i) - } :+ InternalRow(null) + InternalRow(i, UTF8String.fromString(s"str-$i"), InternalRow(i)) + } - val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + val schema = StructType(Seq( + StructField("int", IntegerType), + StructField("str", StringType), + StructField("struct", StructType(Seq(StructField("inner", IntegerType)))) + )) val inputRows = rows.map { row => val proj = UnsafeProjection.create(schema) proj(row).copy() } val ctx = TaskContext.empty() val batchIter = - ArrowConverters.toBatchWithSchemaIterator(inputRows.iterator, schema, 5, 4 * 1024, null) + ArrowConverters.toBatchWithSchemaIterator(inputRows.iterator, schema, 5, 1024 * 1024, null) val (outputRowIter, outputType) = ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx) var count = 0 outputRowIter.zipWithIndex.foreach { case (row, i) => - if (i != 9) { - assert(row.getInt(0) == i) - } else { - assert(row.isNullAt(0)) - } + assert(row.getInt(0) == i) + assert(row.getString(1) == s"str-$i") + assert(row.getStruct(2, 1).getInt(0) == i) count += 1 } @@ -1464,12 +1459,45 @@ class ArrowConvertersSuite extends SharedSparkSession { assert(schema == outputType) } + test("roundtrip empty arrow batches") { + val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) + val ctx = TaskContext.empty() + val batchIter = + ArrowConverters.toBatchWithSchemaIterator(Iterator.empty, schema, 5, 1024 * 1024, null) + val (outputRowIter, outputType) = ArrowConverters.fromBatchWithSchemaIterator(batchIter, ctx) + + assert(0 == outputRowIter.length) + assert(outputType == null) + } + + test("two batches with different schema") { + val schema1 = StructType(Seq(StructField("field1", IntegerType, nullable = true))) + val inputRows1 = Array(InternalRow(1)).map { row => + val proj = UnsafeProjection.create(schema1) + proj(row).copy() + } + val batchIter1 = ArrowConverters.toBatchWithSchemaIterator( + inputRows1.iterator, schema1, 5, 1024 * 1024, null) + + val schema2 = StructType(Seq(StructField("field2", IntegerType, nullable = true))) + val inputRows2 = Array(InternalRow(1)).map { row => + val proj = UnsafeProjection.create(schema2) + proj(row).copy() + } + val batchIter2 = 
ArrowConverters.toBatchWithSchemaIterator( + inputRows2.iterator, schema2, 5, 1024 * 1024, null) + + val iter = batchIter1.toArray ++ batchIter2 + + val ctx = TaskContext.empty() + intercept[IllegalArgumentException] { + ArrowConverters.fromBatchWithSchemaIterator(iter.iterator, ctx)._1.toArray + } + } + /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ private def collectAndValidate( - df: DataFrame, - json: String, - file: String, - timeZoneId: String = null): Unit = { + df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator val batchBytes = df.coalesce(1).toArrowBatchRdd.collect().head val tempFile = new File(tempDataPath, file) From c288a251ca456ff34c0599a031c08e136b65f987 Mon Sep 17 00:00:00 2001 From: dengziming Date: Tue, 22 Nov 2022 20:41:09 +0800 Subject: [PATCH 3/5] resolve comments --- .../planner/SparkConnectProtoSuite.scala | 8 +- .../scala/org/apache/spark/util/Utils.scala | 8 - .../sql/connect/proto/relations_pb2.py | 150 +++++++------- .../sql/connect/proto/relations_pb2.pyi | 48 +++++ .../sql/execution/arrow/ArrowConverters.scala | 188 ++++++++---------- 5 files changed, 222 insertions(+), 180 deletions(-) diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index f924f3811c25..4114175e5101 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -48,17 +48,17 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { lazy val connectTestRelation = createLocalRelationProto( Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()), - Seq()) + Seq.empty) lazy val connectTestRelation2 = createLocalRelationProto( Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()), - Seq()) + Seq.empty) lazy val connectTestRelationMap = createLocalRelationProto( Seq(AttributeReference("id", MapType(StringType, StringType))()), - Seq()) + Seq.empty) lazy val sparkTestRelation: DataFrame = spark.createDataFrame( @@ -76,7 +76,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { StructType(Seq(StructField("id", MapType(StringType, StringType))))) lazy val localRelation = - createLocalRelationProto(Seq(AttributeReference("id", IntegerType)()), Seq()) + createLocalRelationProto(Seq(AttributeReference("id", IntegerType)()), Seq.empty) test("Basic select") { val connectPlan = connectTestRelation.select("id".protoAttr) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2b596ace78c6..70477a5c9c08 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -3257,14 +3257,6 @@ private[spark] object Utils extends Logging { case _ => math.max(sortedSize(len / 2), 1) } } - - def closeAll(closeables: AutoCloseable*): Unit = { - for (closeable <- closeables) { - if (closeable != null) { - closeable.close() - } - } - } } private[util] object CallerContext extends Logging { diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 
787a055773b4..0ba0d8917cc2 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -33,7 +33,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xa6\x0b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12|\n#rename_columns_by_same_length_names\x18\x12 \x01(\x0b\x32-.spark.connect.RenameColumnsBySameLengthNamesH\x00R\x1erenameColumnsBySameLengthNames\x12w\n"rename_columns_by_name_to_name_map\x18\x13 \x01(\x0b\x32+.spark.connect.RenameColumnsByNameToNameMapH\x00R\x1crenameColumnsByNameToNameMap\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"1\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\xaa\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x00R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xc2\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\x8c\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_name"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xd2\x01\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12H\n\x12result_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11resultExpressions"\xa6\x04\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42\x0c\n\n_is_global"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"#\n\rLocalRelation\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta"\xe0\x01\n\x06Sample\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x42\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8d\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x18\n\x07numRows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"r\n\x1eRenameColumnsBySameLengthNames\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\x83\x02\n\x1cRenameColumnsByNameToNameMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12o\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x41.spark.connect.RenameColumnsByNameToNameMap.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xd1\x0b\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b 
\x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12|\n#rename_columns_by_same_length_names\x18\x12 \x01(\x0b\x32-.spark.connect.RenameColumnsBySameLengthNamesH\x00R\x1erenameColumnsBySameLengthNames\x12w\n"rename_columns_by_name_to_name_map\x18\x13 \x01(\x0b\x32+.spark.connect.RenameColumnsByNameToNameMapH\x00R\x1crenameColumnsByNameToNameMap\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"1\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\xaa\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xcf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x00R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\xc2\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\x8c\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 
\x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_name"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xd2\x01\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12H\n\x12result_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x11resultExpressions"\xa6\x04\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02\x42\x0c\n\n_is_global"d\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12-\n\x04\x63ols\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x04\x63ols"\xab\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x42\x16\n\x14_all_columns_as_keys"#\n\rLocalRelation\x12\x12\n\x04\x64\x61ta\x18\x01 \x01(\x0cR\x04\x64\x61ta"\xe0\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x42\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 
\x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8d\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x18\n\x07numRows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"r\n\x1eRenameColumnsBySameLengthNames\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\x83\x02\n\x1cRenameColumnsByNameToNameMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12o\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x41.spark.connect.RenameColumnsByNameToNameMap.RenameColumnsMapEntryR\x10renameColumnsMap\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) @@ -54,6 +54,7 @@ _AGGREGATE = DESCRIPTOR.message_types_by_name["Aggregate"] _SORT = DESCRIPTOR.message_types_by_name["Sort"] _SORT_SORTFIELD = _SORT.nested_types_by_name["SortField"] +_DROP = DESCRIPTOR.message_types_by_name["Drop"] _DEDUPLICATE = DESCRIPTOR.message_types_by_name["Deduplicate"] _LOCALRELATION = DESCRIPTOR.message_types_by_name["LocalRelation"] _SAMPLE = DESCRIPTOR.message_types_by_name["Sample"] @@ -256,6 +257,17 @@ _sym_db.RegisterMessage(Sort) _sym_db.RegisterMessage(Sort.SortField) +Drop = _reflection.GeneratedProtocolMessageType( + "Drop", + (_message.Message,), + { + "DESCRIPTOR": _DROP, + "__module__": "spark.connect.relations_pb2" + # @@protoc_insertion_point(class_scope:spark.connect.Drop) + }, +) +_sym_db.RegisterMessage(Drop) + Deduplicate = _reflection.GeneratedProtocolMessageType( "Deduplicate", (_message.Message,), @@ -407,71 +419,73 @@ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._options = None _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 82 - _RELATION._serialized_end = 1528 - _UNKNOWN._serialized_start = 1530 - _UNKNOWN._serialized_end = 1539 - _RELATIONCOMMON._serialized_start = 1541 - _RELATIONCOMMON._serialized_end = 1590 - _SQL._serialized_start = 1592 - _SQL._serialized_end = 1619 - _READ._serialized_start = 1622 - _READ._serialized_end = 2048 - _READ_NAMEDTABLE._serialized_start = 1764 - _READ_NAMEDTABLE._serialized_end = 1825 - _READ_DATASOURCE._serialized_start = 1828 - _READ_DATASOURCE._serialized_end = 2035 - _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1966 - _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2024 - _PROJECT._serialized_start = 2050 - _PROJECT._serialized_end = 2167 - _FILTER._serialized_start = 2169 - _FILTER._serialized_end = 2281 - _JOIN._serialized_start = 2284 - _JOIN._serialized_end = 2734 - _JOIN_JOINTYPE._serialized_start = 2547 - _JOIN_JOINTYPE._serialized_end = 2734 - _SETOPERATION._serialized_start = 
2737 - _SETOPERATION._serialized_end = 3133 - _SETOPERATION_SETOPTYPE._serialized_start = 2996 - _SETOPERATION_SETOPTYPE._serialized_end = 3110 - _LIMIT._serialized_start = 3135 - _LIMIT._serialized_end = 3211 - _OFFSET._serialized_start = 3213 - _OFFSET._serialized_end = 3292 - _AGGREGATE._serialized_start = 3295 - _AGGREGATE._serialized_end = 3505 - _SORT._serialized_start = 3508 - _SORT._serialized_end = 4058 - _SORT_SORTFIELD._serialized_start = 3662 - _SORT_SORTFIELD._serialized_end = 3850 - _SORT_SORTDIRECTION._serialized_start = 3852 - _SORT_SORTDIRECTION._serialized_end = 3960 - _SORT_SORTNULLS._serialized_start = 3962 - _SORT_SORTNULLS._serialized_end = 4044 - _DEDUPLICATE._serialized_start = 4061 - _DEDUPLICATE._serialized_end = 4232 - _LOCALRELATION._serialized_start = 4234 - _LOCALRELATION._serialized_end = 4269 - _SAMPLE._serialized_start = 4272 - _SAMPLE._serialized_end = 4496 - _RANGE._serialized_start = 4499 - _RANGE._serialized_end = 4644 - _SUBQUERYALIAS._serialized_start = 4646 - _SUBQUERYALIAS._serialized_end = 4760 - _REPARTITION._serialized_start = 4763 - _REPARTITION._serialized_end = 4905 - _SHOWSTRING._serialized_start = 4908 - _SHOWSTRING._serialized_end = 5049 - _STATSUMMARY._serialized_start = 5051 - _STATSUMMARY._serialized_end = 5143 - _STATCROSSTAB._serialized_start = 5145 - _STATCROSSTAB._serialized_end = 5246 - _NAFILL._serialized_start = 5249 - _NAFILL._serialized_end = 5383 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5385 - _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5499 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5502 - _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5761 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5694 - _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5761 + _RELATION._serialized_end = 1571 + _UNKNOWN._serialized_start = 1573 + _UNKNOWN._serialized_end = 1582 + _RELATIONCOMMON._serialized_start = 1584 + _RELATIONCOMMON._serialized_end = 1633 + _SQL._serialized_start = 1635 + _SQL._serialized_end = 1662 + _READ._serialized_start = 1665 + _READ._serialized_end = 2091 + _READ_NAMEDTABLE._serialized_start = 1807 + _READ_NAMEDTABLE._serialized_end = 1868 + _READ_DATASOURCE._serialized_start = 1871 + _READ_DATASOURCE._serialized_end = 2078 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 2009 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 2067 + _PROJECT._serialized_start = 2093 + _PROJECT._serialized_end = 2210 + _FILTER._serialized_start = 2212 + _FILTER._serialized_end = 2324 + _JOIN._serialized_start = 2327 + _JOIN._serialized_end = 2777 + _JOIN_JOINTYPE._serialized_start = 2590 + _JOIN_JOINTYPE._serialized_end = 2777 + _SETOPERATION._serialized_start = 2780 + _SETOPERATION._serialized_end = 3176 + _SETOPERATION_SETOPTYPE._serialized_start = 3039 + _SETOPERATION_SETOPTYPE._serialized_end = 3153 + _LIMIT._serialized_start = 3178 + _LIMIT._serialized_end = 3254 + _OFFSET._serialized_start = 3256 + _OFFSET._serialized_end = 3335 + _AGGREGATE._serialized_start = 3338 + _AGGREGATE._serialized_end = 3548 + _SORT._serialized_start = 3551 + _SORT._serialized_end = 4101 + _SORT_SORTFIELD._serialized_start = 3705 + _SORT_SORTFIELD._serialized_end = 3893 + _SORT_SORTDIRECTION._serialized_start = 3895 + _SORT_SORTDIRECTION._serialized_end = 4003 + _SORT_SORTNULLS._serialized_start = 4005 + _SORT_SORTNULLS._serialized_end = 4087 + _DROP._serialized_start = 4103 + _DROP._serialized_end = 4203 + _DEDUPLICATE._serialized_start = 4206 + _DEDUPLICATE._serialized_end = 4377 + 
_LOCALRELATION._serialized_start = 4379 + _LOCALRELATION._serialized_end = 4414 + _SAMPLE._serialized_start = 4417 + _SAMPLE._serialized_end = 4641 + _RANGE._serialized_start = 4644 + _RANGE._serialized_end = 4789 + _SUBQUERYALIAS._serialized_start = 4791 + _SUBQUERYALIAS._serialized_end = 4905 + _REPARTITION._serialized_start = 4908 + _REPARTITION._serialized_end = 5050 + _SHOWSTRING._serialized_start = 5053 + _SHOWSTRING._serialized_end = 5194 + _STATSUMMARY._serialized_start = 5196 + _STATSUMMARY._serialized_end = 5288 + _STATCROSSTAB._serialized_start = 5290 + _STATCROSSTAB._serialized_end = 5391 + _NAFILL._serialized_start = 5394 + _NAFILL._serialized_end = 5528 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 5530 + _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 5644 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 5647 + _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 5906 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 5839 + _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 5906 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index ef28e5567374..a6a16b448c91 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -79,6 +79,7 @@ class Relation(google.protobuf.message.Message): RENAME_COLUMNS_BY_SAME_LENGTH_NAMES_FIELD_NUMBER: builtins.int RENAME_COLUMNS_BY_NAME_TO_NAME_MAP_FIELD_NUMBER: builtins.int SHOW_STRING_FIELD_NUMBER: builtins.int + DROP_FIELD_NUMBER: builtins.int FILL_NA_FIELD_NUMBER: builtins.int SUMMARY_FIELD_NUMBER: builtins.int CROSSTAB_FIELD_NUMBER: builtins.int @@ -124,6 +125,8 @@ class Relation(google.protobuf.message.Message): @property def show_string(self) -> global___ShowString: ... @property + def drop(self) -> global___Drop: ... 
+ @property def fill_na(self) -> global___NAFill: """NA functions""" @property @@ -156,6 +159,7 @@ class Relation(google.protobuf.message.Message): rename_columns_by_same_length_names: global___RenameColumnsBySameLengthNames | None = ..., rename_columns_by_name_to_name_map: global___RenameColumnsByNameToNameMap | None = ..., show_string: global___ShowString | None = ..., + drop: global___Drop | None = ..., fill_na: global___NAFill | None = ..., summary: global___StatSummary | None = ..., crosstab: global___StatCrosstab | None = ..., @@ -172,6 +176,8 @@ class Relation(google.protobuf.message.Message): b"crosstab", "deduplicate", b"deduplicate", + "drop", + b"drop", "fill_na", b"fill_na", "filter", @@ -227,6 +233,8 @@ class Relation(google.protobuf.message.Message): b"crosstab", "deduplicate", b"deduplicate", + "drop", + b"drop", "fill_na", b"fill_na", "filter", @@ -293,6 +301,7 @@ class Relation(google.protobuf.message.Message): "rename_columns_by_same_length_names", "rename_columns_by_name_to_name_map", "show_string", + "drop", "fill_na", "summary", "crosstab", @@ -961,6 +970,42 @@ class Sort(google.protobuf.message.Message): global___Sort = Sort +class Drop(google.protobuf.message.Message): + """Drop specified columns.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + INPUT_FIELD_NUMBER: builtins.int + COLS_FIELD_NUMBER: builtins.int + @property + def input(self) -> global___Relation: + """(Required) The input relation.""" + @property + def cols( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + pyspark.sql.connect.proto.expressions_pb2.Expression + ]: + """(Required) columns to drop. + + Should contain at least 1 item. + """ + def __init__( + self, + *, + input: global___Relation | None = ..., + cols: collections.abc.Iterable[pyspark.sql.connect.proto.expressions_pb2.Expression] + | None = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["input", b"input"] + ) -> builtins.bool: ... + def ClearField( + self, field_name: typing_extensions.Literal["cols", b"cols", "input", b"input"] + ) -> None: ... + +global___Drop = Drop + class Deduplicate(google.protobuf.message.Message): """Relation of type [[Deduplicate]] which have duplicate rows removed, could consider either only the subset of columns or all the columns. @@ -1032,6 +1077,9 @@ class LocalRelation(google.protobuf.message.Message): DATA_FIELD_NUMBER: builtins.int data: builtins.bytes + """Local collection data serialized into Arrow IPC streaming format which contains + the schema of the data. + """ def __init__( self, *, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 223f9652e063..35302e315e8b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, Ou import java.nio.channels.{Channels, ReadableByteChannel} import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.arrow.flatbuf.MessageHeader import org.apache.arrow.memory.BufferAllocator @@ -213,51 +214,94 @@ private[sql] object ArrowConverters extends Logging { }.next() } - /** - * Maps iterator from serialized ArrowRecordBatches to InternalRows. 
- */ - private[sql] def fromBatchIterator( + private[sql] abstract class InternalRowIterator( arrowBatchIter: Iterator[Array[Byte]], - schema: StructType, - timeZoneId: String, - context: TaskContext): Iterator[InternalRow] = { - val allocator = - ArrowUtils.rootAllocator.newChildAllocator("fromBatchIterator", 0, Long.MaxValue) - - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - - new Iterator[InternalRow] { - private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty + context: TaskContext) + extends Iterator[InternalRow] { + // Keep all the resources we have opened in order, should be closed in reverse order finally. + val resources = new ArrayBuffer[AutoCloseable]() + protected val allocator: BufferAllocator = ArrowUtils.rootAllocator.newChildAllocator( + s"to${this.getClass.getSimpleName}", + 0, + Long.MaxValue) + resources.append(allocator) + + private var rowIterAndSchema = + if (arrowBatchIter.hasNext) nextBatch() else (Iterator.empty, null) + // We will ensure schemas parsed from every batch are the same + val schema: StructType = rowIterAndSchema._2 - if (context != null) context.addTaskCompletionListener[Unit] { _ => - root.close() - allocator.close() - } + if (context != null) context.addTaskCompletionListener[Unit] { _ => + closeAll(resources.reverse: _*) + } - override def hasNext: Boolean = rowIter.hasNext || { - if (arrowBatchIter.hasNext) { - rowIter = nextBatch() - true - } else { - root.close() - allocator.close() - false + override def hasNext: Boolean = rowIterAndSchema._1.hasNext || { + if (arrowBatchIter.hasNext) { + rowIterAndSchema = nextBatch() + if (schema != rowIterAndSchema._2) { + throw new IllegalArgumentException( + s"ArrowBatch iterator contain 2 batches with" + + s" different schema: $schema and ${rowIterAndSchema._2}") } + rowIterAndSchema._1.hasNext + } else { + closeAll(resources.reverse: _*) + false } + } - override def next(): InternalRow = rowIter.next() + override def next(): InternalRow = rowIterAndSchema._1.next() + + def nextBatch(): (Iterator[InternalRow], StructType) + } - private def nextBatch(): Iterator[InternalRow] = { - val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) - val vectorLoader = new VectorLoader(root) - vectorLoader.load(arrowRecordBatch) - arrowRecordBatch.close() - vectorSchemaRootToIter(root) + private[sql] class InternalRowIteratorWithoutSchema( + arrowBatchIter: Iterator[Array[Byte]], + schema: StructType, + timeZoneId: String, + context: TaskContext) + extends InternalRowIterator(arrowBatchIter, context) { + + override def nextBatch(): (Iterator[InternalRow], StructType) = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + resources.append(root) + val arrowRecordBatch = ArrowConverters.loadBatch(arrowBatchIter.next(), allocator) + val vectorLoader = new VectorLoader(root) + vectorLoader.load(arrowRecordBatch) + arrowRecordBatch.close() + (vectorSchemaRootToIter(root), schema) + } + } + + private[sql] class InternalRowIteratorWithSchema( + arrowBatchIter: Iterator[Array[Byte]], + context: TaskContext) + extends InternalRowIterator(arrowBatchIter, context) { + override def nextBatch(): (Iterator[InternalRow], StructType) = { + val reader = + new ArrowStreamReader(new ByteArrayInputStream(arrowBatchIter.next()), allocator) + val root = if (reader.loadNextBatch()) reader.getVectorSchemaRoot else null + 
resources.append(reader, root) + if (root == null) { + (Iterator.empty, null) + } else { + (vectorSchemaRootToIter(root), ArrowUtils.fromArrowSchema(root.getSchema)) } } } + /** + * Maps iterator from serialized ArrowRecordBatches to InternalRows. + */ + private[sql] def fromBatchIterator( + arrowBatchIter: Iterator[Array[Byte]], + schema: StructType, + timeZoneId: String, + context: TaskContext): Iterator[InternalRow] = new InternalRowIteratorWithoutSchema( + arrowBatchIter, schema, timeZoneId, context + ) + /** * Maps iterator from serialized ArrowRecordBatches to InternalRows. Different from * [[fromBatchIterator]], each input arrow batch starts with the schema. @@ -265,72 +309,8 @@ private[sql] object ArrowConverters extends Logging { private[sql] def fromBatchWithSchemaIterator( arrowBatchIter: Iterator[Array[Byte]], context: TaskContext): (Iterator[InternalRow], StructType) = { - var structType: StructType = null - val allocator = - ArrowUtils.rootAllocator.newChildAllocator("fromBatchWithSchemaIterator", 0, Long.MaxValue) - - val iter = new Iterator[InternalRow] { - private var rowIter = if (arrowBatchIter.hasNext) nextBatch() else Iterator.empty - - if (context != null) context.addTaskCompletionListener[Unit] { _ => - allocator.close() - } - - override def hasNext: Boolean = rowIter.hasNext || { - if (arrowBatchIter.hasNext) { - rowIter = nextBatch() - rowIter.hasNext - } else { - Utils.closeAll(allocator) - false - } - } - - override def next(): InternalRow = { - rowIter.next() - } - - private def nextBatch(): Iterator[InternalRow] = { - val rowsAndType = fromBatchWithSchemaBuffer(arrowBatchIter.next(), allocator, context) - if (structType == null) { - structType = rowsAndType._2 - } else if (structType != rowsAndType._2) { - throw new IllegalArgumentException(s"ArrowBatch iterator contain 2 batches with" + - s" different schema: $structType and ${rowsAndType._2}") - } - rowsAndType._1 - } - } - (iter, structType) - } - - private def fromBatchWithSchemaBuffer( - arrowBuffer: Array[Byte], - allocator: BufferAllocator, - context: TaskContext): (Iterator[InternalRow], StructType) = { - val reader = new ArrowStreamReader(new ByteArrayInputStream(arrowBuffer), allocator) - - val root = if (reader.loadNextBatch()) reader.getVectorSchemaRoot else null - val structType = - if (root == null) new StructType() else ArrowUtils.fromArrowSchema(root.getSchema) - - if (context != null) context.addTaskCompletionListener[Unit] { _ => - Utils.closeAll(root, reader) - } - val inner = vectorSchemaRootToIter(root) - val iter = new Iterator[InternalRow] { - override def hasNext: Boolean = { - if (inner.hasNext) { - true - } else { - Utils.closeAll(root, reader) - false - } - } - - override def next(): InternalRow = inner.next() - } - (iter, structType) + val iterator = new InternalRowIteratorWithSchema(arrowBatchIter, context) + (iterator, iterator.schema) } private def vectorSchemaRootToIter(root: VectorSchemaRoot): Iterator[InternalRow] = { @@ -469,4 +449,12 @@ private[sql] object ArrowConverters extends Logging { } } } + + private def closeAll(closeables: AutoCloseable*): Unit = { + for (closeable <- closeables) { + if (closeable != null) { + closeable.close() + } + } + } } From 7c2c0a1c41e82167d209d159d826ade6973f5791 Mon Sep 17 00:00:00 2001 From: dengziming Date: Wed, 23 Nov 2022 15:35:41 +0800 Subject: [PATCH 4/5] comments on new classes and new methods && scala 2.13 fix --- .../sql/execution/arrow/ArrowConverters.scala | 21 ++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) 
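The `resources.toSeq.reverse: _*` calls in the hunk below are the Scala 2.13 fix named in the subject line: on 2.13, `scala.Seq` is an alias for `scala.collection.immutable.Seq`, so a mutable `ArrayBuffer` can no longer be expanded into a varargs parameter with `: _*` without an explicit conversion. A minimal standalone sketch of the issue (illustrative, not taken from the diff; `closeAll` mirrors the private helper added in the previous patch):

import scala.collection.mutable.ArrayBuffer

// The varargs parameter is typed as scala.Seq[AutoCloseable], which is
// immutable.Seq on Scala 2.13 but collection.Seq on 2.12.
def closeAll(closeables: AutoCloseable*): Unit = {
  for (closeable <- closeables) {
    if (closeable != null) {
      closeable.close()
    }
  }
}

val resources = new ArrayBuffer[AutoCloseable]()
// closeAll(resources.reverse: _*)       // compiles on 2.12 only
closeAll(resources.toSeq.reverse: _*)    // compiles on both 2.12 and 2.13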

From 7c2c0a1c41e82167d209d159d826ade6973f5791 Mon Sep 17 00:00:00 2001
From: dengziming
Date: Wed, 23 Nov 2022 15:35:41 +0800
Subject: [PATCH 4/5] comments on new classes and new methods && scala 2.13 fix

---
 .../sql/execution/arrow/ArrowConverters.scala | 21 ++++++++++++++++---
 1 file changed, 18 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
index 35302e315e8b..40de117f6f67 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala
@@ -214,6 +214,10 @@ private[sql] object ArrowConverters extends Logging {
     }.next()
   }
 
+  /**
+   * An InternalRow iterator which parse data from serialized ArrowRecordBatches, subclass should
+   * implement [[nextBatch]] to parse data from binary records.
+   */
   private[sql] abstract class InternalRowIterator(
       arrowBatchIter: Iterator[Array[Byte]],
       context: TaskContext)
@@ -228,11 +232,11 @@
 
     private var rowIterAndSchema =
       if (arrowBatchIter.hasNext) nextBatch() else (Iterator.empty, null)
-    // We will ensure schemas parsed from every batch are the same
+    // We will ensure schemas parsed from every batch are the same.
     val schema: StructType = rowIterAndSchema._2
 
     if (context != null) context.addTaskCompletionListener[Unit] { _ =>
-      closeAll(resources.reverse: _*)
+      closeAll(resources.toSeq.reverse: _*)
     }
 
@@ -245,7 +249,7 @@
         }
         rowIterAndSchema._1.hasNext
       } else {
-        closeAll(resources.reverse: _*)
+        closeAll(resources.toSeq.reverse: _*)
         false
       }
     }
@@ -255,6 +259,10 @@
     def nextBatch(): (Iterator[InternalRow], StructType)
   }
 
+  /**
+   * Parse data from serialized ArrowRecordBatches, the [[arrowBatchIter]] only contains serialized
+   * arrow batch records, the schema is passed in through [[schema]].
+   */
   private[sql] class InternalRowIteratorWithoutSchema(
       arrowBatchIter: Iterator[Array[Byte]],
       schema: StructType,
@@ -274,6 +282,10 @@
     }
   }
 
+  /**
+   * Parse data from serialized ArrowRecordBatches, the arrowBatch in [[arrowBatchIter]] starts with
+   * the schema so we should parse schema from it first.
+   */
   private[sql] class InternalRowIteratorWithSchema(
       arrowBatchIter: Iterator[Array[Byte]],
       context: TaskContext)
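
The `.toSeq` added in the hunks above is the "scala 2.13 fix" from the subject line: since Scala 2.13, a varargs splice `xs: _*` expects an immutable Seq, so splicing the mutable ArrayBuffer returned by `resources.reverse` no longer compiles. A minimal standalone repro of that rule (illustrative names, not Spark code) follows; on 2.12 the extra `.toSeq` is essentially free, and on 2.13 it copies only the handful of tracked handles.

    import scala.collection.mutable.ArrayBuffer

    object VarargsSplice213 {
      def sum(xs: Int*): Int = xs.sum

      def main(args: Array[String]): Unit = {
        val buf = ArrayBuffer(1, 2, 3)
        // println(sum(buf: _*))     // compiles on Scala 2.12, type mismatch on 2.13
        println(sum(buf.toSeq: _*))  // compiles on both; prints 6
      }
    }
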
@@ -313,6 +325,9 @@
     (iterator, iterator.schema)
   }
 
+  /**
+   * Convert an arrow batch container into an iterator of InternalRow.
+   */
   private def vectorSchemaRootToIter(root: VectorSchemaRoot): Iterator[InternalRow] = {
     val columns = root.getFieldVectors.asScala.map { vector =>
       new ArrowColumnVector(vector).asInstanceOf[ColumnVector]

From 21b6482656d6e36cd4e3e09ad4834bec3b9b52bc Mon Sep 17 00:00:00 2001
From: dengziming
Date: Wed, 23 Nov 2022 23:20:19 +0800
Subject: [PATCH 5/5] MINOR: resolve comments

---
 .../apache/spark/sql/connect/planner/SparkConnectPlanner.scala  | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index d21479352fea..9c4299d652fd 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -275,7 +275,7 @@ class SparkConnectPlanner(session: SparkSession) {
 
   private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
     val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
-      Seq(rel.getData.toByteArray).iterator,
+      Iterator(rel.getData.toByteArray),
       TaskContext.get())
     val attributes = structType.toAttributes
     val proj = UnsafeProjection.create(attributes, attributes)
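
With this series, `LocalRelation.data` carries the client's local rows as a single serialized Arrow batch that begins with the schema: the planner hands the bytes to `fromBatchWithSchemaIterator`, derives the attributes from the parsed schema via `structType.toAttributes`, and runs each row through an `UnsafeProjection`. The `Iterator(...)` change above is a small cleanup that avoids materializing a one-element Seq just to get an iterator. Copying each projected row before collecting matters because an UnsafeProjection reuses a single output buffer across calls; the snippet below is a standalone illustration of that pitfall (it only assumes spark-catalyst on the classpath and is not code from this PR).

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    // Illustration only: why projected rows must be copied before being collected.
    object CopyAfterProjection {
      def main(args: Array[String]): Unit = {
        val schema = StructType(Seq(StructField("i", IntegerType)))
        val proj = UnsafeProjection.create(schema)
        val input = Seq(InternalRow(1), InternalRow(2), InternalRow(3))

        // Without copy() every element points at the projection's single reused buffer.
        val shared = input.map(proj(_))
        // With copy() each element is an independent UnsafeRow.
        val copied = input.map(r => proj(r).copy())

        println(shared.map(_.getInt(0))) // List(3, 3, 3) -- all the same buffer
        println(copied.map(_.getInt(0))) // List(1, 2, 3)
      }
    }
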