diff --git a/docs/source/status.rst b/docs/source/status.rst index c232aa280be..5e2c2cc19c8 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -119,6 +119,12 @@ Data Types +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | Variable shape tensor | | | | | | | | | +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| JSON | | | ✓ | | | | | | ++-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| UUID | | | ✓ | | | | | | ++-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| 8-bit Boolean | ✓ | | ✓ | | | | | | ++-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ Notes: diff --git a/go/arrow/array/array_test.go b/go/arrow/array/array_test.go index 4d83766b4fa..4f0627c6000 100644 --- a/go/arrow/array/array_test.go +++ b/go/arrow/array/array_test.go @@ -21,9 +21,9 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/internal/testing/tools" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/stretchr/testify/assert" ) @@ -122,7 +122,7 @@ func TestMakeFromData(t *testing.T) { {name: "dictionary", d: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint64, ValueType: &testDataType{arrow.TIMESTAMP}}, dict: array.NewData(&testDataType{arrow.TIMESTAMP}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */)}, {name: "extension", d: &testDataType{arrow.EXTENSION}, expPanic: true, expError: "arrow/array: DataType for ExtensionArray must implement arrow.ExtensionType"}, - {name: "extension", d: types.NewUUIDType()}, + {name: "extension", d: extensions.NewUUIDType()}, {name: "run end encoded", 
d: arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Int64), child: []arrow.ArrayData{ array.NewData(&testDataType{arrow.INT64}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */), diff --git a/go/arrow/array/diff_test.go b/go/arrow/array/diff_test.go index 65d212be118..9c9ce6a53ae 100644 --- a/go/arrow/array/diff_test.go +++ b/go/arrow/array/diff_test.go @@ -25,9 +25,9 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/internal/json" - "github.com/apache/arrow/go/v18/internal/types" ) type diffTestCase struct { @@ -861,7 +861,7 @@ func TestEdits_UnifiedDiff(t *testing.T) { }, { name: "extensions", - dataType: types.NewUUIDType(), + dataType: extensions.NewUUIDType(), baseJSON: `["00000000-0000-0000-0000-000000000000", "00000000-0000-0000-0000-000000000001"]`, targetJSON: `["00000000-0000-0000-0000-000000000001", "00000000-0000-0000-0000-000000000002"]`, want: `@@ -0, +0 @@ diff --git a/go/arrow/array/extension_test.go b/go/arrow/array/extension_test.go index 71ea9f105af..26245cf015d 100644 --- a/go/arrow/array/extension_test.go +++ b/go/arrow/array/extension_test.go @@ -30,16 +30,6 @@ type ExtensionTypeTestSuite struct { suite.Suite } -func (e *ExtensionTypeTestSuite) SetupTest() { - e.NoError(arrow.RegisterExtensionType(types.NewUUIDType())) -} - -func (e *ExtensionTypeTestSuite) TearDownTest() { - if arrow.GetExtensionType("uuid") != nil { - e.NoError(arrow.UnregisterExtensionType("uuid")) - } -} - func (e *ExtensionTypeTestSuite) TestParametricEquals() { p1Type := types.NewParametric1Type(6) p2Type := types.NewParametric1Type(6) diff --git a/go/arrow/avro/reader_types.go b/go/arrow/avro/reader_types.go index e07cd380d51..dab2b33dce6 100644 --- a/go/arrow/avro/reader_types.go +++ 
b/go/arrow/avro/reader_types.go @@ -27,8 +27,8 @@ import ( "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" ) type dataLoader struct { @@ -436,7 +436,7 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { } return nil } - case *types.UUIDBuilder: + case *extensions.UUIDBuilder: f.appendFunc = func(data interface{}) error { switch dt := data.(type) { case nil: diff --git a/go/arrow/avro/schema.go b/go/arrow/avro/schema.go index 007dad06c19..a6de3718d3c 100644 --- a/go/arrow/avro/schema.go +++ b/go/arrow/avro/schema.go @@ -24,7 +24,7 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/decimal128" - "github.com/apache/arrow/go/v18/internal/types" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/internal/utils" avro "github.com/hamba/avro/v2" ) @@ -349,7 +349,7 @@ func avroLogicalToArrowField(n *schemaNode) { // The uuid logical type represents a random generated universally unique identifier (UUID). // A uuid logical type annotates an Avro string. The string has to conform with RFC-4122 case "uuid": - dt = types.NewUUIDType() + dt = extensions.NewUUIDType() // The date logical type represents a date within the calendar, with no reference to a particular // time zone or time of day. 
diff --git a/go/arrow/compute/exec/span_test.go b/go/arrow/compute/exec/span_test.go index f5beb45ee14..018fbb7d623 100644 --- a/go/arrow/compute/exec/span_test.go +++ b/go/arrow/compute/exec/span_test.go @@ -29,6 +29,7 @@ import ( "github.com/apache/arrow/go/v18/arrow/compute/exec" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/endian" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/arrow/scalar" "github.com/apache/arrow/go/v18/internal/types" @@ -192,9 +193,6 @@ func TestArraySpan_NumBuffers(t *testing.T) { Children []exec.ArraySpan } - arrow.RegisterExtensionType(types.NewUUIDType()) - defer arrow.UnregisterExtensionType("uuid") - tests := []struct { name string fields fields @@ -207,7 +205,7 @@ func TestArraySpan_NumBuffers(t *testing.T) { {"large binary", fields{Type: arrow.BinaryTypes.LargeBinary}, 3}, {"string", fields{Type: arrow.BinaryTypes.String}, 3}, {"large string", fields{Type: arrow.BinaryTypes.LargeString}, 3}, - {"extension", fields{Type: types.NewUUIDType()}, 2}, + {"extension", fields{Type: extensions.NewUUIDType()}, 2}, {"int32", fields{Type: arrow.PrimitiveTypes.Int32}, 2}, } for _, tt := range tests { diff --git a/go/arrow/csv/reader_test.go b/go/arrow/csv/reader_test.go index b0775b9b11a..6a89d497042 100644 --- a/go/arrow/csv/reader_test.go +++ b/go/arrow/csv/reader_test.go @@ -30,8 +30,8 @@ import ( "github.com/apache/arrow/go/v18/arrow/csv" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -356,7 +356,7 @@ func testCSVReader(t *testing.T, filepath string, withHeader bool, stringsCanBeN {Name: "binary", Type: arrow.BinaryTypes.Binary}, {Name: 
"large_binary", Type: arrow.BinaryTypes.LargeBinary}, {Name: "fixed_size_binary", Type: &arrow.FixedSizeBinaryType{ByteWidth: 3}}, - {Name: "uuid", Type: types.NewUUIDType()}, + {Name: "uuid", Type: extensions.NewUUIDType()}, {Name: "date32", Type: arrow.PrimitiveTypes.Date32}, {Name: "date64", Type: arrow.PrimitiveTypes.Date64}, }, diff --git a/go/arrow/csv/writer_test.go b/go/arrow/csv/writer_test.go index be9ab961c3e..2ae01a6d490 100644 --- a/go/arrow/csv/writer_test.go +++ b/go/arrow/csv/writer_test.go @@ -31,9 +31,9 @@ import ( "github.com/apache/arrow/go/v18/arrow/csv" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/float16" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/google/uuid" ) @@ -230,7 +230,7 @@ func testCSVWriter(t *testing.T, data [][]string, writeHeader bool, fmtr func(bo {Name: "binary", Type: arrow.BinaryTypes.Binary}, {Name: "large_binary", Type: arrow.BinaryTypes.LargeBinary}, {Name: "fixed_size_binary", Type: &arrow.FixedSizeBinaryType{ByteWidth: 3}}, - {Name: "uuid", Type: types.NewUUIDType()}, + {Name: "uuid", Type: extensions.NewUUIDType()}, {Name: "null", Type: arrow.Null}, }, nil, @@ -285,7 +285,7 @@ func testCSVWriter(t *testing.T, data [][]string, writeHeader bool, fmtr func(bo b.Field(22).(*array.BinaryBuilder).AppendValues([][]byte{{0, 1, 2}, {3, 4, 5}, {}}, nil) b.Field(23).(*array.BinaryBuilder).AppendValues([][]byte{{0, 1, 2}, {3, 4, 5}, {}}, nil) b.Field(24).(*array.FixedSizeBinaryBuilder).AppendValues([][]byte{{0, 1, 2}, {3, 4, 5}, {}}, nil) - b.Field(25).(*types.UUIDBuilder).AppendValues([]uuid.UUID{uuid.MustParse("00000000-0000-0000-0000-000000000001"), uuid.MustParse("00000000-0000-0000-0000-000000000002"), uuid.MustParse("00000000-0000-0000-0000-000000000003")}, nil) + 
b.Field(25).(*extensions.UUIDBuilder).AppendValues([]uuid.UUID{uuid.MustParse("00000000-0000-0000-0000-000000000001"), uuid.MustParse("00000000-0000-0000-0000-000000000002"), uuid.MustParse("00000000-0000-0000-0000-000000000003")}, nil) b.Field(26).(*array.NullBuilder).AppendEmptyValues(3) for _, field := range b.Fields() { diff --git a/go/arrow/datatype_extension_test.go b/go/arrow/datatype_extension_test.go index c3e595f523e..7244d377bd2 100644 --- a/go/arrow/datatype_extension_test.go +++ b/go/arrow/datatype_extension_test.go @@ -21,7 +21,7 @@ import ( "testing" "github.com/apache/arrow/go/v18/arrow" - "github.com/apache/arrow/go/v18/internal/types" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) @@ -50,24 +50,14 @@ type ExtensionTypeTestSuite struct { suite.Suite } -func (e *ExtensionTypeTestSuite) SetupTest() { - e.NoError(arrow.RegisterExtensionType(types.NewUUIDType())) -} - -func (e *ExtensionTypeTestSuite) TearDownTest() { - if arrow.GetExtensionType("uuid") != nil { - e.NoError(arrow.UnregisterExtensionType("uuid")) - } -} - func (e *ExtensionTypeTestSuite) TestExtensionType() { e.Nil(arrow.GetExtensionType("uuid-unknown")) - e.NotNil(arrow.GetExtensionType("uuid")) + e.NotNil(arrow.GetExtensionType("arrow.uuid")) - e.Error(arrow.RegisterExtensionType(types.NewUUIDType())) + e.Error(arrow.RegisterExtensionType(extensions.NewUUIDType())) e.Error(arrow.UnregisterExtensionType("uuid-unknown")) - typ := types.NewUUIDType() + typ := extensions.NewUUIDType() e.Implements((*arrow.ExtensionType)(nil), typ) e.Equal(arrow.EXTENSION, typ.ID()) e.Equal("extension", typ.Name()) diff --git a/go/arrow/extensions/bool8_test.go b/go/arrow/extensions/bool8_test.go index 9f7365d1555..ff129e24bc8 100644 --- a/go/arrow/extensions/bool8_test.go +++ b/go/arrow/extensions/bool8_test.go @@ -178,9 +178,6 @@ func TestReinterpretStorageEqualToValues(t *testing.T) { func TestBool8TypeBatchIPCRoundTrip(t 
*testing.T) { typ := extensions.NewBool8Type() - arrow.RegisterExtensionType(typ) - defer arrow.UnregisterExtensionType(typ.ExtensionName()) - storage, _, err := array.FromJSON(memory.DefaultAllocator, arrow.PrimitiveTypes.Int8, strings.NewReader(`[-1, 0, 1, 2, null]`)) require.NoError(t, err) diff --git a/go/arrow/extensions/extensions.go b/go/arrow/extensions/extensions.go new file mode 100644 index 00000000000..03c6923e95f --- /dev/null +++ b/go/arrow/extensions/extensions.go @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "github.com/apache/arrow/go/v18/arrow" +) + +var canonicalExtensionTypes = []arrow.ExtensionType{ + &Bool8Type{}, + &UUIDType{}, + &OpaqueType{}, + &JSONType{}, +} + +func init() { + for _, extType := range canonicalExtensionTypes { + if err := arrow.RegisterExtensionType(extType); err != nil { + panic(err) + } + } +} diff --git a/go/arrow/extensions/json.go b/go/arrow/extensions/json.go new file mode 100644 index 00000000000..12c49f9c0a7 --- /dev/null +++ b/go/arrow/extensions/json.go @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "fmt" + "reflect" + "slices" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/internal/json" + "github.com/apache/arrow/go/v18/parquet/schema" +) + +var jsonSupportedStorageTypes = []arrow.DataType{ + arrow.BinaryTypes.String, + arrow.BinaryTypes.LargeString, + arrow.BinaryTypes.StringView, +} + +// JSONType represents a UTF-8 encoded JSON string as specified in RFC8259. +type JSONType struct { + arrow.ExtensionBase +} + +// ParquetLogicalType implements pqarrow.ExtensionCustomParquetType. +func (b *JSONType) ParquetLogicalType() schema.LogicalType { + return schema.JSONLogicalType{} +} + +// NewJSONType creates a new JSONType with the specified storage type. +// storageType must be one of String, LargeString, StringView. 
+func NewJSONType(storageType arrow.DataType) (*JSONType, error) { + if !slices.ContainsFunc(jsonSupportedStorageTypes, func(dt arrow.DataType) bool { return arrow.TypeEqual(dt, storageType) }) { // match by type via arrow.TypeEqual, as ExtensionEquals does, not by interface identity + return nil, fmt.Errorf("unsupported storage type for JSON extension type: %s", storageType) + } + return &JSONType{ExtensionBase: arrow.ExtensionBase{Storage: storageType}}, nil +} + +func (b *JSONType) ArrayType() reflect.Type { return reflect.TypeOf(JSONArray{}) } + +func (b *JSONType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { + if !(data == "" || data == "{}") { + return nil, fmt.Errorf("serialized metadata for JSON extension type must be '' or '{}', found: %s", data) + } + return NewJSONType(storageType) +} + +func (b *JSONType) ExtensionEquals(other arrow.ExtensionType) bool { + return b.ExtensionName() == other.ExtensionName() && arrow.TypeEqual(b.Storage, other.StorageType()) +} + +func (b *JSONType) ExtensionName() string { return "arrow.json" } + +func (b *JSONType) Serialize() string { return "" } + +func (b *JSONType) String() string { + return fmt.Sprintf("extension<%s[storage_type=%s]>", b.ExtensionName(), b.Storage) +} + +// JSONArray is logically an array of UTF-8 encoded JSON strings. +// Its values are unmarshaled to native Go values. +type JSONArray struct { + array.ExtensionArrayBase +} + +func (a *JSONArray) String() string { + b, err := a.MarshalJSON() + if err != nil { + panic(fmt.Sprintf("failed marshal JSONArray: %s", err)) + } + + return string(b) +} + +func (a *JSONArray) Value(i int) any { + val := a.ValueBytes(i) + + var res any + if err := json.Unmarshal(val, &res); err != nil { + panic(err) + } + + return res +} + +func (a *JSONArray) ValueStr(i int) string { + return string(a.ValueBytes(i)) +} + +func (a *JSONArray) ValueBytes(i int) []byte { + // convert to json.RawMessage, set to nil if elem isNull. + val := a.ValueJSON(i) + + // simply returns wrapped bytes, or null if val is nil. 
+ b, err := val.MarshalJSON() + if err != nil { + panic(err) + } + + return b +} + +// ValueJSON wraps the underlying string value as a json.RawMessage, +// or returns nil if the array value is null. +func (a *JSONArray) ValueJSON(i int) json.RawMessage { + var val json.RawMessage + if a.IsValid(i) { + val = json.RawMessage(a.Storage().(array.StringLike).Value(i)) + } + return val +} + +// MarshalJSON implements json.Marshaler. +// Marshaling json.RawMessage is a no-op, except that nil values will +// be marshaled as a JSON null. +func (a *JSONArray) MarshalJSON() ([]byte, error) { + values := make([]json.RawMessage, a.Len()) + for i := 0; i < a.Len(); i++ { + values[i] = a.ValueJSON(i) + } + return json.Marshal(values) +} + +// GetOneForMarshal implements arrow.Array. +func (a *JSONArray) GetOneForMarshal(i int) interface{} { + return a.ValueJSON(i) +} + +var ( + _ arrow.ExtensionType = (*JSONType)(nil) + _ array.ExtensionArray = (*JSONArray)(nil) +) diff --git a/go/arrow/extensions/json_test.go b/go/arrow/extensions/json_test.go new file mode 100644 index 00000000000..21acc58f939 --- /dev/null +++ b/go/arrow/extensions/json_test.go @@ -0,0 +1,268 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package extensions_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" + "github.com/apache/arrow/go/v18/arrow/ipc" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONTypeBasics(t *testing.T) { + typ, err := extensions.NewJSONType(arrow.BinaryTypes.String) + require.NoError(t, err) + + typLarge, err := extensions.NewJSONType(arrow.BinaryTypes.LargeString) + require.NoError(t, err) + + typView, err := extensions.NewJSONType(arrow.BinaryTypes.StringView) + require.NoError(t, err) + + assert.Equal(t, "arrow.json", typ.ExtensionName()) + assert.Equal(t, "arrow.json", typLarge.ExtensionName()) + assert.Equal(t, "arrow.json", typView.ExtensionName()) + + assert.True(t, typ.ExtensionEquals(typ)) + assert.True(t, typLarge.ExtensionEquals(typLarge)) + assert.True(t, typView.ExtensionEquals(typView)) + + assert.False(t, arrow.TypeEqual(arrow.BinaryTypes.String, typ)) + assert.False(t, arrow.TypeEqual(typ, typLarge)) + assert.False(t, arrow.TypeEqual(typ, typView)) + assert.False(t, arrow.TypeEqual(typLarge, typView)) + + assert.True(t, arrow.TypeEqual(arrow.BinaryTypes.String, typ.StorageType())) + assert.True(t, arrow.TypeEqual(arrow.BinaryTypes.LargeString, typLarge.StorageType())) + assert.True(t, arrow.TypeEqual(arrow.BinaryTypes.StringView, typView.StorageType())) + + assert.Equal(t, "extension<arrow.json[storage_type=utf8]>", typ.String()) + assert.Equal(t, "extension<arrow.json[storage_type=large_utf8]>", typLarge.String()) + assert.Equal(t, "extension<arrow.json[storage_type=string_view]>", typView.String()) +} + +var jsonTestCases = []struct { + Name string + StorageType arrow.DataType + StorageBuilder func(mem memory.Allocator) array.Builder +}{ + { + Name: "string", + StorageType: arrow.BinaryTypes.String, + StorageBuilder: func(mem memory.Allocator) array.Builder { return array.NewStringBuilder(mem) }, + }, + { + Name: 
"large_string", + StorageType: arrow.BinaryTypes.LargeString, + StorageBuilder: func(mem memory.Allocator) array.Builder { return array.NewLargeStringBuilder(mem) }, + }, + { + Name: "string_view", + StorageType: arrow.BinaryTypes.StringView, + StorageBuilder: func(mem memory.Allocator) array.Builder { return array.NewStringViewBuilder(mem) }, + }, +} + +func TestJSONTypeCreateFromArray(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + bldr.AppendValueFromString(`true`) + bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 6, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + jsonArr, ok := arr.(*extensions.JSONArray) + require.True(t, ok) + + require.Equal(t, "foobar", jsonArr.Value(0)) + require.Equal(t, nil, jsonArr.Value(1)) + require.Equal(t, map[string]any{"foo": "bar"}, jsonArr.Value(2)) + require.Equal(t, float64(42), jsonArr.Value(3)) + require.Equal(t, true, jsonArr.Value(4)) + require.Equal(t, []any{float64(1), true, "3", nil, map[string]any{"five": float64(5)}}, jsonArr.Value(5)) + }) + } +} + +func TestJSONTypeBatchIPCRoundTrip(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + 
bldr.AppendValueFromString(`true`) + bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + batch := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "field", Type: typ, Nullable: true}}, nil), + []arrow.Array{arr}, -1) + defer batch.Release() + + var written arrow.Record + { + var buf bytes.Buffer + wr := ipc.NewWriter(&buf, ipc.WithSchema(batch.Schema())) + require.NoError(t, wr.Write(batch)) + require.NoError(t, wr.Close()) + + rdr, err := ipc.NewReader(&buf) + require.NoError(t, err) + written, err = rdr.Read() + require.NoError(t, err) + written.Retain() + defer written.Release() + rdr.Release() + } + + assert.Truef(t, batch.Schema().Equal(written.Schema()), "expected: %s, got: %s", + batch.Schema(), written.Schema()) + + assert.Truef(t, array.RecordEqual(batch, written), "expected: %s, got: %s", + batch, written) + }) + } +} + +func TestMarshallJSONArray(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + bldr.AppendValueFromString(`true`) + bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 6, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + jsonArr, ok := arr.(*extensions.JSONArray) + require.True(t, ok) + + b, err := jsonArr.MarshalJSON() + require.NoError(t, err) + + expectedJSON := `["foobar",null,{"foo":"bar"},42,true,[1,true,"3",null,{"five":5}]]` + require.Equal(t, expectedJSON, string(b)) + 
require.Equal(t, expectedJSON, jsonArr.String()) + }) + } +} + +func TestJSONRecordToJSON(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + bldr.AppendValueFromString(`true`) + bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 6, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + jsonArr, ok := arr.(*extensions.JSONArray) + require.True(t, ok) + + rec := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "json", Type: typ, Nullable: true}}, nil), []arrow.Array{jsonArr}, 6) + defer rec.Release() + + buf := bytes.NewBuffer([]byte("\n")) // expected output has leading newline for clearer formatting + require.NoError(t, array.RecordToJSON(rec, buf)) + + expectedJSON := ` + {"json":"foobar"} + {"json":null} + {"json":{"foo":"bar"}} + {"json":42} + {"json":true} + {"json":[1,true,"3",null,{"five":5}]} + ` + + expectedJSONLines := strings.Split(expectedJSON, "\n") + actualJSONLines := strings.Split(buf.String(), "\n") + + require.Equal(t, len(expectedJSONLines), len(actualJSONLines)) + for i := range expectedJSONLines { + if strings.TrimSpace(expectedJSONLines[i]) != "" { + require.JSONEq(t, expectedJSONLines[i], actualJSONLines[i]) + } + } + }) + } +} diff --git a/go/arrow/extensions/opaque_test.go b/go/arrow/extensions/opaque_test.go index b6686e97bc0..a0fc8962ce5 100644 --- a/go/arrow/extensions/opaque_test.go +++ b/go/arrow/extensions/opaque_test.go @@ -161,9 +161,6 @@ func TestOpaqueTypeMetadataRoundTrip(t *testing.T) { func TestOpaqueTypeBatchRoundTrip(t 
*testing.T) { typ := extensions.NewOpaqueType(arrow.BinaryTypes.String, "geometry", "adbc.postgresql") - arrow.RegisterExtensionType(typ) - defer arrow.UnregisterExtensionType(typ.ExtensionName()) - storage, _, err := array.FromJSON(memory.DefaultAllocator, arrow.BinaryTypes.String, strings.NewReader(`["foobar", null]`)) require.NoError(t, err) diff --git a/go/arrow/extensions/uuid.go b/go/arrow/extensions/uuid.go new file mode 100644 index 00000000000..422b9ea1188 --- /dev/null +++ b/go/arrow/extensions/uuid.go @@ -0,0 +1,265 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "bytes" + "fmt" + "reflect" + "strings" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/apache/arrow/go/v18/internal/json" + "github.com/apache/arrow/go/v18/parquet/schema" + "github.com/google/uuid" +) + +type UUIDBuilder struct { + *array.ExtensionBuilder +} + +// NewUUIDBuilder creates a new UUIDBuilder, exposing a convenient and efficient interface +// for writing uuid.UUID (or [16]byte) values to the underlying FixedSizeBinary storage array. 
+func NewUUIDBuilder(mem memory.Allocator) *UUIDBuilder { + return &UUIDBuilder{ExtensionBuilder: array.NewExtensionBuilder(mem, NewUUIDType())} +} + +func (b *UUIDBuilder) Append(v uuid.UUID) { + b.AppendBytes(v) +} + +func (b *UUIDBuilder) AppendBytes(v [16]byte) { + b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).Append(v[:]) +} + +func (b *UUIDBuilder) UnsafeAppend(v uuid.UUID) { + b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).UnsafeAppend(v[:]) +} + +func (b *UUIDBuilder) AppendValueFromString(s string) error { + if s == array.NullValueStr { + b.AppendNull() + return nil + } + + uid, err := uuid.Parse(s) + if err != nil { + return err + } + + b.Append(uid) + return nil +} + +func (b *UUIDBuilder) AppendValues(v []uuid.UUID, valid []bool) { + if len(v) != len(valid) && len(valid) != 0 { + panic("len(v) != len(valid) && len(valid) != 0") + } + + data := make([][]byte, len(v)) + for i := range v { + if len(valid) > 0 && !valid[i] { + continue + } + data[i] = v[i][:] + } + b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).AppendValues(data, valid) +} + +func (b *UUIDBuilder) UnmarshalOne(dec *json.Decoder) error { + t, err := dec.Token() + if err != nil { + return err + } + + var val uuid.UUID + switch v := t.(type) { + case string: + val, err = uuid.Parse(v) + if err != nil { + return err + } + case []byte: + val, err = uuid.ParseBytes(v) + if err != nil { + return err + } + case nil: + b.AppendNull() + return nil + default: + return &json.UnmarshalTypeError{ + Value: fmt.Sprint(t), + Type: reflect.TypeOf([]byte{}), + Offset: dec.InputOffset(), + Struct: fmt.Sprintf("FixedSizeBinary[%d]", 16), + } + } + + b.Append(val) + return nil +} + +func (b *UUIDBuilder) Unmarshal(dec *json.Decoder) error { + for dec.More() { + if err := b.UnmarshalOne(dec); err != nil { + return err + } + } + return nil +} + +func (b *UUIDBuilder) UnmarshalJSON(data []byte) error { + dec := json.NewDecoder(bytes.NewReader(data)) + t, err := dec.Token() + if 
err != nil { + return err + } + + if delim, ok := t.(json.Delim); !ok || delim != '[' { + return fmt.Errorf("uuid builder must unpack from json array, found %v", t) // report the token actually read; delim is the zero value when t is not a json.Delim + } + + return b.Unmarshal(dec) +} + +// UUIDArray is a simple array which is a FixedSizeBinary(16) +type UUIDArray struct { + array.ExtensionArrayBase +} + +func (a *UUIDArray) String() string { + arr := a.Storage().(*array.FixedSizeBinary) + o := new(strings.Builder) + o.WriteString("[") + for i := 0; i < arr.Len(); i++ { + if i > 0 { + o.WriteString(" ") + } + switch { + case a.IsNull(i): + o.WriteString(array.NullValueStr) + default: + fmt.Fprintf(o, "%q", a.Value(i)) + } + } + o.WriteString("]") + return o.String() +} + +func (a *UUIDArray) Value(i int) uuid.UUID { + if a.IsNull(i) { + return uuid.Nil + } + return uuid.Must(uuid.FromBytes(a.Storage().(*array.FixedSizeBinary).Value(i))) +} + +func (a *UUIDArray) Values() []uuid.UUID { + values := make([]uuid.UUID, a.Len()) + for i := range values { + values[i] = a.Value(i) + } + return values +} + +func (a *UUIDArray) ValueStr(i int) string { + switch { + case a.IsNull(i): + return array.NullValueStr + default: + return a.Value(i).String() + } +} + +func (a *UUIDArray) MarshalJSON() ([]byte, error) { + vals := make([]any, a.Len()) + for i := range vals { + vals[i] = a.GetOneForMarshal(i) + } + return json.Marshal(vals) +} + +func (a *UUIDArray) GetOneForMarshal(i int) interface{} { + if a.IsValid(i) { + return a.Value(i) + } + return nil +} + +// UUIDType is a simple extension type that represents a FixedSizeBinary(16) +// to be used for representing UUIDs +type UUIDType struct { + arrow.ExtensionBase +} + +// ParquetLogicalType implements pqarrow.ExtensionCustomParquetType. 
+func (e *UUIDType) ParquetLogicalType() schema.LogicalType { + return schema.UUIDLogicalType{} +} + +// NewUUIDType is a convenience function to create an instance of UUIDType +// with the correct storage type +func NewUUIDType() *UUIDType { + return &UUIDType{ExtensionBase: arrow.ExtensionBase{Storage: &arrow.FixedSizeBinaryType{ByteWidth: 16}}} +} + +// ArrayType returns TypeOf(UUIDArray{}) for constructing UUID arrays +func (*UUIDType) ArrayType() reflect.Type { + return reflect.TypeOf(UUIDArray{}) +} + +func (*UUIDType) ExtensionName() string { + return "arrow.uuid" +} + +func (e *UUIDType) String() string { + return fmt.Sprintf("extension<%s>", e.ExtensionName()) +} + +func (e *UUIDType) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`{"name":"%s","metadata":%s}`, e.ExtensionName(), e.Serialize())), nil +} + +func (*UUIDType) Serialize() string { + return "" +} + +// Deserialize expects storageType to be FixedSizeBinaryType{ByteWidth: 16} +func (*UUIDType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { + if !arrow.TypeEqual(storageType, &arrow.FixedSizeBinaryType{ByteWidth: 16}) { + return nil, fmt.Errorf("invalid storage type for UUIDType: %s", storageType.Name()) + } + return NewUUIDType(), nil +} + +// ExtensionEquals returns true if both extensions have the same name +func (e *UUIDType) ExtensionEquals(other arrow.ExtensionType) bool { + return e.ExtensionName() == other.ExtensionName() +} + +func (*UUIDType) NewBuilder(mem memory.Allocator) array.Builder { + return NewUUIDBuilder(mem) +} + +var ( + _ arrow.ExtensionType = (*UUIDType)(nil) + _ array.CustomExtensionBuilder = (*UUIDType)(nil) + _ array.ExtensionArray = (*UUIDArray)(nil) + _ array.Builder = (*UUIDBuilder)(nil) +) diff --git a/go/arrow/extensions/uuid_test.go b/go/arrow/extensions/uuid_test.go new file mode 100644 index 00000000000..80c621db2a0 --- /dev/null +++ b/go/arrow/extensions/uuid_test.go @@ -0,0 +1,257 @@ +// Licensed to the Apache 
Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions_test + +import ( + "bytes" + "fmt" + "strings" + "testing" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" + "github.com/apache/arrow/go/v18/arrow/ipc" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/apache/arrow/go/v18/internal/json" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var testUUID = uuid.New() + +func TestUUIDExtensionBuilder(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + builder := extensions.NewUUIDBuilder(mem) + builder.Append(testUUID) + builder.AppendNull() + builder.AppendBytes(testUUID) + arr := builder.NewArray() + defer arr.Release() + arrStr := arr.String() + assert.Equal(t, fmt.Sprintf(`["%[1]s" (null) "%[1]s"]`, testUUID), arrStr) + jsonStr, err := json.Marshal(arr) + assert.NoError(t, err) + + arr1, _, err := array.FromJSON(mem, extensions.NewUUIDType(), bytes.NewReader(jsonStr)) + defer arr1.Release() + assert.NoError(t, err) + assert.True(t, array.Equal(arr1, arr)) + + require.NoError(t, json.Unmarshal(jsonStr, builder)) 
+ arr2 := builder.NewArray() + defer arr2.Release() + assert.True(t, array.Equal(arr2, arr)) +} + +func TestUUIDExtensionRecordBuilder(t *testing.T) { + schema := arrow.NewSchema([]arrow.Field{ + {Name: "uuid", Type: extensions.NewUUIDType()}, + }, nil) + builder := array.NewRecordBuilder(memory.DefaultAllocator, schema) + builder.Field(0).(*extensions.UUIDBuilder).Append(testUUID) + builder.Field(0).(*extensions.UUIDBuilder).AppendNull() + builder.Field(0).(*extensions.UUIDBuilder).Append(testUUID) + record := builder.NewRecord() + b, err := record.MarshalJSON() + require.NoError(t, err) + require.Equal(t, "[{\"uuid\":\""+testUUID.String()+"\"}\n,{\"uuid\":null}\n,{\"uuid\":\""+testUUID.String()+"\"}\n]", string(b)) + record1, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, bytes.NewReader(b)) + require.NoError(t, err) + require.Equal(t, record, record1) +} + +func TestUUIDStringRoundTrip(t *testing.T) { + // 1. create array + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + b := extensions.NewUUIDBuilder(mem) + b.Append(uuid.Nil) + b.AppendNull() + b.Append(uuid.NameSpaceURL) + b.AppendNull() + b.Append(testUUID) + + arr := b.NewArray() + defer arr.Release() + + // 2. 
create array via AppendValueFromString + b1 := extensions.NewUUIDBuilder(mem) + defer b1.Release() + + for i := 0; i < arr.Len(); i++ { + assert.NoError(t, b1.AppendValueFromString(arr.ValueStr(i))) + } + + arr1 := b1.NewArray() + defer arr1.Release() + + assert.True(t, array.Equal(arr, arr1)) +} + +func TestUUIDTypeBasics(t *testing.T) { + typ := extensions.NewUUIDType() + + assert.Equal(t, "arrow.uuid", typ.ExtensionName()) + assert.True(t, typ.ExtensionEquals(typ)) + + assert.True(t, arrow.TypeEqual(typ, typ)) + assert.False(t, arrow.TypeEqual(&arrow.FixedSizeBinaryType{ByteWidth: 16}, typ)) + assert.True(t, arrow.TypeEqual(&arrow.FixedSizeBinaryType{ByteWidth: 16}, typ.StorageType())) + + assert.Equal(t, "extension", typ.String()) +} + +func TestUUIDTypeCreateFromArray(t *testing.T) { + typ := extensions.NewUUIDType() + + bldr := array.NewFixedSizeBinaryBuilder(memory.DefaultAllocator, &arrow.FixedSizeBinaryType{ByteWidth: 16}) + defer bldr.Release() + + bldr.Append(testUUID[:]) + bldr.AppendNull() + bldr.Append(testUUID[:]) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 3, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + uuidArr, ok := arr.(*extensions.UUIDArray) + require.True(t, ok) + + require.Equal(t, testUUID, uuidArr.Value(0)) + require.Equal(t, uuid.Nil, uuidArr.Value(1)) + require.Equal(t, testUUID, uuidArr.Value(2)) +} + +func TestUUIDTypeBatchIPCRoundTrip(t *testing.T) { + typ := extensions.NewUUIDType() + + bldr := extensions.NewUUIDBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.Append(testUUID) + bldr.AppendNull() + bldr.AppendBytes(testUUID) + + arr := bldr.NewArray() + defer arr.Release() + + batch := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "field", Type: typ, Nullable: true}}, nil), + []arrow.Array{arr}, -1) + defer batch.Release() + + var written arrow.Record + { + var buf bytes.Buffer + wr := 
ipc.NewWriter(&buf, ipc.WithSchema(batch.Schema())) + require.NoError(t, wr.Write(batch)) + require.NoError(t, wr.Close()) + + rdr, err := ipc.NewReader(&buf) + require.NoError(t, err) + written, err = rdr.Read() + require.NoError(t, err) + written.Retain() + defer written.Release() + rdr.Release() + } + + assert.Truef(t, batch.Schema().Equal(written.Schema()), "expected: %s, got: %s", + batch.Schema(), written.Schema()) + + assert.Truef(t, array.RecordEqual(batch, written), "expected: %s, got: %s", + batch, written) +} + +func TestMarshallUUIDArray(t *testing.T) { + bldr := extensions.NewUUIDBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.Append(testUUID) + bldr.AppendNull() + bldr.AppendBytes(testUUID) + + arr := bldr.NewArray() + defer arr.Release() + + assert.Equal(t, 3, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + uuidArr, ok := arr.(*extensions.UUIDArray) + require.True(t, ok) + + b, err := uuidArr.MarshalJSON() + require.NoError(t, err) + + expectedJSON := fmt.Sprintf(`["%[1]s",null,"%[1]s"]`, testUUID) + require.Equal(t, expectedJSON, string(b)) +} + +func TestUUIDRecordToJSON(t *testing.T) { + typ := extensions.NewUUIDType() + + bldr := extensions.NewUUIDBuilder(memory.DefaultAllocator) + defer bldr.Release() + + uuid1 := uuid.MustParse("8c607ed4-07b2-4b9c-b5eb-c0387357f9ae") + + bldr.Append(uuid1) + bldr.AppendNull() + + // c5f2cbd9-7094-491a-b267-167bb62efe02 + bldr.AppendBytes([16]byte{197, 242, 203, 217, 112, 148, 73, 26, 178, 103, 22, 123, 182, 46, 254, 2}) + + arr := bldr.NewArray() + defer arr.Release() + + assert.Equal(t, 3, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + uuidArr, ok := arr.(*extensions.UUIDArray) + require.True(t, ok) + + rec := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "uuid", Type: typ, Nullable: true}}, nil), []arrow.Array{uuidArr}, 3) + defer rec.Release() + + buf := bytes.NewBuffer([]byte("\n")) // expected output has leading newline for clearer formatting + require.NoError(t, 
array.RecordToJSON(rec, buf)) + + expectedJSON := ` + {"uuid":"8c607ed4-07b2-4b9c-b5eb-c0387357f9ae"} + {"uuid":null} + {"uuid":"c5f2cbd9-7094-491a-b267-167bb62efe02"} + ` + + expectedJSONLines := strings.Split(expectedJSON, "\n") + actualJSONLines := strings.Split(buf.String(), "\n") + + require.Equal(t, len(expectedJSONLines), len(actualJSONLines)) + for i := range expectedJSONLines { + if strings.TrimSpace(expectedJSONLines[i]) != "" { + require.JSONEq(t, expectedJSONLines[i], actualJSONLines[i]) + } + } +} diff --git a/go/arrow/internal/flight_integration/scenario.go b/go/arrow/internal/flight_integration/scenario.go index 1528bb05d9d..b9535002a0a 100644 --- a/go/arrow/internal/flight_integration/scenario.go +++ b/go/arrow/internal/flight_integration/scenario.go @@ -40,7 +40,6 @@ import ( "github.com/apache/arrow/go/v18/arrow/internal/arrjson" "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "golang.org/x/xerrors" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -161,9 +160,6 @@ func (s *defaultIntegrationTester) RunClient(addr string, opts ...grpc.DialOptio ctx := context.Background() - arrow.RegisterExtensionType(types.NewUUIDType()) - defer arrow.UnregisterExtensionType("uuid") - descr := &flight.FlightDescriptor{ Type: flight.DescriptorPATH, Path: []string{s.path}, diff --git a/go/arrow/ipc/cmd/arrow-json-integration-test/main.go b/go/arrow/ipc/cmd/arrow-json-integration-test/main.go index b3e1dcac141..c47a091268b 100644 --- a/go/arrow/ipc/cmd/arrow-json-integration-test/main.go +++ b/go/arrow/ipc/cmd/arrow-json-integration-test/main.go @@ -22,12 +22,10 @@ import ( "log" "os" - "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/arrio" "github.com/apache/arrow/go/v18/arrow/internal/arrjson" "github.com/apache/arrow/go/v18/arrow/ipc" - "github.com/apache/arrow/go/v18/internal/types" ) func 
main() { @@ -50,8 +48,6 @@ func main() { } func runCommand(jsonName, arrowName, mode string, verbose bool) error { - arrow.RegisterExtensionType(types.NewUUIDType()) - if jsonName == "" { return fmt.Errorf("must specify json file name") } diff --git a/go/arrow/ipc/metadata_test.go b/go/arrow/ipc/metadata_test.go index 33bc63c2a00..14b8da2cf7c 100644 --- a/go/arrow/ipc/metadata_test.go +++ b/go/arrow/ipc/metadata_test.go @@ -23,10 +23,10 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/internal/dictutils" "github.com/apache/arrow/go/v18/arrow/internal/flatbuf" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" flatbuffers "github.com/google/flatbuffers/go" "github.com/stretchr/testify/assert" ) @@ -169,7 +169,7 @@ func TestRWFooter(t *testing.T) { } func exampleUUID(mem memory.Allocator) arrow.Array { - extType := types.NewUUIDType() + extType := extensions.NewUUIDType() bldr := array.NewExtensionBuilder(mem, extType) defer bldr.Release() @@ -184,9 +184,6 @@ func TestUnrecognizedExtensionType(t *testing.T) { pool := memory.NewCheckedAllocator(memory.NewGoAllocator()) defer pool.AssertSize(t, 0) - // register the uuid type - assert.NoError(t, arrow.RegisterExtensionType(types.NewUUIDType())) - extArr := exampleUUID(pool) defer extArr.Release() @@ -205,7 +202,9 @@ func TestUnrecognizedExtensionType(t *testing.T) { // unregister the uuid type before we read back the buffer so it is // unrecognized when reading back the record batch. 
- assert.NoError(t, arrow.UnregisterExtensionType("uuid")) + assert.NoError(t, arrow.UnregisterExtensionType("arrow.uuid")) + // re-register once the test is complete + defer arrow.RegisterExtensionType(extensions.NewUUIDType()) rdr, err := NewReader(&buf, WithAllocator(pool)) defer rdr.Release() diff --git a/go/internal/types/extension_types.go b/go/internal/types/extension_types.go index 85c64d86bff..33ada2d488f 100644 --- a/go/internal/types/extension_types.go +++ b/go/internal/types/extension_types.go @@ -18,238 +18,15 @@ package types import ( - "bytes" "encoding/binary" "fmt" "reflect" - "strings" "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" - "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/json" - "github.com/google/uuid" "golang.org/x/xerrors" ) -var UUID = NewUUIDType() - -type UUIDBuilder struct { - *array.ExtensionBuilder -} - -func NewUUIDBuilder(mem memory.Allocator) *UUIDBuilder { - return &UUIDBuilder{ExtensionBuilder: array.NewExtensionBuilder(mem, NewUUIDType())} -} - -func (b *UUIDBuilder) Append(v uuid.UUID) { - b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).Append(v[:]) -} - -func (b *UUIDBuilder) UnsafeAppend(v uuid.UUID) { - b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).UnsafeAppend(v[:]) -} - -func (b *UUIDBuilder) AppendValueFromString(s string) error { - if s == array.NullValueStr { - b.AppendNull() - return nil - } - - uid, err := uuid.Parse(s) - if err != nil { - return err - } - - b.Append(uid) - return nil -} - -func (b *UUIDBuilder) AppendValues(v []uuid.UUID, valid []bool) { - if len(v) != len(valid) && len(valid) != 0 { - panic("len(v) != len(valid) && len(valid) != 0") - } - - data := make([][]byte, len(v)) - for i := range v { - if len(valid) > 0 && !valid[i] { - continue - } - data[i] = v[i][:] - } - b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).AppendValues(data, valid) -} - -func (b *UUIDBuilder) UnmarshalOne(dec 
*json.Decoder) error { - t, err := dec.Token() - if err != nil { - return err - } - - var val uuid.UUID - switch v := t.(type) { - case string: - val, err = uuid.Parse(v) - if err != nil { - return err - } - case []byte: - val, err = uuid.ParseBytes(v) - if err != nil { - return err - } - case nil: - b.AppendNull() - return nil - default: - return &json.UnmarshalTypeError{ - Value: fmt.Sprint(t), - Type: reflect.TypeOf([]byte{}), - Offset: dec.InputOffset(), - Struct: fmt.Sprintf("FixedSizeBinary[%d]", 16), - } - } - - b.Append(val) - return nil -} - -func (b *UUIDBuilder) Unmarshal(dec *json.Decoder) error { - for dec.More() { - if err := b.UnmarshalOne(dec); err != nil { - return err - } - } - return nil -} - -func (b *UUIDBuilder) UnmarshalJSON(data []byte) error { - dec := json.NewDecoder(bytes.NewReader(data)) - t, err := dec.Token() - if err != nil { - return err - } - - if delim, ok := t.(json.Delim); !ok || delim != '[' { - return fmt.Errorf("uuid builder must unpack from json array, found %s", delim) - } - - return b.Unmarshal(dec) -} - -// UUIDArray is a simple array which is a FixedSizeBinary(16) -type UUIDArray struct { - array.ExtensionArrayBase -} - -func (a *UUIDArray) String() string { - arr := a.Storage().(*array.FixedSizeBinary) - o := new(strings.Builder) - o.WriteString("[") - for i := 0; i < arr.Len(); i++ { - if i > 0 { - o.WriteString(" ") - } - switch { - case a.IsNull(i): - o.WriteString(array.NullValueStr) - default: - fmt.Fprintf(o, "%q", a.Value(i)) - } - } - o.WriteString("]") - return o.String() -} - -func (a *UUIDArray) Value(i int) uuid.UUID { - if a.IsNull(i) { - return uuid.Nil - } - return uuid.Must(uuid.FromBytes(a.Storage().(*array.FixedSizeBinary).Value(i))) -} - -func (a *UUIDArray) ValueStr(i int) string { - switch { - case a.IsNull(i): - return array.NullValueStr - default: - return a.Value(i).String() - } -} - -func (a *UUIDArray) MarshalJSON() ([]byte, error) { - arr := a.Storage().(*array.FixedSizeBinary) - values := 
make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - if a.IsValid(i) { - values[i] = uuid.Must(uuid.FromBytes(arr.Value(i))).String() - } - } - return json.Marshal(values) -} - -func (a *UUIDArray) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - return a.Value(i) -} - -// UUIDType is a simple extension type that represents a FixedSizeBinary(16) -// to be used for representing UUIDs -type UUIDType struct { - arrow.ExtensionBase -} - -// NewUUIDType is a convenience function to create an instance of UUIDType -// with the correct storage type -func NewUUIDType() *UUIDType { - return &UUIDType{ExtensionBase: arrow.ExtensionBase{Storage: &arrow.FixedSizeBinaryType{ByteWidth: 16}}} -} - -// ArrayType returns TypeOf(UUIDArray{}) for constructing UUID arrays -func (*UUIDType) ArrayType() reflect.Type { - return reflect.TypeOf(UUIDArray{}) -} - -func (*UUIDType) ExtensionName() string { - return "uuid" -} - -func (e *UUIDType) String() string { - return fmt.Sprintf("extension_type", e.Storage) -} - -func (e *UUIDType) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf(`{"name":"%s","metadata":%s}`, e.ExtensionName(), e.Serialize())), nil -} - -// Serialize returns "uuid-serialized" for testing proper metadata passing -func (*UUIDType) Serialize() string { - return "uuid-serialized" -} - -// Deserialize expects storageType to be FixedSizeBinaryType{ByteWidth: 16} and the data to be -// "uuid-serialized" in order to correctly create a UUIDType for testing deserialize. 
-func (*UUIDType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { - if data != "uuid-serialized" { - return nil, fmt.Errorf("type identifier did not match: '%s'", data) - } - if !arrow.TypeEqual(storageType, &arrow.FixedSizeBinaryType{ByteWidth: 16}) { - return nil, fmt.Errorf("invalid storage type for UUIDType: %s", storageType.Name()) - } - return NewUUIDType(), nil -} - -// ExtensionEquals returns true if both extensions have the same name -func (e *UUIDType) ExtensionEquals(other arrow.ExtensionType) bool { - return e.ExtensionName() == other.ExtensionName() -} - -func (*UUIDType) NewBuilder(mem memory.Allocator) array.Builder { - return NewUUIDBuilder(mem) -} - // Parametric1Array is a simple int32 array for use with the Parametric1Type // in testing a parameterized user-defined extension type. type Parametric1Array struct { @@ -518,14 +295,14 @@ func (SmallintType) ArrayType() reflect.Type { return reflect.TypeOf(SmallintArr func (SmallintType) ExtensionName() string { return "smallint" } -func (SmallintType) Serialize() string { return "smallint" } +func (SmallintType) Serialize() string { return "smallint-serialized" } func (s *SmallintType) ExtensionEquals(other arrow.ExtensionType) bool { return s.Name() == other.Name() } func (SmallintType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { - if data != "smallint" { + if data != "smallint-serialized" { return nil, fmt.Errorf("type identifier did not match: '%s'", data) } if !arrow.TypeEqual(storageType, arrow.PrimitiveTypes.Int16) { diff --git a/go/internal/types/extension_types_test.go b/go/internal/types/extension_types_test.go deleted file mode 100644 index 65f6353d01b..00000000000 --- a/go/internal/types/extension_types_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package types_test - -import ( - "bytes" - "testing" - - "github.com/apache/arrow/go/v18/arrow" - "github.com/apache/arrow/go/v18/arrow/array" - "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/json" - "github.com/apache/arrow/go/v18/internal/types" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var testUUID = uuid.New() - -func TestUUIDExtensionBuilder(t *testing.T) { - mem := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer mem.AssertSize(t, 0) - builder := types.NewUUIDBuilder(mem) - builder.Append(testUUID) - arr := builder.NewArray() - defer arr.Release() - arrStr := arr.String() - assert.Equal(t, "[\""+testUUID.String()+"\"]", arrStr) - jsonStr, err := json.Marshal(arr) - assert.NoError(t, err) - - arr1, _, err := array.FromJSON(mem, types.NewUUIDType(), bytes.NewReader(jsonStr)) - defer arr1.Release() - assert.NoError(t, err) - assert.Equal(t, arr, arr1) -} - -func TestUUIDExtensionRecordBuilder(t *testing.T) { - schema := arrow.NewSchema([]arrow.Field{ - {Name: "uuid", Type: types.NewUUIDType()}, - }, nil) - builder := array.NewRecordBuilder(memory.DefaultAllocator, schema) - builder.Field(0).(*types.UUIDBuilder).Append(testUUID) - record := 
builder.NewRecord() - b, err := record.MarshalJSON() - require.NoError(t, err) - require.Equal(t, "[{\"uuid\":\""+testUUID.String()+"\"}\n]", string(b)) - record1, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, bytes.NewReader(b)) - require.NoError(t, err) - require.Equal(t, record, record1) -} - -func TestUUIDStringRoundTrip(t *testing.T) { - // 1. create array - mem := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer mem.AssertSize(t, 0) - - b := types.NewUUIDBuilder(mem) - b.Append(uuid.Nil) - b.AppendNull() - b.Append(uuid.NameSpaceURL) - b.AppendNull() - b.Append(testUUID) - - arr := b.NewArray() - defer arr.Release() - - // 2. create array via AppendValueFromString - b1 := types.NewUUIDBuilder(mem) - defer b1.Release() - - for i := 0; i < arr.Len(); i++ { - assert.NoError(t, b1.AppendValueFromString(arr.ValueStr(i))) - } - - arr1 := b1.NewArray() - defer arr1.Release() - - assert.True(t, array.Equal(arr, arr1)) -} diff --git a/go/parquet/cmd/parquet_reader/main.go b/go/parquet/cmd/parquet_reader/main.go index 6e04f4254f9..4e480aeb866 100644 --- a/go/parquet/cmd/parquet_reader/main.go +++ b/go/parquet/cmd/parquet_reader/main.go @@ -154,7 +154,7 @@ func main() { if descr.ConvertedType() != schema.ConvertedTypes.None { fmt.Printf("/%s", descr.ConvertedType()) if descr.ConvertedType() == schema.ConvertedTypes.Decimal { - dec := descr.LogicalType().(*schema.DecimalLogicalType) + dec := descr.LogicalType().(schema.DecimalLogicalType) fmt.Printf("(%d,%d)", dec.Precision(), dec.Scale()) } } diff --git a/go/parquet/metadata/app_version.go b/go/parquet/metadata/app_version.go index 887ed79343a..345e9d440a1 100644 --- a/go/parquet/metadata/app_version.go +++ b/go/parquet/metadata/app_version.go @@ -164,7 +164,7 @@ func (v AppVersion) HasCorrectStatistics(coltype parquet.Type, logicalType schem // parquet-cpp-arrow version 4.0.0 fixed Decimal comparisons for creating min/max stats // parquet-cpp also becomes parquet-cpp-arrow as of version 
4.0.0 if v.App == "parquet-cpp" || (v.App == "parquet-cpp-arrow" && v.LessThan(parquet1655FixedVersion)) { - if _, ok := logicalType.(*schema.DecimalLogicalType); ok && coltype == parquet.Types.FixedLenByteArray { + if _, ok := logicalType.(schema.DecimalLogicalType); ok && coltype == parquet.Types.FixedLenByteArray { return false } } diff --git a/go/parquet/pqarrow/encode_arrow_test.go b/go/parquet/pqarrow/encode_arrow_test.go index 16282173a68..a238a78133e 100644 --- a/go/parquet/pqarrow/encode_arrow_test.go +++ b/go/parquet/pqarrow/encode_arrow_test.go @@ -30,6 +30,7 @@ import ( "github.com/apache/arrow/go/v18/arrow/bitutil" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/internal/types" @@ -715,16 +716,6 @@ type ParquetIOTestSuite struct { suite.Suite } -func (ps *ParquetIOTestSuite) SetupTest() { - ps.NoError(arrow.RegisterExtensionType(types.NewUUIDType())) -} - -func (ps *ParquetIOTestSuite) TearDownTest() { - if arrow.GetExtensionType("uuid") != nil { - ps.NoError(arrow.UnregisterExtensionType("uuid")) - } -} - func (ps *ParquetIOTestSuite) makeSimpleSchema(typ arrow.DataType, rep parquet.Repetition) *schema.GroupNode { byteWidth := int32(-1) @@ -2053,7 +2044,7 @@ func (ps *ParquetIOTestSuite) TestArrowExtensionTypeRoundTrip() { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(ps.T(), 0) - builder := types.NewUUIDBuilder(mem) + builder := extensions.NewUUIDBuilder(mem) builder.Append(uuid.New()) arr := builder.NewArray() defer arr.Release() @@ -2076,22 +2067,23 @@ func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() { { // Prepare `written` table with the extension type registered. 
- extType := types.NewUUIDType() + extType := types.NewSmallintType() bldr := array.NewExtensionBuilder(mem, extType) defer bldr.Release() - bldr.Builder.(*array.FixedSizeBinaryBuilder).AppendValues( - [][]byte{nil, []byte("abcdefghijklmno0"), []byte("abcdefghijklmno1"), []byte("abcdefghijklmno2")}, + bldr.Builder.(*array.Int16Builder).AppendValues( + []int16{0, 0, 1, 2}, []bool{false, true, true, true}) arr := bldr.NewArray() defer arr.Release() - if arrow.GetExtensionType("uuid") != nil { - ps.NoError(arrow.UnregisterExtensionType("uuid")) + if arrow.GetExtensionType("smallint") != nil { + ps.NoError(arrow.UnregisterExtensionType("smallint")) + defer arrow.RegisterExtensionType(extType) } - fld := arrow.Field{Name: "uuid", Type: arr.DataType(), Nullable: true} + fld := arrow.Field{Name: "smallint", Type: arr.DataType(), Nullable: true} cnk := arrow.NewChunked(arr.DataType(), []arrow.Array{arr}) defer arr.Release() // NewChunked written = array.NewTable(arrow.NewSchema([]arrow.Field{fld}, nil), []arrow.Column{*arrow.NewColumn(fld, cnk)}, -1) @@ -2101,16 +2093,16 @@ func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() { { // Prepare `expected` table with the extension type unregistered in the underlying type. 
- bldr := array.NewFixedSizeBinaryBuilder(mem, &arrow.FixedSizeBinaryType{ByteWidth: 16}) + bldr := array.NewInt16Builder(mem) defer bldr.Release() bldr.AppendValues( - [][]byte{nil, []byte("abcdefghijklmno0"), []byte("abcdefghijklmno1"), []byte("abcdefghijklmno2")}, + []int16{0, 0, 1, 2}, []bool{false, true, true, true}) arr := bldr.NewArray() defer arr.Release() - fld := arrow.Field{Name: "uuid", Type: arr.DataType(), Nullable: true} + fld := arrow.Field{Name: "smallint", Type: arr.DataType(), Nullable: true} cnk := arrow.NewChunked(arr.DataType(), []arrow.Array{arr}) defer arr.Release() // NewChunked expected = array.NewTable(arrow.NewSchema([]arrow.Field{fld}, nil), []arrow.Column{*arrow.NewColumn(fld, cnk)}, -1) @@ -2147,13 +2139,55 @@ func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() { ps.Truef(array.Equal(exc, tbc), "expected: %T %s\ngot: %T %s", exc, exc, tbc, tbc) expectedMd := arrow.MetadataFrom(map[string]string{ - ipc.ExtensionTypeKeyName: "uuid", - ipc.ExtensionMetadataKeyName: "uuid-serialized", + ipc.ExtensionTypeKeyName: "smallint", + ipc.ExtensionMetadataKeyName: "smallint-serialized", "PARQUET:field_id": "-1", }) ps.Truef(expectedMd.Equal(tbl.Column(0).Field().Metadata), "expected: %v\ngot: %v", expectedMd, tbl.Column(0).Field().Metadata) } +func (ps *ParquetIOTestSuite) TestArrowExtensionTypeLogicalType() { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(ps.T(), 0) + + jsonType, err := extensions.NewJSONType(arrow.BinaryTypes.String) + ps.NoError(err) + + sch := arrow.NewSchema([]arrow.Field{ + {Name: "uuid", Type: extensions.NewUUIDType()}, + {Name: "json", Type: jsonType}, + }, + nil, + ) + bldr := array.NewRecordBuilder(mem, sch) + defer bldr.Release() + + bldr.Field(0).(*extensions.UUIDBuilder).Append(uuid.New()) + bldr.Field(1).(*array.ExtensionBuilder).AppendValueFromString(`{"hello": ["world", 2, true], "world": null}`) + rec := bldr.NewRecord() + defer rec.Release() + + var buf 
bytes.Buffer + wr, err := pqarrow.NewFileWriter( + sch, + &buf, + parquet.NewWriterProperties(), + pqarrow.DefaultWriterProps(), + ) + ps.Require().NoError(err) + + ps.Require().NoError(wr.Write(rec)) + ps.Require().NoError(wr.Close()) + + rdr, err := file.NewParquetReader(bytes.NewReader(buf.Bytes())) + ps.Require().NoError(err) + defer rdr.Close() + + pqSchema := rdr.MetaData().Schema + ps.True(pqSchema.Column(0).LogicalType().Equals(schema.UUIDLogicalType{})) + ps.True(pqSchema.Column(1).LogicalType().Equals(schema.JSONLogicalType{})) +} + func TestWriteTableMemoryAllocation(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) sc := arrow.NewSchema([]arrow.Field{ @@ -2163,7 +2197,7 @@ func TestWriteTableMemoryAllocation(t *testing.T) { arrow.Field{Name: "i64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, arrow.Field{Name: "f64", Type: arrow.PrimitiveTypes.Float64, Nullable: true})}, {Name: "arr_i64", Type: arrow.ListOf(arrow.PrimitiveTypes.Int64)}, - {Name: "uuid", Type: types.NewUUIDType(), Nullable: true}, + {Name: "uuid", Type: extensions.NewUUIDType(), Nullable: true}, }, nil) bld := array.NewRecordBuilder(mem, sc) @@ -2176,7 +2210,7 @@ func TestWriteTableMemoryAllocation(t *testing.T) { abld := bld.Field(3).(*array.ListBuilder) abld.Append(true) abld.ValueBuilder().(*array.Int64Builder).Append(2) - bld.Field(4).(*types.UUIDBuilder).Append(uuid.MustParse("00000000-0000-0000-0000-000000000001")) + bld.Field(4).(*extensions.UUIDBuilder).Append(uuid.MustParse("00000000-0000-0000-0000-000000000001")) rec := bld.NewRecord() bld.Release() diff --git a/go/parquet/pqarrow/path_builder_test.go b/go/parquet/pqarrow/path_builder_test.go index 9bbae426b8a..364f836d0bb 100644 --- a/go/parquet/pqarrow/path_builder_test.go +++ b/go/parquet/pqarrow/path_builder_test.go @@ -22,8 +22,8 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" 
"github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -364,12 +364,12 @@ func TestNestedExtensionListsWithSomeNulls(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(t, 0) - listType := arrow.ListOf(types.NewUUIDType()) + listType := arrow.ListOf(extensions.NewUUIDType()) bldr := array.NewListBuilder(mem, listType) defer bldr.Release() nestedBldr := bldr.ValueBuilder().(*array.ListBuilder) - vb := nestedBldr.ValueBuilder().(*types.UUIDBuilder) + vb := nestedBldr.ValueBuilder().(*extensions.UUIDBuilder) uuid1 := uuid.New() uuid3 := uuid.New() diff --git a/go/parquet/pqarrow/schema.go b/go/parquet/pqarrow/schema.go index ce5cc6f9050..4882077671f 100644 --- a/go/parquet/pqarrow/schema.go +++ b/go/parquet/pqarrow/schema.go @@ -25,7 +25,6 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/flight" - "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/parquet" "github.com/apache/arrow/go/v18/parquet/file" @@ -120,6 +119,15 @@ func (sm *SchemaManifest) GetFieldIndices(indices []int) ([]int, error) { return ret, nil } +// ExtensionCustomParquetType is an interface that Arrow ExtensionTypes may implement +// to specify the target LogicalType to use when converting to Parquet. +// +// The PrimitiveType is not configurable, and is determined by a fixed mapping from +// the extension's StorageType to a Parquet type (see getParquetType in pqarrow source). 
+type ExtensionCustomParquetType interface { + ParquetLogicalType() schema.LogicalType +} + func isDictionaryReadSupported(dt arrow.DataType) bool { return arrow.IsBinaryLike(dt.ID()) } @@ -250,104 +258,14 @@ func structToNode(typ *arrow.StructType, name string, nullable bool, props *parq } func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) { - var ( - logicalType schema.LogicalType = schema.NoLogicalType{} - typ parquet.Type - repType = repFromNullable(field.Nullable) - length = -1 - precision = -1 - scale = -1 - err error - ) + repType := repFromNullable(field.Nullable) + // Handle complex types i.e. GroupNodes switch field.Type.ID() { case arrow.NULL: - typ = parquet.Types.Int32 - logicalType = &schema.NullLogicalType{} if repType != parquet.Repetitions.Optional { return nil, xerrors.New("nulltype arrow field must be nullable") } - case arrow.BOOL: - typ = parquet.Types.Boolean - case arrow.UINT8: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(8, false) - case arrow.INT8: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(8, true) - case arrow.UINT16: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(16, false) - case arrow.INT16: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(16, true) - case arrow.UINT32: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(32, false) - case arrow.INT32: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(32, true) - case arrow.UINT64: - typ = parquet.Types.Int64 - logicalType = schema.NewIntLogicalType(64, false) - case arrow.INT64: - typ = parquet.Types.Int64 - logicalType = schema.NewIntLogicalType(64, true) - case arrow.FLOAT32: - typ = parquet.Types.Float - case arrow.FLOAT64: - typ = parquet.Types.Double - case arrow.STRING, arrow.LARGE_STRING: - logicalType = schema.StringLogicalType{} - fallthrough - case arrow.BINARY, 
arrow.LARGE_BINARY: - typ = parquet.Types.ByteArray - case arrow.FIXED_SIZE_BINARY: - typ = parquet.Types.FixedLenByteArray - length = field.Type.(*arrow.FixedSizeBinaryType).ByteWidth - case arrow.DECIMAL, arrow.DECIMAL256: - dectype := field.Type.(arrow.DecimalType) - precision = int(dectype.GetPrecision()) - scale = int(dectype.GetScale()) - - if props.StoreDecimalAsInteger() && 1 <= precision && precision <= 18 { - if precision <= 9 { - typ = parquet.Types.Int32 - } else { - typ = parquet.Types.Int64 - } - } else { - typ = parquet.Types.FixedLenByteArray - length = int(DecimalSize(int32(precision))) - } - - logicalType = schema.NewDecimalLogicalType(int32(precision), int32(scale)) - case arrow.DATE32: - typ = parquet.Types.Int32 - logicalType = schema.DateLogicalType{} - case arrow.DATE64: - typ = parquet.Types.Int32 - logicalType = schema.DateLogicalType{} - case arrow.TIMESTAMP: - typ, logicalType, err = getTimestampMeta(field.Type.(*arrow.TimestampType), props, arrprops) - if err != nil { - return nil, err - } - case arrow.TIME32: - typ = parquet.Types.Int32 - logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitMillis) - case arrow.TIME64: - typ = parquet.Types.Int64 - timeType := field.Type.(*arrow.Time64Type) - if timeType.Unit == arrow.Nanosecond { - logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitNanos) - } else { - logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitMicros) - } - case arrow.FLOAT16: - typ = parquet.Types.FixedLenByteArray - length = arrow.Float16SizeBytes - logicalType = schema.Float16LogicalType{} case arrow.STRUCT: return structToNode(field.Type.(*arrow.StructType), field.Name, field.Nullable, props, arrprops) case arrow.FIXED_SIZE_LIST, arrow.LIST: @@ -369,16 +287,6 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties dictType := field.Type.(*arrow.DictionaryType) return fieldToNode(name, arrow.Field{Name: name, Type: dictType.ValueType, Nullable: field.Nullable, Metadata: 
field.Metadata}, props, arrprops) - case arrow.EXTENSION: - return fieldToNode(name, arrow.Field{ - Name: name, - Type: field.Type.(arrow.ExtensionType).StorageType(), - Nullable: field.Nullable, - Metadata: arrow.MetadataFrom(map[string]string{ - ipc.ExtensionTypeKeyName: field.Type.(arrow.ExtensionType).ExtensionName(), - ipc.ExtensionMetadataKeyName: field.Type.(arrow.ExtensionType).Serialize(), - }), - }, props, arrprops) case arrow.MAP: mapType := field.Type.(*arrow.MapType) keyNode, err := fieldToNode("key", mapType.KeyField(), props, arrprops) @@ -402,8 +310,12 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties }, -1) } return schema.MapOf(field.Name, keyNode, valueNode, repFromNullable(field.Nullable), -1) - default: - return nil, fmt.Errorf("%w: support for %s", arrow.ErrNotImplemented, field.Type.ID()) + } + + // Not a GroupNode + typ, logicalType, length, err := getParquetType(field.Type, props, arrprops) + if err != nil { + return nil, err } return schema.NewPrimitiveNodeLogical(name, repType, logicalType, typ, length, fieldIDFromMeta(field.Metadata)) @@ -472,7 +384,7 @@ func (s schemaTree) RecordLeaf(leaf *SchemaField) { s.manifest.ColIndexToField[leaf.ColIndex] = leaf } -func arrowInt(log *schema.IntLogicalType) (arrow.DataType, error) { +func arrowInt(log schema.IntLogicalType) (arrow.DataType, error) { switch log.BitWidth() { case 8: if log.IsSigned() { @@ -499,7 +411,7 @@ func arrowInt(log *schema.IntLogicalType) (arrow.DataType, error) { } } -func arrowTime32(logical *schema.TimeLogicalType) (arrow.DataType, error) { +func arrowTime32(logical schema.TimeLogicalType) (arrow.DataType, error) { if logical.TimeUnit() == schema.TimeUnitMillis { return arrow.FixedWidthTypes.Time32ms, nil } @@ -507,7 +419,7 @@ func arrowTime32(logical *schema.TimeLogicalType) (arrow.DataType, error) { return nil, xerrors.New(logical.String() + " cannot annotate a time32") } -func arrowTime64(logical *schema.TimeLogicalType) 
(arrow.DataType, error) { +func arrowTime64(logical schema.TimeLogicalType) (arrow.DataType, error) { switch logical.TimeUnit() { case schema.TimeUnitMicros: return arrow.FixedWidthTypes.Time64us, nil @@ -518,7 +430,7 @@ func arrowTime64(logical *schema.TimeLogicalType) (arrow.DataType, error) { } } -func arrowTimestamp(logical *schema.TimestampLogicalType) (arrow.DataType, error) { +func arrowTimestamp(logical schema.TimestampLogicalType) (arrow.DataType, error) { tz := "" // ConvertedTypes are adjusted to UTC per backward compatibility guidelines @@ -539,7 +451,7 @@ func arrowTimestamp(logical *schema.TimestampLogicalType) (arrow.DataType, error } } -func arrowDecimal(logical *schema.DecimalLogicalType) arrow.DataType { +func arrowDecimal(logical schema.DecimalLogicalType) arrow.DataType { if logical.Precision() <= decimal128.MaxPrecision { return &arrow.Decimal128Type{Precision: logical.Precision(), Scale: logical.Scale()} } @@ -550,11 +462,11 @@ func arrowFromInt32(logical schema.LogicalType) (arrow.DataType, error) { switch logtype := logical.(type) { case schema.NoLogicalType: return arrow.PrimitiveTypes.Int32, nil - case *schema.TimeLogicalType: + case schema.TimeLogicalType: return arrowTime32(logtype) - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil - case *schema.IntLogicalType: + case schema.IntLogicalType: return arrowInt(logtype) case schema.DateLogicalType: return arrow.FixedWidthTypes.Date32, nil @@ -569,13 +481,13 @@ func arrowFromInt64(logical schema.LogicalType) (arrow.DataType, error) { } switch logtype := logical.(type) { - case *schema.IntLogicalType: + case schema.IntLogicalType: return arrowInt(logtype) - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil - case *schema.TimeLogicalType: + case schema.TimeLogicalType: return arrowTime64(logtype) - case *schema.TimestampLogicalType: + case schema.TimestampLogicalType: return 
arrowTimestamp(logtype) default: return nil, xerrors.New(logical.String() + " cannot annotate int64") @@ -586,7 +498,7 @@ func arrowFromByteArray(logical schema.LogicalType) (arrow.DataType, error) { switch logtype := logical.(type) { case schema.StringLogicalType: return arrow.BinaryTypes.String, nil - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil case schema.NoLogicalType, schema.EnumLogicalType, @@ -600,7 +512,7 @@ func arrowFromByteArray(logical schema.LogicalType) (arrow.DataType, error) { func arrowFromFLBA(logical schema.LogicalType, length int) (arrow.DataType, error) { switch logtype := logical.(type) { - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil case schema.NoLogicalType, schema.IntervalLogicalType, schema.UUIDLogicalType: return &arrow.FixedSizeBinaryType{ByteWidth: int(length)}, nil @@ -611,6 +523,84 @@ func arrowFromFLBA(logical schema.LogicalType, length int) (arrow.DataType, erro } } +func getParquetType(typ arrow.DataType, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (parquet.Type, schema.LogicalType, int, error) { + switch typ.ID() { + case arrow.NULL: + return parquet.Types.Int32, schema.NullLogicalType{}, -1, nil + case arrow.BOOL: + return parquet.Types.Boolean, schema.NoLogicalType{}, -1, nil + case arrow.UINT8: + return parquet.Types.Int32, schema.NewIntLogicalType(8, false), -1, nil + case arrow.INT8: + return parquet.Types.Int32, schema.NewIntLogicalType(8, true), -1, nil + case arrow.UINT16: + return parquet.Types.Int32, schema.NewIntLogicalType(16, false), -1, nil + case arrow.INT16: + return parquet.Types.Int32, schema.NewIntLogicalType(16, true), -1, nil + case arrow.UINT32: + return parquet.Types.Int32, schema.NewIntLogicalType(32, false), -1, nil + case arrow.INT32: + return parquet.Types.Int32, schema.NewIntLogicalType(32, true), -1, nil + case arrow.UINT64: + return parquet.Types.Int64, 
schema.NewIntLogicalType(64, false), -1, nil + case arrow.INT64: + return parquet.Types.Int64, schema.NewIntLogicalType(64, true), -1, nil + case arrow.FLOAT32: + return parquet.Types.Float, schema.NoLogicalType{}, -1, nil + case arrow.FLOAT64: + return parquet.Types.Double, schema.NoLogicalType{}, -1, nil + case arrow.STRING, arrow.LARGE_STRING: + return parquet.Types.ByteArray, schema.StringLogicalType{}, -1, nil + case arrow.BINARY, arrow.LARGE_BINARY: + return parquet.Types.ByteArray, schema.NoLogicalType{}, -1, nil + case arrow.FIXED_SIZE_BINARY: + return parquet.Types.FixedLenByteArray, schema.NoLogicalType{}, typ.(*arrow.FixedSizeBinaryType).ByteWidth, nil + case arrow.DECIMAL, arrow.DECIMAL256: + dectype := typ.(arrow.DecimalType) + precision := int(dectype.GetPrecision()) + scale := int(dectype.GetScale()) + + if !props.StoreDecimalAsInteger() || precision > 18 { + return parquet.Types.FixedLenByteArray, schema.NewDecimalLogicalType(int32(precision), int32(scale)), int(DecimalSize(int32(precision))), nil + } + + pqType := parquet.Types.Int32 + if precision > 9 { + pqType = parquet.Types.Int64 + } + + return pqType, schema.NoLogicalType{}, -1, nil + case arrow.DATE32: + return parquet.Types.Int32, schema.DateLogicalType{}, -1, nil + case arrow.DATE64: + return parquet.Types.Int32, schema.DateLogicalType{}, -1, nil + case arrow.TIMESTAMP: + pqType, logicalType, err := getTimestampMeta(typ.(*arrow.TimestampType), props, arrprops) + return pqType, logicalType, -1, err + case arrow.TIME32: + return parquet.Types.Int32, schema.NewTimeLogicalType(true, schema.TimeUnitMillis), -1, nil + case arrow.TIME64: + pqTimeUnit := schema.TimeUnitMicros + if typ.(*arrow.Time64Type).Unit == arrow.Nanosecond { + pqTimeUnit = schema.TimeUnitNanos + } + + return parquet.Types.Int64, schema.NewTimeLogicalType(true, pqTimeUnit), -1, nil + case arrow.FLOAT16: + return parquet.Types.FixedLenByteArray, schema.Float16LogicalType{}, arrow.Float16SizeBytes, nil + case arrow.EXTENSION: + 
storageType := typ.(arrow.ExtensionType).StorageType() + pqType, logicalType, length, err := getParquetType(storageType, props, arrprops) + if withCustomType, ok := typ.(ExtensionCustomParquetType); ok { + logicalType = withCustomType.ParquetLogicalType() + } + + return pqType, logicalType, length, err + default: + return parquet.Type(0), nil, 0, fmt.Errorf("%w: support for %s", arrow.ErrNotImplemented, typ.ID()) + } +} + func getArrowType(physical parquet.Type, logical schema.LogicalType, typeLen int) (arrow.DataType, error) { if !logical.IsValid() || logical.Equals(schema.NullLogicalType{}) { return arrow.Null, nil diff --git a/go/parquet/pqarrow/schema_test.go b/go/parquet/pqarrow/schema_test.go index 24b031c174b..528200fd0e7 100644 --- a/go/parquet/pqarrow/schema_test.go +++ b/go/parquet/pqarrow/schema_test.go @@ -21,10 +21,10 @@ import ( "testing" "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/flight" "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/apache/arrow/go/v18/parquet" "github.com/apache/arrow/go/v18/parquet/metadata" "github.com/apache/arrow/go/v18/parquet/pqarrow" @@ -34,7 +34,7 @@ import ( ) func TestGetOriginSchemaBase64(t *testing.T) { - uuidType := types.NewUUIDType() + uuidType := extensions.NewUUIDType() md := arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"-1"}) extMd := arrow.NewMetadata([]string{ipc.ExtensionMetadataKeyName, ipc.ExtensionTypeKeyName, "PARQUET:field_id"}, []string{uuidType.Serialize(), uuidType.ExtensionName(), "-1"}) origArrSc := arrow.NewSchema([]arrow.Field{ @@ -44,10 +44,6 @@ func TestGetOriginSchemaBase64(t *testing.T) { }, nil) arrSerializedSc := flight.SerializeSchema(origArrSc, memory.DefaultAllocator) - if err := arrow.RegisterExtensionType(uuidType); err != nil { - t.Fatal(err) - } - defer 
arrow.UnregisterExtensionType(uuidType.ExtensionName()) pqschema, err := pqarrow.ToParquet(origArrSc, nil, pqarrow.DefaultWriterProps()) require.NoError(t, err) @@ -71,11 +67,7 @@ func TestGetOriginSchemaBase64(t *testing.T) { } func TestGetOriginSchemaUnregisteredExtension(t *testing.T) { - uuidType := types.NewUUIDType() - if err := arrow.RegisterExtensionType(uuidType); err != nil { - t.Fatal(err) - } - + uuidType := extensions.NewUUIDType() md := arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"-1"}) origArrSc := arrow.NewSchema([]arrow.Field{ {Name: "f1", Type: arrow.BinaryTypes.String, Metadata: md}, @@ -90,6 +82,7 @@ func TestGetOriginSchemaUnregisteredExtension(t *testing.T) { kv.Append("ARROW:schema", base64.StdEncoding.EncodeToString(arrSerializedSc)) arrow.UnregisterExtensionType(uuidType.ExtensionName()) + defer arrow.RegisterExtensionType(uuidType) arrsc, err := pqarrow.FromParquet(pqschema, nil, kv) require.NoError(t, err) diff --git a/go/parquet/schema/converted_types.go b/go/parquet/schema/converted_types.go index 5fc10f61ceb..b2b6f50cbf6 100644 --- a/go/parquet/schema/converted_types.go +++ b/go/parquet/schema/converted_types.go @@ -113,13 +113,9 @@ func (p ConvertedType) ToLogicalType(convertedDecimal DecimalMetadata) LogicalTy case ConvertedTypes.TimeMicros: return NewTimeLogicalType(true /* adjustedToUTC */, TimeUnitMicros) case ConvertedTypes.TimestampMillis: - t := NewTimestampLogicalType(true /* adjustedToUTC */, TimeUnitMillis) - t.(*TimestampLogicalType).fromConverted = true - return t + return NewTimestampLogicalTypeWithOpts(WithTSIsAdjustedToUTC(), WithTSTimeUnitType(TimeUnitMillis), WithTSFromConverted()) case ConvertedTypes.TimestampMicros: - t := NewTimestampLogicalType(true /* adjustedToUTC */, TimeUnitMicros) - t.(*TimestampLogicalType).fromConverted = true - return t + return NewTimestampLogicalTypeWithOpts(WithTSIsAdjustedToUTC(), WithTSTimeUnitType(TimeUnitMicros), WithTSFromConverted()) case ConvertedTypes.Interval: 
return IntervalLogicalType{} case ConvertedTypes.Int8: diff --git a/go/parquet/schema/logical_types.go b/go/parquet/schema/logical_types.go index e8adce1ca14..fa46ea0172f 100644 --- a/go/parquet/schema/logical_types.go +++ b/go/parquet/schema/logical_types.go @@ -45,21 +45,21 @@ func getLogicalType(l *format.LogicalType) LogicalType { case l.IsSetENUM(): return EnumLogicalType{} case l.IsSetDECIMAL(): - return &DecimalLogicalType{typ: l.DECIMAL} + return DecimalLogicalType{typ: l.DECIMAL} case l.IsSetDATE(): return DateLogicalType{} case l.IsSetTIME(): if timeUnitFromThrift(l.TIME.Unit) == TimeUnitUnknown { panic("parquet: TimeUnit must be one of MILLIS, MICROS, or NANOS for Time logical type") } - return &TimeLogicalType{typ: l.TIME} + return TimeLogicalType{typ: l.TIME} case l.IsSetTIMESTAMP(): if timeUnitFromThrift(l.TIMESTAMP.Unit) == TimeUnitUnknown { panic("parquet: TimeUnit must be one of MILLIS, MICROS, or NANOS for Timestamp logical type") } - return &TimestampLogicalType{typ: l.TIMESTAMP} + return TimestampLogicalType{typ: l.TIMESTAMP} case l.IsSetINTEGER(): - return &IntLogicalType{typ: l.INTEGER} + return IntLogicalType{typ: l.INTEGER} case l.IsSetUNKNOWN(): return NullLogicalType{} case l.IsSetJSON(): @@ -344,7 +344,7 @@ func NewDecimalLogicalType(precision int32, scale int32) LogicalType { if scale < 0 || scale > precision { panic("parquet: scale must be a non-negative integer that does not exceed precision for decimal logical type") } - return &DecimalLogicalType{typ: &format.DecimalType{Precision: precision, Scale: scale}} + return DecimalLogicalType{typ: &format.DecimalType{Precision: precision, Scale: scale}} } // DecimalLogicalType is used to represent a decimal value of a given @@ -405,7 +405,7 @@ func (t DecimalLogicalType) toThrift() *format.LogicalType { } func (t DecimalLogicalType) Equals(rhs LogicalType) bool { - other, ok := rhs.(*DecimalLogicalType) + other, ok := rhs.(DecimalLogicalType) if !ok { return false } @@ -509,7 +509,7 @@ func 
createTimeUnit(unit TimeUnitType) *format.TimeUnit { // NewTimeLogicalType returns a time type of the given unit. func NewTimeLogicalType(isAdjustedToUTC bool, unit TimeUnitType) LogicalType { - return &TimeLogicalType{typ: &format.TimeType{ + return TimeLogicalType{typ: &format.TimeType{ IsAdjustedToUTC: isAdjustedToUTC, Unit: createTimeUnit(unit), }} @@ -584,7 +584,7 @@ func (t TimeLogicalType) toThrift() *format.LogicalType { } func (t TimeLogicalType) Equals(rhs LogicalType) bool { - other, ok := rhs.(*TimeLogicalType) + other, ok := rhs.(TimeLogicalType) if !ok { return false } @@ -595,7 +595,7 @@ func (t TimeLogicalType) Equals(rhs LogicalType) bool { // NewTimestampLogicalType returns a logical timestamp type with "forceConverted" // set to false func NewTimestampLogicalType(isAdjustedToUTC bool, unit TimeUnitType) LogicalType { - return &TimestampLogicalType{ + return TimestampLogicalType{ typ: &format.TimestampType{ IsAdjustedToUTC: isAdjustedToUTC, Unit: createTimeUnit(unit), @@ -608,7 +608,7 @@ func NewTimestampLogicalType(isAdjustedToUTC bool, unit TimeUnitType) LogicalTyp // NewTimestampLogicalTypeForce returns a timestamp logical type with // "forceConverted" set to true func NewTimestampLogicalTypeForce(isAdjustedToUTC bool, unit TimeUnitType) LogicalType { - return &TimestampLogicalType{ + return TimestampLogicalType{ typ: &format.TimestampType{ IsAdjustedToUTC: isAdjustedToUTC, Unit: createTimeUnit(unit), @@ -654,14 +654,14 @@ func WithTSFromConverted() TimestampOpt { // // TimestampType Unit defaults to milliseconds (TimeUnitMillis) func NewTimestampLogicalTypeWithOpts(opts ...TimestampOpt) LogicalType { - ts := &TimestampLogicalType{ + ts := TimestampLogicalType{ typ: &format.TimestampType{ Unit: createTimeUnit(TimeUnitMillis), // default to milliseconds }, } for _, o := range opts { - o(ts) + o(&ts) } return ts @@ -760,7 +760,7 @@ func (t TimestampLogicalType) toThrift() *format.LogicalType { } func (t TimestampLogicalType) Equals(rhs 
LogicalType) bool { - other, ok := rhs.(*TimestampLogicalType) + other, ok := rhs.(TimestampLogicalType) if !ok { return false } @@ -778,7 +778,7 @@ func NewIntLogicalType(bitWidth int8, signed bool) LogicalType { default: panic("parquet: bit width must be exactly 8, 16, 32, or 64 for Int logical type") } - return &IntLogicalType{ + return IntLogicalType{ typ: &format.IntType{ BitWidth: bitWidth, IsSigned: signed, @@ -864,7 +864,7 @@ func (t IntLogicalType) toThrift() *format.LogicalType { } func (t IntLogicalType) Equals(rhs LogicalType) bool { - other, ok := rhs.(*IntLogicalType) + other, ok := rhs.(IntLogicalType) if !ok { return false } diff --git a/go/parquet/schema/logical_types_test.go b/go/parquet/schema/logical_types_test.go index e33925966e1..395d1504182 100644 --- a/go/parquet/schema/logical_types_test.go +++ b/go/parquet/schema/logical_types_test.go @@ -38,18 +38,18 @@ func TestConvertedLogicalEquivalences(t *testing.T) { {"list", schema.ConvertedTypes.List, schema.NewListLogicalType(), schema.NewListLogicalType()}, {"enum", schema.ConvertedTypes.Enum, schema.EnumLogicalType{}, schema.EnumLogicalType{}}, {"date", schema.ConvertedTypes.Date, schema.DateLogicalType{}, schema.DateLogicalType{}}, - {"timemilli", schema.ConvertedTypes.TimeMillis, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), &schema.TimeLogicalType{}}, - {"timemicro", schema.ConvertedTypes.TimeMicros, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), &schema.TimeLogicalType{}}, - {"timestampmilli", schema.ConvertedTypes.TimestampMillis, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), &schema.TimestampLogicalType{}}, - {"timestampmicro", schema.ConvertedTypes.TimestampMicros, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), &schema.TimestampLogicalType{}}, - {"uint8", schema.ConvertedTypes.Uint8, schema.NewIntLogicalType(8 /* bitWidth */, false /* signed */), 
&schema.IntLogicalType{}}, - {"uint16", schema.ConvertedTypes.Uint16, schema.NewIntLogicalType(16 /* bitWidth */, false /* signed */), &schema.IntLogicalType{}}, - {"uint32", schema.ConvertedTypes.Uint32, schema.NewIntLogicalType(32 /* bitWidth */, false /* signed */), &schema.IntLogicalType{}}, - {"uint64", schema.ConvertedTypes.Uint64, schema.NewIntLogicalType(64 /* bitWidth */, false /* signed */), &schema.IntLogicalType{}}, - {"int8", schema.ConvertedTypes.Int8, schema.NewIntLogicalType(8 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, - {"int16", schema.ConvertedTypes.Int16, schema.NewIntLogicalType(16 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, - {"int32", schema.ConvertedTypes.Int32, schema.NewIntLogicalType(32 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, - {"int64", schema.ConvertedTypes.Int64, schema.NewIntLogicalType(64 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, + {"timemilli", schema.ConvertedTypes.TimeMillis, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), schema.TimeLogicalType{}}, + {"timemicro", schema.ConvertedTypes.TimeMicros, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), schema.TimeLogicalType{}}, + {"timestampmilli", schema.ConvertedTypes.TimestampMillis, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), schema.TimestampLogicalType{}}, + {"timestampmicro", schema.ConvertedTypes.TimestampMicros, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), schema.TimestampLogicalType{}}, + {"uint8", schema.ConvertedTypes.Uint8, schema.NewIntLogicalType(8 /* bitWidth */, false /* signed */), schema.IntLogicalType{}}, + {"uint16", schema.ConvertedTypes.Uint16, schema.NewIntLogicalType(16 /* bitWidth */, false /* signed */), schema.IntLogicalType{}}, + {"uint32", schema.ConvertedTypes.Uint32, schema.NewIntLogicalType(32 /* bitWidth */, false /* signed */), 
schema.IntLogicalType{}}, + {"uint64", schema.ConvertedTypes.Uint64, schema.NewIntLogicalType(64 /* bitWidth */, false /* signed */), schema.IntLogicalType{}}, + {"int8", schema.ConvertedTypes.Int8, schema.NewIntLogicalType(8 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, + {"int16", schema.ConvertedTypes.Int16, schema.NewIntLogicalType(16 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, + {"int32", schema.ConvertedTypes.Int32, schema.NewIntLogicalType(32 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, + {"int64", schema.ConvertedTypes.Int64, schema.NewIntLogicalType(64 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, {"json", schema.ConvertedTypes.JSON, schema.JSONLogicalType{}, schema.JSONLogicalType{}}, {"bson", schema.ConvertedTypes.BSON, schema.BSONLogicalType{}, schema.BSONLogicalType{}}, {"interval", schema.ConvertedTypes.Interval, schema.IntervalLogicalType{}, schema.IntervalLogicalType{}}, @@ -72,8 +72,8 @@ func TestConvertedLogicalEquivalences(t *testing.T) { fromMake := schema.NewDecimalLogicalType(10, 4) assert.IsType(t, fromMake, fromConverted) assert.True(t, fromConverted.Equals(fromMake)) - assert.IsType(t, &schema.DecimalLogicalType{}, fromConverted) - assert.IsType(t, &schema.DecimalLogicalType{}, fromMake) + assert.IsType(t, schema.DecimalLogicalType{}, fromConverted) + assert.IsType(t, schema.DecimalLogicalType{}, fromMake) assert.True(t, schema.NewDecimalLogicalType(16, 0).Equals(schema.NewDecimalLogicalType(16, 0))) }) } @@ -160,12 +160,12 @@ func TestNewTypeIncompatibility(t *testing.T) { {"uuid", schema.UUIDLogicalType{}, schema.UUIDLogicalType{}}, {"float16", schema.Float16LogicalType{}, schema.Float16LogicalType{}}, {"null", schema.NullLogicalType{}, schema.NullLogicalType{}}, - {"not-utc-time_milli", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMillis), &schema.TimeLogicalType{}}, - {"not-utc-time-micro", schema.NewTimeLogicalType(false /* adjustedToUTC */, 
schema.TimeUnitMicros), &schema.TimeLogicalType{}}, - {"not-utc-time-nano", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimeLogicalType{}}, - {"utc-time-nano", schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimeLogicalType{}}, - {"not-utc-timestamp-nano", schema.NewTimestampLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimestampLogicalType{}}, - {"utc-timestamp-nano", schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimestampLogicalType{}}, + {"not-utc-time_milli", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMillis), schema.TimeLogicalType{}}, + {"not-utc-time-micro", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMicros), schema.TimeLogicalType{}}, + {"not-utc-time-nano", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimeLogicalType{}}, + {"utc-time-nano", schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimeLogicalType{}}, + {"not-utc-timestamp-nano", schema.NewTimestampLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimestampLogicalType{}}, + {"utc-timestamp-nano", schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimestampLogicalType{}}, } for _, tt := range tests { diff --git a/go/parquet/schema/schema_element_test.go b/go/parquet/schema/schema_element_test.go index 7da55ce93ab..e427ba6485e 100644 --- a/go/parquet/schema/schema_element_test.go +++ b/go/parquet/schema/schema_element_test.go @@ -192,7 +192,7 @@ func (s *SchemaElementConstructionSuite) TestSimple() { func (s *SchemaElementConstructionSuite) reconstructDecimal(c schemaElementConstructArgs) *decimalSchemaElementConstruction { ret := s.reconstruct(c) - dec := c.logical.(*DecimalLogicalType) + dec := c.logical.(DecimalLogicalType) return &decimalSchemaElementConstruction{*ret, 
int(dec.Precision()), int(dec.Scale())} } @@ -359,7 +359,7 @@ func (s *SchemaElementConstructionSuite) TestTemporal() { func (s *SchemaElementConstructionSuite) reconstructInteger(c schemaElementConstructArgs) *intSchemaElementConstruction { base := s.reconstruct(c) - l := c.logical.(*IntLogicalType) + l := c.logical.(IntLogicalType) return &intSchemaElementConstruction{ *base, l.BitWidth(),