@@ -22,7 +22,7 @@
 import unittest
 import warnings
 
-from pyspark.sql import Row
+from pyspark.sql import Row, SparkSession
 from pyspark.sql.functions import udf
 from pyspark.sql.types import *
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
@@ -421,6 +421,35 @@ def run_test(num_records, num_parts, max_records, use_delay=False):
         run_test(*case)
 
 
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message)
+class MaxResultArrowTests(unittest.TestCase):
+    # These tests are separate because 'spark.driver.maxResultSize' is a
+    # static configuration of the Spark context.
+
+    @classmethod
+    def setUpClass(cls):
+        cls.spark = SparkSession.builder \
+            .master("local[4]") \
+            .appName(cls.__name__) \
+            .config("spark.driver.maxResultSize", "10k") \
+            .getOrCreate()
+
+        # Explicitly enable Arrow and disable fallback.
+        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
+        cls.spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")
+
+    @classmethod
+    def tearDownClass(cls):
+        if hasattr(cls, "spark"):
+            cls.spark.stop()
+
+    def test_exception_by_max_results(self):
+        with self.assertRaisesRegexp(Exception, "is bigger than"):
+            self.spark.range(0, 10000, 1, 100).toPandas()
+
+
 class EncryptionArrowTests(ArrowTests):
 
     @classmethod
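For reference, a minimal standalone sketch of the behaviour the new test pins down, assuming a local PySpark installation with pandas and pyarrow available. The app name is made up, and the sizes simply mirror the test above. Because 'spark.driver.maxResultSize' is static, it has to be passed to the builder before the session exists, which is why the class above builds its own session instead of reusing the shared one from ReusedSQLTestCase.

from pyspark.sql import SparkSession

# Static config: must be set when the session (and its SparkContext) is created.
spark = SparkSession.builder \
    .master("local[4]") \
    .appName("max-result-size-sketch") \
    .config("spark.driver.maxResultSize", "10k") \
    .getOrCreate()

# Runtime SQL configs can still be set afterwards: force the Arrow path and
# keep toPandas() from retrying without Arrow.
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.pyspark.fallback.enabled", "false")

try:
    # 10,000 rows spread over 100 partitions serialize to well over 10k bytes,
    # so the driver aborts the collect with an error whose message contains
    # "is bigger than spark.driver.maxResultSize".
    spark.range(0, 10000, 1, 100).toPandas()
except Exception as e:
    print(e)
finally:
    spark.stop()

Disabling the fallback ensures the assertion exercises the Arrow collection path itself rather than a silent retry without Arrow.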