
Commit

Fix.
ueshin committed Mar 12, 2021
1 parent a026365 commit f984efd
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions databricks/koalas/tests/test_reshape.py
@@ -20,9 +20,10 @@

 import numpy as np
 import pandas as pd
+import pyspark

 from databricks import koalas as ks
-from databricks.koalas.testing.utils import ReusedSQLTestCase
+from databricks.koalas.testing.utils import ReusedSQLTestCase, SPARK_CONF_ARROW_ENABLED
 from databricks.koalas.utils import name_like_string

@@ -109,23 +110,41 @@ def test_get_dummies_date_datetime(self):
         )
         kdf = ks.from_pandas(pdf)

-        self.assert_eq(ks.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
-        self.assert_eq(ks.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8))
-        self.assert_eq(ks.get_dummies(kdf.dt), pd.get_dummies(pdf.dt, dtype=np.int8))
+        if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
+            self.assert_eq(ks.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
+            self.assert_eq(ks.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8))
+            self.assert_eq(ks.get_dummies(kdf.dt), pd.get_dummies(pdf.dt, dtype=np.int8))
+        else:
+            with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
+                self.assert_eq(ks.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
+                self.assert_eq(ks.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8))
+                self.assert_eq(ks.get_dummies(kdf.dt), pd.get_dummies(pdf.dt, dtype=np.int8))

     def test_get_dummies_boolean(self):
         pdf = pd.DataFrame({"b": [True, False, True]})
         kdf = ks.from_pandas(pdf)

-        self.assert_eq(ks.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
-        self.assert_eq(ks.get_dummies(kdf.b), pd.get_dummies(pdf.b, dtype=np.int8))
+        if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
+            self.assert_eq(ks.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
+            self.assert_eq(ks.get_dummies(kdf.b), pd.get_dummies(pdf.b, dtype=np.int8))
+        else:
+            with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
+                self.assert_eq(ks.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
+                self.assert_eq(ks.get_dummies(kdf.b), pd.get_dummies(pdf.b, dtype=np.int8))

     def test_get_dummies_decimal(self):
         pdf = pd.DataFrame({"d": [Decimal(1.0), Decimal(2.0), Decimal(1)]})
         kdf = ks.from_pandas(pdf)

-        self.assert_eq(ks.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
-        self.assert_eq(ks.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8), almost=True)
+        if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
+            self.assert_eq(ks.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
+            self.assert_eq(ks.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8), almost=True)
+        else:
+            with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}):
+                self.assert_eq(ks.get_dummies(kdf), pd.get_dummies(pdf, dtype=np.int8))
+                self.assert_eq(
+                    ks.get_dummies(kdf.d), pd.get_dummies(pdf.d, dtype=np.int8), almost=True
+                )

     def test_get_dummies_kwargs(self):
         # pser = pd.Series([1, 1, 1, 2, 2, 1, 3, 4], dtype='category')
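
Note on the pattern this change applies: on PySpark 2.4 and newer the get_dummies assertions run as-is, while on older versions they are re-run inside self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}) so that Arrow-based conversion is disabled for the comparison. Below is a minimal standalone sketch of the same guard, assuming the Arrow flag is the usual "spark.sql.execution.arrow.enabled" key; the helper name with_arrow_fallback and its spark/run_checks parameters are illustrative only and not part of koalas.

from distutils.version import LooseVersion

import pyspark


# Assumed key behind SPARK_CONF_ARROW_ENABLED; check
# databricks/koalas/testing/utils.py for the actual constant.
ARROW_CONF = "spark.sql.execution.arrow.enabled"


def with_arrow_fallback(spark, run_checks):
    """Run the checks directly on PySpark >= 2.4; on older versions,
    disable Arrow conversion for the duration of the checks, then restore it."""
    if LooseVersion(pyspark.__version__) >= LooseVersion("2.4"):
        run_checks()
        return
    original = spark.conf.get(ARROW_CONF, "true")
    spark.conf.set(ARROW_CONF, "false")
    try:
        run_checks()
    finally:
        spark.conf.set(ARROW_CONF, original)

Inside the test class this boilerplate is not needed: the test base class already provides the sql_conf context manager used in the diff, which restores the previous configuration value on exit.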
