130 changes: 50 additions & 80 deletions python/pyspark/sql/tests.py
@@ -2461,17 +2461,13 @@ def test_join_without_on(self):
         df1 = self.spark.range(1).toDF("a")
         df2 = self.spark.range(1).toDF("b")

-        try:
-            self.spark.conf.set("spark.sql.crossJoin.enabled", "false")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
             self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())

-            self.spark.conf.set("spark.sql.crossJoin.enabled", "true")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
Member:
So the sql_conf context will change this back to be unset right?

@HyukjinKwon (Member, Author), Mar 16, 2018:
Yup, it originally unset spark.sql.crossJoin.enabled, but now it sets it back to the original value. If spark.sql.crossJoin.enabled was unset before this test, it will be set back as if it were unset.
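For context, a minimal sketch of what such a sql_conf helper could look like as a context manager on the shared test base class (illustrative only; the helper's actual implementation in this PR may differ in detail):

from contextlib import contextmanager

@contextmanager
def sql_conf(self, pairs):
    """Temporarily sets each given SQL conf and restores the original value
    on exit; a key that was unset before entering is unset again."""
    assert isinstance(pairs, dict), "pairs should be a dictionary."

    # Snapshot the current values first; None means the key is currently unset.
    old_values = [(key, self.spark.conf.get(key, None)) for key in pairs]
    for key, new_value in pairs.items():
        self.spark.conf.set(key, new_value)
    try:
        yield
    finally:
        for key, old_value in old_values:
            if old_value is None:
                self.spark.conf.unset(key)
            else:
                self.spark.conf.set(key, old_value)

Because the restore happens in a finally block, the original conf values come back even when the assertions inside the with block fail.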

             actual = df1.join(df2, how="inner").collect()
             expected = [Row(a=0, b=0)]
             self.assertEqual(actual, expected)
-        finally:
-            # We should unset this. Otherwise, other tests are affected.
-            self.spark.conf.unset("spark.sql.crossJoin.enabled")

     # Regression test for invalid join methods when on is None, SPARK-14761
     def test_invalid_join_method(self):
@@ -2932,21 +2928,18 @@ def test_create_dateframe_from_pandas_with_dst(self):
         self.assertPandasEqual(pdf, df.toPandas())

         orig_env_tz = os.environ.get('TZ', None)
-        orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone')
         try:
             tz = 'America/Los_Angeles'
             os.environ['TZ'] = tz
             time.tzset()
-            self.spark.conf.set('spark.sql.session.timeZone', tz)
-
-            df = self.spark.createDataFrame(pdf)
-            self.assertPandasEqual(pdf, df.toPandas())
+            with self.sql_conf({'spark.sql.session.timeZone': tz}):
+                df = self.spark.createDataFrame(pdf)
+                self.assertPandasEqual(pdf, df.toPandas())
         finally:
             del os.environ['TZ']
             if orig_env_tz is not None:
                 os.environ['TZ'] = orig_env_tz
             time.tzset()
-            self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz)


 class HiveSparkSubmitTests(SparkSubmitTests):
@@ -3551,12 +3544,11 @@ def test_null_conversion(self):
         self.assertTrue(all([c == 1 for c in null_counts]))

     def _toPandas_arrow_toggle(self, df):
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             pdf = df.toPandas()
-        finally:
-            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+
         pdf_arrow = df.toPandas()
+
         return pdf, pdf_arrow

     def test_toPandas_arrow_toggle(self):
@@ -3568,16 +3560,17 @@ def test_toPandas_respect_session_timezone(self):

     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
-                self.assertPandasEqual(pdf_arrow_la, pdf_la)
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
+            self.assertPandasEqual(pdf_arrow_la, pdf_la)
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
             self.assertPandasEqual(pdf_arrow_ny, pdf_ny)

@@ -3590,8 +3583,6 @@ def test_toPandas_respect_session_timezone(self):
                 pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
                     pdf_la_corrected[field.name], timezone)
         self.assertPandasEqual(pdf_ny, pdf_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

     def test_pandas_round_trip(self):
         pdf = self.create_pandas_data_frame()
@@ -3607,12 +3598,11 @@ def test_filtered_frame(self):
         self.assertTrue(pdf.empty)

     def _createDataFrame_toggle(self, pdf, schema=None):
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
-        finally:
-            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+
         df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+
         return df_no_arrow, df_arrow

     def test_createDataFrame_toggle(self):
@@ -3623,18 +3613,18 @@ def test_createDataFrame_toggle(self):
     def test_createDataFrame_respect_session_timezone(self):
         from datetime import timedelta
         pdf = self.create_pandas_data_frame()
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
-                result_la = df_no_arrow_la.collect()
-                result_arrow_la = df_arrow_la.collect()
-                self.assertEqual(result_la, result_arrow_la)
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
+            result_la = df_no_arrow_la.collect()
+            result_arrow_la = df_arrow_la.collect()
+            self.assertEqual(result_la, result_arrow_la)
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
             result_ny = df_no_arrow_ny.collect()
             result_arrow_ny = df_arrow_ny.collect()
@@ -3647,8 +3637,6 @@ def test_createDataFrame_respect_session_timezone(self):
                                for k, v in row.asDict().items()})
                           for row in result_la]
         self.assertEqual(result_ny, result_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

     def test_createDataFrame_with_schema(self):
         pdf = self.create_pandas_data_frame()
@@ -4325,9 +4313,7 @@ def gen_timestamps(id):
     def test_vectorized_udf_check_config(self):
         from pyspark.sql.functions import pandas_udf, col
         import pandas as pd
-        orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
-        self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
             df = self.spark.range(10, numPartitions=1)

             @pandas_udf(returnType=LongType())
@@ -4337,11 +4323,6 @@ def check_records_per_batch(x):
             result = df.select(check_records_per_batch(col("id"))).collect()
             for (r,) in result:
                 self.assertTrue(r <= 3)
-        finally:
-            if orig_value is None:
-                self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
-            else:
-                self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)

     def test_vectorized_udf_timestamps_respect_session_timezone(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4360,30 +4341,27 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
         internal_value = pandas_udf(
             lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())

-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
-                    .withColumn("internal_value", internal_value(col("timestamp")))
-                result_la = df_la.select(col("idx"), col("internal_value")).collect()
-                # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
-                diff = 3 * 60 * 60 * 1000 * 1000 * 1000
-                result_la_corrected = \
-                    df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+                .withColumn("internal_value", internal_value(col("timestamp")))
+            result_la = df_la.select(col("idx"), col("internal_value")).collect()
+            # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
+            diff = 3 * 60 * 60 * 1000 * 1000 * 1000
+            result_la_corrected = \
+                df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
                 .withColumn("internal_value", internal_value(col("timestamp")))
             result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()

         self.assertNotEqual(result_ny, result_la)
         self.assertEqual(result_ny, result_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)

     def test_nondeterministic_vectorized_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
@@ -5159,22 +5137,14 @@ def test_complex_expressions(self):

     def test_retain_group_columns(self):
         from pyspark.sql.functions import sum, lit, col
-        orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None)
-        self.spark.conf.set("spark.sql.retainGroupColumns", False)
-        try:
+        with self.sql_conf({"spark.sql.retainGroupColumns": False}):
             df = self.data
             sum_udf = self.pandas_agg_sum_udf

             result1 = df.groupby(df.id).agg(sum_udf(df.v))
             expected1 = df.groupby(df.id).agg(sum(df.v))
             self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
-
-        finally:
-            if orig_value is None:
-                self.spark.conf.unset("spark.sql.retainGroupColumns")
-            else:
-                self.spark.conf.set("spark.sql.retainGroupColumns", orig_value)

     def test_invalid_args(self):
         from pyspark.sql.functions import mean
