Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions datafusion/common/src/types/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
// specific language governing permissions and limitations
// under the License.

use super::NativeType;
use crate::error::Result;
use arrow_schema::DataType;
use core::fmt;
use std::{cmp::Ordering, hash::Hash, sync::Arc};

use super::NativeType;

/// Signature that uniquely identifies a type among other types.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum TypeSignature<'a> {
Expand Down Expand Up @@ -75,8 +76,17 @@ pub type LogicalTypeRef = Arc<dyn LogicalType>;
/// }
/// ```
pub trait LogicalType: Sync + Send {
/// Get the native backing type of this logical type.
fn native(&self) -> &NativeType;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I previously propose can_decode_to(DataType) -> bool, so given logical type and DataType, we can know whether they are paired.

How can we do the equivalent check by the current design?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given say arrow Int64 data, i want to know whether these is numbers, timestamp, time, date or something else (eg user-defined enum). The fact that any of these hypothetical logical types could be stored as Int64 doesn't help me know. Asking logical type "could you please decode this arrow type?" doesn't help me know.
Thus, going from arrow type to logical type is not an option. We simply need to know what logical type this should be.

Copy link
Contributor

@jayzhan211 jayzhan211 Oct 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the idea is that we have LogicalType already. In logical level, they are either LogicalNumber, LogicalTimestamp or LogicalDate, and we can differ them in logical level. They can also decode as i64, i32 in physical level. So asking logical type "could you please decode this arrow type?" is to tell the relationship between logical type and physical type. We don't need to know whether the arrow i64 is number or timestamp, because we already know that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I can follow. @jayzhan211 -- can you write a small practical example? I want to make sure I understand the use case. Thanks :)

Copy link
Contributor

@jayzhan211 jayzhan211 Oct 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

impl From<DataType> for NativeType is enough for native type since we can know whether the ArrayRef matches the LogicalType we have. But for LogicalType::UserDefined, I think we need to define what kind of DataType it could be decoded to.

We can figure this out if we meet any practical usage.

Copy link
Contributor Author

@notfilippo notfilippo Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For any user defined logical type you still know the backing native type (via the native() method), so you should be able to use the same logic to know if your DataType can represent that logical type.

/// Get the unique type signature for this logical type. Logical types with identical
/// signatures are considered equal.
fn signature(&self) -> TypeSignature<'_>;

/// Get the default physical type to cast `origin` to in order to obtain a physical type
/// that is logically compatible with this logical type.
fn default_cast_for(&self, origin: &DataType) -> Result<DataType> {
self.native().default_cast_for(origin)
}
}

impl fmt::Debug for dyn LogicalType {
Expand All @@ -90,7 +100,7 @@ impl fmt::Debug for dyn LogicalType {

impl PartialEq for dyn LogicalType {
fn eq(&self, other: &Self) -> bool {
self.native().eq(other.native()) && self.signature().eq(&other.signature())
self.signature().eq(&other.signature())
}
}

Expand Down
155 changes: 149 additions & 6 deletions datafusion/common/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;

use arrow_schema::{DataType, IntervalUnit, TimeUnit};

use super::{
LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, TypeSignature,
LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields,
TypeSignature,
};
use crate::error::{Result, _internal_err};
use arrow_schema::{
DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields,
};
use std::sync::Arc;

/// Representation of a type that DataFusion can handle natively. It is a subset
/// of the physical variants in Arrow's native [`DataType`].
Expand Down Expand Up @@ -188,6 +190,147 @@ impl LogicalType for NativeType {
fn signature(&self) -> TypeSignature<'_> {
TypeSignature::Native(self)
}

fn default_cast_for(&self, origin: &DataType) -> Result<DataType> {
use DataType::*;

fn default_field_cast(to: &LogicalField, from: &Field) -> Result<FieldRef> {
Ok(Arc::new(Field::new(
to.name.clone(),
to.logical_type.default_cast_for(from.data_type())?,
to.nullable,
)))
}

Ok(match (self, origin) {
(Self::Null, _) => Null,
(Self::Boolean, _) => Boolean,
(Self::Int8, _) => Int8,
(Self::Int16, _) => Int16,
(Self::Int32, _) => Int32,
(Self::Int64, _) => Int64,
(Self::UInt8, _) => UInt8,
(Self::UInt16, _) => UInt16,
(Self::UInt32, _) => UInt32,
(Self::UInt64, _) => UInt64,
(Self::Float16, _) => Float16,
(Self::Float32, _) => Float32,
(Self::Float64, _) => Float64,
(Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s),
(Self::Decimal(p, s), _) => Decimal256(*p, *s),
(Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()),
(Self::Date, _) => Date32,
(Self::Time(tu), _) => match tu {
TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu),
TimeUnit::Microsecond | TimeUnit::Nanosecond => Time64(*tu),
},
(Self::Duration(tu), _) => Duration(*tu),
(Self::Interval(iu), _) => Interval(*iu),
(Self::Binary, LargeUtf8) => LargeBinary,
(Self::Binary, Utf8View) => BinaryView,
(Self::Binary, _) => Binary,
(Self::FixedSizeBinary(size), _) => FixedSizeBinary(*size),
(Self::Utf8, LargeBinary) => LargeUtf8,
(Self::Utf8, BinaryView) => Utf8View,
(Self::Utf8, _) => Utf8,
(Self::List(to_field), List(from_field) | FixedSizeList(from_field, _)) => {
List(default_field_cast(to_field, from_field)?)
}
(Self::List(to_field), LargeList(from_field)) => {
LargeList(default_field_cast(to_field, from_field)?)
}
(Self::List(to_field), ListView(from_field)) => {
ListView(default_field_cast(to_field, from_field)?)
}
(Self::List(to_field), LargeListView(from_field)) => {
LargeListView(default_field_cast(to_field, from_field)?)
}
// List array where each element is a len 1 list of the origin type
(Self::List(field), _) => List(Arc::new(Field::new(
field.name.clone(),
field.logical_type.default_cast_for(origin)?,
field.nullable,
))),
(
Self::FixedSizeList(to_field, to_size),
FixedSizeList(from_field, from_size),
) if from_size == to_size => {
FixedSizeList(default_field_cast(to_field, from_field)?, *to_size)
}
(
Self::FixedSizeList(to_field, size),
List(from_field)
| LargeList(from_field)
| ListView(from_field)
| LargeListView(from_field),
) => FixedSizeList(default_field_cast(to_field, from_field)?, *size),
// FixedSizeList array where each element is a len 1 list of the origin type
(Self::FixedSizeList(field, size), _) => FixedSizeList(
Arc::new(Field::new(
field.name.clone(),
field.logical_type.default_cast_for(origin)?,
field.nullable,
)),
*size,
),
// From https://github.com/apache/arrow-rs/blob/56525efbd5f37b89d1b56aa51709cab9f81bc89e/arrow-cast/src/cast/mod.rs#L189-L196
(Self::Struct(to_fields), Struct(from_fields))
if from_fields.len() == to_fields.len() =>
{
Struct(
from_fields
.iter()
.zip(to_fields.iter())
.map(|(from, to)| default_field_cast(to, from))
.collect::<Result<Fields>>()?,
)
}
(Self::Struct(to_fields), Null) => Struct(
to_fields
.iter()
.map(|field| {
Ok(Arc::new(Field::new(
field.name.clone(),
field.logical_type.default_cast_for(&Null)?,
field.nullable,
)))
})
.collect::<Result<Fields>>()?,
),
(Self::Map(to_field), Map(from_field, sorted)) => {
Map(default_field_cast(to_field, from_field)?, *sorted)
}
(Self::Map(field), Null) => Map(
Arc::new(Field::new(
field.name.clone(),
field.logical_type.default_cast_for(&Null)?,
field.nullable,
)),
false,
),
(Self::Union(to_fields), Union(from_fields, mode))
if from_fields.len() == to_fields.len() =>
{
Union(
from_fields
.iter()
.zip(to_fields.iter())
.map(|((_, from), (i, to))| {
Ok((*i, default_field_cast(to, from)?))
})
.collect::<Result<UnionFields>>()?,
*mode,
)
}
_ => {
return _internal_err!(
"Unavailable default cast for native type {:?} from physical type {:?}",
self,
origin
)
}
})
}
}

// The following From<DataType>, From<Field>, ... implementations are temporary
Expand Down Expand Up @@ -230,9 +373,9 @@ impl From<DataType> for NativeType {
DataType::Union(union_fields, _) => {
Union(LogicalUnionFields::from(&union_fields))
}
DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(),
DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s),
DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())),
DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(),
DataType::RunEndEncoded(_, field) => field.data_type().clone().into(),
}
}
Expand Down