diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 0e54e7628fcb..2a638bc7ec36 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -174,6 +174,13 @@ "`<func_name>` does not allow a Column in a list." ] }, + "CONFLICTING_PIPELINE_REFRESH_OPTIONS": { + "message": [ + "--full-refresh-all option conflicts with <conflicting_option>. ", + "The --full-refresh-all option performs a full refresh of all datasets, ", + "so specifying individual datasets with <conflicting_option> is not allowed." + ] + }, "CONNECT_URL_ALREADY_DEFINED": { "message": [ "Only one Spark Connect client URL can be set; however, got a different URL [<new_url>] from the existing [<existing_url>]." diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index 395f7e9b8374..2a0cf880d10c 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -28,7 +28,7 @@ import yaml from dataclasses import dataclass from pathlib import Path -from typing import Any, Generator, Mapping, Optional, Sequence +from typing import Any, Generator, List, Mapping, Optional, Sequence from pyspark.errors import PySparkException, PySparkTypeError from pyspark.sql import SparkSession @@ -217,8 +217,36 @@ def change_dir(path: Path) -> Generator[None, None, None]: os.chdir(prev) -def run(spec_path: Path) -> None: - """Run the pipeline defined with the given spec.""" +def run( + spec_path: Path, + full_refresh: Sequence[str], + full_refresh_all: bool, + refresh: Sequence[str], +) -> None: + """Run the pipeline defined with the given spec. + + :param spec_path: Path to the pipeline specification file. + :param full_refresh: List of datasets to reset and recompute. + :param full_refresh_all: Perform a full graph reset and recompute. + :param refresh: List of datasets to update. + """ + # Validate conflicting arguments + if full_refresh_all: + if full_refresh: + raise PySparkException( + errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", + messageParameters={ + "conflicting_option": "--full_refresh", + }, + ) + if refresh: + raise PySparkException( + errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", + messageParameters={ + "conflicting_option": "--refresh", + }, + ) + log_with_curr_timestamp(f"Loading pipeline spec from {spec_path}...") spec = load_pipeline_spec(spec_path) @@ -242,20 +270,52 @@ def run(spec_path: Path) -> None: register_definitions(spec_path, registry, spec) log_with_curr_timestamp("Starting run...") - result_iter = start_run(spark, dataflow_graph_id) + result_iter = start_run( + spark, + dataflow_graph_id, + full_refresh=full_refresh, + full_refresh_all=full_refresh_all, + refresh=refresh, + ) try: handle_pipeline_events(result_iter) finally: spark.stop() +def parse_table_list(value: str) -> List[str]: + """Parse a comma-separated list of table names, handling whitespace.""" + return [table.strip() for table in value.split(",") if table.strip()] + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Pipeline CLI") subparsers = parser.add_subparsers(dest="command", required=True) # "run" subcommand - run_parser = subparsers.add_parser("run", help="Run a pipeline.") + run_parser = subparsers.add_parser( + "run", + help="Run a pipeline. 
If no refresh options specified, " + "a default incremental update is performed.", + ) run_parser.add_argument("--spec", help="Path to the pipeline spec.") + run_parser.add_argument( + "--full-refresh", + type=parse_table_list, + action="extend", + help="List of datasets to reset and recompute (comma-separated).", + default=[], + ) + run_parser.add_argument( + "--full-refresh-all", action="store_true", help="Perform a full graph reset and recompute." + ) + run_parser.add_argument( + "--refresh", + type=parse_table_list, + action="extend", + help="List of datasets to update (comma-separated).", + default=[], + ) # "init" subcommand init_parser = subparsers.add_parser( @@ -283,6 +343,11 @@ def run(spec_path: Path) -> None: else: spec_path = find_pipeline_spec(Path.cwd()) - run(spec_path=spec_path) + run( + spec_path=spec_path, + full_refresh=args.full_refresh, + full_refresh_all=args.full_refresh_all, + refresh=args.refresh, + ) elif args.command == "init": init(args.name) diff --git a/python/pyspark/pipelines/spark_connect_pipeline.py b/python/pyspark/pipelines/spark_connect_pipeline.py index 12f43a236c28..f430d33be4a1 100644 --- a/python/pyspark/pipelines/spark_connect_pipeline.py +++ b/python/pyspark/pipelines/spark_connect_pipeline.py @@ -15,7 +15,7 @@ # limitations under the License. # from datetime import timezone -from typing import Any, Dict, Mapping, Iterator, Optional, cast +from typing import Any, Dict, Mapping, Iterator, Optional, cast, Sequence import pyspark.sql.connect.proto as pb2 from pyspark.sql import SparkSession @@ -65,12 +65,26 @@ def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) -> None: log_with_provided_timestamp(event.message, dt) -def start_run(spark: SparkSession, dataflow_graph_id: str) -> Iterator[Dict[str, Any]]: +def start_run( + spark: SparkSession, + dataflow_graph_id: str, + full_refresh: Optional[Sequence[str]] = None, + full_refresh_all: bool = False, + refresh: Optional[Sequence[str]] = None, +) -> Iterator[Dict[str, Any]]: """Start a run of the dataflow graph in the Spark Connect server. :param dataflow_graph_id: The ID of the dataflow graph to start. + :param full_refresh: List of datasets to reset and recompute. + :param full_refresh_all: Perform a full graph reset and recompute. + :param refresh: List of datasets to update. """ - inner_command = pb2.PipelineCommand.StartRun(dataflow_graph_id=dataflow_graph_id) + inner_command = pb2.PipelineCommand.StartRun( + dataflow_graph_id=dataflow_graph_id, + full_refresh_selection=full_refresh or [], + full_refresh_all=full_refresh_all, + refresh_selection=refresh or [], + ) command = pb2.Command() command.pipeline_command.start_run.CopyFrom(inner_command) # Cast because mypy seems to think `spark`` is a function, not an object. 
Likely related to diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index 66303e567c52..ded00e691db4 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -36,6 +36,7 @@ unpack_pipeline_spec, DefinitionsGlob, PipelineSpec, + run, ) from pyspark.pipelines.tests.local_graph_element_registry import LocalGraphElementRegistry @@ -358,6 +359,95 @@ def test_python_import_current_directory(self): ), ) + def test_full_refresh_all_conflicts_with_full_refresh(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a minimal pipeline spec + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test that providing both --full-refresh-all and --full-refresh raises an exception + with self.assertRaises(PySparkException) as context: + run( + spec_path=spec_path, + full_refresh=["table1", "table2"], + full_refresh_all=True, + refresh=[], + ) + + self.assertEqual( + context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" + ) + self.assertEqual( + context.exception.getMessageParameters(), {"conflicting_option": "--full_refresh"} + ) + + def test_full_refresh_all_conflicts_with_refresh(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a minimal pipeline spec + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test that providing both --full-refresh-all and --refresh raises an exception + with self.assertRaises(PySparkException) as context: + run( + spec_path=spec_path, + full_refresh=[], + full_refresh_all=True, + refresh=["table1", "table2"], + ) + + self.assertEqual( + context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" + ) + self.assertEqual( + context.exception.getMessageParameters(), + {"conflicting_option": "--refresh"}, + ) + + def test_full_refresh_all_conflicts_with_both(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a minimal pipeline spec + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test that providing --full-refresh-all with both other options raises an exception + # (it should catch the first conflict - full_refresh) + with self.assertRaises(PySparkException) as context: + run( + spec_path=spec_path, + full_refresh=["table1"], + full_refresh_all=True, + refresh=["table2"], + ) + + self.assertEqual( + context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" + ) + + def test_parse_table_list_single_table(self): + """Test parsing a single table name.""" + from pyspark.pipelines.cli import parse_table_list + + result = parse_table_list("table1") + self.assertEqual(result, ["table1"]) + + def test_parse_table_list_multiple_tables(self): + """Test parsing multiple table names.""" + from pyspark.pipelines.cli import parse_table_list + + result = parse_table_list("table1,table2,table3") + self.assertEqual(result, ["table1", "table2", "table3"]) + + def test_parse_table_list_with_spaces(self): + """Test parsing table names with spaces.""" + from pyspark.pipelines.cli import parse_table_list + + result = parse_table_list("table1, table2 , table3") + self.assertEqual(result, ["table1", "table2", "table3"]) + if __name__ == "__main__": try: diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py index 
413e1fbe12b5..e13877d05a60 100644 --- a/python/pyspark/sql/connect/proto/pipelines_pb2.py +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py @@ -40,7 +40,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xf2\x12\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x1a\x87\x03\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1aQ\n\x08Response\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\xd1\x04\n\rDefineDataset\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12&\n\x0c\x64\x61taset_name\x18\x02 \x01(\tH\x01R\x0b\x64\x61tasetName\x88\x01\x01\x12\x42\n\x0c\x64\x61taset_type\x18\x03 \x01(\x0e\x32\x1a.spark.connect.DatasetTypeH\x02R\x0b\x64\x61tasetType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x03R\x07\x63omment\x88\x01\x01\x12l\n\x10table_properties\x18\x05 \x03(\x0b\x32\x41.spark.connect.PipelineCommand.DefineDataset.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x06 \x03(\tR\rpartitionCols\x12\x34\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x04R\x06schema\x88\x01\x01\x12\x1b\n\x06\x66ormat\x18\x08 \x01(\tH\x05R\x06\x66ormat\x88\x01\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0f\n\r_dataset_nameB\x0f\n\r_dataset_typeB\n\n\x08_commentB\t\n\x07_schemaB\t\n\x07_format\x1a\xc8\x03\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x01R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x02R\x11targetDatasetName\x88\x01\x01\x12\x38\n\x08relation\x18\x04 \x01(\x0b\x32\x17.spark.connect.RelationH\x03R\x08relation\x88\x01\x01\x12Q\n\x08sql_conf\x18\x05 
\x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12\x17\n\x04once\x18\x06 \x01(\x08H\x04R\x04once\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0b\n\t_relationB\x07\n\x05_once\x1aQ\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_textB\x0e\n\x0c\x63ommand_type"\x8e\x02\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message*a\n\x0b\x44\x61tasetType\x12\x1c\n\x18\x44\x41TASET_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\x9a\x14\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x1a\x87\x03\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1aQ\n\x08Response\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 
\x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\xd1\x04\n\rDefineDataset\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12&\n\x0c\x64\x61taset_name\x18\x02 \x01(\tH\x01R\x0b\x64\x61tasetName\x88\x01\x01\x12\x42\n\x0c\x64\x61taset_type\x18\x03 \x01(\x0e\x32\x1a.spark.connect.DatasetTypeH\x02R\x0b\x64\x61tasetType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x03R\x07\x63omment\x88\x01\x01\x12l\n\x10table_properties\x18\x05 \x03(\x0b\x32\x41.spark.connect.PipelineCommand.DefineDataset.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x06 \x03(\tR\rpartitionCols\x12\x34\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x04R\x06schema\x88\x01\x01\x12\x1b\n\x06\x66ormat\x18\x08 \x01(\tH\x05R\x06\x66ormat\x88\x01\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0f\n\r_dataset_nameB\x0f\n\r_dataset_typeB\n\n\x08_commentB\t\n\x07_schemaB\t\n\x07_format\x1a\xc8\x03\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x01R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x02R\x11targetDatasetName\x88\x01\x01\x12\x38\n\x08relation\x18\x04 \x01(\x0b\x32\x17.spark.connect.RelationH\x03R\x08relation\x88\x01\x01\x12Q\n\x08sql_conf\x18\x05 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12\x17\n\x04once\x18\x06 \x01(\x08H\x04R\x04once\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0b\n\t_relationB\x07\n\x05_once\x1a\xf8\x01\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x34\n\x16\x66ull_refresh_selection\x18\x02 \x03(\tR\x14\x66ullRefreshSelection\x12-\n\x10\x66ull_refresh_all\x18\x03 \x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12+\n\x11refresh_selection\x18\x04 \x03(\tR\x10refreshSelectionB\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_all\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_textB\x0e\n\x0c\x63ommand_type"\x8e\x02\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 
\x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message*a\n\x0b\x44\x61tasetType\x12\x1c\n\x18\x44\x41TASET_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -59,10 +59,10 @@ _globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_options = b"8\001" _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options = b"8\001" - _globals["_DATASETTYPE"]._serialized_start = 3026 - _globals["_DATASETTYPE"]._serialized_end = 3123 + _globals["_DATASETTYPE"]._serialized_start = 3194 + _globals["_DATASETTYPE"]._serialized_end = 3291 _globals["_PIPELINECOMMAND"]._serialized_start = 140 - _globals["_PIPELINECOMMAND"]._serialized_end = 2558 + _globals["_PIPELINECOMMAND"]._serialized_end = 2726 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 719 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1110 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start = 928 @@ -79,16 +79,16 @@ _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 2257 _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start = 928 _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 986 - _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2259 - _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2340 - _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 2343 - _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2542 - _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2561 - _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 2831 - _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 2718 - _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 2816 - _globals["_PIPELINEEVENTRESULT"]._serialized_start = 2833 - _globals["_PIPELINEEVENTRESULT"]._serialized_end = 2906 - _globals["_PIPELINEEVENT"]._serialized_start = 2908 - _globals["_PIPELINEEVENT"]._serialized_end = 3024 + _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2260 + _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2508 + _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 2511 + _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2710 + _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2729 + _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 2999 + _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 2886 + _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 2984 + _globals["_PIPELINEEVENTRESULT"]._serialized_start = 3001 + _globals["_PIPELINEEVENTRESULT"]._serialized_end = 3074 + _globals["_PIPELINEEVENT"]._serialized_start = 3076 + _globals["_PIPELINEEVENT"]._serialized_end = 3192 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi index 36fb73f06906..fff130ac3b93 100644 --- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi @@ -530,20 +530,42 @@ class PipelineCommand(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor DATAFLOW_GRAPH_ID_FIELD_NUMBER: builtins.int + 
FULL_REFRESH_SELECTION_FIELD_NUMBER: builtins.int + FULL_REFRESH_ALL_FIELD_NUMBER: builtins.int + REFRESH_SELECTION_FIELD_NUMBER: builtins.int dataflow_graph_id: builtins.str """The graph to start.""" + @property + def full_refresh_selection( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """List of datasets to reset and recompute.""" + full_refresh_all: builtins.bool + """Perform a full graph reset and recompute.""" + @property + def refresh_selection( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """List of datasets to update.""" def __init__( self, *, dataflow_graph_id: builtins.str | None = ..., + full_refresh_selection: collections.abc.Iterable[builtins.str] | None = ..., + full_refresh_all: builtins.bool | None = ..., + refresh_selection: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ "_dataflow_graph_id", b"_dataflow_graph_id", + "_full_refresh_all", + b"_full_refresh_all", "dataflow_graph_id", b"dataflow_graph_id", + "full_refresh_all", + b"full_refresh_all", ], ) -> builtins.bool: ... def ClearField( @@ -551,14 +573,27 @@ class PipelineCommand(google.protobuf.message.Message): field_name: typing_extensions.Literal[ "_dataflow_graph_id", b"_dataflow_graph_id", + "_full_refresh_all", + b"_full_refresh_all", "dataflow_graph_id", b"dataflow_graph_id", + "full_refresh_all", + b"full_refresh_all", + "full_refresh_selection", + b"full_refresh_selection", + "refresh_selection", + b"refresh_selection", ], ) -> None: ... + @typing.overload def WhichOneof( self, oneof_group: typing_extensions.Literal["_dataflow_graph_id", b"_dataflow_graph_id"], ) -> typing_extensions.Literal["dataflow_graph_id"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_full_refresh_all", b"_full_refresh_all"] + ) -> typing_extensions.Literal["full_refresh_all"] | None: ... class DefineSqlGraphElements(google.protobuf.message.Message): """Parses the SQL file and registers all datasets and flows.""" diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto index c5a631264590..751b14abe478 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto @@ -116,6 +116,15 @@ message PipelineCommand { message StartRun { // The graph to start. optional string dataflow_graph_id = 1; + + // List of datasets to reset and recompute. + repeated string full_refresh_selection = 2; + + // Perform a full graph reset and recompute. + optional bool full_refresh_all = 3; + + // List of datasets to update. + repeated string refresh_selection = 4; } // Parses the SQL file and registers all datasets and flows. 
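For illustration, the three new StartRun fields map directly onto the keyword arguments added to `start_run` in `spark_connect_pipeline.py` above. A minimal Python sketch of how a client could populate them, assuming the regenerated `pipelines_pb2` module is importable; the graph ID is a placeholder and the dataset names are only examples:

```python
import pyspark.sql.connect.proto as pb2

# Fully refresh "a" and "mv", incrementally refresh "b"; full_refresh_all is
# passed as False, mirroring the default used by start_run() above.
inner_command = pb2.PipelineCommand.StartRun(
    dataflow_graph_id="example-graph-id",  # placeholder returned by CreateDataflowGraph
    full_refresh_selection=["a", "mv"],
    full_refresh_all=False,
    refresh_selection=["b"],
)
command = pb2.Command()
command.pipeline_command.start_run.CopyFrom(inner_command)
```

On the server side, the `createTableFilters` helper in `PipelinesHandler` below is what turns these selections into `SomeTables`/`AllTables`/`NoTables` filters and rejects inconsistent combinations.
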
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 7bb1d7358557..7f92aa13944c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -26,6 +26,7 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineCommandResult, Relation} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connect.common.DataTypeProtoConverter @@ -33,7 +34,7 @@ import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.pipelines.Language.Python import org.apache.spark.sql.pipelines.QueryOriginType import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED} -import org.apache.spark.sql.pipelines.graph.{FlowAnalysis, GraphIdentifierManager, IdentifierHelper, PipelineUpdateContextImpl, QueryContext, QueryOrigin, SqlGraphRegistrationContext, Table, TemporaryView, UnresolvedFlow} +import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryOrigin, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow} import org.apache.spark.sql.pipelines.logging.{PipelineEvent, RunProgress} import org.apache.spark.sql.types.StructType @@ -224,6 +225,8 @@ private[connect] object PipelinesHandler extends Logging { sessionHolder: SessionHolder): Unit = { val dataflowGraphId = cmd.getDataflowGraphId val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) + val tableFiltersResult = createTableFilters(cmd, graphElementRegistry, sessionHolder) + // We will use this variable to store the run failure event if it occurs. This will be set // by the event callback. @volatile var runFailureEvent = Option.empty[PipelineEvent] @@ -279,8 +282,11 @@ private[connect] object PipelinesHandler extends Logging { .build()) } } - val pipelineUpdateContext = - new PipelineUpdateContextImpl(graphElementRegistry.toDataflowGraph, eventCallback) + val pipelineUpdateContext = new PipelineUpdateContextImpl( + graphElementRegistry.toDataflowGraph, + eventCallback, + tableFiltersResult.refresh, + tableFiltersResult.fullRefresh) sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext) pipelineUpdateContext.pipelineExecution.runPipeline() @@ -290,4 +296,87 @@ private[connect] object PipelinesHandler extends Logging { throw event.error.get } } + + /** + * Creates the table filters for the full refresh and refresh operations based on the StartRun + * command user provided. Also validates the command parameters to ensure that they are + * consistent and do not conflict with each other. + * + * If `fullRefreshAll` is true, create `AllTables` filter for full refresh. + * + * If `fullRefreshTables` and `refreshTables` are both empty, create `AllTables` filter for + * refresh as a default behavior. + * + * If both non-empty, verifies that there is no overlap and creates SomeTables filters for both. 
+ * + * If one non-empty and the other empty, create `SomeTables` filter for the non-empty one, and + * `NoTables` filter for the empty one. + */ + private def createTableFilters( + startRunCommand: proto.PipelineCommand.StartRun, + graphElementRegistry: GraphRegistrationContext, + sessionHolder: SessionHolder): TableFilters = { + // Convert table names to fully qualified TableIdentifier objects + def parseTableNames(tableNames: Seq[String]): Set[TableIdentifier] = { + tableNames.map { name => + GraphIdentifierManager + .parseAndQualifyTableIdentifier( + rawTableIdentifier = + GraphIdentifierManager.parseTableIdentifier(name, sessionHolder.session), + currentCatalog = Some(graphElementRegistry.defaultCatalog), + currentDatabase = Some(graphElementRegistry.defaultDatabase)) + .identifier + }.toSet + } + + val fullRefreshTables = startRunCommand.getFullRefreshSelectionList.asScala.toSeq + val fullRefreshAll = startRunCommand.getFullRefreshAll + val refreshTables = startRunCommand.getRefreshSelectionList.asScala.toSeq + + if (refreshTables.nonEmpty && fullRefreshAll) { + throw new IllegalArgumentException( + "Cannot specify a subset to refresh when full refresh all is set to true.") + } + + if (fullRefreshTables.nonEmpty && fullRefreshAll) { + throw new IllegalArgumentException( + "Cannot specify a subset to full refresh when full refresh all is set to true.") + } + val refreshTableNames = parseTableNames(refreshTables) + val fullRefreshTableNames = parseTableNames(fullRefreshTables) + + if (refreshTables.nonEmpty && fullRefreshTables.nonEmpty) { + // check if there is an intersection between the subset + val intersection = refreshTableNames.intersect(fullRefreshTableNames) + if (intersection.nonEmpty) { + throw new IllegalArgumentException( + "Datasets specified for refresh and full refresh cannot overlap: " + + s"${intersection.mkString(", ")}") + } + } + + if (fullRefreshAll) { + return TableFilters(fullRefresh = AllTables, refresh = NoTables) + } + + (fullRefreshTables, refreshTables) match { + case (Nil, Nil) => + // If both are empty, we default to refreshing all tables + TableFilters(fullRefresh = NoTables, refresh = AllTables) + case (_, Nil) => + TableFilters(fullRefresh = SomeTables(fullRefreshTableNames), refresh = NoTables) + case (Nil, _) => + TableFilters(fullRefresh = NoTables, refresh = SomeTables(refreshTableNames)) + case (_, _) => + // If both are specified, we create filters for both after validation + TableFilters( + fullRefresh = SomeTables(fullRefreshTableNames), + refresh = SomeTables(refreshTableNames)) + } + } + + /** + * A case class to hold the table filters for full refresh and refresh operations. + */ + private case class TableFilters(fullRefresh: TableFilter, refresh: TableFilter) } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala new file mode 100644 index 000000000000..794932544d5f --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala @@ -0,0 +1,353 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.pipelines + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ + +import org.apache.spark.connect.proto.{DatasetType, PipelineCommand, PipelineEvent} +import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} +import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin} + +/** + * Comprehensive test suite that validates pipeline refresh functionality by running actual + * pipelines with different refresh parameters and validating the results. + */ +class PipelineRefreshFunctionalSuite + extends SparkDeclarativePipelinesServerTest + with TestPipelineUpdateContextMixin + with EventVerificationTestHelpers { + + private val externalSourceTable = TableIdentifier( + catalog = Some("spark_catalog"), + database = Some("default"), + table = "source_data") + + override def beforeEach(): Unit = { + super.beforeEach() + // Create source table to simulate streaming updates + spark.sql(s"CREATE TABLE $externalSourceTable AS SELECT * FROM RANGE(1, 2)") + } + + override def afterEach(): Unit = { + super.afterEach() + // Clean up the source table after each test + spark.sql(s"DROP TABLE IF EXISTS $externalSourceTable") + } + + private def createTestPipeline(graphId: String): TestPipelineDefinition = { + new TestPipelineDefinition(graphId) { + // Create tables that depend on the mv + createTable( + name = "a", + datasetType = DatasetType.TABLE, + sql = Some(s"SELECT id FROM STREAM $externalSourceTable")) + createTable( + name = "b", + datasetType = DatasetType.TABLE, + sql = Some(s"SELECT id FROM STREAM $externalSourceTable")) + createTable( + name = "mv", + datasetType = DatasetType.MATERIALIZED_VIEW, + sql = Some(s"SELECT id FROM a")) + } + } + + /** + * Helper method to run refresh tests with common setup and verification logic. This reduces + * code duplication across the refresh test cases. 
+ */ + private def runRefreshTest( + refreshConfigBuilder: String => Option[PipelineCommand.StartRun] = _ => None, + expectedContentAfterRefresh: Map[String, Set[Map[String, Any]]], + eventValidation: Option[ArrayBuffer[PipelineEvent] => Unit] = None): Unit = { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + val pipeline = createTestPipeline(graphId) + registerPipelineDatasets(pipeline) + + // First run to populate tables + startPipelineAndWaitForCompletion(graphId) + + // combine above into a map for verification + val initialContent = Map( + "spark_catalog.default.a" -> Set(Map("id" -> 1)), + "spark_catalog.default.b" -> Set(Map("id" -> 1)), + "spark_catalog.default.mv" -> Set(Map("id" -> 1))) + // Verify initial content + initialContent.foreach { case (tableName, expectedRows) => + checkTableContent(tableName, expectedRows) + } + // Clear cached pipeline execution before starting new run + SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) + .foreach(_.removeAllPipelineExecutions()) + + // Replace source data to simulate a streaming update + spark.sql( + "INSERT OVERWRITE TABLE spark_catalog.default.source_data " + + "SELECT * FROM VALUES (2), (3) AS t(id)") + + // Run with specified refresh configuration + val capturedEvents = refreshConfigBuilder(graphId) match { + case Some(startRun) => startPipelineAndWaitForCompletion(startRun) + case None => startPipelineAndWaitForCompletion(graphId) + } + + // Additional validation if provided + eventValidation.foreach(_(capturedEvents)) + + // Verify final content with checkTableContent + expectedContentAfterRefresh.foreach { case (tableName, expectedRows) => + checkTableContent(tableName, expectedRows) + } + } + } + + test("pipeline runs selective full_refresh") { + runRefreshTest( + refreshConfigBuilder = { graphId => + Some( + PipelineCommand.StartRun + .newBuilder() + .setDataflowGraphId(graphId) + .addAllFullRefreshSelection(List("a").asJava) + .build()) + }, + expectedContentAfterRefresh = Map( + "spark_catalog.default.a" -> Set( + Map("id" -> 2), // a is fully refreshed and only contains the new values + Map("id" -> 3)), + "spark_catalog.default.b" -> Set( + Map("id" -> 1) // b is not refreshed, so it retains the old value + ), + "spark_catalog.default.mv" -> Set( + Map("id" -> 1) // mv is not refreshed, so it retains the old value + )), + eventValidation = Some { capturedEvents => + // assert that table_b is excluded + assert( + capturedEvents.exists( + _.getMessage.contains(s"Flow \'spark_catalog.default.b\' is EXCLUDED."))) + // assert that table_a ran to completion + assert( + capturedEvents.exists( + _.getMessage.contains(s"Flow spark_catalog.default.a has COMPLETED."))) + // assert that mv is excluded + assert( + capturedEvents.exists( + _.getMessage.contains(s"Flow \'spark_catalog.default.mv\' is EXCLUDED."))) + // Verify completion event + assert(capturedEvents.exists(_.getMessage.contains("Run is COMPLETED"))) + }) + } + + test("pipeline runs selective full_refresh and selective refresh") { + runRefreshTest( + refreshConfigBuilder = { graphId => + Some( + PipelineCommand.StartRun + .newBuilder() + .setDataflowGraphId(graphId) + .addAllFullRefreshSelection(Seq("a", "mv").asJava) + .addRefreshSelection("b") + .build()) + }, + expectedContentAfterRefresh = Map( + "spark_catalog.default.a" -> Set( + Map("id" -> 2), // a is fully refreshed and only contains the new values + Map("id" -> 3)), + "spark_catalog.default.b" -> Set( + Map("id" -> 1), 
// b is refreshed, so it retains the old value and adds the new ones + Map("id" -> 2), + Map("id" -> 3)), + "spark_catalog.default.mv" -> Set( + Map("id" -> 2), // mv is fully refreshed and only contains the new values + Map("id" -> 3)))) + } + + test("pipeline runs refresh by default") { + runRefreshTest(expectedContentAfterRefresh = + Map( + "spark_catalog.default.a" -> Set( + Map( + "id" -> 1 + ), // a is refreshed by default, retains the old value and adds the new ones + Map("id" -> 2), + Map("id" -> 3)), + "spark_catalog.default.b" -> Set( + Map( + "id" -> 1 + ), // b is refreshed by default, retains the old value and adds the new ones + Map("id" -> 2), + Map("id" -> 3)), + "spark_catalog.default.mv" -> Set( + Map("id" -> 1), + Map("id" -> 2), // mv is refreshed from table a, retains all values + Map("id" -> 3)))) + } + + test("pipeline runs full refresh all") { + runRefreshTest( + refreshConfigBuilder = { graphId => + Some( + PipelineCommand.StartRun + .newBuilder() + .setDataflowGraphId(graphId) + .setFullRefreshAll(true) + .build()) + }, + // full refresh all causes all tables to lose the initial value + // and only contain the new values after the source data is updated + expectedContentAfterRefresh = Map( + "spark_catalog.default.a" -> Set(Map("id" -> 2), Map("id" -> 3)), + "spark_catalog.default.b" -> Set(Map("id" -> 2), Map("id" -> 3)), + "spark_catalog.default.mv" -> Set(Map("id" -> 2), Map("id" -> 3)))) + } + + test("validation: cannot specify subset refresh when full_refresh_all is true") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + val pipeline = createTestPipeline(graphId) + registerPipelineDatasets(pipeline) + + val startRun = PipelineCommand.StartRun + .newBuilder() + .setDataflowGraphId(graphId) + .setFullRefreshAll(true) + .addRefreshSelection("a") + .build() + + val exception = intercept[IllegalArgumentException] { + startPipelineAndWaitForCompletion(startRun) + } + assert( + exception.getMessage.contains( + "Cannot specify a subset to refresh when full refresh all is set to true")) + } + } + + test("validation: cannot specify subset full_refresh when full_refresh_all is true") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + val pipeline = createTestPipeline(graphId) + registerPipelineDatasets(pipeline) + + val startRun = PipelineCommand.StartRun + .newBuilder() + .setDataflowGraphId(graphId) + .setFullRefreshAll(true) + .addFullRefreshSelection("a") + .build() + + val exception = intercept[IllegalArgumentException] { + startPipelineAndWaitForCompletion(startRun) + } + assert( + exception.getMessage.contains( + "Cannot specify a subset to full refresh when full refresh all is set to true")) + } + } + + test("validation: refresh and full_refresh cannot overlap") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + val pipeline = createTestPipeline(graphId) + registerPipelineDatasets(pipeline) + + val startRun = PipelineCommand.StartRun + .newBuilder() + .setDataflowGraphId(graphId) + .addRefreshSelection("a") + .addFullRefreshSelection("a") + .build() + + val exception = intercept[IllegalArgumentException] { + startPipelineAndWaitForCompletion(startRun) + } + assert( + exception.getMessage.contains( + "Datasets specified for refresh and full refresh cannot overlap")) + assert(exception.getMessage.contains("a")) + } + } + + test("validation: multiple overlapping tables in refresh and full_refresh not allowed") { + withRawBlockingStub { implicit stub => + val graphId = 
createDataflowGraph + val pipeline = createTestPipeline(graphId) + registerPipelineDatasets(pipeline) + + val startRun = PipelineCommand.StartRun + .newBuilder() + .setDataflowGraphId(graphId) + .addRefreshSelection("a") + .addRefreshSelection("b") + .addFullRefreshSelection("a") + .build() + + val exception = intercept[IllegalArgumentException] { + startPipelineAndWaitForCompletion(startRun) + } + assert( + exception.getMessage.contains( + "Datasets specified for refresh and full refresh cannot overlap")) + assert(exception.getMessage.contains("a")) + } + } + + test("validation: fully qualified table names in validation") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + val pipeline = createTestPipeline(graphId) + registerPipelineDatasets(pipeline) + + val startRun = PipelineCommand.StartRun + .newBuilder() + .setDataflowGraphId(graphId) + .addRefreshSelection("spark_catalog.default.a") + .addFullRefreshSelection("a") // This should be treated as the same table + .build() + + val exception = intercept[IllegalArgumentException] { + startPipelineAndWaitForCompletion(startRun) + } + assert( + exception.getMessage.contains( + "Datasets specified for refresh and full refresh cannot overlap")) + } + } + + private def checkTableContent[A <: Map[String, Any]]( + name: String, + expectedContent: Set[A]): Unit = { + spark.catalog.refreshTable(name) // clear cache for the table + val df = spark.table(name) + QueryTest.checkAnswer( + df, + expectedContent + .map(row => { + // Convert each row to a Row object + org.apache.spark.sql.Row.fromSeq(row.values.toSeq) + }) + .toSeq + .asJava) + } +} diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala index ea4cc5f3aba5..003fd30b6075 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.connect.pipelines +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.connect.{proto => sc} +import org.apache.spark.connect.proto.{PipelineCommand, PipelineEvent} import org.apache.spark.sql.connect.{SparkConnectServerTest, SparkConnectTestUtils} import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} @@ -125,15 +128,27 @@ class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { def createPlanner(): SparkConnectPlanner = new SparkConnectPlanner(SparkConnectTestUtils.createDummySessionHolder(spark)) - def startPipelineAndWaitForCompletion(graphId: String): Unit = { + def startPipelineAndWaitForCompletion(graphId: String): ArrayBuffer[PipelineEvent] = { + val defaultStartRunCommand = + PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build() + startPipelineAndWaitForCompletion(defaultStartRunCommand) + } + + def startPipelineAndWaitForCompletion( + startRunCommand: PipelineCommand.StartRun): ArrayBuffer[PipelineEvent] = { withClient { client => - val startRunRequest = buildStartRunPlan( - sc.PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build()) + val capturedEvents = new ArrayBuffer[PipelineEvent]() + val startRunRequest = buildStartRunPlan(startRunCommand) val 
responseIterator = client.execute(startRunRequest) // The response iterator will be closed when the pipeline is completed. while (responseIterator.hasNext) { - responseIterator.next() + val response = responseIterator.next() + if (response.hasPipelineEventResult) { + capturedEvents.append(response.getPipelineEventResult.getEvent) + } } + return capturedEvents } + ArrayBuffer.empty[PipelineEvent] } } diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala index c68882df79ce..5a298c2f17d9 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala @@ -24,10 +24,15 @@ import org.apache.spark.sql.pipelines.logging.{FlowProgressEventLogger, Pipeline * An implementation of the PipelineUpdateContext trait used in production. * @param unresolvedGraph The graph (unresolved) to be executed in this update. * @param eventCallback A callback function to be called when an event is added to the event buffer. + * @param refreshTables Filter for which tables should be refreshed when performing this update. + * @param fullRefreshTables Filter for which tables should be full refreshed + * when performing this update. */ class PipelineUpdateContextImpl( override val unresolvedGraph: DataflowGraph, - override val eventCallback: PipelineEvent => Unit + override val eventCallback: PipelineEvent => Unit, + override val refreshTables: TableFilter = AllTables, + override val fullRefreshTables: TableFilter = NoTables ) extends PipelineUpdateContext { override val spark: SparkSession = SparkSession.getActiveSession.getOrElse( @@ -37,7 +42,5 @@ class PipelineUpdateContextImpl( override val flowProgressEventLogger: FlowProgressEventLogger = new FlowProgressEventLogger(eventCallback = eventCallback) - override val refreshTables: TableFilter = AllTables - override val fullRefreshTables: TableFilter = NoTables override val resetCheckpointFlows: FlowFilter = NoFlows }
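
As a closing usage note, the `type=parse_table_list` plus `action="extend"` wiring in `cli.py` means the new options accept both comma-separated and repeated forms. A small self-contained sketch of that behavior (the helper is recreated locally rather than importing pyspark; the option names match the diff, everything else is illustrative):

```python
import argparse
from typing import List


def parse_table_list(value: str) -> List[str]:
    # Same logic as the helper added in cli.py: split on commas, trim whitespace.
    return [table.strip() for table in value.split(",") if table.strip()]


parser = argparse.ArgumentParser(prog="pipeline-cli-sketch")
parser.add_argument("--full-refresh", type=parse_table_list, action="extend", default=[])
parser.add_argument("--refresh", type=parse_table_list, action="extend", default=[])
parser.add_argument("--full-refresh-all", action="store_true")

# Comma-separated values are split, and repeating a flag extends the same list,
# so these two invocations are equivalent.
args = parser.parse_args(["--full-refresh", "a, mv", "--refresh", "b"])
assert args.full_refresh == ["a", "mv"] and args.refresh == ["b"]

args = parser.parse_args(["--full-refresh", "a", "--full-refresh", "mv", "--refresh", "b"])
assert args.full_refresh == ["a", "mv"] and args.refresh == ["b"]
assert args.full_refresh_all is False
```

Combining `--full-refresh-all` with either list is then rejected client-side by `run()` with CONFLICTING_PIPELINE_REFRESH_OPTIONS, while overlapping refresh and full-refresh selections are rejected server-side by `createTableFilters`.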