From b24698ca64dd3d843afab8faa9ee776e9b03b0e8 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Mon, 14 Jul 2025 15:37:55 -0700 Subject: [PATCH 01/17] cli args --- python/pyspark/pipelines/cli.py | 81 ++++++++++- .../pipelines/spark_connect_pipeline.py | 22 ++- python/pyspark/pipelines/tests/test_cli.py | 128 +++++++++++++++++- 3 files changed, 223 insertions(+), 8 deletions(-) diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index 395f7e9b8374..95a3e28113dc 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -28,7 +28,7 @@ import yaml from dataclasses import dataclass from pathlib import Path -from typing import Any, Generator, Mapping, Optional, Sequence +from typing import Any, Generator, Mapping, Optional, Sequence, List from pyspark.errors import PySparkException, PySparkTypeError from pyspark.sql import SparkSession @@ -217,8 +217,40 @@ def change_dir(path: Path) -> Generator[None, None, None]: os.chdir(prev) -def run(spec_path: Path) -> None: - """Run the pipeline defined with the given spec.""" +def run( + spec_path: Path, + full_refresh: Optional[Sequence[str]] = None, + full_refresh_all: bool = False, + refresh: Optional[Sequence[str]] = None +) -> None: + """Run the pipeline defined with the given spec. + + :param spec_path: Path to the pipeline specification file. + :param full_refresh: List of tables to reset and recompute. + :param full_refresh_all: Perform a full graph reset and recompute. + :param refresh: List of tables to update. + """ + # Validate conflicting arguments + if full_refresh_all: + if full_refresh: + raise PySparkException( + errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", + messageParameters={ + "message": "--full-refresh-all option conflicts with --full-refresh. " + "The --full-refresh-all option performs a full refresh of all tables, " + "so specifying individual tables with --full-refresh is not allowed." + } + ) + if refresh: + raise PySparkException( + errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", + messageParameters={ + "message": "--full-refresh-all option conflicts with --refresh. " + "The --full-refresh-all option performs a full refresh of all tables, " + "so specifying individual tables with --refresh is not allowed." 
+ } + ) + log_with_curr_timestamp(f"Loading pipeline spec from {spec_path}...") spec = load_pipeline_spec(spec_path) @@ -242,13 +274,34 @@ def run(spec_path: Path) -> None: register_definitions(spec_path, registry, spec) log_with_curr_timestamp("Starting run...") - result_iter = start_run(spark, dataflow_graph_id) + result_iter = start_run( + spark, + dataflow_graph_id, + full_refresh=full_refresh, + full_refresh_all=full_refresh_all, + refresh=refresh + ) try: handle_pipeline_events(result_iter) finally: spark.stop() +def parse_table_list(value: str) -> List[str]: + """Parse a comma-separated list of table names, handling whitespace.""" + return [table.strip() for table in value.split(",") if table.strip()] + + +def flatten_table_lists(table_lists: Optional[List[List[str]]]) -> Optional[List[str]]: + """Flatten a list of lists of table names into a single list.""" + if not table_lists: + return None + result = [] + for table_list in table_lists: + result.extend(table_list) + return result if result else None + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Pipeline CLI") subparsers = parser.add_subparsers(dest="command", required=True) @@ -256,6 +309,19 @@ def run(spec_path: Path) -> None: # "run" subcommand run_parser = subparsers.add_parser("run", help="Run a pipeline.") run_parser.add_argument("--spec", help="Path to the pipeline spec.") + run_parser.add_argument( + "--full-refresh", + type=parse_table_list, + action="append", + help="List of tables to reset and recompute (comma-separated). Can be specified multiple times." + ) + run_parser.add_argument("--full-refresh-all", action="store_true", help="Perform a full graph reset and recompute.") + run_parser.add_argument( + "--refresh", + type=parse_table_list, + action="append", + help="List of tables to update (comma-separated). Can be specified multiple times." + ) # "init" subcommand init_parser = subparsers.add_parser( @@ -283,6 +349,11 @@ def run(spec_path: Path) -> None: else: spec_path = find_pipeline_spec(Path.cwd()) - run(spec_path=spec_path) + run( + spec_path=spec_path, + full_refresh=flatten_table_lists(args.full_refresh), + full_refresh_all=args.full_refresh_all, + refresh=flatten_table_lists(args.refresh) + ) elif args.command == "init": init(args.name) diff --git a/python/pyspark/pipelines/spark_connect_pipeline.py b/python/pyspark/pipelines/spark_connect_pipeline.py index 12f43a236c28..2c68b52409d3 100644 --- a/python/pyspark/pipelines/spark_connect_pipeline.py +++ b/python/pyspark/pipelines/spark_connect_pipeline.py @@ -15,7 +15,7 @@ # limitations under the License. # from datetime import timezone -from typing import Any, Dict, Mapping, Iterator, Optional, cast +from typing import Any, Dict, Mapping, Iterator, Optional, cast, Sequence import pyspark.sql.connect.proto as pb2 from pyspark.sql import SparkSession @@ -65,12 +65,30 @@ def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) -> None: log_with_provided_timestamp(event.message, dt) -def start_run(spark: SparkSession, dataflow_graph_id: str) -> Iterator[Dict[str, Any]]: +def start_run( + spark: SparkSession, + dataflow_graph_id: str, + full_refresh: Optional[Sequence[str]] = None, + full_refresh_all: bool = False, + refresh: Optional[Sequence[str]] = None +) -> Iterator[Dict[str, Any]]: """Start a run of the dataflow graph in the Spark Connect server. :param dataflow_graph_id: The ID of the dataflow graph to start. + :param full_refresh: List of tables to reset and recompute. 
+ :param full_refresh_all: Perform a full graph reset and recompute. + :param refresh: List of tables to update. """ + # TODO: Update protobuf schema to include these parameters + # For now, we accept the parameters but don't pass them to the protobuf command inner_command = pb2.PipelineCommand.StartRun(dataflow_graph_id=dataflow_graph_id) + # TODO: Once protobuf schema is updated, uncomment the following: + # inner_command = pb2.PipelineCommand.StartRun( + # dataflow_graph_id=dataflow_graph_id, + # full_refresh=full_refresh or [], + # full_refresh_all=full_refresh_all, + # refresh=refresh or [] + # ) command = pb2.Command() command.pipeline_command.start_run.CopyFrom(inner_command) # Cast because mypy seems to think `spark`` is a function, not an object. Likely related to diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index 66303e567c52..e8ec9590023a 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -19,6 +19,7 @@ import tempfile import textwrap from pathlib import Path +from typing import cast from pyspark.errors import PySparkException from pyspark.testing.connectutils import ( @@ -36,13 +37,14 @@ unpack_pipeline_spec, DefinitionsGlob, PipelineSpec, + run, ) from pyspark.pipelines.tests.local_graph_element_registry import LocalGraphElementRegistry @unittest.skipIf( not should_test_connect or not have_yaml, - connect_requirement_message or yaml_requirement_message, + (connect_requirement_message or yaml_requirement_message) or "Connect or YAML not available", ) class CLIUtilityTests(unittest.TestCase): def test_load_pipeline_spec(self): @@ -359,6 +361,130 @@ def test_python_import_current_directory(self): ) +@unittest.skipIf( + not should_test_connect or not have_yaml, + (connect_requirement_message or yaml_requirement_message) or "Connect or YAML not available", +) +class CLIValidationTests(unittest.TestCase): + def test_full_refresh_all_conflicts_with_full_refresh(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a minimal pipeline spec + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test that providing both --full-refresh-all and --full-refresh raises an exception + with self.assertRaises(PySparkException) as context: + run( + spec_path=spec_path, + full_refresh=["table1", "table2"], + full_refresh_all=True, + refresh=None + ) + + self.assertEqual( + context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" + ) + message_params = context.exception.getMessageParameters() + self.assertIsNotNone(message_params) + message = cast(dict, message_params)["message"] + self.assertIn("--full-refresh-all option conflicts with --full-refresh", message) + self.assertIn("performs a full refresh of all tables", message) + + def test_full_refresh_all_conflicts_with_refresh(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a minimal pipeline spec + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test that providing both --full-refresh-all and --refresh raises an exception + with self.assertRaises(PySparkException) as context: + run( + spec_path=spec_path, + full_refresh=None, + full_refresh_all=True, + refresh=["table1", "table2"] + ) + + self.assertEqual( + context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" + ) + message_params = context.exception.getMessageParameters() 
+ self.assertIsNotNone(message_params) + message = cast(dict, message_params)["message"] + self.assertIn("--full-refresh-all option conflicts with --refresh", message) + self.assertIn("performs a full refresh of all tables", message) + + def test_full_refresh_all_conflicts_with_both(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a minimal pipeline spec + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test that providing --full-refresh-all with both other options raises an exception + # (it should catch the first conflict - full_refresh) + with self.assertRaises(PySparkException) as context: + run( + spec_path=spec_path, + full_refresh=["table1"], + full_refresh_all=True, + refresh=["table2"] + ) + + self.assertEqual( + context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" + ) + message_params = context.exception.getMessageParameters() + self.assertIsNotNone(message_params) + message = cast(dict, message_params)["message"] + self.assertIn("--full-refresh-all option conflicts with --full-refresh", message) + + def test_no_conflict_when_full_refresh_all_alone(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a minimal pipeline spec + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test that providing only --full-refresh-all doesn't raise an exception + # (it should fail later when trying to actually run, but not in our validation) + try: + run( + spec_path=spec_path, + full_refresh=None, + full_refresh_all=True, + refresh=None + ) + # If we get here, the validation passed (it will fail later in pipeline execution) + self.fail("Expected the run to fail later, but validation should have passed") + except PySparkException as e: + # Make sure it's NOT our validation error + self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") + + def test_no_conflict_when_refresh_options_without_full_refresh_all(self): + with tempfile.TemporaryDirectory() as temp_dir: + # Create a minimal pipeline spec + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test that providing --refresh and --full-refresh without --full-refresh-all doesn't raise our validation error + try: + run( + spec_path=spec_path, + full_refresh=["table1"], + full_refresh_all=False, + refresh=["table2"] + ) + # If we get here, the validation passed (it will fail later in pipeline execution) + self.fail("Expected the run to fail later, but validation should have passed") + except PySparkException as e: + # Make sure it's NOT our validation error + self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") + + if __name__ == "__main__": try: import xmlrunner # type: ignore From 1d6ec2cf48c86432a5d471a3810f1d658f32c678 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Mon, 14 Jul 2025 22:19:56 -0700 Subject: [PATCH 02/17] working 1 test --- .../pipelines/spark_connect_pipeline.py | 16 +- python/pyspark/pipelines/tests/test_cli.py | 230 ++++++++++++++++++ .../sql/connect/proto/pipelines_pb2.py | 32 +-- .../sql/connect/proto/pipelines_pb2.pyi | 35 +++ .../protobuf/spark/connect/pipelines.proto | 9 + .../connect/pipelines/PipelinesHandler.scala | 53 +++- .../graph/PipelineUpdateContextImpl.scala | 13 +- 7 files changed, 354 insertions(+), 34 deletions(-) diff --git a/python/pyspark/pipelines/spark_connect_pipeline.py 
b/python/pyspark/pipelines/spark_connect_pipeline.py index 2c68b52409d3..e35ad055b270 100644 --- a/python/pyspark/pipelines/spark_connect_pipeline.py +++ b/python/pyspark/pipelines/spark_connect_pipeline.py @@ -79,16 +79,12 @@ def start_run( :param full_refresh_all: Perform a full graph reset and recompute. :param refresh: List of tables to update. """ - # TODO: Update protobuf schema to include these parameters - # For now, we accept the parameters but don't pass them to the protobuf command - inner_command = pb2.PipelineCommand.StartRun(dataflow_graph_id=dataflow_graph_id) - # TODO: Once protobuf schema is updated, uncomment the following: - # inner_command = pb2.PipelineCommand.StartRun( - # dataflow_graph_id=dataflow_graph_id, - # full_refresh=full_refresh or [], - # full_refresh_all=full_refresh_all, - # refresh=refresh or [] - # ) + inner_command = pb2.PipelineCommand.StartRun( + dataflow_graph_id=dataflow_graph_id, + full_refresh=full_refresh or [], + full_refresh_all=full_refresh_all, + refresh=refresh or [] + ) command = pb2.Command() command.pipeline_command.start_run.CopyFrom(inner_command) # Cast because mypy seems to think `spark`` is a function, not an object. Likely related to diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index e8ec9590023a..242d9ab244a5 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -484,6 +484,236 @@ def test_no_conflict_when_refresh_options_without_full_refresh_all(self): # Make sure it's NOT our validation error self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") + def test_parse_table_list_single_table(self): + """Test parsing a single table name.""" + from pyspark.pipelines.cli import parse_table_list + result = parse_table_list("table1") + self.assertEqual(result, ["table1"]) + + def test_parse_table_list_multiple_tables(self): + """Test parsing multiple table names.""" + from pyspark.pipelines.cli import parse_table_list + result = parse_table_list("table1,table2,table3") + self.assertEqual(result, ["table1", "table2", "table3"]) + + def test_parse_table_list_with_spaces(self): + """Test parsing table names with spaces.""" + from pyspark.pipelines.cli import parse_table_list + result = parse_table_list("table1, table2 , table3") + self.assertEqual(result, ["table1", "table2", "table3"]) + + def test_parse_table_list_empty_string(self): + """Test parsing empty string.""" + from pyspark.pipelines.cli import parse_table_list + result = parse_table_list("") + self.assertEqual(result, []) + + def test_parse_table_list_with_qualified_names(self): + """Test parsing qualified table names.""" + from pyspark.pipelines.cli import parse_table_list + result = parse_table_list("schema1.table1,schema2.table2") + self.assertEqual(result, ["schema1.table1", "schema2.table2"]) + + def test_flatten_table_lists_none(self): + """Test flattening None input.""" + from pyspark.pipelines.cli import flatten_table_lists + result = flatten_table_lists(None) + self.assertEqual(result, []) + + def test_flatten_table_lists_empty(self): + """Test flattening empty list.""" + from pyspark.pipelines.cli import flatten_table_lists + result = flatten_table_lists([]) + self.assertEqual(result, []) + + def test_flatten_table_lists_single_list(self): + """Test flattening single list.""" + from pyspark.pipelines.cli import flatten_table_lists + result = flatten_table_lists([["table1", "table2"]]) + self.assertEqual(result, ["table1", "table2"]) + + 
def test_flatten_table_lists_multiple_lists(self): + """Test flattening multiple lists.""" + from pyspark.pipelines.cli import flatten_table_lists + result = flatten_table_lists([["table1", "table2"], ["table3"], ["table4", "table5"]]) + self.assertEqual(result, ["table1", "table2", "table3", "table4", "table5"]) + + def test_valid_refresh_combinations(self): + """Test valid combinations of refresh parameters.""" + with tempfile.TemporaryDirectory() as temp_dir: + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test individual options don't raise validation errors + test_cases = [ + {"full_refresh": ["table1"]}, + {"refresh": ["table1"]}, + {"full_refresh_all": True}, + {"full_refresh": ["table1"], "refresh": ["table2"]}, + {"full_refresh": ["table1", "table2"], "refresh": ["table3", "table4"]}, + ] + + for case in test_cases: + try: + run(spec_path=spec_path, **case) + self.fail(f"Expected run to fail due to missing pipeline spec content: {case}") + except PySparkException as e: + # Should NOT be our validation error + self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") + + def test_empty_refresh_parameters(self): + """Test behavior with empty refresh parameters.""" + with tempfile.TemporaryDirectory() as temp_dir: + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test empty lists don't cause validation errors + try: + run( + spec_path=spec_path, + full_refresh=[], + refresh=[], + full_refresh_all=False + ) + self.fail("Expected run to fail due to missing pipeline spec content") + except PySparkException as e: + # Should NOT be our validation error + self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") + + def test_cli_argument_parsing_patterns(self): + """Test CLI argument parsing patterns for refresh options.""" + import argparse + from pyspark.pipelines.cli import parse_table_list + + # Simulate the argument parser + parser = argparse.ArgumentParser() + parser.add_argument("--full-refresh", type=parse_table_list, action="append") + parser.add_argument("--full-refresh-all", action="store_true") + parser.add_argument("--refresh", type=parse_table_list, action="append") + + # Test parsing various argument combinations + test_cases = [ + (["--full-refresh", "table1,table2"], {"full_refresh": [["table1", "table2"]]}), + (["--refresh", "table1", "--refresh", "table2"], {"refresh": [["table1"], ["table2"]]}), + (["--full-refresh-all"], {"full_refresh_all": True}), + (["--full-refresh", "table1", "--refresh", "table2"], {"full_refresh": [["table1"]], "refresh": [["table2"]]}), + (["--full-refresh", "schema.table1,schema.table2"], {"full_refresh": [["schema.table1", "schema.table2"]]}), + ] + + for args, expected in test_cases: + parsed = parser.parse_args(args) + for key, value in expected.items(): + self.assertEqual(getattr(parsed, key), value) + + def test_refresh_parameter_validation_edge_cases(self): + """Test edge cases for refresh parameter validation.""" + with tempfile.TemporaryDirectory() as temp_dir: + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test that providing None values works correctly + try: + run( + spec_path=spec_path, + full_refresh=None, + refresh=None, + full_refresh_all=False + ) + self.fail("Expected run to fail due to missing pipeline spec content") + except PySparkException as e: + # Should 
NOT be our validation error + self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") + + def test_refresh_parameters_with_qualified_table_names(self): + """Test refresh parameters with qualified table names.""" + with tempfile.TemporaryDirectory() as temp_dir: + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test qualified table names + try: + run( + spec_path=spec_path, + full_refresh=["schema1.table1", "schema2.table2"], + refresh=["schema3.table3"], + full_refresh_all=False + ) + self.fail("Expected run to fail due to missing pipeline spec content") + except PySparkException as e: + # Should NOT be our validation error + self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") + + def test_large_table_lists_handling(self): + """Test handling of large table lists.""" + with tempfile.TemporaryDirectory() as temp_dir: + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test with large table lists + large_table_list = [f"table_{i}" for i in range(100)] + try: + run( + spec_path=spec_path, + full_refresh=large_table_list, + refresh=large_table_list, + full_refresh_all=False + ) + self.fail("Expected run to fail due to missing pipeline spec content") + except PySparkException as e: + # Should NOT be our validation error + self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") + + def test_refresh_parameter_precedence(self): + """Test that full_refresh_all takes precedence over other parameters.""" + with tempfile.TemporaryDirectory() as temp_dir: + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test that full_refresh_all conflicts with other parameters + conflict_cases = [ + {"full_refresh_all": True, "full_refresh": ["table1"]}, + {"full_refresh_all": True, "refresh": ["table1"]}, + {"full_refresh_all": True, "full_refresh": ["table1"], "refresh": ["table2"]}, + ] + + for case in conflict_cases: + with self.assertRaises(PySparkException) as context: + run(spec_path=spec_path, **case) + self.assertEqual(context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") + + def test_detailed_validation_error_messages(self): + """Test that validation error messages are detailed and helpful.""" + with tempfile.TemporaryDirectory() as temp_dir: + spec_path = Path(temp_dir) / "pipeline.yaml" + with spec_path.open("w") as f: + f.write('{"name": "test_pipeline"}') + + # Test full_refresh_all with full_refresh error message + with self.assertRaises(PySparkException) as context: + run( + spec_path=spec_path, + full_refresh_all=True, + full_refresh=["table1"] + ) + self.assertEqual(context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") + self.assertIn("--full-refresh-all option conflicts with --full-refresh", str(context.exception)) + + # Test full_refresh_all with refresh error message + with self.assertRaises(PySparkException) as context: + run( + spec_path=spec_path, + full_refresh_all=True, + refresh=["table1"] + ) + self.assertEqual(context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") + self.assertIn("--full-refresh-all option conflicts with --refresh", str(context.exception)) + if __name__ == "__main__": try: diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py index 413e1fbe12b5..0e52646e7c3a 100644 
--- a/python/pyspark/sql/connect/proto/pipelines_pb2.py +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py @@ -40,7 +40,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xf2\x12\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x1a\x87\x03\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1aQ\n\x08Response\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\xd1\x04\n\rDefineDataset\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12&\n\x0c\x64\x61taset_name\x18\x02 \x01(\tH\x01R\x0b\x64\x61tasetName\x88\x01\x01\x12\x42\n\x0c\x64\x61taset_type\x18\x03 \x01(\x0e\x32\x1a.spark.connect.DatasetTypeH\x02R\x0b\x64\x61tasetType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x03R\x07\x63omment\x88\x01\x01\x12l\n\x10table_properties\x18\x05 \x03(\x0b\x32\x41.spark.connect.PipelineCommand.DefineDataset.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x06 \x03(\tR\rpartitionCols\x12\x34\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x04R\x06schema\x88\x01\x01\x12\x1b\n\x06\x66ormat\x18\x08 \x01(\tH\x05R\x06\x66ormat\x88\x01\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0f\n\r_dataset_nameB\x0f\n\r_dataset_typeB\n\n\x08_commentB\t\n\x07_schemaB\t\n\x07_format\x1a\xc8\x03\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x01R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x02R\x11targetDatasetName\x88\x01\x01\x12\x38\n\x08relation\x18\x04 \x01(\x0b\x32\x17.spark.connect.RelationH\x03R\x08relation\x88\x01\x01\x12Q\n\x08sql_conf\x18\x05 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12\x17\n\x04once\x18\x06 
\x01(\x08H\x04R\x04once\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0b\n\t_relationB\x07\n\x05_once\x1aQ\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_textB\x0e\n\x0c\x63ommand_type"\x8e\x02\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message*a\n\x0b\x44\x61tasetType\x12\x1c\n\x18\x44\x41TASET_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xf4\x13\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x1a\x87\x03\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1aQ\n\x08Response\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 
\x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\xd1\x04\n\rDefineDataset\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12&\n\x0c\x64\x61taset_name\x18\x02 \x01(\tH\x01R\x0b\x64\x61tasetName\x88\x01\x01\x12\x42\n\x0c\x64\x61taset_type\x18\x03 \x01(\x0e\x32\x1a.spark.connect.DatasetTypeH\x02R\x0b\x64\x61tasetType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x03R\x07\x63omment\x88\x01\x01\x12l\n\x10table_properties\x18\x05 \x03(\x0b\x32\x41.spark.connect.PipelineCommand.DefineDataset.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x06 \x03(\tR\rpartitionCols\x12\x34\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x04R\x06schema\x88\x01\x01\x12\x1b\n\x06\x66ormat\x18\x08 \x01(\tH\x05R\x06\x66ormat\x88\x01\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0f\n\r_dataset_nameB\x0f\n\r_dataset_typeB\n\n\x08_commentB\t\n\x07_schemaB\t\n\x07_format\x1a\xc8\x03\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x01R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x02R\x11targetDatasetName\x88\x01\x01\x12\x38\n\x08relation\x18\x04 \x01(\x0b\x32\x17.spark.connect.RelationH\x03R\x08relation\x88\x01\x01\x12Q\n\x08sql_conf\x18\x05 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12\x17\n\x04once\x18\x06 \x01(\x08H\x04R\x04once\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0b\n\t_relationB\x07\n\x05_once\x1a\xd2\x01\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12!\n\x0c\x66ull_refresh\x18\x02 \x03(\tR\x0b\x66ullRefresh\x12-\n\x10\x66ull_refresh_all\x18\x03 \x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12\x18\n\x07refresh\x18\x04 \x03(\tR\x07refreshB\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_all\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_textB\x0e\n\x0c\x63ommand_type"\x8e\x02\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 
\x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message*a\n\x0b\x44\x61tasetType\x12\x1c\n\x18\x44\x41TASET_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -59,10 +59,10 @@ _globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_options = b"8\001" _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options = b"8\001" - _globals["_DATASETTYPE"]._serialized_start = 3026 - _globals["_DATASETTYPE"]._serialized_end = 3123 + _globals["_DATASETTYPE"]._serialized_start = 3156 + _globals["_DATASETTYPE"]._serialized_end = 3253 _globals["_PIPELINECOMMAND"]._serialized_start = 140 - _globals["_PIPELINECOMMAND"]._serialized_end = 2558 + _globals["_PIPELINECOMMAND"]._serialized_end = 2688 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 719 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1110 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start = 928 @@ -79,16 +79,16 @@ _globals["_PIPELINECOMMAND_DEFINEFLOW"]._serialized_end = 2257 _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start = 928 _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 986 - _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2259 - _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2340 - _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 2343 - _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2542 - _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2561 - _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 2831 - _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 2718 - _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 2816 - _globals["_PIPELINEEVENTRESULT"]._serialized_start = 2833 - _globals["_PIPELINEEVENTRESULT"]._serialized_end = 2906 - _globals["_PIPELINEEVENT"]._serialized_start = 2908 - _globals["_PIPELINEEVENT"]._serialized_end = 3024 + _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2260 + _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2470 + _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 2473 + _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2672 + _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2691 + _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 2961 + _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 2848 + _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 2946 + _globals["_PIPELINEEVENTRESULT"]._serialized_start = 2963 + _globals["_PIPELINEEVENTRESULT"]._serialized_end = 3036 + _globals["_PIPELINEEVENT"]._serialized_start = 3038 + _globals["_PIPELINEEVENT"]._serialized_end = 3154 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi index 36fb73f06906..d52e4addf571 100644 --- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi @@ -530,20 +530,42 @@ class PipelineCommand(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor DATAFLOW_GRAPH_ID_FIELD_NUMBER: builtins.int + 
FULL_REFRESH_FIELD_NUMBER: builtins.int + FULL_REFRESH_ALL_FIELD_NUMBER: builtins.int + REFRESH_FIELD_NUMBER: builtins.int dataflow_graph_id: builtins.str """The graph to start.""" + @property + def full_refresh( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """List of tables to reset and recompute.""" + full_refresh_all: builtins.bool + """Perform a full graph reset and recompute.""" + @property + def refresh( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """List of tables to update.""" def __init__( self, *, dataflow_graph_id: builtins.str | None = ..., + full_refresh: collections.abc.Iterable[builtins.str] | None = ..., + full_refresh_all: builtins.bool | None = ..., + refresh: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ "_dataflow_graph_id", b"_dataflow_graph_id", + "_full_refresh_all", + b"_full_refresh_all", "dataflow_graph_id", b"dataflow_graph_id", + "full_refresh_all", + b"full_refresh_all", ], ) -> builtins.bool: ... def ClearField( @@ -551,14 +573,27 @@ class PipelineCommand(google.protobuf.message.Message): field_name: typing_extensions.Literal[ "_dataflow_graph_id", b"_dataflow_graph_id", + "_full_refresh_all", + b"_full_refresh_all", "dataflow_graph_id", b"dataflow_graph_id", + "full_refresh", + b"full_refresh", + "full_refresh_all", + b"full_refresh_all", + "refresh", + b"refresh", ], ) -> None: ... + @typing.overload def WhichOneof( self, oneof_group: typing_extensions.Literal["_dataflow_graph_id", b"_dataflow_graph_id"], ) -> typing_extensions.Literal["dataflow_graph_id"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_full_refresh_all", b"_full_refresh_all"] + ) -> typing_extensions.Literal["full_refresh_all"] | None: ... class DefineSqlGraphElements(google.protobuf.message.Message): """Parses the SQL file and registers all datasets and flows.""" diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto index c5a631264590..7f4dbb3a1f78 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto @@ -116,6 +116,15 @@ message PipelineCommand { message StartRun { // The graph to start. optional string dataflow_graph_id = 1; + + // List of tables to reset and recompute. + repeated string full_refresh = 2; + + // Perform a full graph reset and recompute. + optional bool full_refresh_all = 3; + + // List of tables to update. + repeated string refresh = 4; } // Parses the SQL file and registers all datasets and flows. 
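For reference, a minimal sketch (not part of the patch, assuming the regenerated pipelines_pb2 bindings above) of how the new CLI flags are expected to populate the StartRun fields; the graph ID and table names below are placeholders:

# Illustrative sketch only: mirrors the StartRun construction added to
# spark_connect_pipeline.py later in this series. Placeholder values only.
import pyspark.sql.connect.proto as pb2

inner_command = pb2.PipelineCommand.StartRun(
    dataflow_graph_id="<dataflow-graph-id>",  # placeholder; returned by CreateDataflowGraph
    full_refresh=["schema1.table1"],          # from --full-refresh schema1.table1
    full_refresh_all=False,                   # from --full-refresh-all (unset here)
    refresh=["schema2.table2"],               # from --refresh schema2.table2
)
command = pb2.Command()
command.pipeline_command.start_run.CopyFrom(inner_command)
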
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 7bb1d7358557..de602df2af40 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -26,6 +26,7 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineCommandResult, Relation} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connect.common.DataTypeProtoConverter @@ -33,7 +34,7 @@ import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.pipelines.Language.Python import org.apache.spark.sql.pipelines.QueryOriginType import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED} -import org.apache.spark.sql.pipelines.graph.{FlowAnalysis, GraphIdentifierManager, IdentifierHelper, PipelineUpdateContextImpl, QueryContext, QueryOrigin, SqlGraphRegistrationContext, Table, TemporaryView, UnresolvedFlow} +import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryOrigin, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow} import org.apache.spark.sql.pipelines.logging.{PipelineEvent, RunProgress} import org.apache.spark.sql.types.StructType @@ -224,6 +225,48 @@ private[connect] object PipelinesHandler extends Logging { sessionHolder: SessionHolder): Unit = { val dataflowGraphId = cmd.getDataflowGraphId val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) + + // Extract refresh parameters from protobuf command + val fullRefreshTables = cmd.getFullRefreshList.asScala.toSeq + val fullRefreshAll = cmd.getFullRefreshAll + val refreshTables = cmd.getRefreshList.asScala.toSeq + + // Convert table names to TableIdentifier objects + def parseTableNames(tableNames: Seq[String]): Set[TableIdentifier] = { + tableNames.map { name => + GraphIdentifierManager.parseAndQualifyTableIdentifier( + rawTableIdentifier = + GraphIdentifierManager.parseTableIdentifier(name, sessionHolder.session), + currentCatalog = Some(graphElementRegistry.defaultCatalog), + currentDatabase = Some(graphElementRegistry.defaultDatabase) + ).identifier + }.toSet + } + + val fullRefreshTablesFilter: TableFilter = if (fullRefreshAll) { + AllTables + } else if (fullRefreshTables.nonEmpty) { + SomeTables(parseTableNames(fullRefreshTables)) + } else { + NoTables + } + + // Create table filters based on refresh parameters + val refreshTablesFilter: TableFilter = if (fullRefreshAll || fullRefreshTables.nonEmpty) { + NoTables + } else if (refreshTables.nonEmpty) { + SomeTables(parseTableNames(refreshTables)) + } else { + AllTables + } + + // print full refresh tables filter for debugging purposes + // scalastyle:off println + println( + s"Full refresh tables filter: $fullRefreshTablesFilter, " + + s"Refresh tables filter: $refreshTablesFilter") + // scalastyle:on println + // We will use this variable to store the run failure event if it occurs. 
This will be set // by the event callback. @volatile var runFailureEvent = Option.empty[PipelineEvent] @@ -279,8 +322,12 @@ private[connect] object PipelinesHandler extends Logging { .build()) } } - val pipelineUpdateContext = - new PipelineUpdateContextImpl(graphElementRegistry.toDataflowGraph, eventCallback) + val pipelineUpdateContext = new PipelineUpdateContextImpl( + graphElementRegistry.toDataflowGraph, + eventCallback, + refreshTablesFilter, + fullRefreshTablesFilter + ) sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext) pipelineUpdateContext.pipelineExecution.runPipeline() diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala index c68882df79ce..e03b6c299797 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala @@ -24,10 +24,17 @@ import org.apache.spark.sql.pipelines.logging.{FlowProgressEventLogger, Pipeline * An implementation of the PipelineUpdateContext trait used in production. * @param unresolvedGraph The graph (unresolved) to be executed in this update. * @param eventCallback A callback function to be called when an event is added to the event buffer. + * @param refreshTables Filter for which tables should be refreshed when performing this update. + * @param fullRefreshTables Filter for which tables should be full refreshed + * when performing this update. + * @param resetCheckpointFlows Filter for which flows should be reset. */ class PipelineUpdateContextImpl( override val unresolvedGraph: DataflowGraph, - override val eventCallback: PipelineEvent => Unit + override val eventCallback: PipelineEvent => Unit, + override val refreshTables: TableFilter = AllTables, + override val fullRefreshTables: TableFilter = NoTables, + override val resetCheckpointFlows: FlowFilter = NoFlows ) extends PipelineUpdateContext { override val spark: SparkSession = SparkSession.getActiveSession.getOrElse( @@ -36,8 +43,4 @@ class PipelineUpdateContextImpl( override val flowProgressEventLogger: FlowProgressEventLogger = new FlowProgressEventLogger(eventCallback = eventCallback) - - override val refreshTables: TableFilter = AllTables - override val fullRefreshTables: TableFilter = NoTables - override val resetCheckpointFlows: FlowFilter = NoFlows } From b8ef4c3312e21e7ae7d9c45c6602290264dc820f Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Tue, 15 Jul 2025 00:09:09 -0700 Subject: [PATCH 03/17] 2 test pass --- .../connect/pipelines/PipelinesHandler.scala | 12 +- .../PipelineRefreshFunctionalSuite.scala | 289 ++++++++++++++++++ .../SparkDeclarativePipelinesServerTest.scala | 20 +- 3 files changed, 314 insertions(+), 7 deletions(-) create mode 100644 sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index de602df2af40..fa5b5d5a2b4d 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -252,12 +252,18 @@ private[connect] object PipelinesHandler 
extends Logging { } // Create table filters based on refresh parameters - val refreshTablesFilter: TableFilter = if (fullRefreshAll || fullRefreshTables.nonEmpty) { + val refreshTablesFilter: TableFilter = if (fullRefreshAll) { NoTables } else if (refreshTables.nonEmpty) { SomeTables(parseTableNames(refreshTables)) - } else { - AllTables + } else { // no tables specified for refresh + if (fullRefreshTablesFilter == NoTables) { + // If no tables are specified for full refresh, we default to refreshing all tables + AllTables + } else { + // If full refresh is specified, we do not need to refresh any additional tables + NoTables + } } // print full refresh tables filter for debugging purposes diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala new file mode 100644 index 000000000000..696661a88067 --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect.pipelines + +import java.io.{File, FileWriter} + +import scala.jdk.CollectionConverters._ + +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.DatasetType +import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} +import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin} +import org.apache.spark.util.Utils +// scalastyle:off println + +/** + * Comprehensive test suite that validates pipeline refresh functionality by running actual + * pipelines with different refresh parameters and validating the results. 
+ */ +class PipelineRefreshFunctionalSuite + extends SparkDeclarativePipelinesServerTest + with TestPipelineUpdateContextMixin + with EventVerificationTestHelpers { + + private var testDataDir: File = _ + private var streamInputDir: File = _ + + override def beforeAll(): Unit = { + super.beforeAll() + // Create temporary directories for test data + testDataDir = Utils.createTempDir("pipeline-refresh-test") + streamInputDir = new File(testDataDir, "stream-input") + streamInputDir.mkdirs() + } + + override def afterAll(): Unit = { + try { + // Clean up test directories + Utils.deleteRecursively(testDataDir) + } finally { + super.afterAll() + } + } + + private def uploadInputFile(filename: String, contents: String): Unit = { + val file = new File(streamInputDir, filename) + val writer = new FileWriter(file) + try { + writer.write(contents) + } finally { + writer.close() + } + } + + + private def createPipelineWithTransformations(graphId: String): TestPipelineDefinition = { + new TestPipelineDefinition(graphId) { + // Create a mv that reads from files + createTable( + name = "file_data", + datasetType = DatasetType.MATERIALIZED_VIEW, + sql = Some(s""" + SELECT id, value FROM JSON.`${streamInputDir.getAbsolutePath}/*.json` + """)) + + // Create tables that depend on the mv + createTable( + name = "table_a", + datasetType = DatasetType.TABLE, + sql = Some("SELECT id, value as odd_value FROM STREAM file_data WHERE id % 2 = 1")) + + createTable( + name = "table_b", + datasetType = DatasetType.TABLE, + sql = Some("SELECT id, value as even_value FROM STREAM file_data WHERE id % 2 = 0")) + } + } + + private def createTwoSTPipeline(graphId: String): TestPipelineDefinition = { + new TestPipelineDefinition(graphId) { + // Create a mv that reads from files + createTable( + name = "file_data", + datasetType = DatasetType.MATERIALIZED_VIEW, + sql = Some(s""" + SELECT id, value FROM JSON.`${streamInputDir.getAbsolutePath}/*.json` + """)) + + // Create tables that depend on the mv + createTable( + name = "a", + datasetType = DatasetType.TABLE, + sql = Some("SELECT id, value FROM STREAM file_data")) + + createTable( + name = "b", + datasetType = DatasetType.TABLE, + sql = Some("SELECT id, value FROM STREAM file_data")) + } + } + + test("pipeline runs selective full_refresh") { + withRawBlockingStub { implicit stub => + uploadInputFile("data.json", """ + |{"id": 1, "value": 1} + |{"id": 2, "value": 2} + """.stripMargin) + val graphId = createDataflowGraph + val pipeline = createPipelineWithTransformations(graphId) + registerPipelineDatasets(pipeline) + + // First run to populate tables + startPipelineAndWaitForCompletion(graphId) + + // Verify initial data from file stream + verifyMultipleTableContent( + tableNames = Set( + "spark_catalog.default.file_data", + "spark_catalog.default.table_a", + "spark_catalog.default.table_b"), + columnsToVerify = Map( + "spark_catalog.default.file_data" -> Seq("id", "value"), + "spark_catalog.default.table_a" -> Seq("id", "odd_value"), + "spark_catalog.default.table_b" -> Seq("id", "even_value") + ), + expectedContent = Map( + "spark_catalog.default.file_data" -> Set( + Map("id" -> 1, "value" -> 1), + Map("id" -> 2, "value" -> 2) + ), + "spark_catalog.default.table_a" -> Set( + Map("id" -> 1, "odd_value" -> 1) + ), + "spark_catalog.default.table_b" -> Set( + Map("id" -> 2, "even_value" -> 2) + ) + ) + ) + + // Clear cached pipeline execution before starting new run + SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) 
+ .foreach(_.removeAllPipelineExecutions()) + + // simulate a full refresh by uploading new data + uploadInputFile("data.json", """ + |{"id": 1, "value": 1} + |{"id": 2, "value": 2} + |{"id": 3, "value": 3} + |{"id": 4, "value": 4} + """.stripMargin) + + // Run with full refresh on specific tables + val fullRefreshTables = List("file_data", "table_a") + val startRun = proto.PipelineCommand.StartRun.newBuilder() + .setDataflowGraphId(graphId) + .addAllFullRefresh(fullRefreshTables.asJava) + .build() + + val capturedEvents = startPipelineAndWaitForCompletion(graphId, Some(startRun)) + // assert that table_b is excluded + assert(capturedEvents.exists( + _.getMessage.contains(s"Flow \'spark_catalog.default.table_b\' is EXCLUDED."))) + // assert that table_a and file_data ran to completion + assert(capturedEvents.exists( + _.getMessage.contains(s"Flow spark_catalog.default.table_a has COMPLETED."))) + assert(capturedEvents.exists( + _.getMessage.contains(s"Flow spark_catalog.default.file_data has COMPLETED."))) + // Verify completion event + assert(capturedEvents.exists(_.getMessage.contains("Run is COMPLETED"))) + + verifyMultipleTableContent( + tableNames = Set("spark_catalog.default.file_data", "spark_catalog.default.table_a"), + columnsToVerify = Map( + "spark_catalog.default.file_data" -> Seq("id", "value"), + "spark_catalog.default.table_a" -> Seq("id", "odd_value"), + "spark_catalog.default.table_b" -> Seq("id", "even_value") + ), + expectedContent = Map( + "spark_catalog.default.file_data" -> Set( + Map("id" -> 1, "value" -> 1), + Map("id" -> 2, "value" -> 2), + Map("id" -> 3, "value" -> 3), + Map("id" -> 4, "value" -> 4) + ), + "spark_catalog.default.table_a" -> Set( + Map("id" -> 1, "odd_value" -> 1), + Map("id" -> 3, "odd_value" -> 3) + ), + "spark_catalog.default.table_b" -> Set( + Map("id" -> 2, "even_value" -> 4) // table_b should not have changed + ) + ) + ) + } + } + + test("pipeline runs selective full_refresh and selective refresh") { + withRawBlockingStub { implicit stub => + uploadInputFile("data.json", """ + |{"id": "x", "value": 1} + """.stripMargin) + val graphId = createDataflowGraph + val pipeline = createTwoSTPipeline(graphId) + registerPipelineDatasets(pipeline) + + // First run to populate tables + startPipelineAndWaitForCompletion(graphId) + + // Verify initial data from file stream + verifyMultipleTableContent( + tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + columnsToVerify = Map( + "spark_catalog.default.a" -> Seq("id", "value"), + "spark_catalog.default.b" -> Seq("id", "value") + ), + expectedContent = Map( + "spark_catalog.default.a" -> Set(Map("id" -> "x", "value" -> 1)), + "spark_catalog.default.b" -> Set(Map("id" -> "x", "value" -> 1)) + ) + ) + + // Clear cached pipeline execution before starting new run + SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) + .foreach(_.removeAllPipelineExecutions()) + + uploadInputFile("data.json", """ + |{"id": "x", "value": 2} + """.stripMargin) + + val startRun = proto.PipelineCommand.StartRun.newBuilder() + .setDataflowGraphId(graphId) + .addAllFullRefresh(List("file_data", "a").asJava) + .addRefresh("b") + .build() + + startPipelineAndWaitForCompletion(graphId, Some(startRun)) + + // assert that table_b is refreshed + verifyMultipleTableContent( + tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + columnsToVerify = Map( + "spark_catalog.default.a" -> Seq("id", "value"), + "spark_catalog.default.b" -> Seq("id", 
"value") + ), + expectedContent = Map( + // a should be fully refreshed and only contain the new value + "spark_catalog.default.a" -> Set(Map("id" -> "x", "value" -> 2)), + "spark_catalog.default.b" -> + // b is incrementally refreshed and contain the new value in addition to the old one + Set(Map("id" -> "x", "value" -> 1), Map("id" -> "x", "value" -> 2)) + ) + ) + } + } + + private def verifyMultipleTableContent( + tableNames: Set[String], + columnsToVerify: Map[String, Seq[String]], + expectedContent: Map[String, Set[Map[String, Any]]]): Unit = { + tableNames.foreach { tableName => + spark.catalog.refreshTable(tableName) + val df = spark.table(tableName) + assert(df.columns.toSet == columnsToVerify(tableName).toSet, + s"Columns in $tableName do not match expected: ${df.columns.mkString(", ")}") + val actualContent = df.collect().map(row => { + columnsToVerify(tableName).map(col => col -> row.getAs[Any](col)).toMap + }).toSet + assert(actualContent == expectedContent(tableName), + s"Content of $tableName does not match expected: $actualContent") + } + } +} diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala index ea4cc5f3aba5..f18681834622 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala @@ -17,12 +17,16 @@ package org.apache.spark.sql.connect.pipelines +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.connect.{proto => sc} +import org.apache.spark.connect.proto.{PipelineCommand, PipelineEvent} import org.apache.spark.sql.connect.{SparkConnectServerTest, SparkConnectTestUtils} import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} import org.apache.spark.sql.pipelines.utils.PipelineTest + class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { override def afterEach(): Unit = { @@ -125,15 +129,23 @@ class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { def createPlanner(): SparkConnectPlanner = new SparkConnectPlanner(SparkConnectTestUtils.createDummySessionHolder(spark)) - def startPipelineAndWaitForCompletion(graphId: String): Unit = { + def startPipelineAndWaitForCompletion( + graphId: String, + customStartRunCommand: Option[PipelineCommand.StartRun] = None): ArrayBuffer[PipelineEvent] = { withClient { client => - val startRunRequest = buildStartRunPlan( - sc.PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build()) + val capturedEvents = new ArrayBuffer[PipelineEvent]() + val startRunRequest = buildStartRunPlan(customStartRunCommand.getOrElse( + PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build())) val responseIterator = client.execute(startRunRequest) // The response iterator will be closed when the pipeline is completed. 
while (responseIterator.hasNext) { - responseIterator.next() + val response = responseIterator.next() + if (response.hasPipelineEventResult) { + capturedEvents.append(response.getPipelineEventResult.getEvent) + } } + return capturedEvents } + ArrayBuffer.empty[PipelineEvent] } } From 7fbe8e7157332504d75b43028f9bf03ecede66a7 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Tue, 15 Jul 2025 11:21:25 -0700 Subject: [PATCH 04/17] more tests --- .../PipelineRefreshFunctionalSuite.scala | 223 ++++++++++++------ 1 file changed, 155 insertions(+), 68 deletions(-) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala index 696661a88067..13f28b85a07f 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala @@ -67,30 +67,6 @@ class PipelineRefreshFunctionalSuite } } - - private def createPipelineWithTransformations(graphId: String): TestPipelineDefinition = { - new TestPipelineDefinition(graphId) { - // Create a mv that reads from files - createTable( - name = "file_data", - datasetType = DatasetType.MATERIALIZED_VIEW, - sql = Some(s""" - SELECT id, value FROM JSON.`${streamInputDir.getAbsolutePath}/*.json` - """)) - - // Create tables that depend on the mv - createTable( - name = "table_a", - datasetType = DatasetType.TABLE, - sql = Some("SELECT id, value as odd_value FROM STREAM file_data WHERE id % 2 = 1")) - - createTable( - name = "table_b", - datasetType = DatasetType.TABLE, - sql = Some("SELECT id, value as even_value FROM STREAM file_data WHERE id % 2 = 0")) - } - } - private def createTwoSTPipeline(graphId: String): TestPipelineDefinition = { new TestPipelineDefinition(graphId) { // Create a mv that reads from files @@ -116,12 +92,12 @@ class PipelineRefreshFunctionalSuite test("pipeline runs selective full_refresh") { withRawBlockingStub { implicit stub => - uploadInputFile("data.json", """ - |{"id": 1, "value": 1} - |{"id": 2, "value": 2} - """.stripMargin) + uploadInputFile("data.json", + """ + |{"id": "x", "value": 1} + """.stripMargin) val graphId = createDataflowGraph - val pipeline = createPipelineWithTransformations(graphId) + val pipeline = createTwoSTPipeline(graphId) registerPipelineDatasets(pipeline) // First run to populate tables @@ -131,23 +107,22 @@ class PipelineRefreshFunctionalSuite verifyMultipleTableContent( tableNames = Set( "spark_catalog.default.file_data", - "spark_catalog.default.table_a", - "spark_catalog.default.table_b"), + "spark_catalog.default.a", + "spark_catalog.default.b"), columnsToVerify = Map( "spark_catalog.default.file_data" -> Seq("id", "value"), - "spark_catalog.default.table_a" -> Seq("id", "odd_value"), - "spark_catalog.default.table_b" -> Seq("id", "even_value") + "spark_catalog.default.a" -> Seq("id", "value"), + "spark_catalog.default.b" -> Seq("id", "value") ), expectedContent = Map( "spark_catalog.default.file_data" -> Set( - Map("id" -> 1, "value" -> 1), - Map("id" -> 2, "value" -> 2) + Map("id" -> "x", "value" -> 1) ), - "spark_catalog.default.table_a" -> Set( - Map("id" -> 1, "odd_value" -> 1) + "spark_catalog.default.a" -> Set( + Map("id" -> "x", "value" -> 1) ), - "spark_catalog.default.table_b" -> Set( - Map("id" -> 2, "even_value" -> 2) + "spark_catalog.default.b" -> Set( + 
Map("id" -> "x", "value" -> 1) ) ) ) @@ -158,15 +133,13 @@ class PipelineRefreshFunctionalSuite .foreach(_.removeAllPipelineExecutions()) // simulate a full refresh by uploading new data - uploadInputFile("data.json", """ - |{"id": 1, "value": 1} - |{"id": 2, "value": 2} - |{"id": 3, "value": 3} - |{"id": 4, "value": 4} - """.stripMargin) + uploadInputFile("data.json", + """ + |{"id": "x", "value": 2} + """.stripMargin) // Run with full refresh on specific tables - val fullRefreshTables = List("file_data", "table_a") + val fullRefreshTables = List("file_data", "a") val startRun = proto.PipelineCommand.StartRun.newBuilder() .setDataflowGraphId(graphId) .addAllFullRefresh(fullRefreshTables.asJava) @@ -175,35 +148,32 @@ class PipelineRefreshFunctionalSuite val capturedEvents = startPipelineAndWaitForCompletion(graphId, Some(startRun)) // assert that table_b is excluded assert(capturedEvents.exists( - _.getMessage.contains(s"Flow \'spark_catalog.default.table_b\' is EXCLUDED."))) + _.getMessage.contains(s"Flow \'spark_catalog.default.b\' is EXCLUDED."))) // assert that table_a and file_data ran to completion assert(capturedEvents.exists( - _.getMessage.contains(s"Flow spark_catalog.default.table_a has COMPLETED."))) + _.getMessage.contains(s"Flow spark_catalog.default.a has COMPLETED."))) assert(capturedEvents.exists( _.getMessage.contains(s"Flow spark_catalog.default.file_data has COMPLETED."))) // Verify completion event assert(capturedEvents.exists(_.getMessage.contains("Run is COMPLETED"))) verifyMultipleTableContent( - tableNames = Set("spark_catalog.default.file_data", "spark_catalog.default.table_a"), + tableNames = Set("spark_catalog.default.file_data", "spark_catalog.default.a", + "spark_catalog.default.b"), columnsToVerify = Map( "spark_catalog.default.file_data" -> Seq("id", "value"), - "spark_catalog.default.table_a" -> Seq("id", "odd_value"), - "spark_catalog.default.table_b" -> Seq("id", "even_value") + "spark_catalog.default.a" -> Seq("id", "value"), + "spark_catalog.default.b" -> Seq("id", "value") ), expectedContent = Map( "spark_catalog.default.file_data" -> Set( - Map("id" -> 1, "value" -> 1), - Map("id" -> 2, "value" -> 2), - Map("id" -> 3, "value" -> 3), - Map("id" -> 4, "value" -> 4) + Map("id" -> "x", "value" -> 2) ), - "spark_catalog.default.table_a" -> Set( - Map("id" -> 1, "odd_value" -> 1), - Map("id" -> 3, "odd_value" -> 3) + "spark_catalog.default.a" -> Set( + Map("id" -> "x", "value" -> 2) ), - "spark_catalog.default.table_b" -> Set( - Map("id" -> 2, "even_value" -> 4) // table_b should not have changed + "spark_catalog.default.b" -> Set( + Map("id" -> "x", "value" -> 1) // b should not be refreshed, so it retains the old value ) ) ) @@ -212,8 +182,9 @@ class PipelineRefreshFunctionalSuite test("pipeline runs selective full_refresh and selective refresh") { withRawBlockingStub { implicit stub => - uploadInputFile("data.json", """ - |{"id": "x", "value": 1} + uploadInputFile("data.json", + """ + |{"id": "x", "value": 1} """.stripMargin) val graphId = createDataflowGraph val pipeline = createTwoSTPipeline(graphId) @@ -240,8 +211,9 @@ class PipelineRefreshFunctionalSuite .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) .foreach(_.removeAllPipelineExecutions()) - uploadInputFile("data.json", """ - |{"id": "x", "value": 2} + uploadInputFile("data.json", + """ + |{"id": "y", "value": 2} """.stripMargin) val startRun = proto.PipelineCommand.StartRun.newBuilder() @@ -261,10 +233,125 @@ class PipelineRefreshFunctionalSuite ), expectedContent = Map( 
// a should be fully refreshed and only contain the new value - "spark_catalog.default.a" -> Set(Map("id" -> "x", "value" -> 2)), + "spark_catalog.default.a" -> Set(Map("id" -> "y", "value" -> 2)), + "spark_catalog.default.b" -> + // b contain the new value in addition to the old one + Set(Map("id" -> "x", "value" -> 1), Map("id" -> "y", "value" -> 2)) + ) + ) + } + } + + test("pipeline runs refresh by default") { + withRawBlockingStub { implicit stub => + uploadInputFile("data.json", + """ + |{"id": "x", "value": 1} + """.stripMargin) + val graphId = createDataflowGraph + val pipeline = createTwoSTPipeline(graphId) + registerPipelineDatasets(pipeline) + + // First run to populate tables + startPipelineAndWaitForCompletion(graphId) + + // Verify initial data from file stream + verifyMultipleTableContent( + tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + columnsToVerify = Map( + "spark_catalog.default.a" -> Seq("id", "value"), + "spark_catalog.default.b" -> Seq("id", "value") + ), + expectedContent = Map( + "spark_catalog.default.a" -> Set(Map("id" -> "x", "value" -> 1)), + "spark_catalog.default.b" -> Set(Map("id" -> "x", "value" -> 1)) + ) + ) + + // Clear cached pipeline execution before starting new run + SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) + .foreach(_.removeAllPipelineExecutions()) + + uploadInputFile("data.json", + """ + |{"id": "y", "value": 2} + """.stripMargin) + + // Create a default StartRun command that refreshes all tables + startPipelineAndWaitForCompletion(graphId) + + // assert that both tables are refreshed + verifyMultipleTableContent( + tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + columnsToVerify = Map( + "spark_catalog.default.a" -> Seq("id", "value"), + "spark_catalog.default.b" -> Seq("id", "value") + ), + expectedContent = Map( + // both tables should contain the new value in addition to the old one + "spark_catalog.default.a" -> + Set(Map("id" -> "x", "value" -> 1), Map("id" -> "y", "value" -> 2)), "spark_catalog.default.b" -> - // b is incrementally refreshed and contain the new value in addition to the old one - Set(Map("id" -> "x", "value" -> 1), Map("id" -> "x", "value" -> 2)) + Set(Map("id" -> "x", "value" -> 1), Map("id" -> "y", "value" -> 2)) + ) + ) + } + } + + test("pipeline runs full_refresh_all") { + withRawBlockingStub { implicit stub => + uploadInputFile("data.json", + """ + |{"id": "x", "value": 1} + """.stripMargin) + val graphId = createDataflowGraph + val pipeline = createTwoSTPipeline(graphId) + registerPipelineDatasets(pipeline) + + // First run to populate tables + startPipelineAndWaitForCompletion(graphId) + + // Verify initial data from file stream + verifyMultipleTableContent( + tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + columnsToVerify = Map( + "spark_catalog.default.a" -> Seq("id", "value"), + "spark_catalog.default.b" -> Seq("id", "value") + ), + expectedContent = Map( + "spark_catalog.default.a" -> Set(Map("id" -> "x", "value" -> 1)), + "spark_catalog.default.b" -> Set(Map("id" -> "x", "value" -> 1)) + ) + ) + // Clear cached pipeline execution before starting new run + SparkConnectService.sessionManager + .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) + .foreach(_.removeAllPipelineExecutions()) + + uploadInputFile("data.json", + """ + |{"id": "y", "value": 2} + """.stripMargin) + + // Create a default StartRun command that refreshes all tables + val 
startRun = proto.PipelineCommand.StartRun.newBuilder() + .setDataflowGraphId(graphId) + .setFullRefreshAll(true) + .build() + startPipelineAndWaitForCompletion(graphId, Some(startRun)) + + // assert that all tables are fully refreshed + verifyMultipleTableContent( + tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + columnsToVerify = Map( + "spark_catalog.default.a" -> Seq("id", "value"), + "spark_catalog.default.b" -> Seq("id", "value") + ), + // both tables should only contain the new value + expectedContent = Map( + "spark_catalog.default.a" -> Set(Map("id" -> "y", "value" -> 2)), + "spark_catalog.default.b" -> Set(Map("id" -> "y", "value" -> 2)) ) ) } @@ -275,7 +362,7 @@ class PipelineRefreshFunctionalSuite columnsToVerify: Map[String, Seq[String]], expectedContent: Map[String, Set[Map[String, Any]]]): Unit = { tableNames.foreach { tableName => - spark.catalog.refreshTable(tableName) + spark.catalog.refreshTable(tableName) // clear cache for the table val df = spark.table(tableName) assert(df.columns.toSet == columnsToVerify(tableName).toSet, s"Columns in $tableName do not match expected: ${df.columns.mkString(", ")}") From 109de103d2841c6f36744a9644f9873c7112ca4f Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Tue, 15 Jul 2025 12:05:30 -0700 Subject: [PATCH 05/17] add server side validation --- .../connect/pipelines/PipelinesHandler.scala | 54 +++++++++++-------- .../PipelineRefreshFunctionalSuite.scala | 2 - 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index fa5b5d5a2b4d..7c6dffde20be 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -231,7 +231,7 @@ private[connect] object PipelinesHandler extends Logging { val fullRefreshAll = cmd.getFullRefreshAll val refreshTables = cmd.getRefreshList.asScala.toSeq - // Convert table names to TableIdentifier objects + // Convert table names to fully qualified TableIdentifier objects def parseTableNames(tableNames: Seq[String]): Set[TableIdentifier] = { tableNames.map { name => GraphIdentifierManager.parseAndQualifyTableIdentifier( @@ -243,35 +243,45 @@ private[connect] object PipelinesHandler extends Logging { }.toSet } + if (fullRefreshTables.nonEmpty && fullRefreshAll) { + throw new IllegalArgumentException( + "Cannot specify a subset to refresh when full refresh all is set to true.") + } + + if (refreshTables.nonEmpty && fullRefreshAll) { + throw new IllegalArgumentException( + "Cannot specify a subset to full refresh when full refresh all is set to true.") + } + val refreshTableNames = parseTableNames(refreshTables) + val fullRefreshTableNames = parseTableNames(fullRefreshTables) + + if (refreshTables.nonEmpty && fullRefreshTables.nonEmpty) { + // check if there is an intersection between the subset + val intersection = refreshTableNames.intersect(fullRefreshTableNames) + if (intersection.nonEmpty) { + throw new IllegalArgumentException( + "Datasets specified for refresh and full refresh cannot overlap: " + + s"${intersection.mkString(", ")}" + ) + } + } + val fullRefreshTablesFilter: TableFilter = if (fullRefreshAll) { AllTables } else if (fullRefreshTables.nonEmpty) { - SomeTables(parseTableNames(fullRefreshTables)) + 
SomeTables(fullRefreshTableNames) } else { NoTables } - // Create table filters based on refresh parameters - val refreshTablesFilter: TableFilter = if (fullRefreshAll) { - NoTables - } else if (refreshTables.nonEmpty) { - SomeTables(parseTableNames(refreshTables)) - } else { // no tables specified for refresh - if (fullRefreshTablesFilter == NoTables) { - // If no tables are specified for full refresh, we default to refreshing all tables - AllTables - } else { - // If full refresh is specified, we do not need to refresh any additional tables + val refreshTablesFilter: TableFilter = + if (refreshTables.nonEmpty) { + SomeTables(refreshTableNames) + } else if (fullRefreshTablesFilter != NoTables) { NoTables - } - } - - // print full refresh tables filter for debugging purposes - // scalastyle:off println - println( - s"Full refresh tables filter: $fullRefreshTablesFilter, " + - s"Refresh tables filter: $refreshTablesFilter") - // scalastyle:on println + } else { + AllTables + } // We will use this variable to store the run failure event if it occurs. This will be set // by the event callback. diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala index 13f28b85a07f..8ad3cde25027 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.connect.proto.DatasetType import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin} import org.apache.spark.util.Utils -// scalastyle:off println /** * Comprehensive test suite that validates pipeline refresh functionality by running actual @@ -50,7 +49,6 @@ class PipelineRefreshFunctionalSuite override def afterAll(): Unit = { try { - // Clean up test directories Utils.deleteRecursively(testDataDir) } finally { super.afterAll() From 7e990c50ff416f22ff79adc4f7d4f357b7e75087 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Tue, 15 Jul 2025 12:10:54 -0700 Subject: [PATCH 06/17] add validation tests --- .../PipelineRefreshFunctionalSuite.scala | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala index 8ad3cde25027..87c2c99620bf 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala @@ -355,6 +355,110 @@ class PipelineRefreshFunctionalSuite } } + test("validation: cannot specify subset refresh when full_refresh_all is true") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + val pipeline = createTwoSTPipeline(graphId) + registerPipelineDatasets(pipeline) + + val startRun = proto.PipelineCommand.StartRun.newBuilder() + .setDataflowGraphId(graphId) + .setFullRefreshAll(true) + .addRefresh("a") + .build() + + val exception = intercept[IllegalArgumentException] { + 
startPipelineAndWaitForCompletion(graphId, Some(startRun)) + } + assert(exception.getMessage.contains( + "Cannot specify a subset to full refresh when full refresh all is set to true")) + } + } + + test("validation: cannot specify subset full_refresh when full_refresh_all is true") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + val pipeline = createTwoSTPipeline(graphId) + registerPipelineDatasets(pipeline) + + val startRun = proto.PipelineCommand.StartRun.newBuilder() + .setDataflowGraphId(graphId) + .setFullRefreshAll(true) + .addFullRefresh("a") + .build() + + val exception = intercept[IllegalArgumentException] { + startPipelineAndWaitForCompletion(graphId, Some(startRun)) + } + assert(exception.getMessage.contains( + "Cannot specify a subset to refresh when full refresh all is set to true")) + } + } + + test("validation: refresh and full_refresh cannot overlap") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + val pipeline = createTwoSTPipeline(graphId) + registerPipelineDatasets(pipeline) + + val startRun = proto.PipelineCommand.StartRun.newBuilder() + .setDataflowGraphId(graphId) + .addRefresh("a") + .addFullRefresh("a") + .build() + + val exception = intercept[IllegalArgumentException] { + startPipelineAndWaitForCompletion(graphId, Some(startRun)) + } + assert(exception.getMessage.contains( + "Datasets specified for refresh and full refresh cannot overlap")) + assert(exception.getMessage.contains("a")) + } + } + + test("validation: multiple overlapping tables in refresh and full_refresh") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + val pipeline = createTwoSTPipeline(graphId) + registerPipelineDatasets(pipeline) + + val startRun = proto.PipelineCommand.StartRun.newBuilder() + .setDataflowGraphId(graphId) + .addRefresh("a") + .addRefresh("b") + .addFullRefresh("a") + .addFullRefresh("file_data") + .build() + + val exception = intercept[IllegalArgumentException] { + startPipelineAndWaitForCompletion(graphId, Some(startRun)) + } + assert(exception.getMessage.contains( + "Datasets specified for refresh and full refresh cannot overlap")) + assert(exception.getMessage.contains("a")) + } + } + + test("validation: fully qualified table names in validation") { + withRawBlockingStub { implicit stub => + val graphId = createDataflowGraph + val pipeline = createTwoSTPipeline(graphId) + registerPipelineDatasets(pipeline) + + val startRun = proto.PipelineCommand.StartRun.newBuilder() + .setDataflowGraphId(graphId) + .addRefresh("spark_catalog.default.a") + .addFullRefresh("a") // This should be treated as the same table + .build() + + val exception = intercept[IllegalArgumentException] { + startPipelineAndWaitForCompletion(graphId, Some(startRun)) + } + assert(exception.getMessage.contains( + "Datasets specified for refresh and full refresh cannot overlap")) + } + } + private def verifyMultipleTableContent( tableNames: Set[String], columnsToVerify: Map[String, Seq[String]], From e494a57617380f5d84b337a4d4a828036448eb01 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Tue, 15 Jul 2025 13:23:06 -0700 Subject: [PATCH 07/17] python cli tests --- python/pyspark/errors/error-conditions.json | 7 + python/pyspark/pipelines/cli.py | 22 +-- .../pipelines/spark_connect_pipeline.py | 4 +- python/pyspark/pipelines/tests/test_cli.py | 126 +----------------- 4 files changed, 19 insertions(+), 140 deletions(-) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json 
index 0e54e7628fcb..f4adea5fba83 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -174,6 +174,13 @@ "`` does not allow a Column in a list." ] }, + "CONFLICTING_PIPELINE_REFRESH_OPTIONS" : { + "message" : [ + "--full-refresh-all option conflicts with --refresh and --full-refresh. ", + "The --full-refresh-all option performs a full refresh of all datasets, ", + "so specifying individual datasets with --refresh or --full-refresh is not allowed." + ] + }, "CONNECT_URL_ALREADY_DEFINED": { "message": [ "Only one Spark Connect client URL can be set; however, got a different URL [] from the existing []." diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index 95a3e28113dc..8efce47e4cdb 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -226,29 +226,21 @@ def run( """Run the pipeline defined with the given spec. :param spec_path: Path to the pipeline specification file. - :param full_refresh: List of tables to reset and recompute. + :param full_refresh: List of datasets to reset and recompute. :param full_refresh_all: Perform a full graph reset and recompute. - :param refresh: List of tables to update. + :param refresh: List of datasets to update. """ # Validate conflicting arguments if full_refresh_all: if full_refresh: raise PySparkException( errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", - messageParameters={ - "message": "--full-refresh-all option conflicts with --full-refresh. " - "The --full-refresh-all option performs a full refresh of all tables, " - "so specifying individual tables with --full-refresh is not allowed." - } + messageParameters={} ) if refresh: raise PySparkException( - errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", - messageParameters={ - "message": "--full-refresh-all option conflicts with --refresh. " - "The --full-refresh-all option performs a full refresh of all tables, " - "so specifying individual tables with --refresh is not allowed." - } + errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", + messageParameters={} ) log_with_curr_timestamp(f"Loading pipeline spec from {spec_path}...") @@ -313,14 +305,14 @@ def flatten_table_lists(table_lists: Optional[List[List[str]]]) -> Optional[List "--full-refresh", type=parse_table_list, action="append", - help="List of tables to reset and recompute (comma-separated). Can be specified multiple times." + help="List of datasets to reset and recompute (comma-separated)." ) run_parser.add_argument("--full-refresh-all", action="store_true", help="Perform a full graph reset and recompute.") run_parser.add_argument( "--refresh", type=parse_table_list, action="append", - help="List of tables to update (comma-separated). Can be specified multiple times." + help="List of datasets to update (comma-separated)." ) # "init" subcommand diff --git a/python/pyspark/pipelines/spark_connect_pipeline.py b/python/pyspark/pipelines/spark_connect_pipeline.py index e35ad055b270..f4f46c3ee3ab 100644 --- a/python/pyspark/pipelines/spark_connect_pipeline.py +++ b/python/pyspark/pipelines/spark_connect_pipeline.py @@ -75,9 +75,9 @@ def start_run( """Start a run of the dataflow graph in the Spark Connect server. :param dataflow_graph_id: The ID of the dataflow graph to start. - :param full_refresh: List of tables to reset and recompute. + :param full_refresh: List of datasets to reset and recompute. :param full_refresh_all: Perform a full graph reset and recompute. - :param refresh: List of tables to update. 
+ :param refresh: List of datasets to update. """ inner_command = pb2.PipelineCommand.StartRun( dataflow_graph_id=dataflow_graph_id, diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index 242d9ab244a5..7471737118c3 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -385,11 +385,7 @@ def test_full_refresh_all_conflicts_with_full_refresh(self): self.assertEqual( context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" ) - message_params = context.exception.getMessageParameters() - self.assertIsNotNone(message_params) - message = cast(dict, message_params)["message"] - self.assertIn("--full-refresh-all option conflicts with --full-refresh", message) - self.assertIn("performs a full refresh of all tables", message) + def test_full_refresh_all_conflicts_with_refresh(self): with tempfile.TemporaryDirectory() as temp_dir: @@ -410,11 +406,6 @@ def test_full_refresh_all_conflicts_with_refresh(self): self.assertEqual( context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" ) - message_params = context.exception.getMessageParameters() - self.assertIsNotNone(message_params) - message = cast(dict, message_params)["message"] - self.assertIn("--full-refresh-all option conflicts with --refresh", message) - self.assertIn("performs a full refresh of all tables", message) def test_full_refresh_all_conflicts_with_both(self): with tempfile.TemporaryDirectory() as temp_dir: @@ -436,10 +427,6 @@ def test_full_refresh_all_conflicts_with_both(self): self.assertEqual( context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" ) - message_params = context.exception.getMessageParameters() - self.assertIsNotNone(message_params) - message = cast(dict, message_params)["message"] - self.assertIn("--full-refresh-all option conflicts with --full-refresh", message) def test_no_conflict_when_full_refresh_all_alone(self): with tempfile.TemporaryDirectory() as temp_dir: @@ -518,13 +505,13 @@ def test_flatten_table_lists_none(self): """Test flattening None input.""" from pyspark.pipelines.cli import flatten_table_lists result = flatten_table_lists(None) - self.assertEqual(result, []) + self.assertEqual(result, None) def test_flatten_table_lists_empty(self): """Test flattening empty list.""" from pyspark.pipelines.cli import flatten_table_lists result = flatten_table_lists([]) - self.assertEqual(result, []) + self.assertEqual(result, None) def test_flatten_table_lists_single_list(self): """Test flattening single list.""" @@ -562,26 +549,6 @@ def test_valid_refresh_combinations(self): # Should NOT be our validation error self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - def test_empty_refresh_parameters(self): - """Test behavior with empty refresh parameters.""" - with tempfile.TemporaryDirectory() as temp_dir: - spec_path = Path(temp_dir) / "pipeline.yaml" - with spec_path.open("w") as f: - f.write('{"name": "test_pipeline"}') - - # Test empty lists don't cause validation errors - try: - run( - spec_path=spec_path, - full_refresh=[], - refresh=[], - full_refresh_all=False - ) - self.fail("Expected run to fail due to missing pipeline spec content") - except PySparkException as e: - # Should NOT be our validation error - self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - def test_cli_argument_parsing_patterns(self): """Test CLI argument parsing patterns for refresh options.""" import argparse @@ -627,93 +594,6 @@ def 
test_refresh_parameter_validation_edge_cases(self): # Should NOT be our validation error self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - def test_refresh_parameters_with_qualified_table_names(self): - """Test refresh parameters with qualified table names.""" - with tempfile.TemporaryDirectory() as temp_dir: - spec_path = Path(temp_dir) / "pipeline.yaml" - with spec_path.open("w") as f: - f.write('{"name": "test_pipeline"}') - - # Test qualified table names - try: - run( - spec_path=spec_path, - full_refresh=["schema1.table1", "schema2.table2"], - refresh=["schema3.table3"], - full_refresh_all=False - ) - self.fail("Expected run to fail due to missing pipeline spec content") - except PySparkException as e: - # Should NOT be our validation error - self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - - def test_large_table_lists_handling(self): - """Test handling of large table lists.""" - with tempfile.TemporaryDirectory() as temp_dir: - spec_path = Path(temp_dir) / "pipeline.yaml" - with spec_path.open("w") as f: - f.write('{"name": "test_pipeline"}') - - # Test with large table lists - large_table_list = [f"table_{i}" for i in range(100)] - try: - run( - spec_path=spec_path, - full_refresh=large_table_list, - refresh=large_table_list, - full_refresh_all=False - ) - self.fail("Expected run to fail due to missing pipeline spec content") - except PySparkException as e: - # Should NOT be our validation error - self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - - def test_refresh_parameter_precedence(self): - """Test that full_refresh_all takes precedence over other parameters.""" - with tempfile.TemporaryDirectory() as temp_dir: - spec_path = Path(temp_dir) / "pipeline.yaml" - with spec_path.open("w") as f: - f.write('{"name": "test_pipeline"}') - - # Test that full_refresh_all conflicts with other parameters - conflict_cases = [ - {"full_refresh_all": True, "full_refresh": ["table1"]}, - {"full_refresh_all": True, "refresh": ["table1"]}, - {"full_refresh_all": True, "full_refresh": ["table1"], "refresh": ["table2"]}, - ] - - for case in conflict_cases: - with self.assertRaises(PySparkException) as context: - run(spec_path=spec_path, **case) - self.assertEqual(context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - - def test_detailed_validation_error_messages(self): - """Test that validation error messages are detailed and helpful.""" - with tempfile.TemporaryDirectory() as temp_dir: - spec_path = Path(temp_dir) / "pipeline.yaml" - with spec_path.open("w") as f: - f.write('{"name": "test_pipeline"}') - - # Test full_refresh_all with full_refresh error message - with self.assertRaises(PySparkException) as context: - run( - spec_path=spec_path, - full_refresh_all=True, - full_refresh=["table1"] - ) - self.assertEqual(context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - self.assertIn("--full-refresh-all option conflicts with --full-refresh", str(context.exception)) - - # Test full_refresh_all with refresh error message - with self.assertRaises(PySparkException) as context: - run( - spec_path=spec_path, - full_refresh_all=True, - refresh=["table1"] - ) - self.assertEqual(context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - self.assertIn("--full-refresh-all option conflicts with --refresh", str(context.exception)) - if __name__ == "__main__": try: From 4105c97baeca65298a8d23331cdb64c599572f3a Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Tue, 
15 Jul 2025 16:36:14 -0700 Subject: [PATCH 08/17] modify backend tests --- .../PipelineRefreshFunctionalSuite.scala | 252 +++++++++--------- 1 file changed, 120 insertions(+), 132 deletions(-) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala index 87c2c99620bf..bc644873cc57 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.sql.connect.pipelines -import java.io.{File, FileWriter} - import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.DatasetType +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin} -import org.apache.spark.util.Utils /** * Comprehensive test suite that validates pipeline refresh functionality by running actual @@ -36,64 +34,41 @@ class PipelineRefreshFunctionalSuite with TestPipelineUpdateContextMixin with EventVerificationTestHelpers { - private var testDataDir: File = _ - private var streamInputDir: File = _ - - override def beforeAll(): Unit = { - super.beforeAll() - // Create temporary directories for test data - testDataDir = Utils.createTempDir("pipeline-refresh-test") - streamInputDir = new File(testDataDir, "stream-input") - streamInputDir.mkdirs() - } + private val externalSourceTable = TableIdentifier( + catalog = Some("spark_catalog"), + database = Some("default"), + table = "source_data" + ) - override def afterAll(): Unit = { - try { - Utils.deleteRecursively(testDataDir) - } finally { - super.afterAll() - } + override def beforeEach(): Unit = { + super.beforeEach() + // Create source directory for streaming input + spark.sql(s"CREATE TABLE $externalSourceTable AS SELECT * FROM RANGE(1, 2)") } - private def uploadInputFile(filename: String, contents: String): Unit = { - val file = new File(streamInputDir, filename) - val writer = new FileWriter(file) - try { - writer.write(contents) - } finally { - writer.close() - } + override def afterEach(): Unit = { + super.afterEach() + // Clean up the source table after each test + spark.sql(s"DROP TABLE IF EXISTS $externalSourceTable") } private def createTwoSTPipeline(graphId: String): TestPipelineDefinition = { new TestPipelineDefinition(graphId) { - // Create a mv that reads from files - createTable( - name = "file_data", - datasetType = DatasetType.MATERIALIZED_VIEW, - sql = Some(s""" - SELECT id, value FROM JSON.`${streamInputDir.getAbsolutePath}/*.json` - """)) - // Create tables that depend on the mv createTable( name = "a", datasetType = DatasetType.TABLE, - sql = Some("SELECT id, value FROM STREAM file_data")) + sql = Some(s"SELECT id FROM STREAM $externalSourceTable")) createTable( name = "b", datasetType = DatasetType.TABLE, - sql = Some("SELECT id, value FROM STREAM file_data")) + sql = Some(s"SELECT id FROM STREAM $externalSourceTable")) } } test("pipeline runs selective full_refresh") { withRawBlockingStub { implicit stub => - uploadInputFile("data.json", - """ - |{"id": "x", "value": 1} - """.stripMargin) val graphId = createDataflowGraph val pipeline 
= createTwoSTPipeline(graphId) registerPipelineDatasets(pipeline) @@ -104,23 +79,18 @@ class PipelineRefreshFunctionalSuite // Verify initial data from file stream verifyMultipleTableContent( tableNames = Set( - "spark_catalog.default.file_data", "spark_catalog.default.a", "spark_catalog.default.b"), columnsToVerify = Map( - "spark_catalog.default.file_data" -> Seq("id", "value"), - "spark_catalog.default.a" -> Seq("id", "value"), - "spark_catalog.default.b" -> Seq("id", "value") + "spark_catalog.default.a" -> Seq("id"), + "spark_catalog.default.b" -> Seq("id") ), expectedContent = Map( - "spark_catalog.default.file_data" -> Set( - Map("id" -> "x", "value" -> 1) - ), "spark_catalog.default.a" -> Set( - Map("id" -> "x", "value" -> 1) + Map("id" -> 1) ), "spark_catalog.default.b" -> Set( - Map("id" -> "x", "value" -> 1) + Map("id" -> 1) ) ) ) @@ -130,14 +100,12 @@ class PipelineRefreshFunctionalSuite .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) .foreach(_.removeAllPipelineExecutions()) - // simulate a full refresh by uploading new data - uploadInputFile("data.json", - """ - |{"id": "x", "value": 2} - """.stripMargin) + // spark overwrite the table source_data with new data + spark.sql("INSERT OVERWRITE TABLE spark_catalog.default.source_data " + + "SELECT * FROM VALUES (2), (3) AS t(id)") // Run with full refresh on specific tables - val fullRefreshTables = List("file_data", "a") + val fullRefreshTables = List("a") val startRun = proto.PipelineCommand.StartRun.newBuilder() .setDataflowGraphId(graphId) .addAllFullRefresh(fullRefreshTables.asJava) @@ -150,28 +118,24 @@ class PipelineRefreshFunctionalSuite // assert that table_a and file_data ran to completion assert(capturedEvents.exists( _.getMessage.contains(s"Flow spark_catalog.default.a has COMPLETED."))) - assert(capturedEvents.exists( - _.getMessage.contains(s"Flow spark_catalog.default.file_data has COMPLETED."))) // Verify completion event assert(capturedEvents.exists(_.getMessage.contains("Run is COMPLETED"))) verifyMultipleTableContent( - tableNames = Set("spark_catalog.default.file_data", "spark_catalog.default.a", + tableNames = Set( + "spark_catalog.default.a", "spark_catalog.default.b"), columnsToVerify = Map( - "spark_catalog.default.file_data" -> Seq("id", "value"), - "spark_catalog.default.a" -> Seq("id", "value"), - "spark_catalog.default.b" -> Seq("id", "value") + "spark_catalog.default.a" -> Seq("id"), + "spark_catalog.default.b" -> Seq("id") ), expectedContent = Map( - "spark_catalog.default.file_data" -> Set( - Map("id" -> "x", "value" -> 2) - ), "spark_catalog.default.a" -> Set( - Map("id" -> "x", "value" -> 2) + Map("id" -> 2), // a should be fully refreshed and only contain the new value + Map("id" -> 3) ), "spark_catalog.default.b" -> Set( - Map("id" -> "x", "value" -> 1) // b should not be refreshed, so it retains the old value + Map("id" -> 1) // b is refreshed, so it retains the old value ) ) ) @@ -180,10 +144,6 @@ class PipelineRefreshFunctionalSuite test("pipeline runs selective full_refresh and selective refresh") { withRawBlockingStub { implicit stub => - uploadInputFile("data.json", - """ - |{"id": "x", "value": 1} - """.stripMargin) val graphId = createDataflowGraph val pipeline = createTwoSTPipeline(graphId) registerPipelineDatasets(pipeline) @@ -193,14 +153,20 @@ class PipelineRefreshFunctionalSuite // Verify initial data from file stream verifyMultipleTableContent( - tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + tableNames = Set( + 
"spark_catalog.default.a", + "spark_catalog.default.b"), columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id", "value"), - "spark_catalog.default.b" -> Seq("id", "value") + "spark_catalog.default.a" -> Seq("id"), + "spark_catalog.default.b" -> Seq("id") ), expectedContent = Map( - "spark_catalog.default.a" -> Set(Map("id" -> "x", "value" -> 1)), - "spark_catalog.default.b" -> Set(Map("id" -> "x", "value" -> 1)) + "spark_catalog.default.a" -> Set( + Map("id" -> 1) + ), + "spark_catalog.default.b" -> Set( + Map("id" -> 1) + ) ) ) @@ -209,14 +175,13 @@ class PipelineRefreshFunctionalSuite .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) .foreach(_.removeAllPipelineExecutions()) - uploadInputFile("data.json", - """ - |{"id": "y", "value": 2} - """.stripMargin) + // spark overwrite the table source_data with new data + spark.sql("INSERT OVERWRITE TABLE spark_catalog.default.source_data " + + "SELECT * FROM VALUES (2), (3) AS t(id)") val startRun = proto.PipelineCommand.StartRun.newBuilder() .setDataflowGraphId(graphId) - .addAllFullRefresh(List("file_data", "a").asJava) + .addFullRefresh("a") .addRefresh("b") .build() @@ -224,17 +189,23 @@ class PipelineRefreshFunctionalSuite // assert that table_b is refreshed verifyMultipleTableContent( - tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + tableNames = Set( + "spark_catalog.default.a", + "spark_catalog.default.b"), columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id", "value"), - "spark_catalog.default.b" -> Seq("id", "value") + "spark_catalog.default.a" -> Seq("id"), + "spark_catalog.default.b" -> Seq("id") ), expectedContent = Map( - // a should be fully refreshed and only contain the new value - "spark_catalog.default.a" -> Set(Map("id" -> "y", "value" -> 2)), - "spark_catalog.default.b" -> - // b contain the new value in addition to the old one - Set(Map("id" -> "x", "value" -> 1), Map("id" -> "y", "value" -> 2)) + "spark_catalog.default.a" -> Set( + Map("id" -> 2), // a is fully refreshed and only contain the new value + Map("id" -> 3) + ), + "spark_catalog.default.b" -> Set( + Map("id" -> 1), // b is refreshed, so it retains the old value and adds the new one + Map("id" -> 2), + Map("id" -> 3) + ) ) ) } @@ -242,10 +213,6 @@ class PipelineRefreshFunctionalSuite test("pipeline runs refresh by default") { withRawBlockingStub { implicit stub => - uploadInputFile("data.json", - """ - |{"id": "x", "value": 1} - """.stripMargin) val graphId = createDataflowGraph val pipeline = createTwoSTPipeline(graphId) registerPipelineDatasets(pipeline) @@ -255,14 +222,20 @@ class PipelineRefreshFunctionalSuite // Verify initial data from file stream verifyMultipleTableContent( - tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + tableNames = Set( + "spark_catalog.default.a", + "spark_catalog.default.b"), columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id", "value"), - "spark_catalog.default.b" -> Seq("id", "value") + "spark_catalog.default.a" -> Seq("id"), + "spark_catalog.default.b" -> Seq("id") ), expectedContent = Map( - "spark_catalog.default.a" -> Set(Map("id" -> "x", "value" -> 1)), - "spark_catalog.default.b" -> Set(Map("id" -> "x", "value" -> 1)) + "spark_catalog.default.a" -> Set( + Map("id" -> 1) + ), + "spark_catalog.default.b" -> Set( + Map("id" -> 1) + ) ) ) @@ -271,38 +244,40 @@ class PipelineRefreshFunctionalSuite .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) .foreach(_.removeAllPipelineExecutions()) - 
uploadInputFile("data.json", - """ - |{"id": "y", "value": 2} - """.stripMargin) + // spark overwrite the table source_data with new data + spark.sql("INSERT OVERWRITE TABLE spark_catalog.default.source_data " + + "SELECT * FROM VALUES (2), (3) AS t(id)") // Create a default StartRun command that refreshes all tables startPipelineAndWaitForCompletion(graphId) // assert that both tables are refreshed verifyMultipleTableContent( - tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + tableNames = Set( + "spark_catalog.default.a", + "spark_catalog.default.b"), columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id", "value"), - "spark_catalog.default.b" -> Seq("id", "value") + "spark_catalog.default.a" -> Seq("id"), + "spark_catalog.default.b" -> Seq("id") ), expectedContent = Map( - // both tables should contain the new value in addition to the old one - "spark_catalog.default.a" -> - Set(Map("id" -> "x", "value" -> 1), Map("id" -> "y", "value" -> 2)), - "spark_catalog.default.b" -> - Set(Map("id" -> "x", "value" -> 1), Map("id" -> "y", "value" -> 2)) + "spark_catalog.default.a" -> Set( + Map("id" -> 1), // a is refreshed by default, retains the old value and adds the new one + Map("id" -> 2), + Map("id" -> 3) + ), + "spark_catalog.default.b" -> Set( + Map("id" -> 1), // b is refreshed by default, retains the old value and adds the new one + Map("id" -> 2), + Map("id" -> 3) + ) ) ) } } - test("pipeline runs full_refresh_all") { + test("pipeline runs full refresh all") { withRawBlockingStub { implicit stub => - uploadInputFile("data.json", - """ - |{"id": "x", "value": 1} - """.stripMargin) val graphId = createDataflowGraph val pipeline = createTwoSTPipeline(graphId) registerPipelineDatasets(pipeline) @@ -312,25 +287,31 @@ class PipelineRefreshFunctionalSuite // Verify initial data from file stream verifyMultipleTableContent( - tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + tableNames = Set( + "spark_catalog.default.a", + "spark_catalog.default.b"), columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id", "value"), - "spark_catalog.default.b" -> Seq("id", "value") + "spark_catalog.default.a" -> Seq("id"), + "spark_catalog.default.b" -> Seq("id") ), expectedContent = Map( - "spark_catalog.default.a" -> Set(Map("id" -> "x", "value" -> 1)), - "spark_catalog.default.b" -> Set(Map("id" -> "x", "value" -> 1)) + "spark_catalog.default.a" -> Set( + Map("id" -> 1) + ), + "spark_catalog.default.b" -> Set( + Map("id" -> 1) + ) ) ) + // Clear cached pipeline execution before starting new run SparkConnectService.sessionManager .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) .foreach(_.removeAllPipelineExecutions()) - uploadInputFile("data.json", - """ - |{"id": "y", "value": 2} - """.stripMargin) + // spark overwrite the table source_data with new data + spark.sql("INSERT OVERWRITE TABLE spark_catalog.default.source_data " + + "SELECT * FROM VALUES (2), (3) AS t(id)") // Create a default StartRun command that refreshes all tables val startRun = proto.PipelineCommand.StartRun.newBuilder() @@ -339,17 +320,24 @@ class PipelineRefreshFunctionalSuite .build() startPipelineAndWaitForCompletion(graphId, Some(startRun)) - // assert that all tables are fully refreshed + // assert that both tables are refreshed verifyMultipleTableContent( - tableNames = Set("spark_catalog.default.a", "spark_catalog.default.b"), + tableNames = Set( + "spark_catalog.default.a", + "spark_catalog.default.b"), columnsToVerify = Map( - 
"spark_catalog.default.a" -> Seq("id", "value"), - "spark_catalog.default.b" -> Seq("id", "value") + "spark_catalog.default.a" -> Seq("id"), + "spark_catalog.default.b" -> Seq("id") ), - // both tables should only contain the new value expectedContent = Map( - "spark_catalog.default.a" -> Set(Map("id" -> "y", "value" -> 2)), - "spark_catalog.default.b" -> Set(Map("id" -> "y", "value" -> 2)) + "spark_catalog.default.a" -> Set( + Map("id" -> 2), + Map("id" -> 3) + ), + "spark_catalog.default.b" -> Set( + Map("id" -> 2), + Map("id" -> 3) + ) ) ) } @@ -416,7 +404,7 @@ class PipelineRefreshFunctionalSuite } } - test("validation: multiple overlapping tables in refresh and full_refresh") { + test("validation: multiple overlapping tables in refresh and full_refresh not allowed") { withRawBlockingStub { implicit stub => val graphId = createDataflowGraph val pipeline = createTwoSTPipeline(graphId) From e580bb41b675dcf9be87722717f778738d355d75 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Tue, 15 Jul 2025 16:48:10 -0700 Subject: [PATCH 09/17] test overhaul --- .../PipelineRefreshFunctionalSuite.scala | 385 +++++++----------- 1 file changed, 146 insertions(+), 239 deletions(-) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala index bc644873cc57..a8253fce7417 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala @@ -17,14 +17,15 @@ package org.apache.spark.sql.connect.pipelines +import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ -import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.DatasetType +import org.apache.spark.connect.proto.{DatasetType, PipelineCommand, PipelineEvent} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin} + /** * Comprehensive test suite that validates pipeline refresh functionality by running actual * pipelines with different refresh parameters and validating the results. 
@@ -42,7 +43,7 @@ class PipelineRefreshFunctionalSuite override def beforeEach(): Unit = { super.beforeEach() - // Create source directory for streaming input + // Create source table to simulate streaming updates spark.sql(s"CREATE TABLE $externalSourceTable AS SELECT * FROM RANGE(1, 2)") } @@ -52,46 +53,57 @@ class PipelineRefreshFunctionalSuite spark.sql(s"DROP TABLE IF EXISTS $externalSourceTable") } - private def createTwoSTPipeline(graphId: String): TestPipelineDefinition = { + private def createTestPipeline(graphId: String): TestPipelineDefinition = { new TestPipelineDefinition(graphId) { // Create tables that depend on the mv createTable( name = "a", datasetType = DatasetType.TABLE, - sql = Some(s"SELECT id FROM STREAM $externalSourceTable")) - + sql = Some(s"SELECT id FROM STREAM $externalSourceTable") + ) createTable( name = "b", datasetType = DatasetType.TABLE, - sql = Some(s"SELECT id FROM STREAM $externalSourceTable")) + sql = Some(s"SELECT id FROM STREAM $externalSourceTable") + ) + createTable( + name = "mv", + datasetType = DatasetType.MATERIALIZED_VIEW, + sql = Some(s"SELECT id FROM a") + ) } } - test("pipeline runs selective full_refresh") { + /** + * Helper method to run refresh tests with common setup and verification logic. + * This reduces code duplication across the refresh test cases. + */ + private def runRefreshTest( + refreshConfigBuilder: String => Option[PipelineCommand.StartRun] = _ => None, + expectedContentAfterRefresh: Map[String, Set[Map[String, Any]]], + eventValidation: Option[ArrayBuffer[PipelineEvent] => Unit] = None + ): Unit = { withRawBlockingStub { implicit stub => val graphId = createDataflowGraph - val pipeline = createTwoSTPipeline(graphId) + val pipeline = createTestPipeline(graphId) registerPipelineDatasets(pipeline) // First run to populate tables startPipelineAndWaitForCompletion(graphId) - // Verify initial data from file stream + // Verify initial data - all tests expect the same initial state verifyMultipleTableContent( - tableNames = Set( - "spark_catalog.default.a", - "spark_catalog.default.b"), + tableNames = Set("spark_catalog.default.a", + "spark_catalog.default.b", "spark_catalog.default.mv"), columnsToVerify = Map( "spark_catalog.default.a" -> Seq("id"), - "spark_catalog.default.b" -> Seq("id") + "spark_catalog.default.b" -> Seq("id"), + "spark_catalog.default.mv" -> Seq("id") ), expectedContent = Map( - "spark_catalog.default.a" -> Set( - Map("id" -> 1) - ), - "spark_catalog.default.b" -> Set( - Map("id" -> 1) - ) + "spark_catalog.default.a" -> Set(Map("id" -> 1)), + "spark_catalog.default.b" -> Set(Map("id" -> 1)), + "spark_catalog.default.mv" -> Set(Map("id" -> 1)) ) ) @@ -100,256 +112,152 @@ class PipelineRefreshFunctionalSuite .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) .foreach(_.removeAllPipelineExecutions()) - // spark overwrite the table source_data with new data + // Replace source data to simulate a streaming update spark.sql("INSERT OVERWRITE TABLE spark_catalog.default.source_data " + "SELECT * FROM VALUES (2), (3) AS t(id)") - // Run with full refresh on specific tables - val fullRefreshTables = List("a") - val startRun = proto.PipelineCommand.StartRun.newBuilder() - .setDataflowGraphId(graphId) - .addAllFullRefresh(fullRefreshTables.asJava) - .build() + // Run with specified refresh configuration + val capturedEvents = refreshConfigBuilder(graphId) match { + case Some(startRun) => startPipelineAndWaitForCompletion(graphId, Some(startRun)) + case None => 
startPipelineAndWaitForCompletion(graphId) + } - val capturedEvents = startPipelineAndWaitForCompletion(graphId, Some(startRun)) - // assert that table_b is excluded - assert(capturedEvents.exists( - _.getMessage.contains(s"Flow \'spark_catalog.default.b\' is EXCLUDED."))) - // assert that table_a and file_data ran to completion - assert(capturedEvents.exists( - _.getMessage.contains(s"Flow spark_catalog.default.a has COMPLETED."))) - // Verify completion event - assert(capturedEvents.exists(_.getMessage.contains("Run is COMPLETED"))) + // Additional validation if provided + eventValidation.foreach(_(capturedEvents)) + // Verify final content verifyMultipleTableContent( - tableNames = Set( - "spark_catalog.default.a", - "spark_catalog.default.b"), + tableNames = Set("spark_catalog.default.a", + "spark_catalog.default.b", "spark_catalog.default.mv"), columnsToVerify = Map( "spark_catalog.default.a" -> Seq("id"), - "spark_catalog.default.b" -> Seq("id") + "spark_catalog.default.b" -> Seq("id"), + "spark_catalog.default.mv" -> Seq("id") ), - expectedContent = Map( - "spark_catalog.default.a" -> Set( - Map("id" -> 2), // a should be fully refreshed and only contain the new value - Map("id" -> 3) - ), - "spark_catalog.default.b" -> Set( - Map("id" -> 1) // b is refreshed, so it retains the old value - ) - ) + expectedContent = expectedContentAfterRefresh ) } } - test("pipeline runs selective full_refresh and selective refresh") { - withRawBlockingStub { implicit stub => - val graphId = createDataflowGraph - val pipeline = createTwoSTPipeline(graphId) - registerPipelineDatasets(pipeline) - - // First run to populate tables - startPipelineAndWaitForCompletion(graphId) - - // Verify initial data from file stream - verifyMultipleTableContent( - tableNames = Set( - "spark_catalog.default.a", - "spark_catalog.default.b"), - columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id"), - "spark_catalog.default.b" -> Seq("id") + test("pipeline runs selective full_refresh") { + runRefreshTest( + refreshConfigBuilder = { graphId => + Some(PipelineCommand.StartRun.newBuilder() + .setDataflowGraphId(graphId) + .addAllFullRefresh(List("a").asJava) + .build()) + }, + expectedContentAfterRefresh = Map( + "spark_catalog.default.a" -> Set( + Map("id" -> 2), // a is fully refreshed and only contains the new values + Map("id" -> 3) ), - expectedContent = Map( - "spark_catalog.default.a" -> Set( - Map("id" -> 1) - ), - "spark_catalog.default.b" -> Set( - Map("id" -> 1) - ) + "spark_catalog.default.b" -> Set( + Map("id" -> 1) // b is not refreshed, so it retains the old value + ), + "spark_catalog.default.mv" -> Set( + Map("id" -> 1) // mv is not refreshed, so it retains the old value ) - ) - - // Clear cached pipeline execution before starting new run - SparkConnectService.sessionManager - .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) - .foreach(_.removeAllPipelineExecutions()) - - // spark overwrite the table source_data with new data - spark.sql("INSERT OVERWRITE TABLE spark_catalog.default.source_data " + - "SELECT * FROM VALUES (2), (3) AS t(id)") - - val startRun = proto.PipelineCommand.StartRun.newBuilder() - .setDataflowGraphId(graphId) - .addFullRefresh("a") - .addRefresh("b") - .build() - - startPipelineAndWaitForCompletion(graphId, Some(startRun)) + ), + eventValidation = Some { capturedEvents => + // assert that table_b is excluded + assert(capturedEvents.exists( + _.getMessage.contains(s"Flow \'spark_catalog.default.b\' is EXCLUDED."))) + // assert that table_a ran to 
completion + assert(capturedEvents.exists( + _.getMessage.contains(s"Flow spark_catalog.default.a has COMPLETED."))) + // assert that mv is excluded + assert(capturedEvents.exists( + _.getMessage.contains(s"Flow \'spark_catalog.default.mv\' is EXCLUDED."))) + // Verify completion event + assert(capturedEvents.exists(_.getMessage.contains("Run is COMPLETED"))) + } + ) + } - // assert that table_b is refreshed - verifyMultipleTableContent( - tableNames = Set( - "spark_catalog.default.a", - "spark_catalog.default.b"), - columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id"), - "spark_catalog.default.b" -> Seq("id") + test("pipeline runs selective full_refresh and selective refresh") { + runRefreshTest( + refreshConfigBuilder = { graphId => + Some(PipelineCommand.StartRun.newBuilder() + .setDataflowGraphId(graphId) + .addAllFullRefresh(Seq("a", "mv").asJava) + .addRefresh("b") + .build()) + }, + expectedContentAfterRefresh = Map( + "spark_catalog.default.a" -> Set( + Map("id" -> 2), // a is fully refreshed and only contains the new values + Map("id" -> 3) ), - expectedContent = Map( - "spark_catalog.default.a" -> Set( - Map("id" -> 2), // a is fully refreshed and only contain the new value - Map("id" -> 3) - ), - "spark_catalog.default.b" -> Set( - Map("id" -> 1), // b is refreshed, so it retains the old value and adds the new one - Map("id" -> 2), - Map("id" -> 3) - ) + "spark_catalog.default.b" -> Set( + Map("id" -> 1), // b is refreshed, so it retains the old value and adds the new ones + Map("id" -> 2), + Map("id" -> 3) + ), + "spark_catalog.default.mv" -> Set( + Map("id" -> 2), // mv is fully refreshed and only contains the new values + Map("id" -> 3) ) ) - } + ) } test("pipeline runs refresh by default") { - withRawBlockingStub { implicit stub => - val graphId = createDataflowGraph - val pipeline = createTwoSTPipeline(graphId) - registerPipelineDatasets(pipeline) - - // First run to populate tables - startPipelineAndWaitForCompletion(graphId) - - // Verify initial data from file stream - verifyMultipleTableContent( - tableNames = Set( - "spark_catalog.default.a", - "spark_catalog.default.b"), - columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id"), - "spark_catalog.default.b" -> Seq("id") + runRefreshTest( + expectedContentAfterRefresh = Map( + "spark_catalog.default.a" -> Set( + Map("id" -> 1), // a is refreshed by default, retains the old value and adds the new ones + Map("id" -> 2), + Map("id" -> 3) ), - expectedContent = Map( - "spark_catalog.default.a" -> Set( - Map("id" -> 1) - ), - "spark_catalog.default.b" -> Set( - Map("id" -> 1) - ) - ) - ) - - // Clear cached pipeline execution before starting new run - SparkConnectService.sessionManager - .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) - .foreach(_.removeAllPipelineExecutions()) - - // spark overwrite the table source_data with new data - spark.sql("INSERT OVERWRITE TABLE spark_catalog.default.source_data " + - "SELECT * FROM VALUES (2), (3) AS t(id)") - - // Create a default StartRun command that refreshes all tables - startPipelineAndWaitForCompletion(graphId) - - // assert that both tables are refreshed - verifyMultipleTableContent( - tableNames = Set( - "spark_catalog.default.a", - "spark_catalog.default.b"), - columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id"), - "spark_catalog.default.b" -> Seq("id") + "spark_catalog.default.b" -> Set( + Map("id" -> 1), // b is refreshed by default, retains the old value and adds the new ones + Map("id" -> 2), + Map("id" -> 3) ), 
- expectedContent = Map( - "spark_catalog.default.a" -> Set( - Map("id" -> 1), // a is refreshed by default, retains the old value and adds the new one - Map("id" -> 2), - Map("id" -> 3) - ), - "spark_catalog.default.b" -> Set( - Map("id" -> 1), // b is refreshed by default, retains the old value and adds the new one - Map("id" -> 2), - Map("id" -> 3) - ) + "spark_catalog.default.mv" -> Set( + Map("id" -> 1), + Map("id" -> 2), // mv is refreshed from table a, retains all values + Map("id" -> 3) ) ) - } + ) } test("pipeline runs full refresh all") { - withRawBlockingStub { implicit stub => - val graphId = createDataflowGraph - val pipeline = createTwoSTPipeline(graphId) - registerPipelineDatasets(pipeline) - - // First run to populate tables - startPipelineAndWaitForCompletion(graphId) - - // Verify initial data from file stream - verifyMultipleTableContent( - tableNames = Set( - "spark_catalog.default.a", - "spark_catalog.default.b"), - columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id"), - "spark_catalog.default.b" -> Seq("id") + runRefreshTest( + refreshConfigBuilder = { graphId => + Some(PipelineCommand.StartRun.newBuilder() + .setDataflowGraphId(graphId) + .setFullRefreshAll(true) + .build()) + }, + // full refresh all causes all tables to lose the initial value + // and only contain the new values after the source data is updated + expectedContentAfterRefresh = Map( + "spark_catalog.default.a" -> Set( + Map("id" -> 2), + Map("id" -> 3) ), - expectedContent = Map( - "spark_catalog.default.a" -> Set( - Map("id" -> 1) - ), - "spark_catalog.default.b" -> Set( - Map("id" -> 1) - ) - ) - ) - - // Clear cached pipeline execution before starting new run - SparkConnectService.sessionManager - .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) - .foreach(_.removeAllPipelineExecutions()) - - // spark overwrite the table source_data with new data - spark.sql("INSERT OVERWRITE TABLE spark_catalog.default.source_data " + - "SELECT * FROM VALUES (2), (3) AS t(id)") - - // Create a default StartRun command that refreshes all tables - val startRun = proto.PipelineCommand.StartRun.newBuilder() - .setDataflowGraphId(graphId) - .setFullRefreshAll(true) - .build() - startPipelineAndWaitForCompletion(graphId, Some(startRun)) - - // assert that both tables are refreshed - verifyMultipleTableContent( - tableNames = Set( - "spark_catalog.default.a", - "spark_catalog.default.b"), - columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id"), - "spark_catalog.default.b" -> Seq("id") + "spark_catalog.default.b" -> Set( + Map("id" -> 2), + Map("id" -> 3) ), - expectedContent = Map( - "spark_catalog.default.a" -> Set( - Map("id" -> 2), - Map("id" -> 3) - ), - "spark_catalog.default.b" -> Set( - Map("id" -> 2), - Map("id" -> 3) - ) + "spark_catalog.default.mv" -> Set( + Map("id" -> 2), + Map("id" -> 3) ) ) - } + ) } test("validation: cannot specify subset refresh when full_refresh_all is true") { withRawBlockingStub { implicit stub => val graphId = createDataflowGraph - val pipeline = createTwoSTPipeline(graphId) + val pipeline = createTestPipeline(graphId) registerPipelineDatasets(pipeline) - val startRun = proto.PipelineCommand.StartRun.newBuilder() + val startRun = PipelineCommand.StartRun.newBuilder() .setDataflowGraphId(graphId) .setFullRefreshAll(true) .addRefresh("a") @@ -366,10 +274,10 @@ class PipelineRefreshFunctionalSuite test("validation: cannot specify subset full_refresh when full_refresh_all is true") { withRawBlockingStub { implicit stub => val graphId = 
createDataflowGraph - val pipeline = createTwoSTPipeline(graphId) + val pipeline = createTestPipeline(graphId) registerPipelineDatasets(pipeline) - val startRun = proto.PipelineCommand.StartRun.newBuilder() + val startRun = PipelineCommand.StartRun.newBuilder() .setDataflowGraphId(graphId) .setFullRefreshAll(true) .addFullRefresh("a") @@ -386,10 +294,10 @@ class PipelineRefreshFunctionalSuite test("validation: refresh and full_refresh cannot overlap") { withRawBlockingStub { implicit stub => val graphId = createDataflowGraph - val pipeline = createTwoSTPipeline(graphId) + val pipeline = createTestPipeline(graphId) registerPipelineDatasets(pipeline) - val startRun = proto.PipelineCommand.StartRun.newBuilder() + val startRun = PipelineCommand.StartRun.newBuilder() .setDataflowGraphId(graphId) .addRefresh("a") .addFullRefresh("a") @@ -407,15 +315,14 @@ class PipelineRefreshFunctionalSuite test("validation: multiple overlapping tables in refresh and full_refresh not allowed") { withRawBlockingStub { implicit stub => val graphId = createDataflowGraph - val pipeline = createTwoSTPipeline(graphId) + val pipeline = createTestPipeline(graphId) registerPipelineDatasets(pipeline) - val startRun = proto.PipelineCommand.StartRun.newBuilder() + val startRun = PipelineCommand.StartRun.newBuilder() .setDataflowGraphId(graphId) .addRefresh("a") .addRefresh("b") .addFullRefresh("a") - .addFullRefresh("file_data") .build() val exception = intercept[IllegalArgumentException] { @@ -430,10 +337,10 @@ class PipelineRefreshFunctionalSuite test("validation: fully qualified table names in validation") { withRawBlockingStub { implicit stub => val graphId = createDataflowGraph - val pipeline = createTwoSTPipeline(graphId) + val pipeline = createTestPipeline(graphId) registerPipelineDatasets(pipeline) - val startRun = proto.PipelineCommand.StartRun.newBuilder() + val startRun = PipelineCommand.StartRun.newBuilder() .setDataflowGraphId(graphId) .addRefresh("spark_catalog.default.a") .addFullRefresh("a") // This should be treated as the same table From 4d26d772757cc7b88334f0cf5bdf86aa531a0219 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Tue, 15 Jul 2025 20:42:24 -0700 Subject: [PATCH 10/17] fmt --- python/pyspark/pipelines/tests/test_cli.py | 68 ------ .../connect/pipelines/PipelinesHandler.scala | 21 +- .../PipelineRefreshFunctionalSuite.scala | 222 +++++++++--------- .../SparkDeclarativePipelinesServerTest.scala | 11 +- 4 files changed, 124 insertions(+), 198 deletions(-) diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index 7471737118c3..c8554c72462b 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -489,30 +489,6 @@ def test_parse_table_list_with_spaces(self): result = parse_table_list("table1, table2 , table3") self.assertEqual(result, ["table1", "table2", "table3"]) - def test_parse_table_list_empty_string(self): - """Test parsing empty string.""" - from pyspark.pipelines.cli import parse_table_list - result = parse_table_list("") - self.assertEqual(result, []) - - def test_parse_table_list_with_qualified_names(self): - """Test parsing qualified table names.""" - from pyspark.pipelines.cli import parse_table_list - result = parse_table_list("schema1.table1,schema2.table2") - self.assertEqual(result, ["schema1.table1", "schema2.table2"]) - - def test_flatten_table_lists_none(self): - """Test flattening None input.""" - from pyspark.pipelines.cli import flatten_table_lists - result = 
flatten_table_lists(None) - self.assertEqual(result, None) - - def test_flatten_table_lists_empty(self): - """Test flattening empty list.""" - from pyspark.pipelines.cli import flatten_table_lists - result = flatten_table_lists([]) - self.assertEqual(result, None) - def test_flatten_table_lists_single_list(self): """Test flattening single list.""" from pyspark.pipelines.cli import flatten_table_lists @@ -525,30 +501,6 @@ def test_flatten_table_lists_multiple_lists(self): result = flatten_table_lists([["table1", "table2"], ["table3"], ["table4", "table5"]]) self.assertEqual(result, ["table1", "table2", "table3", "table4", "table5"]) - def test_valid_refresh_combinations(self): - """Test valid combinations of refresh parameters.""" - with tempfile.TemporaryDirectory() as temp_dir: - spec_path = Path(temp_dir) / "pipeline.yaml" - with spec_path.open("w") as f: - f.write('{"name": "test_pipeline"}') - - # Test individual options don't raise validation errors - test_cases = [ - {"full_refresh": ["table1"]}, - {"refresh": ["table1"]}, - {"full_refresh_all": True}, - {"full_refresh": ["table1"], "refresh": ["table2"]}, - {"full_refresh": ["table1", "table2"], "refresh": ["table3", "table4"]}, - ] - - for case in test_cases: - try: - run(spec_path=spec_path, **case) - self.fail(f"Expected run to fail due to missing pipeline spec content: {case}") - except PySparkException as e: - # Should NOT be our validation error - self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - def test_cli_argument_parsing_patterns(self): """Test CLI argument parsing patterns for refresh options.""" import argparse @@ -574,26 +526,6 @@ def test_cli_argument_parsing_patterns(self): for key, value in expected.items(): self.assertEqual(getattr(parsed, key), value) - def test_refresh_parameter_validation_edge_cases(self): - """Test edge cases for refresh parameter validation.""" - with tempfile.TemporaryDirectory() as temp_dir: - spec_path = Path(temp_dir) / "pipeline.yaml" - with spec_path.open("w") as f: - f.write('{"name": "test_pipeline"}') - - # Test that providing None values works correctly - try: - run( - spec_path=spec_path, - full_refresh=None, - refresh=None, - full_refresh_all=False - ) - self.fail("Expected run to fail due to missing pipeline spec content") - except PySparkException as e: - # Should NOT be our validation error - self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - if __name__ == "__main__": try: diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 7c6dffde20be..a2d7b6fb4c0e 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -234,12 +234,13 @@ private[connect] object PipelinesHandler extends Logging { // Convert table names to fully qualified TableIdentifier objects def parseTableNames(tableNames: Seq[String]): Set[TableIdentifier] = { tableNames.map { name => - GraphIdentifierManager.parseAndQualifyTableIdentifier( - rawTableIdentifier = - GraphIdentifierManager.parseTableIdentifier(name, sessionHolder.session), - currentCatalog = Some(graphElementRegistry.defaultCatalog), - currentDatabase = Some(graphElementRegistry.defaultDatabase) - ).identifier + GraphIdentifierManager + .parseAndQualifyTableIdentifier( + 
rawTableIdentifier = + GraphIdentifierManager.parseTableIdentifier(name, sessionHolder.session), + currentCatalog = Some(graphElementRegistry.defaultCatalog), + currentDatabase = Some(graphElementRegistry.defaultDatabase)) + .identifier }.toSet } @@ -261,8 +262,7 @@ private[connect] object PipelinesHandler extends Logging { if (intersection.nonEmpty) { throw new IllegalArgumentException( "Datasets specified for refresh and full refresh cannot overlap: " + - s"${intersection.mkString(", ")}" - ) + s"${intersection.mkString(", ")}") } } @@ -281,7 +281,7 @@ private[connect] object PipelinesHandler extends Logging { NoTables } else { AllTables - } + } // We will use this variable to store the run failure event if it occurs. This will be set // by the event callback. @@ -342,8 +342,7 @@ private[connect] object PipelinesHandler extends Logging { graphElementRegistry.toDataflowGraph, eventCallback, refreshTablesFilter, - fullRefreshTablesFilter - ) + fullRefreshTablesFilter) sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext) pipelineUpdateContext.pipelineExecution.runPipeline() diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala index a8253fce7417..477457778f06 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala @@ -25,21 +25,19 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin} - /** * Comprehensive test suite that validates pipeline refresh functionality by running actual * pipelines with different refresh parameters and validating the results. */ class PipelineRefreshFunctionalSuite - extends SparkDeclarativePipelinesServerTest + extends SparkDeclarativePipelinesServerTest with TestPipelineUpdateContextMixin with EventVerificationTestHelpers { private val externalSourceTable = TableIdentifier( catalog = Some("spark_catalog"), database = Some("default"), - table = "source_data" - ) + table = "source_data") override def beforeEach(): Unit = { super.beforeEach() @@ -59,30 +57,26 @@ class PipelineRefreshFunctionalSuite createTable( name = "a", datasetType = DatasetType.TABLE, - sql = Some(s"SELECT id FROM STREAM $externalSourceTable") - ) + sql = Some(s"SELECT id FROM STREAM $externalSourceTable")) createTable( name = "b", datasetType = DatasetType.TABLE, - sql = Some(s"SELECT id FROM STREAM $externalSourceTable") - ) + sql = Some(s"SELECT id FROM STREAM $externalSourceTable")) createTable( name = "mv", datasetType = DatasetType.MATERIALIZED_VIEW, - sql = Some(s"SELECT id FROM a") - ) + sql = Some(s"SELECT id FROM a")) } } /** - * Helper method to run refresh tests with common setup and verification logic. - * This reduces code duplication across the refresh test cases. + * Helper method to run refresh tests with common setup and verification logic. This reduces + * code duplication across the refresh test cases. 
*/ private def runRefreshTest( - refreshConfigBuilder: String => Option[PipelineCommand.StartRun] = _ => None, - expectedContentAfterRefresh: Map[String, Set[Map[String, Any]]], - eventValidation: Option[ArrayBuffer[PipelineEvent] => Unit] = None - ): Unit = { + refreshConfigBuilder: String => Option[PipelineCommand.StartRun] = _ => None, + expectedContentAfterRefresh: Map[String, Set[Map[String, Any]]], + eventValidation: Option[ArrayBuffer[PipelineEvent] => Unit] = None): Unit = { withRawBlockingStub { implicit stub => val graphId = createDataflowGraph val pipeline = createTestPipeline(graphId) @@ -93,19 +87,16 @@ class PipelineRefreshFunctionalSuite // Verify initial data - all tests expect the same initial state verifyMultipleTableContent( - tableNames = Set("spark_catalog.default.a", - "spark_catalog.default.b", "spark_catalog.default.mv"), + tableNames = + Set("spark_catalog.default.a", "spark_catalog.default.b", "spark_catalog.default.mv"), columnsToVerify = Map( "spark_catalog.default.a" -> Seq("id"), "spark_catalog.default.b" -> Seq("id"), - "spark_catalog.default.mv" -> Seq("id") - ), + "spark_catalog.default.mv" -> Seq("id")), expectedContent = Map( "spark_catalog.default.a" -> Set(Map("id" -> 1)), "spark_catalog.default.b" -> Set(Map("id" -> 1)), - "spark_catalog.default.mv" -> Set(Map("id" -> 1)) - ) - ) + "spark_catalog.default.mv" -> Set(Map("id" -> 1)))) // Clear cached pipeline execution before starting new run SparkConnectService.sessionManager @@ -113,8 +104,9 @@ class PipelineRefreshFunctionalSuite .foreach(_.removeAllPipelineExecutions()) // Replace source data to simulate a streaming update - spark.sql("INSERT OVERWRITE TABLE spark_catalog.default.source_data " + - "SELECT * FROM VALUES (2), (3) AS t(id)") + spark.sql( + "INSERT OVERWRITE TABLE spark_catalog.default.source_data " + + "SELECT * FROM VALUES (2), (3) AS t(id)") // Run with specified refresh configuration val capturedEvents = refreshConfigBuilder(graphId) match { @@ -127,128 +119,115 @@ class PipelineRefreshFunctionalSuite // Verify final content verifyMultipleTableContent( - tableNames = Set("spark_catalog.default.a", - "spark_catalog.default.b", "spark_catalog.default.mv"), + tableNames = + Set("spark_catalog.default.a", "spark_catalog.default.b", "spark_catalog.default.mv"), columnsToVerify = Map( "spark_catalog.default.a" -> Seq("id"), "spark_catalog.default.b" -> Seq("id"), - "spark_catalog.default.mv" -> Seq("id") - ), - expectedContent = expectedContentAfterRefresh - ) + "spark_catalog.default.mv" -> Seq("id")), + expectedContent = expectedContentAfterRefresh) } } test("pipeline runs selective full_refresh") { runRefreshTest( refreshConfigBuilder = { graphId => - Some(PipelineCommand.StartRun.newBuilder() - .setDataflowGraphId(graphId) - .addAllFullRefresh(List("a").asJava) - .build()) + Some( + PipelineCommand.StartRun + .newBuilder() + .setDataflowGraphId(graphId) + .addAllFullRefresh(List("a").asJava) + .build()) }, expectedContentAfterRefresh = Map( "spark_catalog.default.a" -> Set( Map("id" -> 2), // a is fully refreshed and only contains the new values - Map("id" -> 3) - ), + Map("id" -> 3)), "spark_catalog.default.b" -> Set( Map("id" -> 1) // b is not refreshed, so it retains the old value ), "spark_catalog.default.mv" -> Set( Map("id" -> 1) // mv is not refreshed, so it retains the old value - ) - ), + )), eventValidation = Some { capturedEvents => // assert that table_b is excluded - assert(capturedEvents.exists( - _.getMessage.contains(s"Flow \'spark_catalog.default.b\' is EXCLUDED."))) + 
assert( + capturedEvents.exists( + _.getMessage.contains(s"Flow \'spark_catalog.default.b\' is EXCLUDED."))) // assert that table_a ran to completion - assert(capturedEvents.exists( - _.getMessage.contains(s"Flow spark_catalog.default.a has COMPLETED."))) + assert( + capturedEvents.exists( + _.getMessage.contains(s"Flow spark_catalog.default.a has COMPLETED."))) // assert that mv is excluded - assert(capturedEvents.exists( - _.getMessage.contains(s"Flow \'spark_catalog.default.mv\' is EXCLUDED."))) + assert( + capturedEvents.exists( + _.getMessage.contains(s"Flow \'spark_catalog.default.mv\' is EXCLUDED."))) // Verify completion event assert(capturedEvents.exists(_.getMessage.contains("Run is COMPLETED"))) - } - ) + }) } test("pipeline runs selective full_refresh and selective refresh") { runRefreshTest( refreshConfigBuilder = { graphId => - Some(PipelineCommand.StartRun.newBuilder() - .setDataflowGraphId(graphId) - .addAllFullRefresh(Seq("a", "mv").asJava) - .addRefresh("b") - .build()) + Some( + PipelineCommand.StartRun + .newBuilder() + .setDataflowGraphId(graphId) + .addAllFullRefresh(Seq("a", "mv").asJava) + .addRefresh("b") + .build()) }, expectedContentAfterRefresh = Map( "spark_catalog.default.a" -> Set( Map("id" -> 2), // a is fully refreshed and only contains the new values - Map("id" -> 3) - ), + Map("id" -> 3)), "spark_catalog.default.b" -> Set( Map("id" -> 1), // b is refreshed, so it retains the old value and adds the new ones Map("id" -> 2), - Map("id" -> 3) - ), + Map("id" -> 3)), "spark_catalog.default.mv" -> Set( Map("id" -> 2), // mv is fully refreshed and only contains the new values - Map("id" -> 3) - ) - ) - ) + Map("id" -> 3)))) } test("pipeline runs refresh by default") { - runRefreshTest( - expectedContentAfterRefresh = Map( + runRefreshTest(expectedContentAfterRefresh = + Map( "spark_catalog.default.a" -> Set( - Map("id" -> 1), // a is refreshed by default, retains the old value and adds the new ones + Map( + "id" -> 1 + ), // a is refreshed by default, retains the old value and adds the new ones Map("id" -> 2), - Map("id" -> 3) - ), + Map("id" -> 3)), "spark_catalog.default.b" -> Set( - Map("id" -> 1), // b is refreshed by default, retains the old value and adds the new ones + Map( + "id" -> 1 + ), // b is refreshed by default, retains the old value and adds the new ones Map("id" -> 2), - Map("id" -> 3) - ), + Map("id" -> 3)), "spark_catalog.default.mv" -> Set( Map("id" -> 1), Map("id" -> 2), // mv is refreshed from table a, retains all values - Map("id" -> 3) - ) - ) - ) + Map("id" -> 3)))) } test("pipeline runs full refresh all") { runRefreshTest( refreshConfigBuilder = { graphId => - Some(PipelineCommand.StartRun.newBuilder() - .setDataflowGraphId(graphId) - .setFullRefreshAll(true) - .build()) + Some( + PipelineCommand.StartRun + .newBuilder() + .setDataflowGraphId(graphId) + .setFullRefreshAll(true) + .build()) }, // full refresh all causes all tables to lose the initial value // and only contain the new values after the source data is updated expectedContentAfterRefresh = Map( - "spark_catalog.default.a" -> Set( - Map("id" -> 2), - Map("id" -> 3) - ), - "spark_catalog.default.b" -> Set( - Map("id" -> 2), - Map("id" -> 3) - ), - "spark_catalog.default.mv" -> Set( - Map("id" -> 2), - Map("id" -> 3) - ) - ) - ) + "spark_catalog.default.a" -> Set(Map("id" -> 2), Map("id" -> 3)), + "spark_catalog.default.b" -> Set(Map("id" -> 2), Map("id" -> 3)), + "spark_catalog.default.mv" -> Set(Map("id" -> 2), Map("id" -> 3)))) } test("validation: cannot specify subset 
refresh when full_refresh_all is true") { @@ -257,7 +236,8 @@ class PipelineRefreshFunctionalSuite val pipeline = createTestPipeline(graphId) registerPipelineDatasets(pipeline) - val startRun = PipelineCommand.StartRun.newBuilder() + val startRun = PipelineCommand.StartRun + .newBuilder() .setDataflowGraphId(graphId) .setFullRefreshAll(true) .addRefresh("a") @@ -266,8 +246,9 @@ class PipelineRefreshFunctionalSuite val exception = intercept[IllegalArgumentException] { startPipelineAndWaitForCompletion(graphId, Some(startRun)) } - assert(exception.getMessage.contains( - "Cannot specify a subset to full refresh when full refresh all is set to true")) + assert( + exception.getMessage.contains( + "Cannot specify a subset to full refresh when full refresh all is set to true")) } } @@ -277,7 +258,8 @@ class PipelineRefreshFunctionalSuite val pipeline = createTestPipeline(graphId) registerPipelineDatasets(pipeline) - val startRun = PipelineCommand.StartRun.newBuilder() + val startRun = PipelineCommand.StartRun + .newBuilder() .setDataflowGraphId(graphId) .setFullRefreshAll(true) .addFullRefresh("a") @@ -286,8 +268,9 @@ class PipelineRefreshFunctionalSuite val exception = intercept[IllegalArgumentException] { startPipelineAndWaitForCompletion(graphId, Some(startRun)) } - assert(exception.getMessage.contains( - "Cannot specify a subset to refresh when full refresh all is set to true")) + assert( + exception.getMessage.contains( + "Cannot specify a subset to refresh when full refresh all is set to true")) } } @@ -297,7 +280,8 @@ class PipelineRefreshFunctionalSuite val pipeline = createTestPipeline(graphId) registerPipelineDatasets(pipeline) - val startRun = PipelineCommand.StartRun.newBuilder() + val startRun = PipelineCommand.StartRun + .newBuilder() .setDataflowGraphId(graphId) .addRefresh("a") .addFullRefresh("a") @@ -306,8 +290,9 @@ class PipelineRefreshFunctionalSuite val exception = intercept[IllegalArgumentException] { startPipelineAndWaitForCompletion(graphId, Some(startRun)) } - assert(exception.getMessage.contains( - "Datasets specified for refresh and full refresh cannot overlap")) + assert( + exception.getMessage.contains( + "Datasets specified for refresh and full refresh cannot overlap")) assert(exception.getMessage.contains("a")) } } @@ -318,7 +303,8 @@ class PipelineRefreshFunctionalSuite val pipeline = createTestPipeline(graphId) registerPipelineDatasets(pipeline) - val startRun = PipelineCommand.StartRun.newBuilder() + val startRun = PipelineCommand.StartRun + .newBuilder() .setDataflowGraphId(graphId) .addRefresh("a") .addRefresh("b") @@ -328,8 +314,9 @@ class PipelineRefreshFunctionalSuite val exception = intercept[IllegalArgumentException] { startPipelineAndWaitForCompletion(graphId, Some(startRun)) } - assert(exception.getMessage.contains( - "Datasets specified for refresh and full refresh cannot overlap")) + assert( + exception.getMessage.contains( + "Datasets specified for refresh and full refresh cannot overlap")) assert(exception.getMessage.contains("a")) } } @@ -340,33 +327,40 @@ class PipelineRefreshFunctionalSuite val pipeline = createTestPipeline(graphId) registerPipelineDatasets(pipeline) - val startRun = PipelineCommand.StartRun.newBuilder() + val startRun = PipelineCommand.StartRun + .newBuilder() .setDataflowGraphId(graphId) .addRefresh("spark_catalog.default.a") - .addFullRefresh("a") // This should be treated as the same table + .addFullRefresh("a") // This should be treated as the same table .build() val exception = intercept[IllegalArgumentException] { 
startPipelineAndWaitForCompletion(graphId, Some(startRun)) } - assert(exception.getMessage.contains( - "Datasets specified for refresh and full refresh cannot overlap")) + assert( + exception.getMessage.contains( + "Datasets specified for refresh and full refresh cannot overlap")) } } private def verifyMultipleTableContent( - tableNames: Set[String], - columnsToVerify: Map[String, Seq[String]], - expectedContent: Map[String, Set[Map[String, Any]]]): Unit = { + tableNames: Set[String], + columnsToVerify: Map[String, Seq[String]], + expectedContent: Map[String, Set[Map[String, Any]]]): Unit = { tableNames.foreach { tableName => spark.catalog.refreshTable(tableName) // clear cache for the table val df = spark.table(tableName) - assert(df.columns.toSet == columnsToVerify(tableName).toSet, + assert( + df.columns.toSet == columnsToVerify(tableName).toSet, s"Columns in $tableName do not match expected: ${df.columns.mkString(", ")}") - val actualContent = df.collect().map(row => { - columnsToVerify(tableName).map(col => col -> row.getAs[Any](col)).toMap - }).toSet - assert(actualContent == expectedContent(tableName), + val actualContent = df + .collect() + .map(row => { + columnsToVerify(tableName).map(col => col -> row.getAs[Any](col)).toMap + }) + .toSet + assert( + actualContent == expectedContent(tableName), s"Content of $tableName does not match expected: $actualContent") } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala index f18681834622..eb234b12929d 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} import org.apache.spark.sql.pipelines.utils.PipelineTest - class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { override def afterEach(): Unit = { @@ -130,12 +129,14 @@ class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { new SparkConnectPlanner(SparkConnectTestUtils.createDummySessionHolder(spark)) def startPipelineAndWaitForCompletion( - graphId: String, - customStartRunCommand: Option[PipelineCommand.StartRun] = None): ArrayBuffer[PipelineEvent] = { + graphId: String, + customStartRunCommand: Option[PipelineCommand.StartRun] = None) + : ArrayBuffer[PipelineEvent] = { withClient { client => val capturedEvents = new ArrayBuffer[PipelineEvent]() - val startRunRequest = buildStartRunPlan(customStartRunCommand.getOrElse( - PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build())) + val startRunRequest = buildStartRunPlan( + customStartRunCommand.getOrElse( + PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build())) val responseIterator = client.execute(startRunRequest) // The response iterator will be closed when the pipeline is completed. 
while (responseIterator.hasNext) { From e94c85f1fe73cc8b1809b5b2b479249928aa1104 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Tue, 15 Jul 2025 22:02:45 -0700 Subject: [PATCH 11/17] fmt --- python/pyspark/pipelines/tests/test_cli.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index c8554c72462b..d8f8675842dd 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -44,7 +44,7 @@ @unittest.skipIf( not should_test_connect or not have_yaml, - (connect_requirement_message or yaml_requirement_message) or "Connect or YAML not available", + connect_requirement_message or yaml_requirement_message, ) class CLIUtilityTests(unittest.TestCase): def test_load_pipeline_spec(self): @@ -363,7 +363,7 @@ def test_python_import_current_directory(self): @unittest.skipIf( not should_test_connect or not have_yaml, - (connect_requirement_message or yaml_requirement_message) or "Connect or YAML not available", + connect_requirement_message or yaml_requirement_message, ) class CLIValidationTests(unittest.TestCase): def test_full_refresh_all_conflicts_with_full_refresh(self): From 695054f1439a89785a7f62424b177a598cd91a94 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Tue, 15 Jul 2025 22:34:19 -0700 Subject: [PATCH 12/17] fmt --- python/pyspark/pipelines/cli.py | 36 ++++++------- .../pipelines/spark_connect_pipeline.py | 6 +-- python/pyspark/pipelines/tests/test_cli.py | 53 ++++++++++--------- 3 files changed, 50 insertions(+), 45 deletions(-) diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index 8efce47e4cdb..691e8b2b8695 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -218,13 +218,13 @@ def change_dir(path: Path) -> Generator[None, None, None]: def run( - spec_path: Path, + spec_path: Path, full_refresh: Optional[Sequence[str]] = None, full_refresh_all: bool = False, - refresh: Optional[Sequence[str]] = None + refresh: Optional[Sequence[str]] = None, ) -> None: """Run the pipeline defined with the given spec. - + :param spec_path: Path to the pipeline specification file. :param full_refresh: List of datasets to reset and recompute. :param full_refresh_all: Perform a full graph reset and recompute. 
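For readers following the series, a minimal sketch of how the new options are intended to be combined (illustration only; the spec filename and table names below are made up):

    from pathlib import Path
    from pyspark.pipelines.cli import run

    # Roughly what `run --spec pipeline.yaml --full-refresh "sales,orders" --refresh customers`
    # resolves to after argument parsing:
    run(
        spec_path=Path("pipeline.yaml"),
        full_refresh=["sales", "orders"],  # reset and recompute these tables
        full_refresh_all=False,
        refresh=["customers"],             # incrementally update this table
    )
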
@@ -234,15 +234,13 @@ def run( if full_refresh_all: if full_refresh: raise PySparkException( - errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", - messageParameters={} + errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", messageParameters={} ) if refresh: raise PySparkException( - errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", - messageParameters={} + errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", messageParameters={} ) - + log_with_curr_timestamp(f"Loading pipeline spec from {spec_path}...") spec = load_pipeline_spec(spec_path) @@ -267,11 +265,11 @@ def run( log_with_curr_timestamp("Starting run...") result_iter = start_run( - spark, + spark, dataflow_graph_id, full_refresh=full_refresh, full_refresh_all=full_refresh_all, - refresh=refresh + refresh=refresh, ) try: handle_pipeline_events(result_iter) @@ -302,17 +300,19 @@ def flatten_table_lists(table_lists: Optional[List[List[str]]]) -> Optional[List run_parser = subparsers.add_parser("run", help="Run a pipeline.") run_parser.add_argument("--spec", help="Path to the pipeline spec.") run_parser.add_argument( - "--full-refresh", - type=parse_table_list, + "--full-refresh", + type=parse_table_list, action="append", - help="List of datasets to reset and recompute (comma-separated)." + help="List of datasets to reset and recompute (comma-separated).", ) - run_parser.add_argument("--full-refresh-all", action="store_true", help="Perform a full graph reset and recompute.") run_parser.add_argument( - "--refresh", + "--full-refresh-all", action="store_true", help="Perform a full graph reset and recompute." + ) + run_parser.add_argument( + "--refresh", type=parse_table_list, - action="append", - help="List of datasets to update (comma-separated)." + action="append", + help="List of datasets to update (comma-separated).", ) # "init" subcommand @@ -345,7 +345,7 @@ def flatten_table_lists(table_lists: Optional[List[List[str]]]) -> Optional[List spec_path=spec_path, full_refresh=flatten_table_lists(args.full_refresh), full_refresh_all=args.full_refresh_all, - refresh=flatten_table_lists(args.refresh) + refresh=flatten_table_lists(args.refresh), ) elif args.command == "init": init(args.name) diff --git a/python/pyspark/pipelines/spark_connect_pipeline.py b/python/pyspark/pipelines/spark_connect_pipeline.py index f4f46c3ee3ab..59f0a4df586d 100644 --- a/python/pyspark/pipelines/spark_connect_pipeline.py +++ b/python/pyspark/pipelines/spark_connect_pipeline.py @@ -66,11 +66,11 @@ def handle_pipeline_events(iter: Iterator[Dict[str, Any]]) -> None: def start_run( - spark: SparkSession, + spark: SparkSession, dataflow_graph_id: str, full_refresh: Optional[Sequence[str]] = None, full_refresh_all: bool = False, - refresh: Optional[Sequence[str]] = None + refresh: Optional[Sequence[str]] = None, ) -> Iterator[Dict[str, Any]]: """Start a run of the dataflow graph in the Spark Connect server. 
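A small self-contained check of the parsing behavior shown above (illustration only): with `action="append"`, each occurrence of the flag contributes one parsed list, which is why `flatten_table_lists` is needed before calling `run`; a later commit in this series switches to `action="extend"` and drops the helper.

    import argparse
    from pyspark.pipelines.cli import parse_table_list, flatten_table_lists

    parser = argparse.ArgumentParser()
    parser.add_argument("--full-refresh", type=parse_table_list, action="append")

    args = parser.parse_args(["--full-refresh", "a, b", "--full-refresh", "c"])
    assert args.full_refresh == [["a", "b"], ["c"]]  # one inner list per flag occurrence
    assert flatten_table_lists(args.full_refresh) == ["a", "b", "c"]
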
@@ -83,7 +83,7 @@ def start_run( dataflow_graph_id=dataflow_graph_id, full_refresh=full_refresh or [], full_refresh_all=full_refresh_all, - refresh=refresh or [] + refresh=refresh or [], ) command = pb2.Command() command.pipeline_command.start_run.CopyFrom(inner_command) diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index d8f8675842dd..401b7127abf7 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -372,37 +372,36 @@ def test_full_refresh_all_conflicts_with_full_refresh(self): spec_path = Path(temp_dir) / "pipeline.yaml" with spec_path.open("w") as f: f.write('{"name": "test_pipeline"}') - + # Test that providing both --full-refresh-all and --full-refresh raises an exception with self.assertRaises(PySparkException) as context: run( spec_path=spec_path, full_refresh=["table1", "table2"], full_refresh_all=True, - refresh=None + refresh=None, ) - + self.assertEqual( context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" ) - def test_full_refresh_all_conflicts_with_refresh(self): with tempfile.TemporaryDirectory() as temp_dir: # Create a minimal pipeline spec spec_path = Path(temp_dir) / "pipeline.yaml" with spec_path.open("w") as f: f.write('{"name": "test_pipeline"}') - + # Test that providing both --full-refresh-all and --refresh raises an exception with self.assertRaises(PySparkException) as context: run( spec_path=spec_path, full_refresh=None, full_refresh_all=True, - refresh=["table1", "table2"] + refresh=["table1", "table2"], ) - + self.assertEqual( context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" ) @@ -413,7 +412,7 @@ def test_full_refresh_all_conflicts_with_both(self): spec_path = Path(temp_dir) / "pipeline.yaml" with spec_path.open("w") as f: f.write('{"name": "test_pipeline"}') - + # Test that providing --full-refresh-all with both other options raises an exception # (it should catch the first conflict - full_refresh) with self.assertRaises(PySparkException) as context: @@ -421,9 +420,9 @@ def test_full_refresh_all_conflicts_with_both(self): spec_path=spec_path, full_refresh=["table1"], full_refresh_all=True, - refresh=["table2"] + refresh=["table2"], ) - + self.assertEqual( context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" ) @@ -434,16 +433,11 @@ def test_no_conflict_when_full_refresh_all_alone(self): spec_path = Path(temp_dir) / "pipeline.yaml" with spec_path.open("w") as f: f.write('{"name": "test_pipeline"}') - + # Test that providing only --full-refresh-all doesn't raise an exception # (it should fail later when trying to actually run, but not in our validation) try: - run( - spec_path=spec_path, - full_refresh=None, - full_refresh_all=True, - refresh=None - ) + run(spec_path=spec_path, full_refresh=None, full_refresh_all=True, refresh=None) # If we get here, the validation passed (it will fail later in pipeline execution) self.fail("Expected the run to fail later, but validation should have passed") except PySparkException as e: @@ -456,14 +450,14 @@ def test_no_conflict_when_refresh_options_without_full_refresh_all(self): spec_path = Path(temp_dir) / "pipeline.yaml" with spec_path.open("w") as f: f.write('{"name": "test_pipeline"}') - + # Test that providing --refresh and --full-refresh without --full-refresh-all doesn't raise our validation error try: run( spec_path=spec_path, full_refresh=["table1"], full_refresh_all=False, - refresh=["table2"] + refresh=["table2"], ) # If we get here, the validation 
passed (it will fail later in pipeline execution) self.fail("Expected the run to fail later, but validation should have passed") @@ -474,30 +468,35 @@ def test_no_conflict_when_refresh_options_without_full_refresh_all(self): def test_parse_table_list_single_table(self): """Test parsing a single table name.""" from pyspark.pipelines.cli import parse_table_list + result = parse_table_list("table1") self.assertEqual(result, ["table1"]) def test_parse_table_list_multiple_tables(self): """Test parsing multiple table names.""" from pyspark.pipelines.cli import parse_table_list + result = parse_table_list("table1,table2,table3") self.assertEqual(result, ["table1", "table2", "table3"]) def test_parse_table_list_with_spaces(self): """Test parsing table names with spaces.""" from pyspark.pipelines.cli import parse_table_list + result = parse_table_list("table1, table2 , table3") self.assertEqual(result, ["table1", "table2", "table3"]) def test_flatten_table_lists_single_list(self): """Test flattening single list.""" from pyspark.pipelines.cli import flatten_table_lists + result = flatten_table_lists([["table1", "table2"]]) self.assertEqual(result, ["table1", "table2"]) def test_flatten_table_lists_multiple_lists(self): """Test flattening multiple lists.""" from pyspark.pipelines.cli import flatten_table_lists + result = flatten_table_lists([["table1", "table2"], ["table3"], ["table4", "table5"]]) self.assertEqual(result, ["table1", "table2", "table3", "table4", "table5"]) @@ -505,22 +504,28 @@ def test_cli_argument_parsing_patterns(self): """Test CLI argument parsing patterns for refresh options.""" import argparse from pyspark.pipelines.cli import parse_table_list - + # Simulate the argument parser parser = argparse.ArgumentParser() parser.add_argument("--full-refresh", type=parse_table_list, action="append") parser.add_argument("--full-refresh-all", action="store_true") parser.add_argument("--refresh", type=parse_table_list, action="append") - + # Test parsing various argument combinations test_cases = [ (["--full-refresh", "table1,table2"], {"full_refresh": [["table1", "table2"]]}), (["--refresh", "table1", "--refresh", "table2"], {"refresh": [["table1"], ["table2"]]}), (["--full-refresh-all"], {"full_refresh_all": True}), - (["--full-refresh", "table1", "--refresh", "table2"], {"full_refresh": [["table1"]], "refresh": [["table2"]]}), - (["--full-refresh", "schema.table1,schema.table2"], {"full_refresh": [["schema.table1", "schema.table2"]]}), + ( + ["--full-refresh", "table1", "--refresh", "table2"], + {"full_refresh": [["table1"]], "refresh": [["table2"]]}, + ), + ( + ["--full-refresh", "schema.table1,schema.table2"], + {"full_refresh": [["schema.table1", "schema.table2"]]}, + ), ] - + for args, expected in test_cases: parsed = parser.parse_args(args) for key, value in expected.items(): From 1abe5c3a76255305f4d761a36202f883cf40149d Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Tue, 15 Jul 2025 22:40:13 -0700 Subject: [PATCH 13/17] fmt --- python/pyspark/pipelines/tests/test_cli.py | 39 ---------------------- 1 file changed, 39 deletions(-) diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index 401b7127abf7..e801a0deaed8 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -19,7 +19,6 @@ import tempfile import textwrap from pathlib import Path -from typing import cast from pyspark.errors import PySparkException from pyspark.testing.connectutils import ( @@ -427,44 +426,6 @@ def 
test_full_refresh_all_conflicts_with_both(self): context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" ) - def test_no_conflict_when_full_refresh_all_alone(self): - with tempfile.TemporaryDirectory() as temp_dir: - # Create a minimal pipeline spec - spec_path = Path(temp_dir) / "pipeline.yaml" - with spec_path.open("w") as f: - f.write('{"name": "test_pipeline"}') - - # Test that providing only --full-refresh-all doesn't raise an exception - # (it should fail later when trying to actually run, but not in our validation) - try: - run(spec_path=spec_path, full_refresh=None, full_refresh_all=True, refresh=None) - # If we get here, the validation passed (it will fail later in pipeline execution) - self.fail("Expected the run to fail later, but validation should have passed") - except PySparkException as e: - # Make sure it's NOT our validation error - self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - - def test_no_conflict_when_refresh_options_without_full_refresh_all(self): - with tempfile.TemporaryDirectory() as temp_dir: - # Create a minimal pipeline spec - spec_path = Path(temp_dir) / "pipeline.yaml" - with spec_path.open("w") as f: - f.write('{"name": "test_pipeline"}') - - # Test that providing --refresh and --full-refresh without --full-refresh-all doesn't raise our validation error - try: - run( - spec_path=spec_path, - full_refresh=["table1"], - full_refresh_all=False, - refresh=["table2"], - ) - # If we get here, the validation passed (it will fail later in pipeline execution) - self.fail("Expected the run to fail later, but validation should have passed") - except PySparkException as e: - # Make sure it's NOT our validation error - self.assertNotEqual(e.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS") - def test_parse_table_list_single_table(self): """Test parsing a single table name.""" from pyspark.pipelines.cli import parse_table_list From 1693ac546225c8a6be1d96eb5e64fcf03f77a344 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Wed, 16 Jul 2025 16:48:40 -0700 Subject: [PATCH 14/17] address feedback --- python/pyspark/errors/error-conditions.json | 4 +- python/pyspark/pipelines/cli.py | 38 ++--- python/pyspark/pipelines/tests/test_cli.py | 65 ++------ .../connect/pipelines/PipelinesHandler.scala | 151 +++++++++++------- .../PipelineRefreshFunctionalSuite.scala | 86 +++++----- .../SparkDeclarativePipelinesServerTest.scala | 15 +- .../graph/PipelineUpdateContextImpl.scala | 6 +- 7 files changed, 170 insertions(+), 195 deletions(-) diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index f4adea5fba83..2a638bc7ec36 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -176,9 +176,9 @@ }, "CONFLICTING_PIPELINE_REFRESH_OPTIONS" : { "message" : [ - "--full-refresh-all option conflicts with --refresh and --full-refresh. ", + "--full-refresh-all option conflicts with ", "The --full-refresh-all option performs a full refresh of all datasets, ", - "so specifying individual datasets with --refresh or --full-refresh is not allowed." + "so specifying individual datasets with is not allowed." 
] }, "CONNECT_URL_ALREADY_DEFINED": { diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index 691e8b2b8695..5c58c134b9bf 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -219,9 +219,9 @@ def change_dir(path: Path) -> Generator[None, None, None]: def run( spec_path: Path, - full_refresh: Optional[Sequence[str]] = None, - full_refresh_all: bool = False, - refresh: Optional[Sequence[str]] = None, + full_refresh: Sequence[str], + full_refresh_all: bool, + refresh: Sequence[str], ) -> None: """Run the pipeline defined with the given spec. @@ -234,11 +234,15 @@ def run( if full_refresh_all: if full_refresh: raise PySparkException( - errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", messageParameters={} + errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", messageParameters={ + "conflicting_option": "--full_refresh", + } ) if refresh: raise PySparkException( - errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", messageParameters={} + errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", messageParameters={ + "conflicting_option": "--refresh", + } ) log_with_curr_timestamp(f"Loading pipeline spec from {spec_path}...") @@ -281,28 +285,20 @@ def parse_table_list(value: str) -> List[str]: """Parse a comma-separated list of table names, handling whitespace.""" return [table.strip() for table in value.split(",") if table.strip()] - -def flatten_table_lists(table_lists: Optional[List[List[str]]]) -> Optional[List[str]]: - """Flatten a list of lists of table names into a single list.""" - if not table_lists: - return None - result = [] - for table_list in table_lists: - result.extend(table_list) - return result if result else None - - if __name__ == "__main__": parser = argparse.ArgumentParser(description="Pipeline CLI") subparsers = parser.add_subparsers(dest="command", required=True) # "run" subcommand - run_parser = subparsers.add_parser("run", help="Run a pipeline.") + run_parser = subparsers.add_parser( + "run", + help="Run a pipeline. 
If no refresh options are specified, a default incremental update is performed.", + ) run_parser.add_argument("--spec", help="Path to the pipeline spec.") run_parser.add_argument( "--full-refresh", type=parse_table_list, - action="append", + action="extend", help="List of datasets to reset and recompute (comma-separated).", ) run_parser.add_argument( @@ -311,7 +307,7 @@ def flatten_table_lists(table_lists: Optional[List[List[str]]]) -> Optional[List run_parser.add_argument( "--refresh", type=parse_table_list, - action="append", + action="extend", help="List of datasets to update (comma-separated).", ) @@ -343,9 +339,9 @@ def flatten_table_lists(table_lists: Optional[List[List[str]]]) -> Optional[List run( spec_path=spec_path, - full_refresh=flatten_table_lists(args.full_refresh), + full_refresh=args.full_refresh, full_refresh_all=args.full_refresh_all, - refresh=flatten_table_lists(args.refresh), + refresh=args.refresh, ) elif args.command == "init": init(args.name) diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index e801a0deaed8..319f637c8744 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -359,12 +359,6 @@ def test_python_import_current_directory(self): ), ) - -@unittest.skipIf( - not should_test_connect or not have_yaml, - connect_requirement_message or yaml_requirement_message, -) -class CLIValidationTests(unittest.TestCase): def test_full_refresh_all_conflicts_with_full_refresh(self): with tempfile.TemporaryDirectory() as temp_dir: # Create a minimal pipeline spec @@ -378,12 +372,17 @@ def test_full_refresh_all_conflicts_with_full_refresh(self): spec_path=spec_path, full_refresh=["table1", "table2"], full_refresh_all=True, - refresh=None, + refresh=[], ) self.assertEqual( context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" ) + self.assertEqual( + context.exception.getMessageParameters(), { + "conflicting_option": "--full_refresh" + } + ) def test_full_refresh_all_conflicts_with_refresh(self): with tempfile.TemporaryDirectory() as temp_dir: @@ -396,7 +395,7 @@ def test_full_refresh_all_conflicts_with_refresh(self): with self.assertRaises(PySparkException) as context: run( spec_path=spec_path, - full_refresh=None, + full_refresh=[], full_refresh_all=True, refresh=["table1", "table2"], ) @@ -404,6 +403,11 @@ def test_full_refresh_all_conflicts_with_refresh(self): self.assertEqual( context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" ) + self.assertEqual( + context.exception.getMessageParameters(), { + "conflicting_option": "--refresh" + }, + ) def test_full_refresh_all_conflicts_with_both(self): with tempfile.TemporaryDirectory() as temp_dir: @@ -447,51 +451,6 @@ def test_parse_table_list_with_spaces(self): result = parse_table_list("table1, table2 , table3") self.assertEqual(result, ["table1", "table2", "table3"]) - def test_flatten_table_lists_single_list(self): - """Test flattening single list.""" - from pyspark.pipelines.cli import flatten_table_lists - - result = flatten_table_lists([["table1", "table2"]]) - self.assertEqual(result, ["table1", "table2"]) - - def test_flatten_table_lists_multiple_lists(self): - """Test flattening multiple lists.""" - from pyspark.pipelines.cli import flatten_table_lists - - result = flatten_table_lists([["table1", "table2"], ["table3"], ["table4", "table5"]]) - self.assertEqual(result, ["table1", "table2", "table3", "table4", "table5"]) - - def test_cli_argument_parsing_patterns(self): - """Test 
CLI argument parsing patterns for refresh options.""" - import argparse - from pyspark.pipelines.cli import parse_table_list - - # Simulate the argument parser - parser = argparse.ArgumentParser() - parser.add_argument("--full-refresh", type=parse_table_list, action="append") - parser.add_argument("--full-refresh-all", action="store_true") - parser.add_argument("--refresh", type=parse_table_list, action="append") - - # Test parsing various argument combinations - test_cases = [ - (["--full-refresh", "table1,table2"], {"full_refresh": [["table1", "table2"]]}), - (["--refresh", "table1", "--refresh", "table2"], {"refresh": [["table1"], ["table2"]]}), - (["--full-refresh-all"], {"full_refresh_all": True}), - ( - ["--full-refresh", "table1", "--refresh", "table2"], - {"full_refresh": [["table1"]], "refresh": [["table2"]]}, - ), - ( - ["--full-refresh", "schema.table1,schema.table2"], - {"full_refresh": [["schema.table1", "schema.table2"]]}, - ), - ] - - for args, expected in test_cases: - parsed = parser.parse_args(args) - for key, value in expected.items(): - self.assertEqual(getattr(parsed, key), value) - if __name__ == "__main__": try: diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index a2d7b6fb4c0e..e5d9835165cf 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.connect.service.SessionHolder import org.apache.spark.sql.pipelines.Language.Python import org.apache.spark.sql.pipelines.QueryOriginType import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED} -import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryOrigin, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow} +import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryOrigin, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow} import org.apache.spark.sql.pipelines.logging.{PipelineEvent, RunProgress} import org.apache.spark.sql.types.StructType @@ -225,63 +225,7 @@ private[connect] object PipelinesHandler extends Logging { sessionHolder: SessionHolder): Unit = { val dataflowGraphId = cmd.getDataflowGraphId val graphElementRegistry = DataflowGraphRegistry.getDataflowGraphOrThrow(dataflowGraphId) - - // Extract refresh parameters from protobuf command - val fullRefreshTables = cmd.getFullRefreshList.asScala.toSeq - val fullRefreshAll = cmd.getFullRefreshAll - val refreshTables = cmd.getRefreshList.asScala.toSeq - - // Convert table names to fully qualified TableIdentifier objects - def parseTableNames(tableNames: Seq[String]): Set[TableIdentifier] = { - tableNames.map { name => - GraphIdentifierManager - .parseAndQualifyTableIdentifier( - rawTableIdentifier = - GraphIdentifierManager.parseTableIdentifier(name, sessionHolder.session), - currentCatalog = Some(graphElementRegistry.defaultCatalog), - currentDatabase = Some(graphElementRegistry.defaultDatabase)) - .identifier - }.toSet - } - - if (fullRefreshTables.nonEmpty && 
fullRefreshAll) { - throw new IllegalArgumentException( - "Cannot specify a subset to refresh when full refresh all is set to true.") - } - - if (refreshTables.nonEmpty && fullRefreshAll) { - throw new IllegalArgumentException( - "Cannot specify a subset to full refresh when full refresh all is set to true.") - } - val refreshTableNames = parseTableNames(refreshTables) - val fullRefreshTableNames = parseTableNames(fullRefreshTables) - - if (refreshTables.nonEmpty && fullRefreshTables.nonEmpty) { - // check if there is an intersection between the subset - val intersection = refreshTableNames.intersect(fullRefreshTableNames) - if (intersection.nonEmpty) { - throw new IllegalArgumentException( - "Datasets specified for refresh and full refresh cannot overlap: " + - s"${intersection.mkString(", ")}") - } - } - - val fullRefreshTablesFilter: TableFilter = if (fullRefreshAll) { - AllTables - } else if (fullRefreshTables.nonEmpty) { - SomeTables(fullRefreshTableNames) - } else { - NoTables - } - - val refreshTablesFilter: TableFilter = - if (refreshTables.nonEmpty) { - SomeTables(refreshTableNames) - } else if (fullRefreshTablesFilter != NoTables) { - NoTables - } else { - AllTables - } + val tableFiltersResult = createTableFilters(cmd, graphElementRegistry, sessionHolder) // We will use this variable to store the run failure event if it occurs. This will be set // by the event callback. @@ -341,8 +285,8 @@ private[connect] object PipelinesHandler extends Logging { val pipelineUpdateContext = new PipelineUpdateContextImpl( graphElementRegistry.toDataflowGraph, eventCallback, - refreshTablesFilter, - fullRefreshTablesFilter) + tableFiltersResult.refresh, + tableFiltersResult.fullRefresh) sessionHolder.cachePipelineExecution(dataflowGraphId, pipelineUpdateContext) pipelineUpdateContext.pipelineExecution.runPipeline() @@ -352,4 +296,91 @@ private[connect] object PipelinesHandler extends Logging { throw event.error.get } } + + /** + * Creates the table filters for the full refresh and refresh operations based on the + * StartRun command user provided. Also validates the command parameters to ensure that they are + * consistent and do not conflict with each other. + * + * If `fullRefreshAll` is true, create `AllTables` filter for full refresh. + * + * If `fullRefreshTables` and `refreshTables` are both empty, + * create `AllTables` filter for refresh as a default behavior. + * + * If both non-empty, verifies that there is no overlap and creates SomeTables filters for both. + * + * If one non-empty and the other empty, create `SomeTables` filter for the non-empty one, and + * `NoTables` filter for the empty one. 
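+   *
+   * For example, the mapping implied by the rules above (illustration only):
+   *   fullRefreshAll = true                 => (fullRefresh = AllTables, refresh = NoTables)
+   *   fullRefresh = [],  refresh = []       => (fullRefresh = NoTables, refresh = AllTables)
+   *   fullRefresh = [a], refresh = []       => (fullRefresh = SomeTables(a), refresh = NoTables)
+   *   fullRefresh = [a], refresh = [b]      => (fullRefresh = SomeTables(a), refresh = SomeTables(b))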
+ */ + private def createTableFilters( + startRunCommand: proto.PipelineCommand.StartRun, + graphElementRegistry: GraphRegistrationContext, + sessionHolder: SessionHolder): TableFilters = { + // Convert table names to fully qualified TableIdentifier objects + def parseTableNames(tableNames: Seq[String]): Set[TableIdentifier] = { + tableNames.map { name => + GraphIdentifierManager + .parseAndQualifyTableIdentifier( + rawTableIdentifier = + GraphIdentifierManager.parseTableIdentifier(name, sessionHolder.session), + currentCatalog = Some(graphElementRegistry.defaultCatalog), + currentDatabase = Some(graphElementRegistry.defaultDatabase)) + .identifier + }.toSet + } + + val fullRefreshTables = startRunCommand.getFullRefreshList.asScala.toSeq + val fullRefreshAll = startRunCommand.getFullRefreshAll + val refreshTables = startRunCommand.getRefreshList.asScala.toSeq + + if (refreshTables.nonEmpty && fullRefreshAll) { + throw new IllegalArgumentException( + "Cannot specify a subset to refresh when full refresh all is set to true.") + } + + if (fullRefreshTables.nonEmpty && fullRefreshAll) { + throw new IllegalArgumentException( + "Cannot specify a subset to full refresh when full refresh all is set to true.") + } + val refreshTableNames = parseTableNames(refreshTables) + val fullRefreshTableNames = parseTableNames(fullRefreshTables) + + if (refreshTables.nonEmpty && fullRefreshTables.nonEmpty) { + // check if there is an intersection between the subset + val intersection = refreshTableNames.intersect(fullRefreshTableNames) + if (intersection.nonEmpty) { + throw new IllegalArgumentException( + "Datasets specified for refresh and full refresh cannot overlap: " + + s"${intersection.mkString(", ")}") + } + } + + if (fullRefreshAll) { + return TableFilters(fullRefresh = AllTables, refresh = NoTables) + } + + (fullRefreshTables, refreshTables) match { + case (Nil, Nil) => + // If both are empty, we default to refreshing all tables + TableFilters(fullRefresh = NoTables, refresh = AllTables) + case (_, Nil) => + TableFilters(fullRefresh = SomeTables(fullRefreshTableNames), refresh = NoTables) + case (Nil, _) => + TableFilters(fullRefresh = NoTables, refresh = SomeTables(refreshTableNames)) + case (_, _) => + // If both are specified, we create filters for both after validation + TableFilters( + fullRefresh = SomeTables(fullRefreshTableNames), + refresh = SomeTables(refreshTableNames) + ) + } + } + + /** + * A case class to hold the table filters for full refresh and refresh operations. 
+ */ + private case class TableFilters( + fullRefresh: TableFilter, + refresh: TableFilter + ) } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala index 477457778f06..83846c55d72c 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import org.apache.spark.connect.proto.{DatasetType, PipelineCommand, PipelineEvent} +import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.connect.service.{SessionKey, SparkConnectService} import org.apache.spark.sql.pipelines.utils.{EventVerificationTestHelpers, TestPipelineUpdateContextMixin} @@ -85,19 +86,16 @@ class PipelineRefreshFunctionalSuite // First run to populate tables startPipelineAndWaitForCompletion(graphId) - // Verify initial data - all tests expect the same initial state - verifyMultipleTableContent( - tableNames = - Set("spark_catalog.default.a", "spark_catalog.default.b", "spark_catalog.default.mv"), - columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id"), - "spark_catalog.default.b" -> Seq("id"), - "spark_catalog.default.mv" -> Seq("id")), - expectedContent = Map( - "spark_catalog.default.a" -> Set(Map("id" -> 1)), - "spark_catalog.default.b" -> Set(Map("id" -> 1)), - "spark_catalog.default.mv" -> Set(Map("id" -> 1)))) - + // combine above into a map for verification + val initialContent = Map( + "spark_catalog.default.a" -> Set(Map("id" -> 1)), + "spark_catalog.default.b" -> Set(Map("id" -> 1)), + "spark_catalog.default.mv" -> Set(Map("id" -> 1)) + ) + // Verify initial content + initialContent.foreach { case (tableName, expectedRows) => + checkTableContent(tableName, expectedRows) + } // Clear cached pipeline execution before starting new run SparkConnectService.sessionManager .getIsolatedSessionIfPresent(SessionKey(defaultUserId, defaultSessionId)) @@ -110,22 +108,17 @@ class PipelineRefreshFunctionalSuite // Run with specified refresh configuration val capturedEvents = refreshConfigBuilder(graphId) match { - case Some(startRun) => startPipelineAndWaitForCompletion(graphId, Some(startRun)) + case Some(startRun) => startPipelineAndWaitForCompletion(startRun) case None => startPipelineAndWaitForCompletion(graphId) } // Additional validation if provided eventValidation.foreach(_(capturedEvents)) - // Verify final content - verifyMultipleTableContent( - tableNames = - Set("spark_catalog.default.a", "spark_catalog.default.b", "spark_catalog.default.mv"), - columnsToVerify = Map( - "spark_catalog.default.a" -> Seq("id"), - "spark_catalog.default.b" -> Seq("id"), - "spark_catalog.default.mv" -> Seq("id")), - expectedContent = expectedContentAfterRefresh) + // Verify final content with checkTableContent + expectedContentAfterRefresh.foreach { case (tableName, expectedRows) => + checkTableContent(tableName, expectedRows) + } } } @@ -244,11 +237,11 @@ class PipelineRefreshFunctionalSuite .build() val exception = intercept[IllegalArgumentException] { - startPipelineAndWaitForCompletion(graphId, Some(startRun)) + startPipelineAndWaitForCompletion(startRun) } assert( exception.getMessage.contains( - "Cannot 
specify a subset to full refresh when full refresh all is set to true")) + "Cannot specify a subset to refresh when full refresh all is set to true")) } } @@ -266,11 +259,11 @@ class PipelineRefreshFunctionalSuite .build() val exception = intercept[IllegalArgumentException] { - startPipelineAndWaitForCompletion(graphId, Some(startRun)) + startPipelineAndWaitForCompletion(startRun) } assert( exception.getMessage.contains( - "Cannot specify a subset to refresh when full refresh all is set to true")) + "Cannot specify a subset to full refresh when full refresh all is set to true")) } } @@ -288,7 +281,7 @@ class PipelineRefreshFunctionalSuite .build() val exception = intercept[IllegalArgumentException] { - startPipelineAndWaitForCompletion(graphId, Some(startRun)) + startPipelineAndWaitForCompletion(startRun) } assert( exception.getMessage.contains( @@ -312,7 +305,7 @@ class PipelineRefreshFunctionalSuite .build() val exception = intercept[IllegalArgumentException] { - startPipelineAndWaitForCompletion(graphId, Some(startRun)) + startPipelineAndWaitForCompletion(startRun) } assert( exception.getMessage.contains( @@ -335,7 +328,7 @@ class PipelineRefreshFunctionalSuite .build() val exception = intercept[IllegalArgumentException] { - startPipelineAndWaitForCompletion(graphId, Some(startRun)) + startPipelineAndWaitForCompletion(startRun) } assert( exception.getMessage.contains( @@ -343,25 +336,18 @@ class PipelineRefreshFunctionalSuite } } - private def verifyMultipleTableContent( - tableNames: Set[String], - columnsToVerify: Map[String, Seq[String]], - expectedContent: Map[String, Set[Map[String, Any]]]): Unit = { - tableNames.foreach { tableName => - spark.catalog.refreshTable(tableName) // clear cache for the table - val df = spark.table(tableName) - assert( - df.columns.toSet == columnsToVerify(tableName).toSet, - s"Columns in $tableName do not match expected: ${df.columns.mkString(", ")}") - val actualContent = df - .collect() - .map(row => { - columnsToVerify(tableName).map(col => col -> row.getAs[Any](col)).toMap - }) - .toSet - assert( - actualContent == expectedContent(tableName), - s"Content of $tableName does not match expected: $actualContent") - } + private def checkTableContent[A <: Map[String, Any]]( + name: String, + expectedContent: Set[A] + ): Unit = { + spark.catalog.refreshTable(name) // clear cache for the table + val df = spark.table(name) + QueryTest.checkAnswer( + df, + expectedContent.map(row => { + // Convert each row to a Row object + org.apache.spark.sql.Row.fromSeq(row.values.toSeq) + }).toSeq.asJava + ) } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala index eb234b12929d..b3124b236ba0 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala @@ -128,15 +128,18 @@ class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { def createPlanner(): SparkConnectPlanner = new SparkConnectPlanner(SparkConnectTestUtils.createDummySessionHolder(spark)) + def startPipelineAndWaitForCompletion(graphId: String) + : ArrayBuffer[PipelineEvent] = { + val defaultStartRunCommand = + PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build() + 
startPipelineAndWaitForCompletion(defaultStartRunCommand) + } + def startPipelineAndWaitForCompletion( - graphId: String, - customStartRunCommand: Option[PipelineCommand.StartRun] = None) - : ArrayBuffer[PipelineEvent] = { + startRunCommand: PipelineCommand.StartRun): ArrayBuffer[PipelineEvent] = { withClient { client => val capturedEvents = new ArrayBuffer[PipelineEvent]() - val startRunRequest = buildStartRunPlan( - customStartRunCommand.getOrElse( - PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build())) + val startRunRequest = buildStartRunPlan(startRunCommand) val responseIterator = client.execute(startRunRequest) // The response iterator will be closed when the pipeline is completed. while (responseIterator.hasNext) { diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala index e03b6c299797..5a298c2f17d9 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala +++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala @@ -27,14 +27,12 @@ import org.apache.spark.sql.pipelines.logging.{FlowProgressEventLogger, Pipeline * @param refreshTables Filter for which tables should be refreshed when performing this update. * @param fullRefreshTables Filter for which tables should be full refreshed * when performing this update. - * @param resetCheckpointFlows Filter for which flows should be reset. */ class PipelineUpdateContextImpl( override val unresolvedGraph: DataflowGraph, override val eventCallback: PipelineEvent => Unit, override val refreshTables: TableFilter = AllTables, - override val fullRefreshTables: TableFilter = NoTables, - override val resetCheckpointFlows: FlowFilter = NoFlows + override val fullRefreshTables: TableFilter = NoTables ) extends PipelineUpdateContext { override val spark: SparkSession = SparkSession.getActiveSession.getOrElse( @@ -43,4 +41,6 @@ class PipelineUpdateContextImpl( override val flowProgressEventLogger: FlowProgressEventLogger = new FlowProgressEventLogger(eventCallback = eventCallback) + + override val resetCheckpointFlows: FlowFilter = NoFlows } From e56d39b815308e46e767408c9c9e370dd9e61480 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Wed, 16 Jul 2025 22:10:50 -0700 Subject: [PATCH 15/17] rename proto --- python/pyspark/pipelines/cli.py | 2 ++ .../pipelines/spark_connect_pipeline.py | 4 +-- .../sql/connect/proto/pipelines_pb2.py | 30 +++++++++---------- .../sql/connect/proto/pipelines_pb2.pyi | 24 +++++++-------- .../protobuf/spark/connect/pipelines.proto | 8 ++--- .../connect/pipelines/PipelinesHandler.scala | 4 +-- .../PipelineRefreshFunctionalSuite.scala | 24 +++++++-------- 7 files changed, 49 insertions(+), 47 deletions(-) diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index 5c58c134b9bf..1d7fb4a12385 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -300,6 +300,7 @@ def parse_table_list(value: str) -> List[str]: type=parse_table_list, action="extend", help="List of datasets to reset and recompute (comma-separated).", + default=[], ) run_parser.add_argument( "--full-refresh-all", action="store_true", help="Perform a full graph reset and recompute." 
@@ -309,6 +310,7 @@ def parse_table_list(value: str) -> List[str]: type=parse_table_list, action="extend", help="List of datasets to update (comma-separated).", + default=[], ) # "init" subcommand diff --git a/python/pyspark/pipelines/spark_connect_pipeline.py b/python/pyspark/pipelines/spark_connect_pipeline.py index 59f0a4df586d..f430d33be4a1 100644 --- a/python/pyspark/pipelines/spark_connect_pipeline.py +++ b/python/pyspark/pipelines/spark_connect_pipeline.py @@ -81,9 +81,9 @@ def start_run( """ inner_command = pb2.PipelineCommand.StartRun( dataflow_graph_id=dataflow_graph_id, - full_refresh=full_refresh or [], + full_refresh_selection=full_refresh or [], full_refresh_all=full_refresh_all, - refresh=refresh or [], + refresh_selection=refresh or [], ) command = pb2.Command() command.pipeline_command.start_run.CopyFrom(inner_command) diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py index 0e52646e7c3a..e13877d05a60 100644 --- a/python/pyspark/sql/connect/proto/pipelines_pb2.py +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py @@ -40,7 +40,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xf4\x13\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 \x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x1a\x87\x03\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1aQ\n\x08Response\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\xd1\x04\n\rDefineDataset\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12&\n\x0c\x64\x61taset_name\x18\x02 \x01(\tH\x01R\x0b\x64\x61tasetName\x88\x01\x01\x12\x42\n\x0c\x64\x61taset_type\x18\x03 \x01(\x0e\x32\x1a.spark.connect.DatasetTypeH\x02R\x0b\x64\x61tasetType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x03R\x07\x63omment\x88\x01\x01\x12l\n\x10table_properties\x18\x05 
\x03(\x0b\x32\x41.spark.connect.PipelineCommand.DefineDataset.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x06 \x03(\tR\rpartitionCols\x12\x34\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x04R\x06schema\x88\x01\x01\x12\x1b\n\x06\x66ormat\x18\x08 \x01(\tH\x05R\x06\x66ormat\x88\x01\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0f\n\r_dataset_nameB\x0f\n\r_dataset_typeB\n\n\x08_commentB\t\n\x07_schemaB\t\n\x07_format\x1a\xc8\x03\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x01R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x02R\x11targetDatasetName\x88\x01\x01\x12\x38\n\x08relation\x18\x04 \x01(\x0b\x32\x17.spark.connect.RelationH\x03R\x08relation\x88\x01\x01\x12Q\n\x08sql_conf\x18\x05 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12\x17\n\x04once\x18\x06 \x01(\x08H\x04R\x04once\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0b\n\t_relationB\x07\n\x05_once\x1a\xd2\x01\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12!\n\x0c\x66ull_refresh\x18\x02 \x03(\tR\x0b\x66ullRefresh\x12-\n\x10\x66ull_refresh_all\x18\x03 \x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12\x18\n\x07refresh\x18\x04 \x03(\tR\x07refreshB\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_all\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_textB\x0e\n\x0c\x63ommand_type"\x8e\x02\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message*a\n\x0b\x44\x61tasetType\x12\x1c\n\x18\x44\x41TASET_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\x9a\x14\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12U\n\x0e\x64\x65\x66ine_dataset\x18\x02 
\x01(\x0b\x32,.spark.connect.PipelineCommand.DefineDatasetH\x00R\rdefineDataset\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x1a\x87\x03\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1aQ\n\x08Response\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\xd1\x04\n\rDefineDataset\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12&\n\x0c\x64\x61taset_name\x18\x02 \x01(\tH\x01R\x0b\x64\x61tasetName\x88\x01\x01\x12\x42\n\x0c\x64\x61taset_type\x18\x03 \x01(\x0e\x32\x1a.spark.connect.DatasetTypeH\x02R\x0b\x64\x61tasetType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x03R\x07\x63omment\x88\x01\x01\x12l\n\x10table_properties\x18\x05 \x03(\x0b\x32\x41.spark.connect.PipelineCommand.DefineDataset.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x06 \x03(\tR\rpartitionCols\x12\x34\n\x06schema\x18\x07 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x04R\x06schema\x88\x01\x01\x12\x1b\n\x06\x66ormat\x18\x08 \x01(\tH\x05R\x06\x66ormat\x88\x01\x01\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0f\n\r_dataset_nameB\x0f\n\r_dataset_typeB\n\n\x08_commentB\t\n\x07_schemaB\t\n\x07_format\x1a\xc8\x03\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x01R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x02R\x11targetDatasetName\x88\x01\x01\x12\x38\n\x08relation\x18\x04 \x01(\x0b\x32\x17.spark.connect.RelationH\x03R\x08relation\x88\x01\x01\x12Q\n\x08sql_conf\x18\x05 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12\x17\n\x04once\x18\x06 \x01(\x08H\x04R\x04once\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0b\n\t_relationB\x07\n\x05_once\x1a\xf8\x01\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x34\n\x16\x66ull_refresh_selection\x18\x02 \x03(\tR\x14\x66ullRefreshSelection\x12-\n\x10\x66ull_refresh_all\x18\x03 
\x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12+\n\x11refresh_selection\x18\x04 \x03(\tR\x10refreshSelectionB\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_all\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_textB\x0e\n\x0c\x63ommand_type"\x8e\x02\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message*a\n\x0b\x44\x61tasetType\x12\x1c\n\x18\x44\x41TASET_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -59,10 +59,10 @@ _globals["_PIPELINECOMMAND_DEFINEDATASET_TABLEPROPERTIESENTRY"]._serialized_options = b"8\001" _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options = b"8\001" - _globals["_DATASETTYPE"]._serialized_start = 3156 - _globals["_DATASETTYPE"]._serialized_end = 3253 + _globals["_DATASETTYPE"]._serialized_start = 3194 + _globals["_DATASETTYPE"]._serialized_end = 3291 _globals["_PIPELINECOMMAND"]._serialized_start = 140 - _globals["_PIPELINECOMMAND"]._serialized_end = 2688 + _globals["_PIPELINECOMMAND"]._serialized_end = 2726 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 719 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_end = 1110 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH_SQLCONFENTRY"]._serialized_start = 928 @@ -80,15 +80,15 @@ _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_start = 928 _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_end = 986 _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_start = 2260 - _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2470 - _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 2473 - _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2672 - _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2691 - _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 2961 - _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 2848 - _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 2946 - _globals["_PIPELINEEVENTRESULT"]._serialized_start = 2963 - _globals["_PIPELINEEVENTRESULT"]._serialized_end = 3036 - _globals["_PIPELINEEVENT"]._serialized_start = 3038 - _globals["_PIPELINEEVENT"]._serialized_end = 3154 + _globals["_PIPELINECOMMAND_STARTRUN"]._serialized_end = 2508 + _globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_start = 2511 + 
_globals["_PIPELINECOMMAND_DEFINESQLGRAPHELEMENTS"]._serialized_end = 2710 + _globals["_PIPELINECOMMANDRESULT"]._serialized_start = 2729 + _globals["_PIPELINECOMMANDRESULT"]._serialized_end = 2999 + _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_start = 2886 + _globals["_PIPELINECOMMANDRESULT_CREATEDATAFLOWGRAPHRESULT"]._serialized_end = 2984 + _globals["_PIPELINEEVENTRESULT"]._serialized_start = 3001 + _globals["_PIPELINEEVENTRESULT"]._serialized_end = 3074 + _globals["_PIPELINEEVENT"]._serialized_start = 3076 + _globals["_PIPELINEEVENT"]._serialized_end = 3192 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi index d52e4addf571..fff130ac3b93 100644 --- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi @@ -530,30 +530,30 @@ class PipelineCommand(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor DATAFLOW_GRAPH_ID_FIELD_NUMBER: builtins.int - FULL_REFRESH_FIELD_NUMBER: builtins.int + FULL_REFRESH_SELECTION_FIELD_NUMBER: builtins.int FULL_REFRESH_ALL_FIELD_NUMBER: builtins.int - REFRESH_FIELD_NUMBER: builtins.int + REFRESH_SELECTION_FIELD_NUMBER: builtins.int dataflow_graph_id: builtins.str """The graph to start.""" @property - def full_refresh( + def full_refresh_selection( self, ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: - """List of tables to reset and recompute.""" + """List of dataset to reset and recompute.""" full_refresh_all: builtins.bool """Perform a full graph reset and recompute.""" @property - def refresh( + def refresh_selection( self, ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: - """List of tables to update.""" + """List of dataset to update.""" def __init__( self, *, dataflow_graph_id: builtins.str | None = ..., - full_refresh: collections.abc.Iterable[builtins.str] | None = ..., + full_refresh_selection: collections.abc.Iterable[builtins.str] | None = ..., full_refresh_all: builtins.bool | None = ..., - refresh: collections.abc.Iterable[builtins.str] | None = ..., + refresh_selection: collections.abc.Iterable[builtins.str] | None = ..., ) -> None: ... def HasField( self, @@ -577,12 +577,12 @@ class PipelineCommand(google.protobuf.message.Message): b"_full_refresh_all", "dataflow_graph_id", b"dataflow_graph_id", - "full_refresh", - b"full_refresh", "full_refresh_all", b"full_refresh_all", - "refresh", - b"refresh", + "full_refresh_selection", + b"full_refresh_selection", + "refresh_selection", + b"refresh_selection", ], ) -> None: ... @typing.overload diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto index 7f4dbb3a1f78..751b14abe478 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto @@ -117,14 +117,14 @@ message PipelineCommand { // The graph to start. optional string dataflow_graph_id = 1; - // List of tables to reset and recompute. - repeated string full_refresh = 2; + // List of dataset to reset and recompute. + repeated string full_refresh_selection = 2; // Perform a full graph reset and recompute. optional bool full_refresh_all = 3; - // List of tables to update. - repeated string refresh = 4; + // List of dataset to update. 
+ repeated string refresh_selection = 4; } // Parses the SQL file and registers all datasets and flows. diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index e5d9835165cf..b41fc10ffa83 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -329,9 +329,9 @@ private[connect] object PipelinesHandler extends Logging { }.toSet } - val fullRefreshTables = startRunCommand.getFullRefreshList.asScala.toSeq + val fullRefreshTables = startRunCommand.getFullRefreshSelectionList.asScala.toSeq val fullRefreshAll = startRunCommand.getFullRefreshAll - val refreshTables = startRunCommand.getRefreshList.asScala.toSeq + val refreshTables = startRunCommand.getRefreshSelectionList.asScala.toSeq if (refreshTables.nonEmpty && fullRefreshAll) { throw new IllegalArgumentException( diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala index 83846c55d72c..a444e908db81 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala @@ -129,7 +129,7 @@ class PipelineRefreshFunctionalSuite PipelineCommand.StartRun .newBuilder() .setDataflowGraphId(graphId) - .addAllFullRefresh(List("a").asJava) + .addAllFullRefreshSelection(List("a").asJava) .build()) }, expectedContentAfterRefresh = Map( @@ -167,8 +167,8 @@ class PipelineRefreshFunctionalSuite PipelineCommand.StartRun .newBuilder() .setDataflowGraphId(graphId) - .addAllFullRefresh(Seq("a", "mv").asJava) - .addRefresh("b") + .addAllFullRefreshSelection(Seq("a", "mv").asJava) + .addRefreshSelection("b") .build()) }, expectedContentAfterRefresh = Map( @@ -233,7 +233,7 @@ class PipelineRefreshFunctionalSuite .newBuilder() .setDataflowGraphId(graphId) .setFullRefreshAll(true) - .addRefresh("a") + .addRefreshSelection("a") .build() val exception = intercept[IllegalArgumentException] { @@ -255,7 +255,7 @@ class PipelineRefreshFunctionalSuite .newBuilder() .setDataflowGraphId(graphId) .setFullRefreshAll(true) - .addFullRefresh("a") + .addFullRefreshSelection("a") .build() val exception = intercept[IllegalArgumentException] { @@ -276,8 +276,8 @@ class PipelineRefreshFunctionalSuite val startRun = PipelineCommand.StartRun .newBuilder() .setDataflowGraphId(graphId) - .addRefresh("a") - .addFullRefresh("a") + .addRefreshSelection("a") + .addFullRefreshSelection("a") .build() val exception = intercept[IllegalArgumentException] { @@ -299,9 +299,9 @@ class PipelineRefreshFunctionalSuite val startRun = PipelineCommand.StartRun .newBuilder() .setDataflowGraphId(graphId) - .addRefresh("a") - .addRefresh("b") - .addFullRefresh("a") + .addRefreshSelection("a") + .addRefreshSelection("b") + .addFullRefreshSelection("a") .build() val exception = intercept[IllegalArgumentException] { @@ -323,8 +323,8 @@ class PipelineRefreshFunctionalSuite val startRun = PipelineCommand.StartRun .newBuilder() .setDataflowGraphId(graphId) - .addRefresh("spark_catalog.default.a") - .addFullRefresh("a") // This should be treated as the same table + 
.addRefreshSelection("spark_catalog.default.a") + .addFullRefreshSelection("a") // This should be treated as the same table .build() val exception = intercept[IllegalArgumentException] { From b04ac242c161f6f070a577a11fca5d2dcaefd753 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Wed, 16 Jul 2025 23:39:31 -0700 Subject: [PATCH 16/17] fmt --- python/pyspark/pipelines/cli.py | 14 ++++++++----- python/pyspark/pipelines/tests/test_cli.py | 9 +++------ .../connect/pipelines/PipelinesHandler.scala | 16 ++++++--------- .../PipelineRefreshFunctionalSuite.scala | 20 +++++++++---------- .../SparkDeclarativePipelinesServerTest.scala | 3 +-- 5 files changed, 29 insertions(+), 33 deletions(-) diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index 1d7fb4a12385..a7a0f01b4ca1 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -234,15 +234,17 @@ def run( if full_refresh_all: if full_refresh: raise PySparkException( - errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", messageParameters={ + errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", + messageParameters={ "conflicting_option": "--full_refresh", - } + }, ) if refresh: raise PySparkException( - errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", messageParameters={ + errorClass="CONFLICTING_PIPELINE_REFRESH_OPTIONS", + messageParameters={ "conflicting_option": "--refresh", - } + }, ) log_with_curr_timestamp(f"Loading pipeline spec from {spec_path}...") @@ -285,6 +287,7 @@ def parse_table_list(value: str) -> List[str]: """Parse a comma-separated list of table names, handling whitespace.""" return [table.strip() for table in value.split(",") if table.strip()] + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Pipeline CLI") subparsers = parser.add_subparsers(dest="command", required=True) @@ -292,7 +295,8 @@ def parse_table_list(value: str) -> List[str]: # "run" subcommand run_parser = subparsers.add_parser( "run", - help="Run a pipeline. If no refresh options are specified, a default incremental update is performed.", + help="Run a pipeline. 
If no refresh options specified, " + "a default incremental update is performed.", ) run_parser.add_argument("--spec", help="Path to the pipeline spec.") run_parser.add_argument( diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index 319f637c8744..ded00e691db4 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -379,9 +379,7 @@ def test_full_refresh_all_conflicts_with_full_refresh(self): context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" ) self.assertEqual( - context.exception.getMessageParameters(), { - "conflicting_option": "--full_refresh" - } + context.exception.getMessageParameters(), {"conflicting_option": "--full_refresh"} ) def test_full_refresh_all_conflicts_with_refresh(self): @@ -404,9 +402,8 @@ def test_full_refresh_all_conflicts_with_refresh(self): context.exception.getCondition(), "CONFLICTING_PIPELINE_REFRESH_OPTIONS" ) self.assertEqual( - context.exception.getMessageParameters(), { - "conflicting_option": "--refresh" - }, + context.exception.getMessageParameters(), + {"conflicting_option": "--refresh"}, ) def test_full_refresh_all_conflicts_with_both(self): diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index b41fc10ffa83..7f92aa13944c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -298,14 +298,14 @@ private[connect] object PipelinesHandler extends Logging { } /** - * Creates the table filters for the full refresh and refresh operations based on the - * StartRun command user provided. Also validates the command parameters to ensure that they are + * Creates the table filters for the full refresh and refresh operations based on the StartRun + * command user provided. Also validates the command parameters to ensure that they are * consistent and do not conflict with each other. * * If `fullRefreshAll` is true, create `AllTables` filter for full refresh. * - * If `fullRefreshTables` and `refreshTables` are both empty, - * create `AllTables` filter for refresh as a default behavior. + * If `fullRefreshTables` and `refreshTables` are both empty, create `AllTables` filter for + * refresh as a default behavior. * * If both non-empty, verifies that there is no overlap and creates SomeTables filters for both. * @@ -371,16 +371,12 @@ private[connect] object PipelinesHandler extends Logging { // If both are specified, we create filters for both after validation TableFilters( fullRefresh = SomeTables(fullRefreshTableNames), - refresh = SomeTables(refreshTableNames) - ) + refresh = SomeTables(refreshTableNames)) } } /** * A case class to hold the table filters for full refresh and refresh operations. 
*/ - private case class TableFilters( - fullRefresh: TableFilter, - refresh: TableFilter - ) + private case class TableFilters(fullRefresh: TableFilter, refresh: TableFilter) } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala index a444e908db81..794932544d5f 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PipelineRefreshFunctionalSuite.scala @@ -90,8 +90,7 @@ class PipelineRefreshFunctionalSuite val initialContent = Map( "spark_catalog.default.a" -> Set(Map("id" -> 1)), "spark_catalog.default.b" -> Set(Map("id" -> 1)), - "spark_catalog.default.mv" -> Set(Map("id" -> 1)) - ) + "spark_catalog.default.mv" -> Set(Map("id" -> 1))) // Verify initial content initialContent.foreach { case (tableName, expectedRows) => checkTableContent(tableName, expectedRows) @@ -337,17 +336,18 @@ class PipelineRefreshFunctionalSuite } private def checkTableContent[A <: Map[String, Any]]( - name: String, - expectedContent: Set[A] - ): Unit = { + name: String, + expectedContent: Set[A]): Unit = { spark.catalog.refreshTable(name) // clear cache for the table val df = spark.table(name) QueryTest.checkAnswer( df, - expectedContent.map(row => { - // Convert each row to a Row object - org.apache.spark.sql.Row.fromSeq(row.values.toSeq) - }).toSeq.asJava - ) + expectedContent + .map(row => { + // Convert each row to a Row object + org.apache.spark.sql.Row.fromSeq(row.values.toSeq) + }) + .toSeq + .asJava) } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala index b3124b236ba0..003fd30b6075 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerTest.scala @@ -128,8 +128,7 @@ class SparkDeclarativePipelinesServerTest extends SparkConnectServerTest { def createPlanner(): SparkConnectPlanner = new SparkConnectPlanner(SparkConnectTestUtils.createDummySessionHolder(spark)) - def startPipelineAndWaitForCompletion(graphId: String) - : ArrayBuffer[PipelineEvent] = { + def startPipelineAndWaitForCompletion(graphId: String): ArrayBuffer[PipelineEvent] = { val defaultStartRunCommand = PipelineCommand.StartRun.newBuilder().setDataflowGraphId(graphId).build() startPipelineAndWaitForCompletion(defaultStartRunCommand) From f21d79fbbd5bb4741c9afb50a2a8e6f92c3ccd17 Mon Sep 17 00:00:00 2001 From: Jacky Wang Date: Wed, 16 Jul 2025 23:45:18 -0700 Subject: [PATCH 17/17] nit --- python/pyspark/pipelines/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index a7a0f01b4ca1..2a0cf880d10c 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -28,7 +28,7 @@ import yaml from dataclasses import dataclass from pathlib import Path -from typing import Any, Generator, Mapping, Optional, Sequence, List +from typing import Any, Generator, List, Mapping, Optional, Sequence from pyspark.errors import PySparkException, PySparkTypeError 
from pyspark.sql import SparkSession
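
The CLI hunks above replace the original append-and-flatten handling of --full-refresh/--refresh with argparse's "extend" action, a list-returning type, and default=[]. The standalone sketch below (not part of the patch; only parse_table_list, the flag names, and the extend/default wiring are taken from it, the driver code is illustrative) shows the resulting behaviour when repeated flags and comma-separated values are mixed:

import argparse
from typing import List


def parse_table_list(value: str) -> List[str]:
    """Parse a comma-separated list of table names, handling whitespace."""
    return [table.strip() for table in value.split(",") if table.strip()]


parser = argparse.ArgumentParser(description="Pipeline CLI (illustrative sketch)")
parser.add_argument("--full-refresh", type=parse_table_list, action="extend", default=[])
parser.add_argument("--full-refresh-all", action="store_true")
parser.add_argument("--refresh", type=parse_table_list, action="extend", default=[])

args = parser.parse_args(
    ["--full-refresh", "a, b", "--full-refresh", "c", "--refresh", "schema.d"]
)
# "extend" appends the elements of each parsed list onto one flat list, which is
# why the earlier append action plus flatten_table_lists helper is no longer needed.
assert args.full_refresh == ["a", "b", "c"]
assert args.refresh == ["schema.d"]
assert args.full_refresh_all is False

However the script is ultimately launched (the patch does not show a console-script name), repeated flags and comma-separated values can be combined freely, and an invocation with no refresh flags at all falls through to the default incremental update described in the run subcommand's help text.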
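
On the server side, the createTableFilters helper added to PipelinesHandler.scala turns the StartRun selections into a pair of table filters. As a reading aid, here is a simplified Python restatement of the decision rules documented in its scaladoc; this is a sketch under the assumption that filters can be modelled as plain strings and sets, whereas the real code returns AllTables/SomeTables/NoTables and qualifies each identifier against the graph's default catalog and database:

from typing import FrozenSet, Tuple, Union

# "all", "none", or an explicit set of fully qualified table names.
TableFilter = Union[str, FrozenSet[str]]


def table_filters(
    full_refresh: FrozenSet[str],
    full_refresh_all: bool,
    refresh: FrozenSet[str],
) -> Tuple[TableFilter, TableFilter]:
    """Return (full_refresh_filter, refresh_filter), mirroring the documented rules."""
    if full_refresh_all:
        if refresh:
            raise ValueError(
                "Cannot specify a subset to refresh when full refresh all is set to true.")
        if full_refresh:
            raise ValueError(
                "Cannot specify a subset to full refresh when full refresh all is set to true.")
        return "all", "none"
    overlap = refresh & full_refresh
    if overlap:
        raise ValueError(
            "Datasets specified for refresh and full refresh cannot overlap: "
            + ", ".join(sorted(overlap)))
    if not full_refresh and not refresh:
        # Default behaviour: incrementally refresh every table in the graph.
        return "none", "all"
    return (full_refresh or "none"), (refresh or "none")

For example, table_filters(frozenset(), False, frozenset()) yields ("none", "all"), the default incremental update, while table_filters(frozenset({"a"}), False, frozenset({"b"})) yields ({"a"}, {"b"}), matching the SomeTables/SomeTables case exercised by the functional suite above.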