|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
18 | | -use std::sync::Arc; |
19 | | - |
20 | | -use arrow_schema::{DataType, IntervalUnit, TimeUnit}; |
21 | | - |
22 | 18 | use super::{ |
23 | | - LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, TypeSignature, |
| 19 | + LogicalField, LogicalFieldRef, LogicalFields, LogicalType, LogicalUnionFields, |
| 20 | + TypeSignature, |
24 | 21 | }; |
| 22 | +use crate::error::{Result, _internal_err}; |
| 23 | +use arrow_schema::{ |
| 24 | + DataType, Field, FieldRef, Fields, IntervalUnit, TimeUnit, UnionFields, |
| 25 | +}; |
| 26 | +use std::sync::Arc; |
25 | 27 |
|
26 | 28 | /// Representation of a type that DataFusion can handle natively. It is a subset |
27 | 29 | /// of the physical variants in Arrow's native [`DataType`]. |
@@ -188,6 +190,147 @@ impl LogicalType for NativeType { |
188 | 190 | fn signature(&self) -> TypeSignature<'_> { |
189 | 191 | TypeSignature::Native(self) |
190 | 192 | } |
| 193 | + |
| 194 | + fn default_cast_for(&self, origin: &DataType) -> Result<DataType> { |
| 195 | + use DataType::*; |
| 196 | + |
| 197 | + fn default_field_cast(to: &LogicalField, from: &Field) -> Result<FieldRef> { |
| 198 | + Ok(Arc::new(Field::new( |
| 199 | + to.name.clone(), |
| 200 | + to.logical_type.default_cast_for(from.data_type())?, |
| 201 | + to.nullable, |
| 202 | + ))) |
| 203 | + } |
| 204 | + |
| 205 | + Ok(match (self, origin) { |
| 206 | + (Self::Null, _) => Null, |
| 207 | + (Self::Boolean, _) => Boolean, |
| 208 | + (Self::Int8, _) => Int8, |
| 209 | + (Self::Int16, _) => Int16, |
| 210 | + (Self::Int32, _) => Int32, |
| 211 | + (Self::Int64, _) => Int64, |
| 212 | + (Self::UInt8, _) => UInt8, |
| 213 | + (Self::UInt16, _) => UInt16, |
| 214 | + (Self::UInt32, _) => UInt32, |
| 215 | + (Self::UInt64, _) => UInt64, |
| 216 | + (Self::Float16, _) => Float16, |
| 217 | + (Self::Float32, _) => Float32, |
| 218 | + (Self::Float64, _) => Float64, |
| 219 | + (Self::Decimal(p, s), _) if p <= &38 => Decimal128(*p, *s), |
| 220 | + (Self::Decimal(p, s), _) => Decimal256(*p, *s), |
| 221 | + (Self::Timestamp(tu, tz), _) => Timestamp(*tu, tz.clone()), |
| 222 | + (Self::Date, _) => Date32, |
| 223 | + (Self::Time(tu), _) => match tu { |
| 224 | + TimeUnit::Second | TimeUnit::Millisecond => Time32(*tu), |
| 225 | + TimeUnit::Microsecond | TimeUnit::Nanosecond => Time64(*tu), |
| 226 | + }, |
| 227 | + (Self::Duration(tu), _) => Duration(*tu), |
| 228 | + (Self::Interval(iu), _) => Interval(*iu), |
| 229 | + (Self::Binary, LargeUtf8) => LargeBinary, |
| 230 | + (Self::Binary, Utf8View) => BinaryView, |
| 231 | + (Self::Binary, _) => Binary, |
| 232 | + (Self::FixedSizeBinary(size), _) => FixedSizeBinary(*size), |
| 233 | + (Self::Utf8, LargeBinary) => LargeUtf8, |
| 234 | + (Self::Utf8, BinaryView) => Utf8View, |
| 235 | + (Self::Utf8, _) => Utf8, |
| 236 | + (Self::List(to_field), List(from_field) | FixedSizeList(from_field, _)) => { |
| 237 | + List(default_field_cast(to_field, from_field)?) |
| 238 | + } |
| 239 | + (Self::List(to_field), LargeList(from_field)) => { |
| 240 | + LargeList(default_field_cast(to_field, from_field)?) |
| 241 | + } |
| 242 | + (Self::List(to_field), ListView(from_field)) => { |
| 243 | + ListView(default_field_cast(to_field, from_field)?) |
| 244 | + } |
| 245 | + (Self::List(to_field), LargeListView(from_field)) => { |
| 246 | + LargeListView(default_field_cast(to_field, from_field)?) |
| 247 | + } |
| 248 | + // List array where each element is a len 1 list of the origin type |
| 249 | + (Self::List(field), _) => List(Arc::new(Field::new( |
| 250 | + field.name.clone(), |
| 251 | + field.logical_type.default_cast_for(origin)?, |
| 252 | + field.nullable, |
| 253 | + ))), |
| 254 | + ( |
| 255 | + Self::FixedSizeList(to_field, to_size), |
| 256 | + FixedSizeList(from_field, from_size), |
| 257 | + ) if from_size == to_size => { |
| 258 | + FixedSizeList(default_field_cast(to_field, from_field)?, *to_size) |
| 259 | + } |
| 260 | + ( |
| 261 | + Self::FixedSizeList(to_field, size), |
| 262 | + List(from_field) |
| 263 | + | LargeList(from_field) |
| 264 | + | ListView(from_field) |
| 265 | + | LargeListView(from_field), |
| 266 | + ) => FixedSizeList(default_field_cast(to_field, from_field)?, *size), |
| 267 | + // FixedSizeList array where each element is a len 1 list of the origin type |
| 268 | + (Self::FixedSizeList(field, size), _) => FixedSizeList( |
| 269 | + Arc::new(Field::new( |
| 270 | + field.name.clone(), |
| 271 | + field.logical_type.default_cast_for(origin)?, |
| 272 | + field.nullable, |
| 273 | + )), |
| 274 | + *size, |
| 275 | + ), |
| 276 | + // From https://github.com/apache/arrow-rs/blob/56525efbd5f37b89d1b56aa51709cab9f81bc89e/arrow-cast/src/cast/mod.rs#L189-L196 |
| 277 | + (Self::Struct(to_fields), Struct(from_fields)) |
| 278 | + if from_fields.len() == to_fields.len() => |
| 279 | + { |
| 280 | + Struct( |
| 281 | + from_fields |
| 282 | + .iter() |
| 283 | + .zip(to_fields.iter()) |
| 284 | + .map(|(from, to)| default_field_cast(to, from)) |
| 285 | + .collect::<Result<Fields>>()?, |
| 286 | + ) |
| 287 | + } |
| 288 | + (Self::Struct(to_fields), Null) => Struct( |
| 289 | + to_fields |
| 290 | + .iter() |
| 291 | + .map(|field| { |
| 292 | + Ok(Arc::new(Field::new( |
| 293 | + field.name.clone(), |
| 294 | + field.logical_type.default_cast_for(&Null)?, |
| 295 | + field.nullable, |
| 296 | + ))) |
| 297 | + }) |
| 298 | + .collect::<Result<Fields>>()?, |
| 299 | + ), |
| 300 | + (Self::Map(to_field), Map(from_field, sorted)) => { |
| 301 | + Map(default_field_cast(to_field, from_field)?, *sorted) |
| 302 | + } |
| 303 | + (Self::Map(field), Null) => Map( |
| 304 | + Arc::new(Field::new( |
| 305 | + field.name.clone(), |
| 306 | + field.logical_type.default_cast_for(&Null)?, |
| 307 | + field.nullable, |
| 308 | + )), |
| 309 | + false, |
| 310 | + ), |
| 311 | + (Self::Union(to_fields), Union(from_fields, mode)) |
| 312 | + if from_fields.len() == to_fields.len() => |
| 313 | + { |
| 314 | + Union( |
| 315 | + from_fields |
| 316 | + .iter() |
| 317 | + .zip(to_fields.iter()) |
| 318 | + .map(|((_, from), (i, to))| { |
| 319 | + Ok((*i, default_field_cast(to, from)?)) |
| 320 | + }) |
| 321 | + .collect::<Result<UnionFields>>()?, |
| 322 | + *mode, |
| 323 | + ) |
| 324 | + } |
| 325 | + _ => { |
| 326 | + return _internal_err!( |
| 327 | + "Unavailable default cast for native type {:?} from physical type {:?}", |
| 328 | + self, |
| 329 | + origin |
| 330 | + ) |
| 331 | + } |
| 332 | + }) |
| 333 | + } |
191 | 334 | } |
192 | 335 |
|
193 | 336 | // The following From<DataType>, From<Field>, ... implementations are temporary |
@@ -230,9 +373,9 @@ impl From<DataType> for NativeType { |
230 | 373 | DataType::Union(union_fields, _) => { |
231 | 374 | Union(LogicalUnionFields::from(&union_fields)) |
232 | 375 | } |
233 | | - DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(), |
234 | 376 | DataType::Decimal128(p, s) | DataType::Decimal256(p, s) => Decimal(p, s), |
235 | 377 | DataType::Map(field, _) => Map(Arc::new(field.as_ref().into())), |
| 378 | + DataType::Dictionary(_, data_type) => data_type.as_ref().clone().into(), |
236 | 379 | DataType::RunEndEncoded(_, field) => field.data_type().clone().into(), |
237 | 380 | } |
238 | 381 | } |
|
0 commit comments