diff --git a/connector/connect/src/main/protobuf/spark/connect/relations.proto b/connector/connect/src/main/protobuf/spark/connect/relations.proto
index 353fbebd0460..eadedf495d3e 100644
--- a/connector/connect/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/src/main/protobuf/spark/connect/relations.proto
@@ -67,11 +67,21 @@ message SQL {
 message Read {
   oneof read_type {
     NamedTable named_table = 1;
+    DataSource data_source = 2;
   }
 
   message NamedTable {
     string unparsed_identifier = 1;
   }
+
+  message DataSource {
+    // Required. Supported formats include: parquet, orc, text, json, csv, avro.
+    string format = 1;
+    // Optional. If not set, Spark will infer the schema.
+    string schema = 2;
+    // The key is case insensitive.
+    map<string, string> options = 3;
+  }
 }
 
 // Projection of a bag of expressions for a given input relation.
diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6a6b5a15a087..450283a9b81f 100644
--- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeRef
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.{logical, FullOuter, Inner, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sample, SubqueryAlias}
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.types._
 
 final case class InvalidPlanInput(
@@ -112,7 +113,19 @@ class SparkConnectPlanner(plan: proto.Relation, session: SparkSession) {
         } else {
           child
         }
-      case _ => throw InvalidPlanInput()
+      case proto.Read.ReadTypeCase.DATA_SOURCE =>
+        if (rel.getDataSource.getFormat == "") {
+          throw InvalidPlanInput("DataSource requires a format")
+        }
+        val localMap = CaseInsensitiveMap[String](rel.getDataSource.getOptionsMap.asScala.toMap)
+        val reader = session.read
+        reader.format(rel.getDataSource.getFormat)
+        localMap.foreach { case (key, value) => reader.option(key, value) }
+        if (rel.getDataSource.getSchema.nonEmpty) {
+          reader.schema(rel.getDataSource.getSchema)
+        }
+        reader.load().queryExecution.analyzed
+      case _ => throw InvalidPlanInput("Does not support " + rel.getReadTypeCase.name())
     }
     baseRelation
   }
diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index fc3d219ec6ba..83bf76efce1d 100644
--- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -255,4 +255,15 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest {
     assert(res.nodeName == "Aggregate")
   }
 
+  test("Invalid DataSource") {
+    val dataSource = proto.Read.DataSource.newBuilder()
+
+    val e = intercept[InvalidPlanInput](
+      transform(
+        proto.Relation
+          .newBuilder()
+          .setRead(proto.Read.newBuilder().setDataSource(dataSource))
+          .build()))
+    assert(e.getMessage.contains("DataSource requires a format"))
+  }
 }
diff --git a/python/pyspark/sql/connect/plan.py
b/python/pyspark/sql/connect/plan.py index 5fcd468924d8..c564b71cdba6 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -23,6 +23,7 @@ Union, cast, TYPE_CHECKING, + Mapping, ) import pyspark.sql.connect.proto as proto @@ -111,6 +112,46 @@ def _child_repr_(self) -> str: return self._child._repr_html_() if self._child is not None else "" +class DataSource(LogicalPlan): + """A datasource with a format and optional a schema from which Spark reads data""" + + def __init__( + self, + format: str = "", + schema: Optional[str] = None, + options: Optional[Mapping[str, str]] = None, + ) -> None: + super().__init__(None) + self.format = format + self.schema = schema + self.options = options + + def plan(self, session: Optional["RemoteSparkSession"]) -> proto.Relation: + plan = proto.Relation() + if self.format is not None: + plan.read.data_source.format = self.format + if self.schema is not None: + plan.read.data_source.schema = self.schema + if self.options is not None: + for k in self.options.keys(): + v = self.options.get(k) + if v is not None: + plan.read.data_source.options[k] = v + return plan + + def _repr_html_(self) -> str: + return f""" + + """ + + class Read(LogicalPlan): def __init__(self, table_name: str) -> None: super().__init__(None) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index c3b7b7ec2eaf..b244cdf8dcb9 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -32,7 +32,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcf\x05\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\x95\x01\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifierB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x9d\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xb8\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12)\n\x10with_replacement\x18\x04 \x01(\x08R\x0fwithReplacement\x12\x12\n\x04seed\x18\x05 \x01(\x03R\x04seedB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x1fspark/connect/expressions.proto"\xcf\x05\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 
\x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12,\n\x05union\x18\x06 \x01(\x0b\x32\x14.spark.connect.UnionH\x00R\x05union\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\t\n\x07Unknown"G\n\x0eRelationCommon\x12\x1f\n\x0bsource_info\x18\x01 \x01(\tR\nsourceInfo\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias"\x1b\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query"\x9a\x03\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x1a=\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x1a\xbf\x01\n\nDataSource\x12\x16\n\x06\x66ormat\x18\x01 \x01(\tR\x06\x66ormat\x12\x16\n\x06schema\x18\x02 \x01(\tR\x06schema\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x9d\x03\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType"\xbb\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06"\xcd\x01\n\x05Union\x12/\n\x06inputs\x18\x01 \x03(\x0b\x32\x17.spark.connect.RelationR\x06inputs\x12=\n\nunion_type\x18\x02 \x01(\x0e\x32\x1e.spark.connect.Union.UnionTypeR\tunionType"T\n\tUnionType\x12\x1a\n\x16UNION_TYPE_UNSPECIFIED\x10\x00\x12\x17\n\x13UNION_TYPE_DISTINCT\x10\x01\x12\x12\n\x0eUNION_TYPE_ALL\x10\x02"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"\xc5\x02\n\tAggregate\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12Y\n\x12result_expressions\x18\x03 \x03(\x0b\x32*.spark.connect.Aggregate.AggregateFunctionR\x11resultExpressions\x1a`\n\x11\x41ggregateFunction\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"\xf6\x03\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\x0bsort_fields\x18\x02 \x03(\x0b\x32\x1d.spark.connect.Sort.SortFieldR\nsortFields\x1a\xbc\x01\n\tSortField\x12\x39\n\nexpression\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\nexpression\x12?\n\tdirection\x18\x02 \x01(\x0e\x32!.spark.connect.Sort.SortDirectionR\tdirection\x12\x33\n\x05nulls\x18\x03 \x01(\x0e\x32\x1d.spark.connect.Sort.SortNullsR\x05nulls"l\n\rSortDirection\x12\x1e\n\x1aSORT_DIRECTION_UNSPECIFIED\x10\x00\x12\x1c\n\x18SORT_DIRECTION_ASCENDING\x10\x01\x12\x1d\n\x19SORT_DIRECTION_DESCENDING\x10\x02"R\n\tSortNulls\x12\x1a\n\x16SORT_NULLS_UNSPECIFIED\x10\x00\x12\x14\n\x10SORT_NULLS_FIRST\x10\x01\x12\x13\n\x0fSORT_NULLS_LAST\x10\x02"]\n\rLocalRelation\x12L\n\nattributes\x18\x01 \x03(\x0b\x32,.spark.connect.Expression.QualifiedAttributeR\nattributes"\xb8\x01\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12)\n\x10with_replacement\x18\x04 \x01(\x08R\x0fwithReplacement\x12\x12\n\x04seed\x18\x05 \x01(\x03R\x04seedB"\n\x1eorg.apache.spark.connect.protoP\x01\x62\x06proto3' ) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -41,6 +41,8 @@ DESCRIPTOR._options = None DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001" + _READ_DATASOURCE_OPTIONSENTRY._options = None + _READ_DATASOURCE_OPTIONSENTRY._serialized_options = b"8\001" _RELATION._serialized_start = 82 _RELATION._serialized_end = 801 _UNKNOWN._serialized_start = 803 @@ -50,39 +52,43 @@ _SQL._serialized_start = 887 _SQL._serialized_end = 914 _READ._serialized_start = 917 - _READ._serialized_end = 1066 - _READ_NAMEDTABLE._serialized_start = 992 - _READ_NAMEDTABLE._serialized_end = 1053 - _PROJECT._serialized_start = 1068 - _PROJECT._serialized_end = 1185 - _FILTER._serialized_start = 1187 - _FILTER._serialized_end = 1299 - _JOIN._serialized_start = 1302 - _JOIN._serialized_end = 1715 - _JOIN_JOINTYPE._serialized_start = 1528 - _JOIN_JOINTYPE._serialized_end = 1715 - _UNION._serialized_start = 1718 - _UNION._serialized_end = 1923 - _UNION_UNIONTYPE._serialized_start = 1839 - _UNION_UNIONTYPE._serialized_end = 1923 - _LIMIT._serialized_start = 1925 - _LIMIT._serialized_end = 2001 - _OFFSET._serialized_start = 2003 - _OFFSET._serialized_end = 2082 - _AGGREGATE._serialized_start = 2085 - _AGGREGATE._serialized_end = 2410 - _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2314 - _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2410 - _SORT._serialized_start = 2413 - _SORT._serialized_end = 2915 - _SORT_SORTFIELD._serialized_start = 2533 - _SORT_SORTFIELD._serialized_end = 2721 - _SORT_SORTDIRECTION._serialized_start = 2723 - _SORT_SORTDIRECTION._serialized_end = 2831 - _SORT_SORTNULLS._serialized_start = 2833 - _SORT_SORTNULLS._serialized_end = 2915 - _LOCALRELATION._serialized_start = 2917 - _LOCALRELATION._serialized_end = 3010 - _SAMPLE._serialized_start = 3013 - _SAMPLE._serialized_end = 3197 + _READ._serialized_end 
= 1327 + _READ_NAMEDTABLE._serialized_start = 1059 + _READ_NAMEDTABLE._serialized_end = 1120 + _READ_DATASOURCE._serialized_start = 1123 + _READ_DATASOURCE._serialized_end = 1314 + _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 1256 + _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 1314 + _PROJECT._serialized_start = 1329 + _PROJECT._serialized_end = 1446 + _FILTER._serialized_start = 1448 + _FILTER._serialized_end = 1560 + _JOIN._serialized_start = 1563 + _JOIN._serialized_end = 1976 + _JOIN_JOINTYPE._serialized_start = 1789 + _JOIN_JOINTYPE._serialized_end = 1976 + _UNION._serialized_start = 1979 + _UNION._serialized_end = 2184 + _UNION_UNIONTYPE._serialized_start = 2100 + _UNION_UNIONTYPE._serialized_end = 2184 + _LIMIT._serialized_start = 2186 + _LIMIT._serialized_end = 2262 + _OFFSET._serialized_start = 2264 + _OFFSET._serialized_end = 2343 + _AGGREGATE._serialized_start = 2346 + _AGGREGATE._serialized_end = 2671 + _AGGREGATE_AGGREGATEFUNCTION._serialized_start = 2575 + _AGGREGATE_AGGREGATEFUNCTION._serialized_end = 2671 + _SORT._serialized_start = 2674 + _SORT._serialized_end = 3176 + _SORT_SORTFIELD._serialized_start = 2794 + _SORT_SORTFIELD._serialized_end = 2982 + _SORT_SORTDIRECTION._serialized_start = 2984 + _SORT_SORTDIRECTION._serialized_end = 3092 + _SORT_SORTNULLS._serialized_start = 3094 + _SORT_SORTNULLS._serialized_end = 3176 + _LOCALRELATION._serialized_start = 3178 + _LOCALRELATION._serialized_end = 3271 + _SAMPLE._serialized_start = 3274 + _SAMPLE._serialized_end = 3458 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index 3354fc86f45d..f0a8b6412b51 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -280,29 +280,79 @@ class Read(google.protobuf.message.Message): field_name: typing_extensions.Literal["unparsed_identifier", b"unparsed_identifier"], ) -> None: ... + class DataSource(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class OptionsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField( + self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"] + ) -> None: ... + + FORMAT_FIELD_NUMBER: builtins.int + SCHEMA_FIELD_NUMBER: builtins.int + OPTIONS_FIELD_NUMBER: builtins.int + format: builtins.str + """Required. Supported formats include: parquet, orc, text, json, parquet, csv, avro.""" + schema: builtins.str + """Optional. If not set, Spark will infer the schema.""" + @property + def options( + self, + ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: + """The key is case insensitive.""" + def __init__( + self, + *, + format: builtins.str = ..., + schema: builtins.str = ..., + options: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "format", b"format", "options", b"options", "schema", b"schema" + ], + ) -> None: ... + NAMED_TABLE_FIELD_NUMBER: builtins.int + DATA_SOURCE_FIELD_NUMBER: builtins.int @property def named_table(self) -> global___Read.NamedTable: ... 
+    @property
+    def data_source(self) -> global___Read.DataSource: ...
     def __init__(
         self,
         *,
         named_table: global___Read.NamedTable | None = ...,
+        data_source: global___Read.DataSource | None = ...,
     ) -> None: ...
     def HasField(
         self,
         field_name: typing_extensions.Literal[
-            "named_table", b"named_table", "read_type", b"read_type"
+            "data_source", b"data_source", "named_table", b"named_table", "read_type", b"read_type"
         ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
-            "named_table", b"named_table", "read_type", b"read_type"
+            "data_source", b"data_source", "named_table", b"named_table", "read_type", b"read_type"
         ],
     ) -> None: ...
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["read_type", b"read_type"]
-    ) -> typing_extensions.Literal["named_table"] | None: ...
+    ) -> typing_extensions.Literal["named_table", "data_source"] | None: ...
 
 global___Read = Read
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 285e78e59ae9..66e48eeab76b 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -15,8 +15,16 @@
 # limitations under the License.
 #
 
+
+from typing import Dict, Optional
+
+from pyspark.sql.connect.column import PrimitiveType
 from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.plan import Read
+from pyspark.sql.connect.plan import Read, DataSource
+from pyspark.sql.utils import to_str
+
+
+OptionalPrimitiveType = Optional[PrimitiveType]
 
 from typing import TYPE_CHECKING
 
@@ -29,8 +37,114 @@ class DataFrameReader:
     TODO(SPARK-40539) Achieve parity with PySpark.
     """
 
-    def __init__(self, client: "RemoteSparkSession") -> None:
+    def __init__(self, client: "RemoteSparkSession"):
         self._client = client
+        self._format = ""
+        self._schema = ""
+        self._options: Dict[str, str] = {}
+
+    def format(self, source: str) -> "DataFrameReader":
+        """
+        Specifies the input data source format.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        source : str
+            string, name of the data source, e.g. 'json', 'parquet'.
+
+        """
+        self._format = source
+        return self
+
+    # TODO(SPARK-40539): support StructType in python client and support schema as StructType.
+    def schema(self, schema: str) -> "DataFrameReader":
+        """
+        Specifies the input schema.
+
+        Some data sources (e.g. JSON) can infer the input schema automatically from data.
+        By specifying the schema here, the underlying data source can skip the schema
+        inference step, and thus speed up data loading.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        schema : str
+            a DDL-formatted string
+            (For example ``col0 INT, col1 DOUBLE``).
+
+        """
+        self._schema = schema
+        return self
+
+    def option(self, key: str, value: "OptionalPrimitiveType") -> "DataFrameReader":
+        """
+        Adds an input option for the underlying data source.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        key : str
+            The key for the option to set. The key string is case-insensitive.
+        value
+            The value for the option to set.
+
+        """
+        self._options[key] = str(value)
+        return self
+
+    def options(self, **options: "OptionalPrimitiveType") -> "DataFrameReader":
+        """
+        Adds input options for the underlying data source.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        **options : dict
+            The dictionary of string keys and primitive-type values.
+ """ + for k in options: + self.option(k, to_str(options[k])) + return self + + def load( + self, + path: Optional[str] = None, + format: Optional[str] = None, + schema: Optional[str] = None, + **options: "OptionalPrimitiveType", + ) -> "DataFrame": + """ + Loads data from a data source and returns it as a :class:`DataFrame`. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + path : str or list, optional + optional string or a list of string for file-system backed data sources. + format : str, optional + optional string for format of the data source. + schema : str, optional + optional DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). + **options : dict + all other string options + """ + if format is not None: + self.format(format) + if schema is not None: + self.schema(schema) + self.options(**options) + if path is not None: + self.option("path", path) + + plan = DataSource(format=self._format, schema=self._schema, options=self._options) + df = DataFrame.withPlan(plan, self._client) + return df def table(self, tableName: str) -> "DataFrame": df = DataFrame.withPlan(Read(tableName), self._client) diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 1a59e7d596ee..de300946932f 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -16,6 +16,7 @@ # from typing import Any import unittest +import shutil import tempfile import pandas @@ -24,6 +25,7 @@ from pyspark.sql.connect.client import RemoteSparkSession from pyspark.sql.connect.function_builder import udf from pyspark.sql.connect.functions import lit +from pyspark.sql.dataframe import DataFrame from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import ReusedPySparkTestCase @@ -35,6 +37,7 @@ class SparkConnectSQLTestCase(ReusedPySparkTestCase): connect: RemoteSparkSession tbl_name: str + df_text: "DataFrame" @classmethod def setUpClass(cls: Any): @@ -44,7 +47,9 @@ def setUpClass(cls: Any): # Create the new Spark Session cls.spark = SparkSession(cls.sc) cls.testData = [Row(key=i, value=str(i)) for i in range(100)] + cls.testDataStr = [Row(key=str(i)) for i in range(100)] cls.df = cls.sc.parallelize(cls.testData).toDF() + cls.df_text = cls.sc.parallelize(cls.testDataStr).toDF() cls.tbl_name = "test_connect_basic_table_1" @@ -101,6 +106,21 @@ def test_simple_binary_expressions(self): res = pandas.DataFrame(data={"id": [0, 30, 60, 90]}) self.assert_(pd.equals(res), f"{pd.to_string()} != {res.to_string()}") + def test_simple_datasource_read(self) -> None: + writeDf = self.df_text + tmpPath = tempfile.mkdtemp() + shutil.rmtree(tmpPath) + writeDf.write.text(tmpPath) + + readDf = self.connect.read.format("text").schema("id STRING").load(path=tmpPath) + expectResult = writeDf.collect() + pandasResult = readDf.toPandas() + if pandasResult is None: + self.assertTrue(False, "Empty pandas dataframe") + else: + actualResult = pandasResult.values.tolist() + self.assertEqual(len(expectResult), len(actualResult)) + if __name__ == "__main__": from pyspark.sql.tests.connect.test_connect_basic import * # noqa: F401 diff --git a/python/pyspark/sql/tests/connect/test_connect_plan_only.py b/python/pyspark/sql/tests/connect/test_connect_plan_only.py index c547000bdcf7..96bbb8aa8347 100644 --- a/python/pyspark/sql/tests/connect/test_connect_plan_only.py +++ b/python/pyspark/sql/tests/connect/test_connect_plan_only.py @@ -18,6 
+18,7 @@ from pyspark.testing.connectutils import PlanOnlyTestFixture import pyspark.sql.connect.proto as proto +from pyspark.sql.connect.readwriter import DataFrameReader from pyspark.sql.connect.function_builder import UserDefinedFunction, udf from pyspark.sql.types import StringType @@ -48,6 +49,18 @@ def test_relation_alias(self): plan = df.alias("table_alias")._plan.to_proto(self.connect) self.assertEqual(plan.root.common.alias, "table_alias") + def test_datasource_read(self): + reader = DataFrameReader(self.connect) + df = reader.load(path="test_path", format="text", schema="id INT", op1="opv", op2="opv2") + plan = df._plan.to_proto(self.connect) + data_source = plan.root.read.data_source + self.assertEqual(data_source.format, "text") + self.assertEqual(data_source.schema, "id INT") + self.assertEqual(len(data_source.options), 3) + self.assertEqual(data_source.options.get("path"), "test_path") + self.assertEqual(data_source.options.get("op1"), "opv") + self.assertEqual(data_source.options.get("op2"), "opv2") + def test_simple_udf(self): u = udf(lambda x: "Martin", StringType()) self.assertIsNotNone(u)
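
For reference, a minimal end-to-end sketch of the client API added above, mirroring test_datasource_read and test_connect_basic's test_simple_datasource_read. The RemoteSparkSession instance (called session here), the file path, and the lineSep option are illustrative assumptions, not part of this change:

    # Chain the builder methods added in readwriter.py; each call returns the reader.
    df = (
        session.read.format("text")       # becomes Read.DataSource.format
        .schema("id STRING")              # DDL string, becomes Read.DataSource.schema
        .option("lineSep", "\n")          # collected into the case-insensitive options map
        .load(path="/tmp/people.txt")     # the path travels as the "path" entry in options
    )

    # Nothing is read yet; the unresolved plan simply carries the DataSource fields.
    plan = df._plan.to_proto(session)
    assert plan.root.read.data_source.format == "text"
    assert plan.root.read.data_source.options["path"] == "/tmp/people.txt"

    pdf = df.toPandas()  # executes the read on the server and returns a pandas DataFrame

On the server side, SparkConnectPlanner replays these fields onto session.read (format, case-insensitive options, optional DDL schema) and returns the analyzed plan of reader.load().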