18 changes: 16 additions & 2 deletions python/pyspark/pipelines/api.py
@@ -76,6 +76,7 @@ def _validate_stored_dataset_args(
name: Optional[str],
table_properties: Optional[Dict[str, str]],
partition_cols: Optional[List[str]],
cluster_by: Optional[List[str]],
) -> None:
if name is not None and type(name) is not str:
raise PySparkTypeError(
@@ -91,6 +92,7 @@ def _validate_stored_dataset_args(
},
)
validate_optional_list_of_str_arg(arg_name="partition_cols", arg_value=partition_cols)
validate_optional_list_of_str_arg(arg_name="cluster_by", arg_value=cluster_by)


@overload
@@ -107,6 +109,7 @@ def table(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
) -> Callable[[QueryFunction], None]:
...
@@ -120,6 +123,7 @@ def table(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
format: Optional[str] = None,
) -> Union[Callable[[QueryFunction], None], None]:
@@ -142,11 +146,12 @@ def table(
:param table_properties: A dict where the keys are the property names and the values are the \
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
:param cluster_by: A list containing the column names of the cluster columns.
:param schema: Explicit Spark SQL schema to materialize this table with. Supports either a \
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
"""
_validate_stored_dataset_args(name, table_properties, partition_cols)
_validate_stored_dataset_args(name, table_properties, partition_cols, cluster_by)

source_code_location = get_caller_source_code_location(stacklevel=1)

@@ -163,6 +168,7 @@ def outer(
name=resolved_name,
table_properties=table_properties or {},
partition_cols=partition_cols,
cluster_by=cluster_by,
schema=schema,
source_code_location=source_code_location,
format=format,
@@ -209,6 +215,7 @@ def materialized_view(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
) -> Callable[[QueryFunction], None]:
...
@@ -222,6 +229,7 @@ def materialized_view(
spark_conf: Optional[Dict[str, str]] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
format: Optional[str] = None,
) -> Union[Callable[[QueryFunction], None], None]:
@@ -244,11 +252,12 @@ def materialized_view(
:param table_properties: A dict where the keys are the property names and the values are the \
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
:param cluster_by: A list containing the column names of the cluster columns.
:param schema: Explicit Spark SQL schema to materialize this table with. Supports either a \
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
"""
_validate_stored_dataset_args(name, table_properties, partition_cols)
_validate_stored_dataset_args(name, table_properties, partition_cols, cluster_by)

source_code_location = get_caller_source_code_location(stacklevel=1)

@@ -265,6 +274,7 @@ def outer(
name=resolved_name,
table_properties=table_properties or {},
partition_cols=partition_cols,
cluster_by=cluster_by,
schema=schema,
source_code_location=source_code_location,
format=format,
@@ -403,6 +413,7 @@ def create_streaming_table(
comment: Optional[str] = None,
table_properties: Optional[Dict[str, str]] = None,
partition_cols: Optional[List[str]] = None,
cluster_by: Optional[List[str]] = None,
schema: Optional[Union[StructType, str]] = None,
format: Optional[str] = None,
) -> None:
@@ -417,6 +428,7 @@ def create_streaming_table(
:param table_properties: A dict where the keys are the property names and the values are the \
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
:param cluster_by: A list containing the column names of the cluster columns.
:param schema: Explicit Spark SQL schema to materialize this table with. Supports either a \
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
@@ -435,6 +447,7 @@ def create_streaming_table(
},
)
validate_optional_list_of_str_arg(arg_name="partition_cols", arg_value=partition_cols)
validate_optional_list_of_str_arg(arg_name="cluster_by", arg_value=cluster_by)

source_code_location = get_caller_source_code_location(stacklevel=1)

@@ -444,6 +457,7 @@ def create_streaming_table(
source_code_location=source_code_location,
table_properties=table_properties or {},
partition_cols=partition_cols,
cluster_by=cluster_by,
schema=schema,
format=format,
)
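
For orientation, here is a minimal usage sketch of the new `cluster_by` argument. It assumes the pipelines module is importable as `from pyspark import pipelines as dp` and that `spark` is in scope in the pipeline source file, as in the Python test further down; the table and column names are hypothetical.

```python
# Hypothetical pipeline source file exercising the new cluster_by argument.
from pyspark import pipelines as dp
from pyspark.sql.functions import col

@dp.materialized_view(cluster_by=["region", "event_date"])
def events_by_region():
    # Materialized with clustering instead of partitioning.
    return spark.read.table("raw_events").withColumn("event_date", col("ts").cast("date"))

@dp.table(cluster_by=["user_id"])
def user_events():
    return spark.readStream.table("events_by_region")

# Streaming tables created without a decorated query can be clustered the same way.
dp.create_streaming_table(
    "ingest_target",
    cluster_by=["user_id"],
    schema="user_id BIGINT, event_date DATE",
)
```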
2 changes: 2 additions & 0 deletions python/pyspark/pipelines/output.py
@@ -45,13 +45,15 @@ class Table(Output):
:param table_properties: A dict where the keys are the property names and the values are the
property values. These properties will be set on the table.
:param partition_cols: A list containing the column names of the partition columns.
:param cluster_by: A list containing the column names of the cluster columns.
:param schema: Explicit Spark SQL schema to materialize this table with. Supports either a
Pyspark StructType or a SQL DDL string, such as "a INT, b STRING".
:param format: The format of the table, e.g. "parquet".
"""

table_properties: Mapping[str, str]
partition_cols: Optional[Sequence[str]]
cluster_by: Optional[Sequence[str]]
schema: Optional[Union[StructType, str]]
format: Optional[str]

@@ -63,6 +63,7 @@ def register_output(self, output: Output) -> None:
table_details = pb2.PipelineCommand.DefineOutput.TableDetails(
table_properties=output.table_properties,
partition_cols=output.partition_cols,
clustering_columns=output.cluster_by,
format=output.format,
# Even though schema_string is not required, the generated Python code seems to
# erroneously think it is required.
88 changes: 44 additions & 44 deletions python/pyspark/sql/connect/proto/pipelines_pb2.py

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions python/pyspark/sql/connect/proto/pipelines_pb2.pyi
@@ -240,6 +240,7 @@ class PipelineCommand(google.protobuf.message.Message):
FORMAT_FIELD_NUMBER: builtins.int
SCHEMA_DATA_TYPE_FIELD_NUMBER: builtins.int
SCHEMA_STRING_FIELD_NUMBER: builtins.int
CLUSTERING_COLUMNS_FIELD_NUMBER: builtins.int
@property
def table_properties(
self,
@@ -255,6 +256,11 @@
@property
def schema_data_type(self) -> pyspark.sql.connect.proto.types_pb2.DataType: ...
schema_string: builtins.str
@property
def clustering_columns(
self,
) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
"""Optional cluster columns for the table."""
def __init__(
self,
*,
@@ -263,6 +269,7 @@
format: builtins.str | None = ...,
schema_data_type: pyspark.sql.connect.proto.types_pb2.DataType | None = ...,
schema_string: builtins.str = ...,
clustering_columns: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(
self,
@@ -284,6 +291,8 @@
field_name: typing_extensions.Literal[
"_format",
b"_format",
"clustering_columns",
b"clustering_columns",
"format",
b"format",
"partition_cols",
@@ -104,6 +104,9 @@ message PipelineCommand {
spark.connect.DataType schema_data_type = 4;
string schema_string = 5;
}

// Optional cluster columns for the table.
repeated string clustering_columns = 6;
}

// Metadata that's only applicable to sinks.
@@ -203,6 +203,8 @@ private[connect] object PipelinesHandler extends Logging {
},
partitionCols = Option(tableDetails.getPartitionColsList.asScala.toSeq)
.filter(_.nonEmpty),
clusterCols = Option(tableDetails.getClusteringColumnsList.asScala.toSeq)
.filter(_.nonEmpty),
properties = tableDetails.getTablePropertiesMap.asScala.toMap,
origin = QueryOrigin(
filePath = Option.when(output.getSourceCodeLocation.hasFileName)(
@@ -865,4 +865,41 @@ class PythonPipelineSuite

(exitCode, output.toSeq)
}

test("empty cluster_by list should work and create table with no clustering") {
withTable("mv", "st") {
val graph = buildGraph("""
|from pyspark.sql.functions import col
|
|@dp.materialized_view(cluster_by = [])
|def mv():
| return spark.range(5).withColumn("id_mod", col("id") % 2)
|
|@dp.table(cluster_by = [])
|def st():
| return spark.readStream.table("mv")
|""".stripMargin)
val updateContext =
new PipelineUpdateContextImpl(graph, eventCallback = _ => (), storageRoot = storageRoot)
updateContext.pipelineExecution.runPipeline()
updateContext.pipelineExecution.awaitCompletion()

// Check tables are created with no clustering transforms
val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]

val mvIdentifier = Identifier.of(Array("default"), "mv")
val mvTable = catalog.loadTable(mvIdentifier)
val mvTransforms = mvTable.partitioning()
assert(
mvTransforms.isEmpty,
s"MaterializedView should have no transforms, but got: ${mvTransforms.mkString(", ")}")

val stIdentifier = Identifier.of(Array("default"), "st")
val stTable = catalog.loadTable(stIdentifier)
val stTransforms = stTable.partitioning()
assert(
stTransforms.isEmpty,
s"Table should have no transforms, but got: ${stTransforms.mkString(", ")}")
}
}
}
@@ -41,10 +41,12 @@ class TestPipelineDefinition(graphId: String) {
// TODO: Add support for specifiedSchema
// specifiedSchema: Option[StructType] = None,
partitionCols: Option[Seq[String]] = None,
clusterCols: Option[Seq[String]] = None,
properties: Map[String, String] = Map.empty): Unit = {
val tableDetails = sc.PipelineCommand.DefineOutput.TableDetails
.newBuilder()
.addAllPartitionCols(partitionCols.getOrElse(Seq()).asJava)
.addAllClusteringColumns(clusterCols.getOrElse(Seq()).asJava)
.putAllTableProperties(properties.asJava)
.build()

@@ -34,7 +34,7 @@ import org.apache.spark.sql.connector.catalog.{
TableInfo
}
import org.apache.spark.sql.connector.catalog.CatalogV2Util.v2ColumnsToStructType
import org.apache.spark.sql.connector.expressions.Expressions
import org.apache.spark.sql.connector.expressions.{ClusterByTransform, Expressions}
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.sql.pipelines.graph.QueryOrigin.ExceptionHelpers
import org.apache.spark.sql.pipelines.util.SchemaInferenceUtils.diffSchemas
@@ -266,22 +266,35 @@ object DatasetManager extends Logging {
)
val mergedProperties = resolveTableProperties(table, identifier)
val partitioning = table.partitionCols.toSeq.flatten.map(Expressions.identity)
val clustering = table.clusterCols.map(cols =>
ClusterByTransform(cols.map(col => Expressions.column(col)))
).toSeq

// Validate that partition and cluster columns don't coexist
if (partitioning.nonEmpty && clustering.nonEmpty) {
throw new AnalysisException(
errorClass = "SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED",
messageParameters = Map.empty
)
}

val allTransforms = partitioning ++ clustering

val existingTableOpt = if (catalog.tableExists(identifier)) {
Some(catalog.loadTable(identifier))
} else {
None
}

// Error if partitioning doesn't match
// Error if partitioning/clustering doesn't match
if (existingTableOpt.isDefined) {
val existingPartitioning = existingTableOpt.get.partitioning().toSeq
if (existingPartitioning != partitioning) {
val existingTransforms = existingTableOpt.get.partitioning().toSeq
if (existingTransforms != allTransforms) {
throw new AnalysisException(
errorClass = "CANNOT_UPDATE_PARTITION_COLUMNS",
messageParameters = Map(
"existingPartitionColumns" -> existingPartitioning.mkString(", "),
"requestedPartitionColumns" -> partitioning.mkString(", ")
"existingPartitionColumns" -> existingTransforms.mkString(", "),
"requestedPartitionColumns" -> allTransforms.mkString(", ")
)
)
}
@@ -314,7 +327,7 @@
new TableInfo.Builder()
.withProperties(mergedProperties.asJava)
.withColumns(CatalogV2Util.structTypeToV2Columns(outputSchema))
.withPartitions(partitioning.toArray)
.withPartitions(allTransforms.toArray)
.build()
)
}
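
The new check above makes partitioning and clustering mutually exclusive for a given table. As a hypothetical illustration (names invented), a definition like the following would be expected to fail at materialization time with SPECIFY_CLUSTER_BY_WITH_PARTITIONED_BY_IS_NOT_ALLOWED:

```python
# Hypothetical: declaring both partition_cols and cluster_by on one dataset should be rejected.
@dp.materialized_view(
    partition_cols=["event_date"],
    cluster_by=["user_id"],
)
def conflicting_mv():
    return spark.read.table("raw_events")
```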
@@ -264,12 +264,20 @@ class BatchTableWrite(
if (destination.format.isDefined) {
dataFrameWriter.format(destination.format.get)
}

// In "append" mode with saveAsTable, partition/cluster columns must be specified in query
// because the format and options of the existing table is used, and the table could
// have been created with partition columns.
if (destination.clusterCols.isDefined) {
val clusterCols = destination.clusterCols.get
dataFrameWriter.clusterBy(clusterCols.head, clusterCols.tail: _*)
}
if (destination.partitionCols.isDefined) {
dataFrameWriter.partitionBy(destination.partitionCols.get: _*)
}

dataFrameWriter
.mode("append")
// In "append" mode with saveAsTable, partition columns must be specified in query
// because the format and options of the existing table is used, and the table could
// have been created with partition columns.
.partitionBy(destination.partitionCols.getOrElse(Seq.empty): _*)
.saveAsTable(destination.identifier.unquotedString)
}
}
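
The append path above is roughly what a user would write by hand with the DataFrame API. A PySpark sketch of the equivalent write, assuming `DataFrameWriter.clusterBy` is available (as in recent Spark releases) and using hypothetical table and column names:

```python
# Rough PySpark analogue of the batch append performed above.
df = spark.read.table("staged_user_events")

(df.write
    .format("parquet")      # only set when the destination declares a format
    .clusterBy("user_id")   # cluster columns must be restated for append + saveAsTable
    .mode("append")
    .saveAsTable("user_events"))
```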
@@ -192,6 +192,7 @@ class SqlGraphRegistrationContext(
specifiedSchema =
Option.when(cst.columns.nonEmpty)(StructType(cst.columns.map(_.toV1Column))),
partitionCols = Option(PartitionHelper.applyPartitioning(cst.partitioning, queryOrigin)),
clusterCols = None,
properties = cst.tableSpec.properties,
origin = queryOrigin.copy(
objectName = Option(stIdentifier.unquotedString),
@@ -223,6 +224,7 @@
specifiedSchema =
Option.when(cst.columns.nonEmpty)(StructType(cst.columns.map(_.toV1Column))),
partitionCols = Option(PartitionHelper.applyPartitioning(cst.partitioning, queryOrigin)),
clusterCols = None,
properties = cst.tableSpec.properties,
origin = queryOrigin.copy(
objectName = Option(stIdentifier.unquotedString),
@@ -273,6 +275,7 @@
specifiedSchema =
Option.when(cmv.columns.nonEmpty)(StructType(cmv.columns.map(_.toV1Column))),
partitionCols = Option(PartitionHelper.applyPartitioning(cmv.partitioning, queryOrigin)),
clusterCols = None,
properties = cmv.tableSpec.properties,
origin = queryOrigin.copy(
objectName = Option(mvIdentifier.unquotedString),
@@ -114,6 +114,7 @@ sealed trait TableInput extends Input {
* @param identifier The identifier of this table within the graph.
* @param specifiedSchema The user-specified schema for this table.
* @param partitionCols What columns the table should be partitioned by when materialized.
* @param clusterCols What columns the table should be clustered by when materialized.
* @param normalizedPath Normalized storage location for the table based on the user-specified table
* path (if not defined, we will normalize a managed storage path for it).
* @param properties Table Properties to set in table metadata.
@@ -124,6 +125,7 @@ case class Table(
identifier: TableIdentifier,
specifiedSchema: Option[StructType],
partitionCols: Option[Seq[String]],
clusterCols: Option[Seq[String]],
normalizedPath: Option[String],
properties: Map[String, String] = Map.empty,
comment: Option[String],