Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 72 additions & 71 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import array
import ctypes
import py4j
from contextlib import contextmanager

try:
import xmlrunner
Expand Down Expand Up @@ -201,6 +202,28 @@ def assertPandasEqual(self, expected, result):
"\n\nResult:\n%s\n%s" % (result, result.dtypes))
self.assertTrue(expected.equals(result), msg=msg)

@contextmanager

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was extracted on its own from d6632d1.

def sql_conf(self, pairs):
    """
    A convenient context manager to test some configuration specific logic. This sets
    every `key` in `pairs` to its `value` on ``self.spark.conf`` and then restores the
    previous values back when it exits. Keys that had no previous value are unset.

    :param pairs: a dictionary mapping configuration keys to the values they should
        hold while the ``with`` block runs.
    """
    assert isinstance(pairs, dict), "pairs should be a dictionary."

    # Snapshot keys and values up front so that saving, setting and restoring all
    # walk the same sequence even if the caller mutates `pairs` inside the block.
    keys = list(pairs.keys())
    new_values = [pairs[key] for key in keys]
    old_values = [self.spark.conf.get(key, None) for key in keys]
    try:
        # Set inside the `try`: if setting one key fails part-way through, the
        # keys that were already overwritten are still restored below.
        for key, new_value in zip(keys, new_values):
            self.spark.conf.set(key, new_value)
        yield
    finally:
        for key, old_value in zip(keys, old_values):
            if old_value is None:
                # `None` is the "was not set" marker returned by `get` above.
                self.spark.conf.unset(key)
            else:
                self.spark.conf.set(key, old_value)


class DataTypeTests(unittest.TestCase):
# regression test for SPARK-6055
Expand Down Expand Up @@ -2409,17 +2432,13 @@ def test_join_without_on(self):
df1 = self.spark.range(1).toDF("a")
df2 = self.spark.range(1).toDF("b")

try:

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The other diffs are basically the same.

self.spark.conf.set("spark.sql.crossJoin.enabled", "false")
with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())

self.spark.conf.set("spark.sql.crossJoin.enabled", "true")
with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
actual = df1.join(df2, how="inner").collect()
expected = [Row(a=0, b=0)]
self.assertEqual(actual, expected)
finally:
# We should unset this. Otherwise, other tests are affected.
self.spark.conf.unset("spark.sql.crossJoin.enabled")

# Regression test for invalid join methods when on is None, Spark-14761
def test_invalid_join_method(self):
Expand Down Expand Up @@ -2891,21 +2910,18 @@ def test_create_dateframe_from_pandas_with_dst(self):
self.assertPandasEqual(pdf, df.toPandas())

orig_env_tz = os.environ.get('TZ', None)
orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone')
try:
tz = 'America/Los_Angeles'
os.environ['TZ'] = tz
time.tzset()
self.spark.conf.set('spark.sql.session.timeZone', tz)

df = self.spark.createDataFrame(pdf)
self.assertPandasEqual(pdf, df.toPandas())
with self.sql_conf({'spark.sql.session.timeZone': tz}):
df = self.spark.createDataFrame(pdf)
self.assertPandasEqual(pdf, df.toPandas())
finally:
del os.environ['TZ']
if orig_env_tz is not None:
os.environ['TZ'] = orig_env_tz
time.tzset()
self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz)


class HiveSparkSubmitTests(SparkSubmitTests):
Expand Down Expand Up @@ -3472,12 +3488,11 @@ def test_null_conversion(self):
self.assertTrue(all([c == 1 for c in null_counts]))

def _toPandas_arrow_toggle(self, df):
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
try:
with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
pdf = df.toPandas()
finally:
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")

pdf_arrow = df.toPandas()

return pdf, pdf_arrow

def test_toPandas_arrow_toggle(self):
Expand All @@ -3489,16 +3504,17 @@ def test_toPandas_arrow_toggle(self):

def test_toPandas_respect_session_timezone(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
try:
timezone = "America/New_York"
self.spark.conf.set("spark.sql.session.timeZone", timezone)
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
try:
pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
self.assertPandasEqual(pdf_arrow_la, pdf_la)
finally:
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")

timezone = "America/New_York"
with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": False,
"spark.sql.session.timeZone": timezone}):
pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
self.assertPandasEqual(pdf_arrow_la, pdf_la)

with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": True,
"spark.sql.session.timeZone": timezone}):
pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
self.assertPandasEqual(pdf_arrow_ny, pdf_ny)

Expand All @@ -3511,8 +3527,6 @@ def test_toPandas_respect_session_timezone(self):
pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
pdf_la_corrected[field.name], timezone)
self.assertPandasEqual(pdf_ny, pdf_la_corrected)
finally:
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

def test_pandas_round_trip(self):
pdf = self.create_pandas_data_frame()
Expand All @@ -3528,12 +3542,11 @@ def test_filtered_frame(self):
self.assertTrue(pdf.empty)

def _createDataFrame_toggle(self, pdf, schema=None):
self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
try:
with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
finally:
self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")

df_arrow = self.spark.createDataFrame(pdf, schema=schema)

return df_no_arrow, df_arrow

def test_createDataFrame_toggle(self):
Expand All @@ -3544,18 +3557,18 @@ def test_createDataFrame_toggle(self):
def test_createDataFrame_respect_session_timezone(self):
from datetime import timedelta
pdf = self.create_pandas_data_frame()
orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
try:
timezone = "America/New_York"
self.spark.conf.set("spark.sql.session.timeZone", timezone)
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
try:
df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
result_la = df_no_arrow_la.collect()
result_arrow_la = df_arrow_la.collect()
self.assertEqual(result_la, result_arrow_la)
finally:
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
timezone = "America/New_York"
with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": False,
"spark.sql.session.timeZone": timezone}):
df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
result_la = df_no_arrow_la.collect()
result_arrow_la = df_arrow_la.collect()
self.assertEqual(result_la, result_arrow_la)

with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": True,
"spark.sql.session.timeZone": timezone}):
df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
result_ny = df_no_arrow_ny.collect()
result_arrow_ny = df_arrow_ny.collect()
Expand All @@ -3568,8 +3581,6 @@ def test_createDataFrame_respect_session_timezone(self):
for k, v in row.asDict().items()})
for row in result_la]
self.assertEqual(result_ny, result_la_corrected)
finally:
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

def test_createDataFrame_with_schema(self):
pdf = self.create_pandas_data_frame()
Expand Down Expand Up @@ -4222,9 +4233,7 @@ def gen_timestamps(id):
def test_vectorized_udf_check_config(self):
from pyspark.sql.functions import pandas_udf, col
import pandas as pd
orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
try:
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
df = self.spark.range(10, numPartitions=1)

@pandas_udf(returnType=LongType())
Expand All @@ -4234,11 +4243,6 @@ def check_records_per_batch(x):
result = df.select(check_records_per_batch(col("id"))).collect()
for (r,) in result:
self.assertTrue(r <= 3)
finally:
if orig_value is None:
self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
else:
self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)

def test_vectorized_udf_timestamps_respect_session_timezone(self):
from pyspark.sql.functions import pandas_udf, col
Expand All @@ -4257,30 +4261,27 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
internal_value = pandas_udf(
lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())

orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
try:
timezone = "America/New_York"
self.spark.conf.set("spark.sql.session.timeZone", timezone)
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
try:
df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
.withColumn("internal_value", internal_value(col("timestamp")))
result_la = df_la.select(col("idx"), col("internal_value")).collect()
# Correct result_la by adjusting 3 hours difference between Los Angeles and New York
diff = 3 * 60 * 60 * 1000 * 1000 * 1000
result_la_corrected = \
df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
finally:
self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
timezone = "America/New_York"
with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": False,
"spark.sql.session.timeZone": timezone}):
df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
.withColumn("internal_value", internal_value(col("timestamp")))
result_la = df_la.select(col("idx"), col("internal_value")).collect()
# Correct result_la by adjusting 3 hours difference between Los Angeles and New York
diff = 3 * 60 * 60 * 1000 * 1000 * 1000
result_la_corrected = \
df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()

with self.sql_conf({
"spark.sql.execution.pandas.respectSessionTimeZone": True,
"spark.sql.session.timeZone": timezone}):
df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
.withColumn("internal_value", internal_value(col("timestamp")))
result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()

self.assertNotEqual(result_ny, result_la)
self.assertEqual(result_ny, result_la_corrected)
finally:
self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

def test_nondeterministic_vectorized_udf(self):
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
Expand Down