cpp/src/arrow/extension_type.cc (3 additions, 0 deletions)
@@ -138,4 +138,7 @@ std::shared_ptr<ExtensionType> GetExtensionType(const std::string& type_name) {
return registry->GetType(type_name);
}

extern const char kExtensionTypeKeyName[] = "ARROW:extension:name";
extern const char kExtensionMetadataKeyName[] = "ARROW:extension:metadata";

} // namespace arrow
cpp/src/arrow/extension_type.h (3 additions, 0 deletions)
@@ -142,4 +142,7 @@ Status UnregisterExtensionType(const std::string& type_name);
ARROW_EXPORT
std::shared_ptr<ExtensionType> GetExtensionType(const std::string& type_name);

ARROW_EXPORT extern const char kExtensionTypeKeyName[];
ARROW_EXPORT extern const char kExtensionMetadataKeyName[];

} // namespace arrow
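
Note on the exported constants above: moving these key names out of ipc/metadata_internal.cc lets other modules (such as the Parquet reader changed below) detect extension metadata on a field. A minimal sketch of that usage, assuming a field whose metadata was populated by the IPC layer; the helper name is hypothetical and not part of this PR:

```cpp
#include <memory>
#include <string>

#include "arrow/extension_type.h"
#include "arrow/type.h"
#include "arrow/util/key_value_metadata.h"

// Hypothetical helper: returns the registered ExtensionType named in a
// field's metadata, or nullptr if the field carries no extension name or
// the type is not registered.
std::shared_ptr<arrow::ExtensionType> LookupExtensionType(const arrow::Field& field) {
  const auto& metadata = field.metadata();
  if (metadata == nullptr) {
    return nullptr;
  }
  int name_index = metadata->FindKey(arrow::kExtensionTypeKeyName);
  if (name_index == -1) {
    return nullptr;
  }
  // GetExtensionType consults the global registry by name.
  return arrow::GetExtensionType(metadata->value(name_index));
}
```
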
cpp/src/arrow/ipc/metadata_internal.cc (0 additions, 3 deletions)
@@ -62,9 +62,6 @@ using Offset = flatbuffers::Offset<void>;
using FBString = flatbuffers::Offset<flatbuffers::String>;
using KVVector = flatbuffers::Vector<KeyValueOffset>;

static const char kExtensionTypeKeyName[] = "ARROW:extension:name";
static const char kExtensionMetadataKeyName[] = "ARROW:extension:metadata";

MetadataVersion GetMetadataVersion(flatbuf::MetadataVersion version) {
switch (version) {
case flatbuf::MetadataVersion_V1:
cpp/src/parquet/arrow/reader_internal.cc (44 additions, 0 deletions)
@@ -32,6 +32,7 @@
#include "arrow/array.h"
#include "arrow/builder.h"
#include "arrow/compute/kernel.h"
#include "arrow/extension_type.h"
#include "arrow/io/memory.h"
#include "arrow/ipc/reader.h"
#include "arrow/status.h"
@@ -620,6 +621,27 @@ Status ApplyOriginalMetadata(std::shared_ptr<Field> field, const Field& origin_f
field = field->WithType(
::arrow::dictionary(::arrow::int32(), field->type(), dict_origin_type.ordered()));
}
// restore field metadata
std::shared_ptr<const KeyValueMetadata> field_metadata = origin_field.metadata();
if (field_metadata != nullptr) {
field = field->WithMetadata(field_metadata);

// extension type
int name_index = field_metadata->FindKey(::arrow::kExtensionTypeKeyName);
if (name_index != -1) {
std::string type_name = field_metadata->value(name_index);
int data_index = field_metadata->FindKey(::arrow::kExtensionMetadataKeyName);
std::string type_data = data_index == -1 ? "" : field_metadata->value(data_index);

std::shared_ptr<::arrow::ExtensionType> ext_type =
::arrow::GetExtensionType(type_name);
if (ext_type != nullptr) {
std::shared_ptr<DataType> deserialized;
RETURN_NOT_OK(ext_type->Deserialize(field->type(), type_data, &deserialized));
field = field->WithType(deserialized);
}
}
}
*out = field;
return Status::OK();
}
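
For context on the `ext_type->Deserialize(...)` call above: the reader only restores an extension type when an implementation is registered under the name found in ARROW:extension:name; otherwise the column stays as its storage type. A rough sketch of what such an implementation looks like, using a hypothetical UuidType (not part of this PR), with the Status-based Deserialize signature this code calls:

```cpp
#include <memory>
#include <string>

#include "arrow/array.h"
#include "arrow/extension_type.h"
#include "arrow/status.h"
#include "arrow/type.h"

// Hypothetical extension type: 16-byte UUIDs stored as fixed_size_binary(16).
class UuidType : public arrow::ExtensionType {
 public:
  UuidType() : arrow::ExtensionType(arrow::fixed_size_binary(16)) {}

  std::string extension_name() const override { return "example.uuid"; }

  bool ExtensionEquals(const arrow::ExtensionType& other) const override {
    return other.extension_name() == extension_name();
  }

  // Wraps storage ArrayData into the user-facing array; TransferExtension
  // below relies on this to rebuild each chunk.
  std::shared_ptr<arrow::Array> MakeArray(
      std::shared_ptr<arrow::ArrayData> data) const override {
    return std::make_shared<arrow::ExtensionArray>(data);
  }

  // Called by ApplyOriginalMetadata with the storage type and the value of
  // ARROW:extension:metadata.
  arrow::Status Deserialize(std::shared_ptr<arrow::DataType> storage_type,
                            const std::string& serialized,
                            std::shared_ptr<arrow::DataType>* out) const override {
    if (!storage_type->Equals(*arrow::fixed_size_binary(16))) {
      return arrow::Status::Invalid("Unexpected storage type for example.uuid");
    }
    *out = std::make_shared<UuidType>();
    return arrow::Status::OK();
  }

  // Becomes the value of ARROW:extension:metadata when the schema is written.
  std::string Serialize() const override { return ""; }
};
```

Such a type would be registered once via arrow::RegisterExtensionType(std::make_shared<UuidType>()); without that, GetExtensionType returns nullptr and the field is left as its storage type, the same fallback the Python test below exercises.
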
@@ -1098,6 +1120,25 @@ Status TransferDecimal(RecordReader* reader, MemoryPool* pool,
return Status::OK();
}

Status TransferExtension(RecordReader* reader, std::shared_ptr<DataType> value_type,
const ColumnDescriptor* descr, MemoryPool* pool, Datum* out) {
std::shared_ptr<ChunkedArray> result;
auto ext_type = std::static_pointer_cast<::arrow::ExtensionType>(value_type);
auto storage_type = ext_type->storage_type();
RETURN_NOT_OK(TransferColumnData(reader, storage_type, descr, pool, &result));

::arrow::ArrayVector out_chunks(result->num_chunks());
for (int i = 0; i < result->num_chunks(); i++) {
auto chunk = result->chunk(i);
auto ext_data = chunk->data()->Copy();
ext_data->type = ext_type;
auto ext_result = ext_type->MakeArray(ext_data);
out_chunks[i] = ext_result;
}
*out = std::make_shared<ChunkedArray>(out_chunks);
return Status::OK();
}

#define TRANSFER_INT32(ENUM, ArrowType) \
case ::arrow::Type::ENUM: { \
Status s = TransferInt<ArrowType, Int32Type>(reader, pool, value_type, &result); \
@@ -1194,6 +1235,9 @@ Status TransferColumnData(internal::RecordReader* reader,
return Status::NotImplemented("TimeUnit not supported");
}
} break;
case ::arrow::Type::EXTENSION: {
RETURN_NOT_OK(TransferExtension(reader, value_type, descr, pool, &result));
} break;
default:
return Status::NotImplemented("No support for reading columns of type ",
value_type->ToString());
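
The per-chunk rewrapping in TransferExtension above is zero-copy: only the ArrayData descriptor is copied, and its type pointer is swapped before MakeArray rebuilds the chunk. The same idiom in isolation (a sketch, not code taken verbatim from this PR):

```cpp
#include <memory>

#include "arrow/array.h"
#include "arrow/extension_type.h"

// Rewrap a storage chunk as an extension array without copying any buffers.
std::shared_ptr<arrow::Array> WrapAsExtension(
    const std::shared_ptr<arrow::ExtensionType>& ext_type,
    const std::shared_ptr<arrow::Array>& storage_chunk) {
  // ArrayData::Copy() duplicates the descriptor (length, null count, buffer
  // pointers), not the underlying buffers.
  std::shared_ptr<arrow::ArrayData> data = storage_chunk->data()->Copy();
  data->type = ext_type;
  return ext_type->MakeArray(data);
}
```
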
cpp/src/parquet/arrow/schema.cc (7 additions, 0 deletions)
@@ -20,6 +20,7 @@
#include <string>
#include <vector>

#include "arrow/extension_type.h"
#include "arrow/type.h"
#include "arrow/util/checked_cast.h"

@@ -308,6 +309,12 @@ Status FieldToNode(const std::shared_ptr<Field>& field,
field->name(), dict_type.value_type(), field->nullable(), field->metadata());
return FieldToNode(unpacked_field, properties, arrow_properties, out);
}
case ArrowTypeId::EXTENSION: {
auto ext_type = std::static_pointer_cast<::arrow::ExtensionType>(field->type());
std::shared_ptr<::arrow::Field> storage_field = ::arrow::field(
field->name(), ext_type->storage_type(), field->nullable(), field->metadata());
return FieldToNode(storage_field, properties, arrow_properties, out);
}
default: {
// TODO: DENSE_UNION, SPARE_UNION, JSON_SCALAR, DECIMAL_TEXT, VARCHAR
return Status::NotImplemented(
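
On the schema.cc change: an extension field is converted by substituting its storage field, so the Parquet schema only ever sees the storage type, while the extension name and metadata travel in the serialized Arrow schema stored in the file's key-value metadata. The substitution as a standalone sketch (helper name is illustrative, not from this PR):

```cpp
#include <memory>

#include "arrow/extension_type.h"
#include "arrow/type.h"

// Replace an extension field with an equivalent field of its storage type,
// preserving name, nullability and metadata; other fields pass through.
std::shared_ptr<arrow::Field> UnwrapExtensionField(
    const std::shared_ptr<arrow::Field>& field) {
  if (field->type()->id() != arrow::Type::EXTENSION) {
    return field;
  }
  auto ext_type = std::static_pointer_cast<arrow::ExtensionType>(field->type());
  return arrow::field(field->name(), ext_type->storage_type(), field->nullable(),
                      field->metadata());
}
```
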
cpp/src/parquet/arrow/writer.cc (4 additions, 1 deletion)
@@ -26,6 +26,7 @@

#include "arrow/array.h"
#include "arrow/buffer_builder.h"
#include "arrow/extension_type.h"
#include "arrow/ipc/writer.h"
#include "arrow/table.h"
#include "arrow/type.h"
@@ -48,6 +49,7 @@ using arrow::DictionaryArray;
using arrow::Field;
using arrow::FixedSizeBinaryArray;
using Int16BufferBuilder = arrow::TypedBufferBuilder<int16_t>;
using arrow::ExtensionArray;
using arrow::ListArray;
using arrow::MemoryPool;
using arrow::NumericArray;
@@ -115,6 +117,8 @@ class LevelBuilder {
return VisitInline(*array.values());
}

Status Visit(const ExtensionArray& array) { return VisitInline(*array.storage()); }

#define NOT_IMPLEMENTED_VISIT(ArrowTypePrefix) \
Status Visit(const ::arrow::ArrowTypePrefix##Array& array) { \
return Status::NotImplemented("Level generation for " #ArrowTypePrefix \
@@ -126,7 +130,6 @@ class LevelBuilder {
NOT_IMPLEMENTED_VISIT(FixedSizeList)
NOT_IMPLEMENTED_VISIT(Struct)
NOT_IMPLEMENTED_VISIT(Union)
NOT_IMPLEMENTED_VISIT(Extension)

#undef NOT_IMPLEMENTED_VISIT

python/pyarrow/tests/test_extension_type.py (33 additions, 0 deletions)
@@ -352,3 +352,36 @@ def test_generic_ext_type_register(registered_period_type):
period_type = PeriodType('D')
with pytest.raises(KeyError):
pa.register_extension_type(period_type)


@pytest.mark.parquet
def test_parquet(tmpdir, registered_period_type):
# parquet support for extension types
period_type = PeriodType('D')
storage = pa.array([1, 2, 3, 4], pa.int64())
arr = pa.ExtensionArray.from_storage(period_type, storage)
table = pa.table([arr], names=["ext"])

import pyarrow.parquet as pq

filename = tmpdir / 'extension_type.parquet'
pq.write_table(table, filename)

# stored in parquet as storage type but with extension metadata saved
# in the serialized arrow schema
meta = pq.read_metadata(filename)
assert meta.schema.column(0).physical_type == "INT64"
assert b"ARROW:schema" in meta.metadata
schema = pa.read_schema(pa.BufferReader(meta.metadata[b"ARROW:schema"]))
assert schema.field("ext").metadata == {
b'ARROW:extension:metadata': b'freq=D',
b'ARROW:extension:name': b'pandas.period'}

# when reading in, properly create extension type if it is registered
result = pq.read_table(filename)
assert result.column("ext").type == period_type

# when the type is not registered, read in as storage type
pa.unregister_extension_type(period_type.extension_name)
result = pq.read_table(filename)
assert result.column("ext").type == pa.int64()