Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support to set attributes for DataFrame. #1989

Merged
merged 1 commit into from
Dec 29, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion databricks/koalas/extensions.py
Original file line number Diff line number Diff line change
@@ -49,7 +49,7 @@ def __get__(self, obj, cls):
if obj is None:
return self._accessor
accessor_obj = self._accessor(obj)
setattr(obj, self._name, accessor_obj)
object.__setattr__(obj, self._name, accessor_obj)
return accessor_obj


72 changes: 46 additions & 26 deletions databricks/koalas/frame.py
Original file line number Diff line number Diff line change
@@ -488,17 +488,19 @@ def __init__(self, data=None, index=None, columns=None, dtype=None, copy=False):
pdf = pd.DataFrame(data=data, index=index, columns=columns, dtype=dtype, copy=copy)
internal = InternalFrame.from_pandas(pdf)

self._internal_frame = internal
object.__setattr__(self, "_internal_frame", internal)

@property
def _ksers(self):
""" Return a dict of column label -> Series which anchors `self`. """
from databricks.koalas.series import Series

if not hasattr(self, "_kseries"):
self._kseries = {
label: Series(data=self, index=label) for label in self._internal.column_labels
}
object.__setattr__(
self,
"_kseries",
{label: Series(data=self, index=label) for label in self._internal.column_labels},
)
else:
kseries = self._kseries
assert len(self._internal.column_labels) == len(kseries), (
@@ -535,30 +537,32 @@ def _update_internal_frame(self, internal: InternalFrame, requires_same_anchor:
"""
from databricks.koalas.series import Series

kseries = {}
if hasattr(self, "_kseries"):
kseries = {}

for old_label, new_label in zip_longest(
self._internal.column_labels, internal.column_labels
):
if old_label is not None:
kser = self._ksers[old_label]
for old_label, new_label in zip_longest(
self._internal.column_labels, internal.column_labels
):
if old_label is not None:
kser = self._ksers[old_label]

renamed = old_label != new_label
not_same_anchor = requires_same_anchor and not same_anchor(internal, kser)
renamed = old_label != new_label
not_same_anchor = requires_same_anchor and not same_anchor(internal, kser)

if renamed or not_same_anchor:
kdf = DataFrame(self._internal.select_column(old_label)) # type: DataFrame
kser._update_anchor(kdf)
if renamed or not_same_anchor:
kdf = DataFrame(self._internal.select_column(old_label)) # type: DataFrame
kser._update_anchor(kdf)
kser = None
else:
kser = None
else:
kser = None
if new_label is not None:
if kser is None:
kser = Series(data=self, index=new_label)
kseries[new_label] = kser
if new_label is not None:
if kser is None:
kser = Series(data=self, index=new_label)
kseries[new_label] = kser

self._kseries = kseries

self._internal_frame = internal
self._kseries = kseries

if hasattr(self, "_repr_pandas_cache"):
del self._repr_pandas_cache
@@ -10101,7 +10105,7 @@ def info(self, verbose=None, buf=None, max_cols=None, null_counts=None) -> None:
):
try:
# hack to use pandas' info as is.
self._data = self
object.__setattr__(self, "_data", self)
count_func = self.count
self.count = lambda: count_func().to_pandas() # type: ignore
return pd.DataFrame.info(
@@ -10903,7 +10907,9 @@ def _to_internal_pandas(self):

def _get_or_create_repr_pandas_cache(self, n):
    """Return the cached pandas object used for repr with ``n`` rows.

    On a cache miss (no cache yet, or the cache was built for another
    ``n``) the cache is replaced by a single-entry dict ``{n: ...}``
    built from ``self.head(n + 1)._to_internal_pandas()``.

    Parameters
    ----------
    n : hashable
        Cache key; ``n + 1`` is passed to ``head``.
    """
    if not hasattr(self, "_repr_pandas_cache") or n not in self._repr_pandas_cache:
        # Store through object.__setattr__ only: a plain ``self.x = ...``
        # would go through DataFrame.__setattr__, which warns instead of
        # creating a new attribute, and keeping both forms would evaluate
        # ``head(n + 1)`` twice per miss.
        object.__setattr__(
            self, "_repr_pandas_cache", {n: self.head(n + 1)._to_internal_pandas()}
        )
    return self._repr_pandas_cache[n]

def __repr__(self):
@@ -11081,6 +11087,20 @@ def __getattr__(self, key: str) -> Any:
"'%s' object has no attribute '%s'" % (self.__class__.__name__, key)
)

def __setattr__(self, key: str, value) -> None:
    """Set an existing attribute, or assign to an existing column.

    Resolution order:

    1. If ``key`` already resolves through ``object.__getattribute__``,
       it is assigned with ``object.__setattr__``.
    2. Otherwise, if ``(key,)`` is one of ``self._internal.column_labels``,
       the assignment is forwarded to ``self[key] = value``.
    3. Otherwise nothing is stored and a ``UserWarning`` is emitted.

    Raises
    ------
    AttributeError
        If ``key`` names an existing attribute that cannot be assigned
        (for example a property without a setter).
    """
    try:
        # Probe only: the try body is kept to the lookup so that an
        # AttributeError raised by the assignment itself is not mistaken
        # for "attribute does not exist".
        object.__getattribute__(self, key)
    except AttributeError:
        pass
    else:
        object.__setattr__(self, key, value)
        return

    if (key,) in self._internal.column_labels:
        self[key] = value
    else:
        warnings.warn(
            "Koalas doesn't allow columns to be created via a new attribute name", UserWarning
        )

def __len__(self):
    """Return the count reported by the resolved copy's Spark frame."""
    resolved = self._internal.resolved_copy
    spark_frame = resolved.spark_frame
    return spark_frame.count()

@@ -11170,9 +11190,9 @@ class CachedDataFrame(DataFrame):

def __init__(self, internal, storage_level=None):
if storage_level is None:
self._cached = internal.spark_frame.cache()
object.__setattr__(self, "_cached", internal.spark_frame.cache())
elif isinstance(storage_level, StorageLevel):
self._cached = internal.spark_frame.persist(storage_level)
object.__setattr__(self, "_cached", internal.spark_frame.persist(storage_level))
else:
raise TypeError(
"Only a valid pyspark.StorageLevel type is acceptable for the `storage_level`"
5 changes: 5 additions & 0 deletions databricks/koalas/tests/test_dataframe.py
Original file line number Diff line number Diff line change
@@ -452,6 +452,11 @@ def test_assign(self):

self.assert_eq(kdf, pdf)

kdf.w = 10.0
pdf.w = 10.0

self.assert_eq(kdf, pdf)

kdf[1] = 1.0
pdf[1] = 1.0