diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc index 0b1e09e70f2..42ed178b012 100644 --- a/cpp/src/arrow/extension_type.cc +++ b/cpp/src/arrow/extension_type.cc @@ -138,4 +138,7 @@ std::shared_ptr GetExtensionType(const std::string& type_name) { return registry->GetType(type_name); } +extern const char kExtensionTypeKeyName[] = "ARROW:extension:name"; +extern const char kExtensionMetadataKeyName[] = "ARROW:extension:metadata"; + } // namespace arrow diff --git a/cpp/src/arrow/extension_type.h b/cpp/src/arrow/extension_type.h index 32ab7c6cc45..560121eaa1f 100644 --- a/cpp/src/arrow/extension_type.h +++ b/cpp/src/arrow/extension_type.h @@ -142,4 +142,7 @@ Status UnregisterExtensionType(const std::string& type_name); ARROW_EXPORT std::shared_ptr GetExtensionType(const std::string& type_name); +ARROW_EXPORT extern const char kExtensionTypeKeyName[]; +ARROW_EXPORT extern const char kExtensionMetadataKeyName[]; + } // namespace arrow diff --git a/cpp/src/arrow/ipc/metadata_internal.cc b/cpp/src/arrow/ipc/metadata_internal.cc index dff3369a27f..d4ed8b7e0da 100644 --- a/cpp/src/arrow/ipc/metadata_internal.cc +++ b/cpp/src/arrow/ipc/metadata_internal.cc @@ -62,9 +62,6 @@ using Offset = flatbuffers::Offset; using FBString = flatbuffers::Offset; using KVVector = flatbuffers::Vector; -static const char kExtensionTypeKeyName[] = "ARROW:extension:name"; -static const char kExtensionMetadataKeyName[] = "ARROW:extension:metadata"; - MetadataVersion GetMetadataVersion(flatbuf::MetadataVersion version) { switch (version) { case flatbuf::MetadataVersion_V1: diff --git a/cpp/src/parquet/arrow/reader_internal.cc b/cpp/src/parquet/arrow/reader_internal.cc index b234291eb84..f8307c54d0d 100644 --- a/cpp/src/parquet/arrow/reader_internal.cc +++ b/cpp/src/parquet/arrow/reader_internal.cc @@ -32,6 +32,7 @@ #include "arrow/array.h" #include "arrow/builder.h" #include "arrow/compute/kernel.h" +#include "arrow/extension_type.h" #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" #include "arrow/status.h" @@ -620,6 +621,27 @@ Status ApplyOriginalMetadata(std::shared_ptr field, const Field& origin_f field = field->WithType( ::arrow::dictionary(::arrow::int32(), field->type(), dict_origin_type.ordered())); } + // restore field metadata + std::shared_ptr field_metadata = origin_field.metadata(); + if (field_metadata != nullptr) { + field = field->WithMetadata(field_metadata); + + // extension type + int name_index = field_metadata->FindKey(::arrow::kExtensionTypeKeyName); + if (name_index != -1) { + std::string type_name = field_metadata->value(name_index); + int data_index = field_metadata->FindKey(::arrow::kExtensionMetadataKeyName); + std::string type_data = data_index == -1 ? "" : field_metadata->value(data_index); + + std::shared_ptr<::arrow::ExtensionType> ext_type = + ::arrow::GetExtensionType(type_name); + if (ext_type != nullptr) { + std::shared_ptr deserialized; + RETURN_NOT_OK(ext_type->Deserialize(field->type(), type_data, &deserialized)); + field = field->WithType(deserialized); + } + } + } *out = field; return Status::OK(); } @@ -1098,6 +1120,25 @@ Status TransferDecimal(RecordReader* reader, MemoryPool* pool, return Status::OK(); } +Status TransferExtension(RecordReader* reader, std::shared_ptr value_type, + const ColumnDescriptor* descr, MemoryPool* pool, Datum* out) { + std::shared_ptr result; + auto ext_type = std::static_pointer_cast<::arrow::ExtensionType>(value_type); + auto storage_type = ext_type->storage_type(); + RETURN_NOT_OK(TransferColumnData(reader, storage_type, descr, pool, &result)); + + ::arrow::ArrayVector out_chunks(result->num_chunks()); + for (int i = 0; i < result->num_chunks(); i++) { + auto chunk = result->chunk(i); + auto ext_data = chunk->data()->Copy(); + ext_data->type = ext_type; + auto ext_result = ext_type->MakeArray(ext_data); + out_chunks[i] = ext_result; + } + *out = std::make_shared(out_chunks); + return Status::OK(); +} + #define TRANSFER_INT32(ENUM, ArrowType) \ case ::arrow::Type::ENUM: { \ Status s = TransferInt(reader, pool, value_type, &result); \ @@ -1194,6 +1235,9 @@ Status TransferColumnData(internal::RecordReader* reader, return Status::NotImplemented("TimeUnit not supported"); } } break; + case ::arrow::Type::EXTENSION: { + RETURN_NOT_OK(TransferExtension(reader, value_type, descr, pool, &result)); + } break; default: return Status::NotImplemented("No support for reading columns of type ", value_type->ToString()); diff --git a/cpp/src/parquet/arrow/schema.cc b/cpp/src/parquet/arrow/schema.cc index 49d0dd98808..2a8893f271e 100644 --- a/cpp/src/parquet/arrow/schema.cc +++ b/cpp/src/parquet/arrow/schema.cc @@ -20,6 +20,7 @@ #include #include +#include "arrow/extension_type.h" #include "arrow/type.h" #include "arrow/util/checked_cast.h" @@ -308,6 +309,12 @@ Status FieldToNode(const std::shared_ptr& field, field->name(), dict_type.value_type(), field->nullable(), field->metadata()); return FieldToNode(unpacked_field, properties, arrow_properties, out); } + case ArrowTypeId::EXTENSION: { + auto ext_type = std::static_pointer_cast<::arrow::ExtensionType>(field->type()); + std::shared_ptr<::arrow::Field> storage_field = ::arrow::field( + field->name(), ext_type->storage_type(), field->nullable(), field->metadata()); + return FieldToNode(storage_field, properties, arrow_properties, out); + } default: { // TODO: DENSE_UNION, SPARE_UNION, JSON_SCALAR, DECIMAL_TEXT, VARCHAR return Status::NotImplemented( diff --git a/cpp/src/parquet/arrow/writer.cc b/cpp/src/parquet/arrow/writer.cc index 950e3de721f..cfd58b2a452 100644 --- a/cpp/src/parquet/arrow/writer.cc +++ b/cpp/src/parquet/arrow/writer.cc @@ -26,6 +26,7 @@ #include "arrow/array.h" #include "arrow/buffer_builder.h" +#include "arrow/extension_type.h" #include "arrow/ipc/writer.h" #include "arrow/table.h" #include "arrow/type.h" @@ -48,6 +49,7 @@ using arrow::DictionaryArray; using arrow::Field; using arrow::FixedSizeBinaryArray; using Int16BufferBuilder = arrow::TypedBufferBuilder; +using arrow::ExtensionArray; using arrow::ListArray; using arrow::MemoryPool; using arrow::NumericArray; @@ -115,6 +117,8 @@ class LevelBuilder { return VisitInline(*array.values()); } + Status Visit(const ExtensionArray& array) { return VisitInline(*array.storage()); } + #define NOT_IMPLEMENTED_VISIT(ArrowTypePrefix) \ Status Visit(const ::arrow::ArrowTypePrefix##Array& array) { \ return Status::NotImplemented("Level generation for " #ArrowTypePrefix \ @@ -126,7 +130,6 @@ class LevelBuilder { NOT_IMPLEMENTED_VISIT(FixedSizeList) NOT_IMPLEMENTED_VISIT(Struct) NOT_IMPLEMENTED_VISIT(Union) - NOT_IMPLEMENTED_VISIT(Extension) #undef NOT_IMPLEMENTED_VISIT diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 3c03c4c31ec..dd9208549f9 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -352,3 +352,36 @@ def test_generic_ext_type_register(registered_period_type): period_type = PeriodType('D') with pytest.raises(KeyError): pa.register_extension_type(period_type) + + +@pytest.mark.parquet +def test_parquet(tmpdir, registered_period_type): + # parquet support for extension types + period_type = PeriodType('D') + storage = pa.array([1, 2, 3, 4], pa.int64()) + arr = pa.ExtensionArray.from_storage(period_type, storage) + table = pa.table([arr], names=["ext"]) + + import pyarrow.parquet as pq + + filename = tmpdir / 'extension_type.parquet' + pq.write_table(table, filename) + + # stored in parquet as storage type but with extension metadata saved + # in the serialized arrow schema + meta = pq.read_metadata(filename) + assert meta.schema.column(0).physical_type == "INT64" + assert b"ARROW:schema" in meta.metadata + schema = pa.read_schema(pa.BufferReader(meta.metadata[b"ARROW:schema"])) + assert schema.field("ext").metadata == { + b'ARROW:extension:metadata': b'freq=D', + b'ARROW:extension:name': b'pandas.period'} + + # when reading in, properly create extension type if it is registered + result = pq.read_table(filename) + assert result.column("ext").type == period_type + + # when the type is not registered, read in as storage type + pa.unregister_extension_type(period_type.extension_name) + result = pq.read_table(filename) + assert result.column("ext").type == pa.int64()