@@ -285,9 +285,9 @@ message Deduplicate {
 
 // A relation that does not need to be qualified by name.
 message LocalRelation {
-  // (Optional) A list qualified attributes.
-  repeated Expression.QualifiedAttribute attributes = 1;
-  // TODO: support local data.
+  // Local collection data serialized into Arrow IPC streaming format which contains
+  // the schema of the data.
+  bytes data = 1;
 }
 
 // Relation of type [[Sample]] that samples a fraction of the dataset.
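The `data` field now carries a complete Arrow IPC stream (schema plus record batches), so the schema no longer travels as separate proto fields. A minimal sketch of how a producer can fill it, mirroring the test helper added later in this diff; ArrowConverters is Spark-internal, the two Long.MaxValue arguments are assumed to be the per-batch record-count and byte-size limits (so everything lands in one batch), and the trailing null is assumed to be the time-zone ID:

import com.google.protobuf.ByteString
import org.apache.spark.connect.proto
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema = StructType(Seq(StructField("id", IntegerType)))
// Mirror the tests: project generic rows to UnsafeRows before serializing.
val proj = UnsafeProjection.create(schema)
val rows = Seq(InternalRow(1), InternalRow(2)).map(r => proj(r).copy())
val bytes = ArrowConverters
  .toBatchWithSchemaIterator(rows.iterator, schema, Long.MaxValue, Long.MaxValue, null)
  .next()
val rel = proto.Relation
  .newBuilder()
  .setLocalRelation(proto.LocalRelation.newBuilder().setData(ByteString.copyFrom(bytes)))
  .build()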
@@ -22,21 +22,23 @@ import scala.collection.mutable
 
 import com.google.common.collect.{Lists, Maps}
 
+import org.apache.spark.TaskContext
 import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.WriteOperation
 import org.apache.spark.sql.{Column, Dataset, SparkSession}
 import org.apache.spark.sql.catalyst.AliasIdentifier
 import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Expression, NamedExpression, UnsafeProjection}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}
 import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter, UsingJoin}
 import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Except, Intersect, LocalRelation, LogicalPlan, Sample, SubqueryAlias, Union}
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.CreateViewCommand
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.types._
@@ -272,8 +274,12 @@ class SparkConnectPlanner(session: SparkSession) {
   }
 
   private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
-    val attributes = rel.getAttributesList.asScala.map(transformAttribute(_)).toSeq
-    new org.apache.spark.sql.catalyst.plans.logical.LocalRelation(attributes)
+    val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
+      Iterator(rel.getData.toByteArray),
+      TaskContext.get())
+    val attributes = structType.toAttributes
+    val proj = UnsafeProjection.create(attributes, attributes)
+    new logical.LocalRelation(attributes, rows.map(r => proj(r).copy()).toSeq)
   }
 
   private def transformAttribute(exp: proto.Expression.QualifiedAttribute): Attribute = {
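Why the UnsafeProjection plus copy() above: fromBatchWithSchemaIterator hands back both the row iterator and the schema recovered from the IPC stream, so the relation's attributes are derived entirely from the payload. The rows it yields, however, may be backed by a single reused buffer, and materializing the iterator into a Seq without copying could leave every element aliasing the last row. A sketch of the pattern, with hypothetical `it` and `proj` standing in for the iterator and projection above:

// Hypothetical names: `it` yields rows that may share one mutable backing row.
val wrong = it.toSeq                           // elements may all alias the same row
val right = it.map(r => proj(r).copy()).toSeq  // each element owns its own bytes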
@@ -19,12 +19,19 @@ package org.apache.spark.sql.connect.planner
 
 import scala.collection.JavaConverters._
 
+import com.google.protobuf.ByteString
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Expression.UnresolvedStar
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection}
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Testing trait for SparkConnect tests with some helper methods to make it easier to create new
@@ -55,17 +62,26 @@ trait SparkConnectPlanTest extends SharedSparkSession {
   * equivalent in Catalyst and can be easily used for planner testing.
   *
   * @param attrs
+  *   the attributes of LocalRelation
+  * @param data
+  *   the data of LocalRelation
   * @return
   */
-  def createLocalRelationProto(attrs: Seq[AttributeReference]): proto.Relation = {
+  def createLocalRelationProto(
+      attrs: Seq[AttributeReference],
+      data: Seq[InternalRow]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
-    for (attr <- attrs) {
-      localRelationBuilder.addAttributes(
-        proto.Expression.QualifiedAttribute
-          .newBuilder()
-          .setName(attr.name)
-          .setType(DataTypeProtoConverter.toConnectProtoType(attr.dataType)))
-    }
+
+    val bytes = ArrowConverters
+      .toBatchWithSchemaIterator(
+        data.iterator,
+        StructType.fromAttributes(attrs.map(_.toAttribute)),
+        Long.MaxValue,
+        Long.MaxValue,
+        null)
+      .next()
+
+    localRelationBuilder.setData(ByteString.copyFrom(bytes))
     proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build()
   }
 }
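The helper serializes the entire dataset into a single Arrow batch: the two Long.MaxValue arguments are assumed to be the per-batch record-count and estimated-size limits, and the trailing null the time-zone ID, so the schema and all rows travel together in one IPC stream. A typical call, matching the fixtures in SparkConnectProtoSuite later in this diff:

val rel = createLocalRelationProto(
  Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()),
  Seq.empty)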
@@ -96,7 +112,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
       new SparkConnectPlanner(None.orNull)
         .transformRelation(
           proto.Relation.newBuilder.setUnknown(proto.Unknown.newBuilder().build()).build()))
-
   }
 
test("Simple Read") {
Expand Down Expand Up @@ -199,7 +214,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
}

test("Simple Join") {

val incompleteJoin =
proto.Relation.newBuilder.setJoin(proto.Join.newBuilder.setLeft(readRel)).build()
intercept[AssertionError](transform(incompleteJoin))
@@ -267,7 +281,6 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
 
     val res = transform(proto.Relation.newBuilder.setProject(project).build())
     assert(res.nodeName == "Project")
-
   }
 
test("Simple Aggregation") {
Expand Down Expand Up @@ -354,4 +367,61 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
transform(proto.Relation.newBuilder.setSetOp(intersect).build()))
assert(e2.getMessage.contains("Intersect does not support union_by_name"))
}

test("transform LocalRelation") {
val rows = (0 until 10).map { i =>
InternalRow(i, UTF8String.fromString(s"str-$i"), InternalRow(i))
}

val schema = StructType(
Seq(
StructField("int", IntegerType),
StructField("str", StringType),
StructField("struct", StructType(Seq(StructField("inner", IntegerType))))))
val inputRows = rows.map { row =>
val proj = UnsafeProjection.create(schema)
proj(row).copy()
}

val localRelation = createLocalRelationProto(schema.toAttributes, inputRows)
val df = Dataset.ofRows(spark, transform(localRelation))
val array = df.collect()
assertResult(10)(array.length)
assert(schema == df.schema)
for (i <- 0 until 10) {
assert(i == array(i).getInt(0))
assert(s"str-$i" == array(i).getString(1))
assert(i == array(i).getStruct(2).getInt(0))
}
}

test("Empty ArrowBatch") {
val schema = StructType(Seq(StructField("int", IntegerType)))
val data = ArrowConverters.createEmptyArrowBatch(schema, null)
val localRelation = proto.Relation
.newBuilder()
.setLocalRelation(
proto.LocalRelation
.newBuilder()
.setData(ByteString.copyFrom(data))
.build())
.build()
val df = Dataset.ofRows(spark, transform(localRelation))
assert(schema == df.schema)
assert(df.isEmpty)
}

test("Illegal LocalRelation data") {
intercept[Exception] {
transform(
proto.Relation
.newBuilder()
.setLocalRelation(
proto.LocalRelation
.newBuilder()
.setData(ByteString.copyFrom("illegal".getBytes()))
.build())
.build())
}
}
}
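Taken together, the three new tests pin down the contract of the `data` field: a populated stream round-trips values and nested schema exactly; a schema-only stream (from createEmptyArrowBatch, whose second argument is presumably the time-zone ID) plans to an empty but correctly typed DataFrame; and bytes that are not a valid Arrow IPC stream fail inside transform. Note the intercept is deliberately broad (Exception): it asserts only that malformed payloads cannot silently produce a relation, without pinning the exact exception type the Arrow reader throws.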
@@ -18,6 +18,8 @@ package org.apache.spark.sql.connect.planner
 
 import java.nio.file.{Files, Paths}
 
+import com.google.protobuf.ByteString
+
 import org.apache.spark.SparkClassNotFoundException
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.Join.JoinType
@@ -31,6 +33,7 @@ import org.apache.spark.sql.connect.dsl.MockRemoteSession
 import org.apache.spark.sql.connect.dsl.commands._
 import org.apache.spark.sql.connect.dsl.expressions._
 import org.apache.spark.sql.connect.dsl.plans._
+import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{IntegerType, MapType, Metadata, StringType, StructField, StructType}
@@ -44,14 +47,18 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
 
   lazy val connectTestRelation =
     createLocalRelationProto(
-      Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()))
+      Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()),
+      Seq.empty)
 
   lazy val connectTestRelation2 =
     createLocalRelationProto(
-      Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()))
+      Seq(AttributeReference("id", IntegerType)(), AttributeReference("name", StringType)()),
+      Seq.empty)
 
   lazy val connectTestRelationMap =
-    createLocalRelationProto(Seq(AttributeReference("id", MapType(StringType, StringType))()))
+    createLocalRelationProto(
+      Seq(AttributeReference("id", MapType(StringType, StringType))()),
+      Seq.empty)
 
   lazy val sparkTestRelation: DataFrame =
     spark.createDataFrame(
@@ -68,7 +75,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
       new java.util.ArrayList[Row](),
       StructType(Seq(StructField("id", MapType(StringType, StringType)))))
 
-  lazy val localRelation = createLocalRelationProto(Seq(AttributeReference("id", IntegerType)()))
+  lazy val localRelation =
+    createLocalRelationProto(Seq(AttributeReference("id", IntegerType)()), Seq.empty)
 
   test("Basic select") {
     val connectPlan = connectTestRelation.select("id".protoAttr)
@@ -500,10 +508,21 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
   private def createLocalRelationProtoByQualifiedAttributes(
       attrs: Seq[proto.Expression.QualifiedAttribute]): proto.Relation = {
     val localRelationBuilder = proto.LocalRelation.newBuilder()
-    for (attr <- attrs) {
-      localRelationBuilder.addAttributes(attr)
-    }
-    proto.Relation.newBuilder().setLocalRelation(localRelationBuilder.build()).build()
+
+    val attributes = attrs.map(exp =>
+      AttributeReference(exp.getName, DataTypeProtoConverter.toCatalystType(exp.getType))())
+    val buffer = ArrowConverters
+      .toBatchWithSchemaIterator(
+        Iterator.empty,
+        StructType.fromAttributes(attributes),
+        Long.MaxValue,
+        Long.MaxValue,
+        null)
+      .next()
+    proto.Relation
+      .newBuilder()
+      .setLocalRelation(localRelationBuilder.setData(ByteString.copyFrom(buffer)).build())
+      .build()
   }
 
   // This is a function for testing only. This is used when the plan is ready and it only waits
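With the attributes field gone from the proto, even a relation defined only by name/type pairs must ship a (possibly empty) Arrow stream: the helper above converts each QualifiedAttribute to a Catalyst AttributeReference via DataTypeProtoConverter.toCatalystType, then serializes an empty batch so the schema still travels inside `data`. An end-to-end sketch tying the pieces of this diff together, assuming a live `spark` session (the suites' `transform` helper wraps the planner call shown here):

val schema = StructType(Seq(StructField("id", IntegerType)))
val proj = UnsafeProjection.create(schema)
val rel = createLocalRelationProto(
  schema.toAttributes,
  Seq(InternalRow(42)).map(r => proj(r).copy()))
val df = Dataset.ofRows(spark, new SparkConnectPlanner(spark).transformRelation(rel))
assert(df.collect().head.getInt(0) == 42)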