Skip to content

Commit

Permalink
fix some datatype bugs
Browse files: browse the repository at this point in the history
  • Loading branch information
Jasmine-ge committed Nov 21, 2024
1 parent bf16bcb commit 4405501
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 36 deletions.
44 changes: 20 additions & 24 deletions ibis/backends/sql/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import ibis.expr.datatypes as dt
from ibis.common.collections import FrozenDict
from ibis.formats import TypeMapper
from ibis.util import get_subclasses

typecode = sge.DataType.Type

Expand Down Expand Up @@ -172,6 +173,7 @@ def to_ibis(cls, typ: sge.DataType, nullable: bool | None = None) -> dt.DataType
nullable = typ.args.get(
"nullable", nullable if nullable is not None else cls.default_nullable
)

if method := getattr(cls, f"_from_sqlglot_{typecode.name}", None):
dtype = method(*typ.expressions, nullable=nullable)
elif (known_typ := _from_sqlglot_types.get(typecode)) is not None:
Expand Down Expand Up @@ -1242,41 +1244,26 @@ class DatabricksType(SqlglotType):
dialect = "databricks"


class DatabricksType(SqlglotType):
dialect = "databricks"


TYPE_MAPPERS = {
mapper.dialect: mapper
for mapper in set(get_subclasses(SqlglotType)) - {SqlglotType, BigQueryUDFType}
}


class TimeplusType(SqlglotType):
dialect = "timeplus"
default_decimal_precision = None
default_decimal_scale = None

@classmethod
def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
typ = super().from_ibis(dtype)

if typ.this == typecode.NULLABLE:
return typ

if dtype.nullable:
# If the type is already nullable, no need to wrap it again
return sge.DataType(this=typecode.NULLABLE, expressions=[typ])
else:
typ.args["nullable"] = False
return typ

@classmethod
def _from_sqlglot_NULLABLE(cls, inner_type: sge.DataType) -> dt.DataType:
def _from_sqlglot_NULLABLE(
cls, inner_type: sge.DataType, nullable: bool | None = None
) -> dt.DataType:
return cls.to_ibis(inner_type, nullable=True)

@classmethod
def _from_sqlglot_DATETIME(
cls, timezone: sge.DataTypeParam | None = None
cls, timezone: sge.DataTypeParam | None = None, nullable: bool | None = None
) -> dt.Timestamp:
return dt.Timestamp(
scale=0,
Expand All @@ -1289,6 +1276,7 @@ def _from_sqlglot_DATETIME64(
cls,
scale: sge.DataTypeSize | None = None,
timezone: sge.Literal | None = None,
nullable: bool | None = None,
) -> dt.Timestamp:
return dt.Timestamp(
timezone=None if timezone is None else timezone.this.this,
Expand All @@ -1297,11 +1285,15 @@ def _from_sqlglot_DATETIME64(
)

@classmethod
def _from_sqlglot_LOWCARDINALITY(cls, inner_type: sge.DataType) -> dt.DataType:
def _from_sqlglot_LOWCARDINALITY(
cls, inner_type: sge.DataType, nullable: bool | None = None
) -> dt.DataType:
return cls.to_ibis(inner_type)

@classmethod
def _from_sqlglot_NESTED(cls, *fields: sge.DataType) -> dt.Struct:
def _from_sqlglot_NESTED(
cls, *fields: sge.DataType, nullable: bool | None = None
) -> dt.Struct:
fields = {
field.name: dt.Array(
cls.to_ibis(field.args["kind"]), nullable=cls.default_nullable
Expand All @@ -1311,7 +1303,9 @@ def _from_sqlglot_NESTED(cls, *fields: sge.DataType) -> dt.Struct:
return dt.Struct(fields, nullable=cls.default_nullable)

@classmethod
def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
def _from_ibis_Timestamp(
cls, dtype: dt.Timestamp, nullable: bool | None = None
) -> sge.DataType:
if dtype.timezone is None:
timezone = None
else:
Expand All @@ -1324,7 +1318,9 @@ def _from_ibis_Timestamp(cls, dtype: dt.Timestamp) -> sge.DataType:
return sge.DataType(this=typecode.DATETIME64, expressions=[scale, timezone])

@classmethod
def _from_ibis_Map(cls, dtype: dt.Map) -> sge.DataType:
def _from_ibis_Map(
cls, dtype: dt.Map, nullable: bool | None = None
) -> sge.DataType:
key_type = cls.from_ibis(dtype.key_type.copy(nullable=False))
value_type = cls.from_ibis(dtype.value_type)
return sge.DataType(
Expand Down
5 changes: 2 additions & 3 deletions ibis/backends/sql/dialects.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,16 +491,15 @@ class Generator(Postgres.Generator):
)
}


class Timeplus(ClickHouse):
"""Subclass of ClickHouse dialect for Timeplus.
This is here to allow referring to the ClickHouse dialect as "Timeplus".
"""

class Generator(ClickHouse.Generator):
TYPE_MAPPING = ClickHouse.Generator.TYPE_MAPPING.copy() | {
sge.DataType.Type.NULLABLE: "nullable",
}
TYPE_MAPPING = ClickHouse.Generator.TYPE_MAPPING.copy() | {}
STRING_TYPE_MAPPING = ClickHouse.Generator.STRING_TYPE_MAPPING.copy() | {
sge.DataType.Type.CHAR: "string",
sge.DataType.Type.LONGBLOB: "string",
Expand Down
12 changes: 9 additions & 3 deletions ibis/backends/timeplus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def get_schema(
query = query.sql(dialect=self.dialect, pretty=True)
results = self.con.execute(query)
names, types = zip(*[(name, type_) for name, type_, *_ in results])

return sch.Schema(
dict(zip(names, map(self.compiler.type_mapper.from_string, types)))
)
Expand Down Expand Up @@ -397,7 +398,7 @@ def create_table(
Create a temporary table. This is not yet supported, and exists for
API compatibility.
overwrite
Whether to overwrite the table
Whether to overwrite the table. This is not yet supported.
engine
The engine to use.
order_by
Expand Down Expand Up @@ -711,10 +712,12 @@ def insert(
name: str,
obj: pd.DataFrame | ir.Table,
database: str | None = None,
overwrite: bool = False,
overwrite: bool | None = None,
**kwargs: Any,
):
# ir.Table, pa.Table, dict, pd.DataFrame
if overwrite is not None:
raise com.IbisError("`overwrite` namespaces are not supported by timeplus")
# ir.Table, pa.Table, dict, pd.DataFrame
if isinstance(obj, ir.Table):
statement = InsertSelect(
name,
Expand Down Expand Up @@ -837,3 +840,6 @@ def batcher(
return pa.ipc.RecordBatchReader.from_batches(
schema, batcher(sql, schema=schema, settings=settings, **kwargs)
)

def _register_in_memory_table(self, op: ops.InMemoryTable) -> None:
"""No-op."""
7 changes: 1 addition & 6 deletions ibis/backends/timeplus/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,6 @@ def test_create_and_drop_table(con, temp_table):
),
],
)
def test_get_schema_using_query(con, query, expected_schema):
result = con._get_schema_using_query(query)
assert result == expected_schema


def test_list_tables_database(con):
tables = con.list_tables()
tables2 = con.list_tables(database=con.current_database)
Expand Down Expand Up @@ -174,7 +169,7 @@ def test_insert_no_overwrite_from_dataframe(
con, test_employee_data_2, employee_empty_temp_table
):
temporary = con.table(employee_empty_temp_table)
con.insert(employee_empty_temp_table, obj=test_employee_data_2, overwrite=False)
con.insert(employee_empty_temp_table, obj=test_employee_data_2)
result = temporary.execute(settings={"query_mode": "table"})
assert len(result) == 3
assert (
Expand Down

0 comments on commit 4405501

Please sign in to comment.