[BUG] Register super extension on to_arrow (#3030)
Daft Extension types were not being converted to PyArrow properly. @jaychia discovered this while trying to write Parquet with a tensor column: the Extension metadata for the tensor type was being dropped.

A simple test to reproduce the error: 
```python
import daft
import numpy as np
from daft import Series

# Create sample tensor data with some null values
tensor_data = [np.array([[1, 2], [3, 4]]), None, None]

# Uncomment this and it will work
# from daft.datatype import _ensure_registered_super_ext_type 
# _ensure_registered_super_ext_type()

df_original = daft.from_pydict({"tensor_col": Series.from_pylist(tensor_data)})
print(df_original.to_arrow().schema)
```


Output:
```
tensor_col: struct<data: large_list<item: int64>, shape: large_list<item: uint64>>
  child 0, data: large_list<item: int64>
      child 0, item: int64
  child 1, shape: large_list<item: uint64>
      child 0, item: uint64
```
It's not a tensor type! However, if you uncomment the
`_ensure_registered_super_ext_type()` call, you will see:
```
tensor_col: extension<daft.super_extension<DaftExtension>>
```


The root cause is that `DaftExtension` (a subclass of `pa.ExtensionType`) is never imported during the FFI conversion, because it now lives behind a lazy import that must be triggered via `_ensure_registered_super_ext_type()`.

This PR adds calls to `_ensure_registered_super_ext_type()` in `to_arrow` for both `Series` and `Schema`. However, I do not know if this covers every conversion path, and I will give it more thought. @desmondcheongzx @samster25

---------

Co-authored-by: Colin Ho <[email protected]>
colin-ho and Colin Ho authored Oct 11, 2024
1 parent f4c3afa commit ab1b772
Showing 2 changed files with 4 additions and 1 deletion.
3 changes: 2 additions & 1 deletion daft/logical/schema.py

```diff
@@ -8,7 +8,7 @@
 from daft.daft import read_csv_schema as _read_csv_schema
 from daft.daft import read_json_schema as _read_json_schema
 from daft.daft import read_parquet_schema as _read_parquet_schema
-from daft.datatype import DataType, TimeUnit
+from daft.datatype import DataType, TimeUnit, _ensure_registered_super_ext_type

 if TYPE_CHECKING:
     import pyarrow as pa
@@ -82,6 +82,7 @@ def to_pyarrow_schema(self) -> pa.Schema:
         Returns:
             pa.Schema: PyArrow schema that corresponds to the provided Daft schema
         """
+        _ensure_registered_super_ext_type()
         return self._schema.to_pyarrow_schema()

     @classmethod
```
2 changes: 2 additions & 0 deletions daft/series.py

```diff
@@ -213,6 +213,8 @@ def to_arrow(self) -> pa.Array:
         """
         Convert this Series to an pyarrow array.
         """
+        _ensure_registered_super_ext_type()
+
         dtype = self.datatype()
         arrow_arr = self._series.to_arrow()
```
