Skip to content

Commit bdd0155

Browse files
committed
Add default_cast_for
1 parent 7ed7891 commit bdd0155

File tree

2 files changed

+162
-9
lines changed

2 files changed

+162
-9
lines changed

datafusion/common/src/types/logical.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use super::NativeType;
19+
use crate::error::Result;
20+
use arrow_schema::DataType;
1821
use core::fmt;
1922
use std::{cmp::Ordering, hash::Hash, sync::Arc};
2023

21-
use super::NativeType;
22-
2324
/// Signature that uniquely identifies a type among other types.
2425
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
2526
pub enum TypeSignature<'a> {
@@ -75,8 +76,17 @@ pub type LogicalTypeRef = Arc<dyn LogicalType>;
7576
/// }
7677
/// ```
7778
pub trait LogicalType: Sync + Send {
79+
/// Get the native backing type of this logical type.
7880
fn native(&self) -> &NativeType;
81+
/// Get the unique type signature for this logical type. Logical types with identical
82+
/// signatures are considered equal.
7983
fn signature(&self) -> TypeSignature<'_>;
84+
85+
/// Get the default physical type to cast `origin` to in order to obtain a physical type
86+
/// that is logically compatible with this logical type.
87+
fn default_cast_for(&self, origin: &DataType) -> Result<DataType> {
88+
self.native().default_cast_for(origin)
89+
}
8090
}
8191

8292
impl fmt::Debug for dyn LogicalType {
@@ -90,7 +100,7 @@ impl fmt::Debug for dyn LogicalType {
90100

91101
impl PartialEq for dyn LogicalType {
92102
fn eq(&self, other: &Self) -> bool {
93-
self.native().eq(other.native()) && self.signature().eq(&other.signature())
103+
self.signature().eq(&other.signature())
94104
}
95105
}
96106

datafusion/common/src/types/native.rs

Lines changed: 149 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::Arc;
19-
20-
use arrow_schema::{DataType, IntervalUnit, TimeUnit};
21-
2218
use super::{
23-
LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, TypeSignature,
19+
LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields,
20+
TypeSignature,
2421
};
22+
use crate::error::{Result, _internal_err};
23+
use arrow_schema::{
24+
DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields,
25+
};
26+
use std::sync::Arc;
2527

2628
/// Representation of a type that DataFusion can handle natively. It is a subset
2729
/// of the physical variants in Arrow's native [`DataType`].
@@ -188,6 +190,147 @@ impl LogicalType for NativeType {
188190
fn signature(&self) -> TypeSignature<'_> {
189191
TypeSignature::Native(self)
190192
}
193+
194+
fn default_cast_for(&self, origin: &DataType) -> Result<DataType> {
195+
use DataType::*;
196+
197+
fn default_field_cast(to: &LogicalField, from: &Field) -> Result<FieldRef> {
198+
Ok(Arc::new(Field::new(
199+
to.name.clone(),
200+
to.logical_type.default_cast_for(from.data_type())?,
201+
to.nullable,
202+
)))
203+
}
204+
205+
Ok(match (self, origin) {
206+
(Self::Null, _) => Null,
207+
(Self::Boolean, _) => Boolean,
208+
(Self::Int8, _) => Int8,
209+
(Self::Int16, _) => Int16,
210+
(Self::Int32, _) => Int32,
211+
(Self::Int64, _) => Int64,
212+
(Self::UInt8, _) => UInt8,
213+
(Self::UInt16, _) => UInt16,
214+
(Self::UInt32, _) => UInt32,
215+
(Self::UInt64, _) => UInt64,
216+
(Self::Float16, _) => Float16,
217+
(Self::Float32, _) => Float32,
218+
(Self::Float64, _) => Float64,
219+
(Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s),
220+
(Self::Decimal(p, s), _) => Decimal256(*p, *s),
221+
(Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()),
222+
(Self::Date, _) => Date32,
223+
(Self::Time(tu), _) => match tu {
224+
TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu),
225+
TimeUnit::Microsecond | TimeUnit::Nanosecond => Time64(*tu),
226+
},
227+
(Self::Duration(tu), _) => Duration(*tu),
228+
(Self::Interval(iu), _) => Interval(*iu),
229+
(Self::Binary, LargeUtf8) => LargeBinary,
230+
(Self::Binary, Utf8View) => BinaryView,
231+
(Self::Binary, _) => Binary,
232+
(Self::FixedSizeBinary(size), _) => FixedSizeBinary(*size),
233+
(Self::Utf8, LargeBinary) => LargeUtf8,
234+
(Self::Utf8, BinaryView) => Utf8View,
235+
(Self::Utf8, _) => Utf8,
236+
(Self::List(to_field), List(from_field) | FixedSizeList(from_field, _)) => {
237+
List(default_field_cast(to_field, from_field)?)
238+
}
239+
(Self::List(to_field), LargeList(from_field)) => {
240+
LargeList(default_field_cast(to_field, from_field)?)
241+
}
242+
(Self::List(to_field), ListView(from_field)) => {
243+
ListView(default_field_cast(to_field, from_field)?)
244+
}
245+
(Self::List(to_field), LargeListView(from_field)) => {
246+
LargeListView(default_field_cast(to_field, from_field)?)
247+
}
248+
// List array where each element is a len 1 list of the origin type
249+
(Self::List(field), _) => List(Arc::new(Field::new(
250+
field.name.clone(),
251+
field.logical_type.default_cast_for(origin)?,
252+
field.nullable,
253+
))),
254+
(
255+
Self::FixedSizeList(to_field, to_size),
256+
FixedSizeList(from_field, from_size),
257+
) if from_size == to_size => {
258+
FixedSizeList(default_field_cast(to_field, from_field)?, *to_size)
259+
}
260+
(
261+
Self::FixedSizeList(to_field, size),
262+
List(from_field)
263+
| LargeList(from_field)
264+
| ListView(from_field)
265+
| LargeListView(from_field),
266+
) => FixedSizeList(default_field_cast(to_field, from_field)?, *size),
267+
// FixedSizeList array where each element is a len 1 list of the origin type
268+
(Self::FixedSizeList(field, size), _) => FixedSizeList(
269+
Arc::new(Field::new(
270+
field.name.clone(),
271+
field.logical_type.default_cast_for(origin)?,
272+
field.nullable,
273+
)),
274+
*size,
275+
),
276+
// From https://github.com/apache/arrow-rs/blob/56525efbd5f37b89d1b56aa51709cab9f81bc89e/arrow-cast/src/cast/mod.rs#L189-L196
277+
(Self::Struct(to_fields), Struct(from_fields))
278+
if from_fields.len() == to_fields.len() =>
279+
{
280+
Struct(
281+
from_fields
282+
.iter()
283+
.zip(to_fields.iter())
284+
.map(|(from, to)| default_field_cast(to, from))
285+
.collect::<Result<Fields>>()?,
286+
)
287+
}
288+
(Self::Struct(to_fields), Null) => Struct(
289+
to_fields
290+
.iter()
291+
.map(|field| {
292+
Ok(Arc::new(Field::new(
293+
field.name.clone(),
294+
field.logical_type.default_cast_for(&Null)?,
295+
field.nullable,
296+
)))
297+
})
298+
.collect::<Result<Fields>>()?,
299+
),
300+
(Self::Map(to_field), Map(from_field, sorted)) => {
301+
Map(default_field_cast(to_field, from_field)?, *sorted)
302+
}
303+
(Self::Map(field), Null) => Map(
304+
Arc::new(Field::new(
305+
field.name.clone(),
306+
field.logical_type.default_cast_for(&Null)?,
307+
field.nullable,
308+
)),
309+
false,
310+
),
311+
(Self::Union(to_fields), Union(from_fields, mode))
312+
if from_fields.len() == to_fields.len() =>
313+
{
314+
Union(
315+
from_fields
316+
.iter()
317+
.zip(to_fields.iter())
318+
.map(|((_, from), (i, to))| {
319+
Ok((*i, default_field_cast(to, from)?))
320+
})
321+
.collect::<Result<UnionFields>>()?,
322+
*mode,
323+
)
324+
}
325+
_ => {
326+
return _internal_err!(
327+
"Unavailable default cast for native type {:?} from physical type {:?}",
328+
self,
329+
origin
330+
)
331+
}
332+
})
333+
}
191334
}
192335

193336
// The following From<DataType>, From<Field>, ... implementations are temporary
@@ -230,9 +373,9 @@ impl From<DataType> for NativeType {
230373
DataType::Union(union_fields, _) => {
231374
Union(LogicalUnionFields::from(&union_fields))
232375
}
233-
DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(),
234376
DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s),
235377
DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())),
378+
DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(),
236379
DataType::RunEndEncoded(_, field) => field.data_type().clone().into(),
237380
}
238381
}

0 commit comments

Comments
 (0)