From 39fe6fc45d14accf63b7aefed5a8f1225f6b552a Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Mon, 12 Aug 2024 16:49:57 -0400 Subject: [PATCH 001/157] GH-17682: [Go] Bool8 Extension Type Implementation (#43323) ### Rationale for this change Go implementation of #43234 ### What changes are included in this PR? - Go implementation of the `Bool8` extension type - Minor refactor of existing extension builder interfaces ### Are these changes tested? Yes, unit tests and basic read/write benchmarks are included. ### Are there any user-facing changes? - A new extension type is added - Custom extension builders no longer need another builder created and released separately. * GitHub Issue: #17682 Authored-by: Joel Lubinitsky Signed-off-by: Joel Lubinitsky --- go/arrow/array/builder.go | 11 +- go/arrow/array/extension_builder.go | 10 +- go/arrow/extensions/bool8.go | 216 +++++++++++++++ go/arrow/extensions/bool8_test.go | 319 ++++++++++++++++++++++ go/arrow/extensions/extensions_test.go | 105 +++++++ go/internal/types/extension_types.go | 9 +- go/internal/types/extension_types_test.go | 16 +- go/parquet/pqarrow/encode_arrow_test.go | 4 +- 8 files changed, 663 insertions(+), 27 deletions(-) create mode 100644 go/arrow/extensions/bool8.go create mode 100644 go/arrow/extensions/bool8_test.go create mode 100644 go/arrow/extensions/extensions_test.go diff --git a/go/arrow/array/builder.go b/go/arrow/array/builder.go index 6c8ea877a2f..1f4d0ea9635 100644 --- a/go/arrow/array/builder.go +++ b/go/arrow/array/builder.go @@ -349,12 +349,13 @@ func NewBuilder(mem memory.Allocator, dtype arrow.DataType) Builder { typ := dtype.(*arrow.LargeListViewType) return NewLargeListViewBuilderWithField(mem, typ.ElemField()) case arrow.EXTENSION: - typ := dtype.(arrow.ExtensionType) - bldr := NewExtensionBuilder(mem, typ) - if custom, ok := typ.(ExtensionBuilderWrapper); ok { - return custom.NewBuilder(bldr) + if custom, ok := dtype.(CustomExtensionBuilder); ok { + return custom.NewBuilder(mem) } - return bldr + if typ, ok := dtype.(arrow.ExtensionType); ok { + return NewExtensionBuilder(mem, typ) + } + panic(fmt.Errorf("arrow/array: invalid extension type: %T", dtype)) case arrow.FIXED_SIZE_LIST: typ := dtype.(*arrow.FixedSizeListType) return NewFixedSizeListBuilderWithField(mem, typ.Len(), typ.ElemField()) diff --git a/go/arrow/array/extension_builder.go b/go/arrow/array/extension_builder.go index a71287faf0e..9c2ee880564 100644 --- a/go/arrow/array/extension_builder.go +++ b/go/arrow/array/extension_builder.go @@ -16,8 +16,10 @@ package array -// ExtensionBuilderWrapper is an interface that you need to implement in your custom extension type if you want to provide a customer builder as well. -// See example in ./arrow/internal/testing/types/extension_types.go -type ExtensionBuilderWrapper interface { - NewBuilder(bldr *ExtensionBuilder) Builder +import "github.com/apache/arrow/go/v18/arrow/memory" + +// CustomExtensionBuilder is an interface that custom extension types may implement to provide a custom builder +// instead of the underlying storage type's builder when array.NewBuilder is called with that type. 
+type CustomExtensionBuilder interface { + NewBuilder(memory.Allocator) Builder } diff --git a/go/arrow/extensions/bool8.go b/go/arrow/extensions/bool8.go new file mode 100644 index 00000000000..20ab024a2a2 --- /dev/null +++ b/go/arrow/extensions/bool8.go @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "fmt" + "reflect" + "strconv" + "strings" + "unsafe" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/apache/arrow/go/v18/internal/json" +) + +// Bool8Type represents a logical boolean that is stored using 8 bits. +type Bool8Type struct { + arrow.ExtensionBase +} + +// NewBool8Type creates a new Bool8Type with the underlying storage type set correctly to Int8. +func NewBool8Type() *Bool8Type { + return &Bool8Type{ExtensionBase: arrow.ExtensionBase{Storage: arrow.PrimitiveTypes.Int8}} +} + +func (b *Bool8Type) ArrayType() reflect.Type { return reflect.TypeOf(Bool8Array{}) } + +func (b *Bool8Type) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { + if !arrow.TypeEqual(storageType, arrow.PrimitiveTypes.Int8) { + return nil, fmt.Errorf("invalid storage type for Bool8Type: %s", storageType.Name()) + } + return NewBool8Type(), nil +} + +func (b *Bool8Type) ExtensionEquals(other arrow.ExtensionType) bool { + return b.ExtensionName() == other.ExtensionName() +} + +func (b *Bool8Type) ExtensionName() string { return "arrow.bool8" } + +func (b *Bool8Type) Serialize() string { return "" } + +func (b *Bool8Type) String() string { return fmt.Sprintf("extension<%s>", b.ExtensionName()) } + +func (*Bool8Type) NewBuilder(mem memory.Allocator) array.Builder { + return NewBool8Builder(mem) +} + +// Bool8Array is logically an array of boolean values but uses +// 8 bits to store values instead of 1 bit as in the native BooleanArray. 
+type Bool8Array struct { + array.ExtensionArrayBase +} + +func (a *Bool8Array) String() string { + var o strings.Builder + o.WriteString("[") + for i := 0; i < a.Len(); i++ { + if i > 0 { + o.WriteString(" ") + } + switch { + case a.IsNull(i): + o.WriteString(array.NullValueStr) + default: + fmt.Fprintf(&o, "%v", a.Value(i)) + } + } + o.WriteString("]") + return o.String() +} + +func (a *Bool8Array) Value(i int) bool { + return a.Storage().(*array.Int8).Value(i) != 0 +} + +func (a *Bool8Array) BoolValues() []bool { + int8s := a.Storage().(*array.Int8).Int8Values() + return unsafe.Slice((*bool)(unsafe.Pointer(unsafe.SliceData(int8s))), len(int8s)) +} + +func (a *Bool8Array) ValueStr(i int) string { + switch { + case a.IsNull(i): + return array.NullValueStr + default: + return fmt.Sprint(a.Value(i)) + } +} + +func (a *Bool8Array) MarshalJSON() ([]byte, error) { + values := make([]interface{}, a.Len()) + for i := 0; i < a.Len(); i++ { + if a.IsValid(i) { + values[i] = a.Value(i) + } + } + return json.Marshal(values) +} + +func (a *Bool8Array) GetOneForMarshal(i int) interface{} { + if a.IsNull(i) { + return nil + } + return a.Value(i) +} + +// boolToInt8 performs the simple scalar conversion of bool to the canonical int8 +// value for the Bool8Type. +func boolToInt8(v bool) int8 { + var res int8 + if v { + res = 1 + } + return res +} + +// Bool8Builder is a convenience builder for the Bool8 extension type, +// allowing arrays to be built with boolean values rather than the underlying storage type. +type Bool8Builder struct { + *array.ExtensionBuilder +} + +// NewBool8Builder creates a new Bool8Builder, exposing a convenient and efficient interface +// for writing boolean values to the underlying int8 storage array. +func NewBool8Builder(mem memory.Allocator) *Bool8Builder { + return &Bool8Builder{ExtensionBuilder: array.NewExtensionBuilder(mem, NewBool8Type())} +} + +func (b *Bool8Builder) Append(v bool) { + b.ExtensionBuilder.Builder.(*array.Int8Builder).Append(boolToInt8(v)) +} + +func (b *Bool8Builder) UnsafeAppend(v bool) { + b.ExtensionBuilder.Builder.(*array.Int8Builder).UnsafeAppend(boolToInt8(v)) +} + +func (b *Bool8Builder) AppendValueFromString(s string) error { + if s == array.NullValueStr { + b.AppendNull() + return nil + } + + val, err := strconv.ParseBool(s) + if err != nil { + return err + } + + b.Append(val) + return nil +} + +func (b *Bool8Builder) AppendValues(v []bool, valid []bool) { + boolsAsInt8s := unsafe.Slice((*int8)(unsafe.Pointer(unsafe.SliceData(v))), len(v)) + b.ExtensionBuilder.Builder.(*array.Int8Builder).AppendValues(boolsAsInt8s, valid) +} + +func (b *Bool8Builder) UnmarshalOne(dec *json.Decoder) error { + t, err := dec.Token() + if err != nil { + return err + } + + switch v := t.(type) { + case bool: + b.Append(v) + return nil + case string: + return b.AppendValueFromString(v) + case int8: + b.ExtensionBuilder.Builder.(*array.Int8Builder).Append(v) + return nil + case nil: + b.AppendNull() + return nil + default: + return &json.UnmarshalTypeError{ + Value: fmt.Sprint(t), + Type: reflect.TypeOf([]byte{}), + Offset: dec.InputOffset(), + Struct: "Bool8Builder", + } + } +} + +func (b *Bool8Builder) Unmarshal(dec *json.Decoder) error { + for dec.More() { + if err := b.UnmarshalOne(dec); err != nil { + return err + } + } + return nil +} + +var ( + _ arrow.ExtensionType = (*Bool8Type)(nil) + _ array.CustomExtensionBuilder = (*Bool8Type)(nil) + _ array.ExtensionArray = (*Bool8Array)(nil) + _ array.Builder = (*Bool8Builder)(nil) +) diff --git 
a/go/arrow/extensions/bool8_test.go b/go/arrow/extensions/bool8_test.go new file mode 100644 index 00000000000..9f7365d1555 --- /dev/null +++ b/go/arrow/extensions/bool8_test.go @@ -0,0 +1,319 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions_test + +import ( + "bytes" + "fmt" + "strings" + "testing" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" + "github.com/apache/arrow/go/v18/arrow/ipc" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/apache/arrow/go/v18/internal/json" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + MINSIZE = 1024 + MAXSIZE = 65536 +) + +func TestBool8ExtensionBuilder(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + builder := extensions.NewBool8Builder(mem) + defer builder.Release() + + builder.Append(true) + builder.AppendNull() + builder.Append(false) + arr := builder.NewArray() + defer arr.Release() + + arrStr := arr.String() + require.Equal(t, "[true (null) false]", arrStr) + + jsonStr, err := json.Marshal(arr) + require.NoError(t, err) + + arr1, _, err := array.FromJSON(mem, extensions.NewBool8Type(), bytes.NewReader(jsonStr)) + require.NoError(t, err) + defer arr1.Release() + + require.Equal(t, arr, arr1) +} + +func TestBool8ExtensionRecordBuilder(t *testing.T) { + schema := arrow.NewSchema([]arrow.Field{ + {Name: "bool8", Type: extensions.NewBool8Type()}, + }, nil) + + builder := array.NewRecordBuilder(memory.DefaultAllocator, schema) + defer builder.Release() + + builder.Field(0).(*extensions.Bool8Builder).Append(true) + record := builder.NewRecord() + defer record.Release() + + b, err := record.MarshalJSON() + require.NoError(t, err) + require.Equal(t, "[{\"bool8\":true}\n]", string(b)) + + record1, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, bytes.NewReader(b)) + require.NoError(t, err) + defer record1.Release() + + require.Equal(t, record, record1) + + require.NoError(t, builder.UnmarshalJSON([]byte(`{"bool8":true}`))) + record = builder.NewRecord() + defer record.Release() + + require.Equal(t, schema, record.Schema()) + require.Equal(t, true, record.Column(0).(*extensions.Bool8Array).Value(0)) +} + +func TestBool8StringRoundTrip(t *testing.T) { + // 1. create array + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + b := extensions.NewBool8Builder(mem) + b.Append(true) + b.AppendNull() + b.Append(false) + b.AppendNull() + b.Append(true) + + arr := b.NewArray() + defer arr.Release() + + // 2. 
create array via AppendValueFromString + b1 := extensions.NewBool8Builder(mem) + defer b1.Release() + + for i := 0; i < arr.Len(); i++ { + assert.NoError(t, b1.AppendValueFromString(arr.ValueStr(i))) + } + + arr1 := b1.NewArray() + defer arr1.Release() + + assert.True(t, array.Equal(arr, arr1)) +} + +func TestCompareBool8AndBoolean(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + bool8bldr := extensions.NewBool8Builder(mem) + defer bool8bldr.Release() + + boolbldr := array.NewBooleanBuilder(mem) + defer boolbldr.Release() + + inputVals := []bool{true, false, false, false, true} + inputValidity := []bool{true, false, true, false, true} + + bool8bldr.AppendValues(inputVals, inputValidity) + bool8Arr := bool8bldr.NewExtensionArray().(*extensions.Bool8Array) + defer bool8Arr.Release() + + boolbldr.AppendValues(inputVals, inputValidity) + boolArr := boolbldr.NewBooleanArray() + defer boolArr.Release() + + require.Equal(t, boolArr.Len(), bool8Arr.Len()) + for i := 0; i < boolArr.Len(); i++ { + require.Equal(t, boolArr.Value(i), bool8Arr.Value(i)) + } +} + +func TestReinterpretStorageEqualToValues(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + bool8bldr := extensions.NewBool8Builder(mem) + defer bool8bldr.Release() + + inputVals := []bool{true, false, false, false, true} + inputValidity := []bool{true, false, true, false, true} + + bool8bldr.AppendValues(inputVals, inputValidity) + bool8Arr := bool8bldr.NewExtensionArray().(*extensions.Bool8Array) + defer bool8Arr.Release() + + boolValsCopy := make([]bool, bool8Arr.Len()) + for i := 0; i < bool8Arr.Len(); i++ { + boolValsCopy[i] = bool8Arr.Value(i) + } + + boolValsZeroCopy := bool8Arr.BoolValues() + + require.Equal(t, len(boolValsZeroCopy), len(boolValsCopy)) + for i := range boolValsCopy { + require.Equal(t, boolValsZeroCopy[i], boolValsCopy[i]) + } +} + +func TestBool8TypeBatchIPCRoundTrip(t *testing.T) { + typ := extensions.NewBool8Type() + arrow.RegisterExtensionType(typ) + defer arrow.UnregisterExtensionType(typ.ExtensionName()) + + storage, _, err := array.FromJSON(memory.DefaultAllocator, arrow.PrimitiveTypes.Int8, + strings.NewReader(`[-1, 0, 1, 2, null]`)) + require.NoError(t, err) + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + batch := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "field", Type: typ, Nullable: true}}, nil), + []arrow.Array{arr}, -1) + defer batch.Release() + + var written arrow.Record + { + var buf bytes.Buffer + wr := ipc.NewWriter(&buf, ipc.WithSchema(batch.Schema())) + require.NoError(t, wr.Write(batch)) + require.NoError(t, wr.Close()) + + rdr, err := ipc.NewReader(&buf) + require.NoError(t, err) + written, err = rdr.Read() + require.NoError(t, err) + written.Retain() + defer written.Release() + rdr.Release() + } + + assert.Truef(t, batch.Schema().Equal(written.Schema()), "expected: %s, got: %s", + batch.Schema(), written.Schema()) + + assert.Truef(t, array.RecordEqual(batch, written), "expected: %s, got: %s", + batch, written) +} + +func BenchmarkWriteBool8Array(b *testing.B) { + bool8bldr := extensions.NewBool8Builder(memory.DefaultAllocator) + defer bool8bldr.Release() + + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + + values := make([]bool, sz) + for idx := range values { + values[idx] = true + } + + b.ResetTimer() + b.SetBytes(int64(sz)) + for n := 0; n < b.N; 
n++ { + bool8bldr.AppendValues(values, nil) + bool8bldr.NewArray() + } + }) + } +} + +func BenchmarkWriteBooleanArray(b *testing.B) { + boolbldr := array.NewBooleanBuilder(memory.DefaultAllocator) + defer boolbldr.Release() + + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + + values := make([]bool, sz) + for idx := range values { + values[idx] = true + } + + b.ResetTimer() + b.SetBytes(int64(len(values))) + for n := 0; n < b.N; n++ { + boolbldr.AppendValues(values, nil) + boolbldr.NewArray() + } + }) + } +} + +// storage benchmark result at package level to prevent compiler from eliminating the function call +var result []bool + +func BenchmarkReadBool8Array(b *testing.B) { + bool8bldr := extensions.NewBool8Builder(memory.DefaultAllocator) + defer bool8bldr.Release() + + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + + values := make([]bool, sz) + for idx := range values { + values[idx] = true + } + + bool8bldr.AppendValues(values, nil) + bool8Arr := bool8bldr.NewArray().(*extensions.Bool8Array) + defer bool8Arr.Release() + + var r []bool + b.ResetTimer() + b.SetBytes(int64(len(values))) + for n := 0; n < b.N; n++ { + r = bool8Arr.BoolValues() + } + result = r + }) + } +} + +func BenchmarkReadBooleanArray(b *testing.B) { + boolbldr := array.NewBooleanBuilder(memory.DefaultAllocator) + defer boolbldr.Release() + + for sz := MINSIZE; sz < MAXSIZE+1; sz *= 2 { + b.Run(fmt.Sprintf("len %d", sz), func(b *testing.B) { + + values := make([]bool, sz) + output := make([]bool, sz) + for idx := range values { + values[idx] = true + } + + boolbldr.AppendValues(values, nil) + boolArr := boolbldr.NewArray().(*array.Boolean) + defer boolArr.Release() + + b.ResetTimer() + b.SetBytes(int64(len(values))) + for n := 0; n < b.N; n++ { + for i := 0; i < boolArr.Len(); i++ { + output[i] = boolArr.Value(i) + } + } + }) + } +} diff --git a/go/arrow/extensions/extensions_test.go b/go/arrow/extensions/extensions_test.go new file mode 100644 index 00000000000..f56fed5e132 --- /dev/null +++ b/go/arrow/extensions/extensions_test.go @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions_test + +import ( + "bytes" + "fmt" + "reflect" + "testing" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/stretchr/testify/require" +) + +// testBool8Type minimally implements arrow.ExtensionType, but importantly does not implement array.CustomExtensionBuilder +// so it will fall back to the storage type's default builder. 
+type testBool8Type struct { + arrow.ExtensionBase +} + +func newTestBool8Type() *testBool8Type { + return &testBool8Type{ExtensionBase: arrow.ExtensionBase{Storage: arrow.PrimitiveTypes.Int8}} +} + +func (t *testBool8Type) ArrayType() reflect.Type { return reflect.TypeOf(testBool8Array{}) } +func (t *testBool8Type) ExtensionEquals(arrow.ExtensionType) bool { panic("unimplemented") } +func (t *testBool8Type) ExtensionName() string { panic("unimplemented") } +func (t *testBool8Type) Serialize() string { panic("unimplemented") } +func (t *testBool8Type) Deserialize(arrow.DataType, string) (arrow.ExtensionType, error) { + panic("unimplemented") +} + +type testBool8Array struct { + array.ExtensionArrayBase +} + +func TestUnmarshalExtensionTypes(t *testing.T) { + logicalJSON := `[true,null,false,null,true]` + storageJSON := `[1,null,0,null,1]` + + // extensions.Bool8Type implements array.CustomExtensionBuilder so we expect the array to be built with the custom builder + arrCustomBuilder, _, err := array.FromJSON(memory.DefaultAllocator, extensions.NewBool8Type(), bytes.NewBufferString(logicalJSON)) + require.NoError(t, err) + defer arrCustomBuilder.Release() + require.Equal(t, 5, arrCustomBuilder.Len()) + + // testBoolType falls back to the default builder for the storage type, so it cannot deserialize native booleans + _, _, err = array.FromJSON(memory.DefaultAllocator, newTestBool8Type(), bytes.NewBufferString(logicalJSON)) + require.ErrorContains(t, err, "cannot unmarshal true into Go value of type int8") + + // testBoolType must build the array with the native storage type: Int8 + arrDefaultBuilder, _, err := array.FromJSON(memory.DefaultAllocator, newTestBool8Type(), bytes.NewBufferString(storageJSON)) + require.NoError(t, err) + defer arrDefaultBuilder.Release() + require.Equal(t, 5, arrDefaultBuilder.Len()) + + arrBool8, ok := arrCustomBuilder.(*extensions.Bool8Array) + require.True(t, ok) + + arrExt, ok := arrDefaultBuilder.(array.ExtensionArray) + require.True(t, ok) + + // The physical layout of both arrays is identical + require.True(t, array.Equal(arrBool8.Storage(), arrExt.Storage())) +} + +// invalidExtensionType does not fully implement the arrow.ExtensionType interface, even though it embeds arrow.ExtensionBase +type invalidExtensionType struct { + arrow.ExtensionBase +} + +func newInvalidExtensionType() *invalidExtensionType { + return &invalidExtensionType{ExtensionBase: arrow.ExtensionBase{Storage: arrow.BinaryTypes.String}} +} + +func TestInvalidExtensionType(t *testing.T) { + jsonStr := `["one","two","three"]` + typ := newInvalidExtensionType() + + require.PanicsWithError(t, fmt.Sprintf("arrow/array: invalid extension type: %T", typ), func() { + array.FromJSON(memory.DefaultAllocator, typ, bytes.NewBufferString(jsonStr)) + }) +} + +var ( + _ arrow.ExtensionType = (*testBool8Type)(nil) + _ array.ExtensionArray = (*testBool8Array)(nil) +) diff --git a/go/internal/types/extension_types.go b/go/internal/types/extension_types.go index 3c63b368746..85c64d86bff 100644 --- a/go/internal/types/extension_types.go +++ b/go/internal/types/extension_types.go @@ -26,6 +26,7 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/internal/json" "github.com/google/uuid" "golang.org/x/xerrors" @@ -37,8 +38,8 @@ type UUIDBuilder struct { *array.ExtensionBuilder } -func NewUUIDBuilder(builder *array.ExtensionBuilder) *UUIDBuilder { - return &UUIDBuilder{ExtensionBuilder: 
builder} +func NewUUIDBuilder(mem memory.Allocator) *UUIDBuilder { + return &UUIDBuilder{ExtensionBuilder: array.NewExtensionBuilder(mem, NewUUIDType())} } func (b *UUIDBuilder) Append(v uuid.UUID) { @@ -245,8 +246,8 @@ func (e *UUIDType) ExtensionEquals(other arrow.ExtensionType) bool { return e.ExtensionName() == other.ExtensionName() } -func (*UUIDType) NewBuilder(bldr *array.ExtensionBuilder) array.Builder { - return NewUUIDBuilder(bldr) +func (*UUIDType) NewBuilder(mem memory.Allocator) array.Builder { + return NewUUIDBuilder(mem) } // Parametric1Array is a simple int32 array for use with the Parametric1Type diff --git a/go/internal/types/extension_types_test.go b/go/internal/types/extension_types_test.go index 50abaae3a9e..65f6353d01b 100644 --- a/go/internal/types/extension_types_test.go +++ b/go/internal/types/extension_types_test.go @@ -32,12 +32,10 @@ import ( var testUUID = uuid.New() -func TestExtensionBuilder(t *testing.T) { +func TestUUIDExtensionBuilder(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(t, 0) - extBuilder := array.NewExtensionBuilder(mem, types.NewUUIDType()) - defer extBuilder.Release() - builder := types.NewUUIDBuilder(extBuilder) + builder := types.NewUUIDBuilder(mem) builder.Append(testUUID) arr := builder.NewArray() defer arr.Release() @@ -52,7 +50,7 @@ func TestExtensionBuilder(t *testing.T) { assert.Equal(t, arr, arr1) } -func TestExtensionRecordBuilder(t *testing.T) { +func TestUUIDExtensionRecordBuilder(t *testing.T) { schema := arrow.NewSchema([]arrow.Field{ {Name: "uuid", Type: types.NewUUIDType()}, }, nil) @@ -72,9 +70,7 @@ func TestUUIDStringRoundTrip(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(t, 0) - extBuilder := array.NewExtensionBuilder(mem, types.NewUUIDType()) - defer extBuilder.Release() - b := types.NewUUIDBuilder(extBuilder) + b := types.NewUUIDBuilder(mem) b.Append(uuid.Nil) b.AppendNull() b.Append(uuid.NameSpaceURL) @@ -85,9 +81,7 @@ func TestUUIDStringRoundTrip(t *testing.T) { defer arr.Release() // 2. create array via AppendValueFromString - extBuilder1 := array.NewExtensionBuilder(mem, types.NewUUIDType()) - defer extBuilder1.Release() - b1 := types.NewUUIDBuilder(extBuilder1) + b1 := types.NewUUIDBuilder(mem) defer b1.Release() for i := 0; i < arr.Len(); i++ { diff --git a/go/parquet/pqarrow/encode_arrow_test.go b/go/parquet/pqarrow/encode_arrow_test.go index 9b3419988d6..16282173a68 100644 --- a/go/parquet/pqarrow/encode_arrow_test.go +++ b/go/parquet/pqarrow/encode_arrow_test.go @@ -2053,9 +2053,7 @@ func (ps *ParquetIOTestSuite) TestArrowExtensionTypeRoundTrip() { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(ps.T(), 0) - extBuilder := array.NewExtensionBuilder(mem, types.NewUUIDType()) - defer extBuilder.Release() - builder := types.NewUUIDBuilder(extBuilder) + builder := types.NewUUIDBuilder(mem) builder.Append(uuid.New()) arr := builder.NewArray() defer arr.Release() From 483bc7b6d10d62e3bb83c167569cde84e2912744 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Tue, 13 Aug 2024 07:45:11 +0530 Subject: [PATCH 002/157] GH-43638: [Java] LargeListViewVector RangeEqualVisitor and TypeEqualVisitor integration (#43642) ### Rationale for this change LargeListViewVector requires `RangeEqualVisitor` and `TypeEqualVisitor` to support the C Data interface. ### What changes are included in this PR? Adding `RangeEqualVisitor`, `TypeEqualVisitor` and the corresponding test cases. 
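As a quick illustration (a hedged sketch, not part of this patch; the class name and the vector population step are placeholders), the new overloads can be exercised roughly as follows, mirroring the test cases added below:

```java
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.compare.Range;
import org.apache.arrow.vector.compare.RangeEqualsVisitor;
import org.apache.arrow.vector.compare.TypeEqualsVisitor;
import org.apache.arrow.vector.complex.LargeListViewVector;

class LargeListViewVisitorSketch {
  static void compare() {
    try (BufferAllocator allocator = new RootAllocator();
        LargeListViewVector left = LargeListViewVector.empty("largelistview", allocator);
        LargeListViewVector right = LargeListViewVector.empty("largelistview", allocator)) {
      // ... populate both vectors identically via their writers ...

      // Type-level comparison now dispatches to the LargeListViewVector overload.
      boolean sameType = new TypeEqualsVisitor(right).equals(left);

      // Value-level comparison over a range of entries in both vectors.
      RangeEqualsVisitor visitor = new RangeEqualsVisitor(left, right);
      boolean sameValues = visitor.rangeEquals(new Range(0, 0, left.getValueCount()));
    }
  }
}
```

Before this change, `LargeListViewVector#accept` threw `UnsupportedOperationException`, so neither comparison was possible.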
### Are these changes tested? Yes. ### Are there any user-facing changes? No * GitHub Issue: #43638 Authored-by: Vibhatha Abeykoon Signed-off-by: David Li --- .../vector/compare/RangeEqualsVisitor.java | 57 ++++++++++ .../vector/compare/TypeEqualsVisitor.java | 6 ++ .../arrow/vector/compare/VectorVisitor.java | 6 ++ .../vector/complex/LargeListViewVector.java | 2 +- .../apache/arrow/vector/TestValueVector.java | 95 ++++++++++++++++ .../compare/TestRangeEqualsVisitor.java | 102 ++++++++++++++++++ .../vector/compare/TestTypeEqualsVisitor.java | 17 +++ 7 files changed, 284 insertions(+), 1 deletion(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java index fbc28a3609c..9aa1bffb846 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java @@ -31,11 +31,13 @@ import org.apache.arrow.vector.ExtensionTypeVector; import org.apache.arrow.vector.NullVector; import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.BaseLargeRepeatedValueViewVector; import org.apache.arrow.vector.complex.BaseRepeatedValueVector; import org.apache.arrow.vector.complex.BaseRepeatedValueViewVector; import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.NonNullableStructVector; @@ -244,6 +246,14 @@ public Boolean visit(ListViewVector left, Range range) { return compareListViewVectors(range); } + @Override + public Boolean visit(LargeListViewVector left, Range range) { + if (!validate(left)) { + return false; + } + return compareLargeListViewVectors(range); + } + protected RangeEqualsVisitor createInnerVisitor( ValueVector leftInner, ValueVector rightInner, @@ -759,4 +769,51 @@ protected boolean compareListViewVectors(Range range) { } return true; } + + protected boolean compareLargeListViewVectors(Range range) { + LargeListViewVector leftVector = (LargeListViewVector) left; + LargeListViewVector rightVector = (LargeListViewVector) right; + + RangeEqualsVisitor innerVisitor = + createInnerVisitor( + leftVector.getDataVector(), rightVector.getDataVector(), /*type comparator*/ null); + Range innerRange = new Range(); + + for (int i = 0; i < range.getLength(); i++) { + int leftIndex = range.getLeftStart() + i; + int rightIndex = range.getRightStart() + i; + + boolean isNull = leftVector.isNull(leftIndex); + if (isNull != rightVector.isNull(rightIndex)) { + return false; + } + + int offsetWidth = BaseLargeRepeatedValueViewVector.OFFSET_WIDTH; + int sizeWidth = BaseLargeRepeatedValueViewVector.SIZE_WIDTH; + + if (!isNull) { + final int startIndexLeft = + leftVector.getOffsetBuffer().getInt((long) leftIndex * offsetWidth); + final int leftSize = leftVector.getSizeBuffer().getInt((long) leftIndex * sizeWidth); + + final int startIndexRight = + rightVector.getOffsetBuffer().getInt((long) rightIndex * offsetWidth); + final int rightSize = rightVector.getSizeBuffer().getInt((long) rightIndex * sizeWidth); + + if (leftSize != rightSize) { + return false; + } + + innerRange = + innerRange + .setRightStart(startIndexRight) + 
.setLeftStart(startIndexLeft) + .setLength(leftSize); + if (!innerVisitor.rangeEquals(innerRange)) { + return false; + } + } + } + return true; + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/TypeEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/TypeEqualsVisitor.java index 6e15d6a83e7..ce92b22ef61 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compare/TypeEqualsVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/TypeEqualsVisitor.java @@ -28,6 +28,7 @@ import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.NonNullableStructVector; @@ -130,6 +131,11 @@ public Boolean visit(ListViewVector left, Void value) { return compareField(left.getField(), right.getField()); } + @Override + public Boolean visit(LargeListViewVector left, Void value) { + return compareField(left.getField(), right.getField()); + } + private boolean compareField(Field leftField, Field rightField) { if (leftField == rightField) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/VectorVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/VectorVisitor.java index c912359d4af..e20f8cd9cfb 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compare/VectorVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/VectorVisitor.java @@ -25,6 +25,7 @@ import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.NonNullableStructVector; @@ -65,4 +66,9 @@ public interface VectorVisitor { default OUT visit(ListViewVector left, IN value) { throw new UnsupportedOperationException("VectorVisitor for ListViewVector is not supported."); } + + default OUT visit(LargeListViewVector left, IN value) { + throw new UnsupportedOperationException( + "VectorVisitor for LargeListViewVector is not supported."); + } } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java index 1bb24a53fc2..17ccdbf0eae 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java @@ -449,7 +449,7 @@ public int hashCode(int index, ArrowBufHasher hasher) { @Override public OUT accept(VectorVisitor visitor, IN value) { - throw new UnsupportedOperationException(); + return visitor.visit(this, value); } @Override diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java index 4dd55afdb8b..83e470ae258 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java @@ -46,11 +46,13 @@ import 
org.apache.arrow.vector.compare.VectorEqualsVisitor; import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; import org.apache.arrow.vector.complex.impl.NullableStructWriter; +import org.apache.arrow.vector.complex.impl.UnionLargeListViewWriter; import org.apache.arrow.vector.complex.impl.UnionListViewWriter; import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.holders.NullableIntHolder; @@ -2910,6 +2912,35 @@ public void testListViewVectorEqualsWithNull() { } } + @Test + public void testLargeListViewVectorEqualsWithNull() { + try (final LargeListViewVector vector1 = LargeListViewVector.empty("largelistview", allocator); + final LargeListViewVector vector2 = + LargeListViewVector.empty("largelistview", allocator); ) { + + UnionLargeListViewWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + // set some values + writeLargeListViewVector(writer1, new int[] {1, 2}); + writeLargeListViewVector(writer1, new int[] {3, 4}); + writeLargeListViewVector(writer1, new int[] {}); + writer1.setValueCount(3); + + UnionLargeListViewWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + // set some values + writeLargeListViewVector(writer2, new int[] {1, 2}); + writeLargeListViewVector(writer2, new int[] {3, 4}); + writer2.setValueCount(3); + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(); + + assertFalse(visitor.vectorEquals(vector1, vector2)); + } + } + @Test public void testListVectorEquals() { try (final ListVector vector1 = ListVector.empty("list", allocator); @@ -2974,6 +3005,39 @@ public void testListViewVectorEquals() { } } + @Test + public void testLargeListViewVectorEquals() { + try (final LargeListViewVector vector1 = LargeListViewVector.empty("largelistview", allocator); + final LargeListViewVector vector2 = + LargeListViewVector.empty("largelistview", allocator); ) { + + UnionLargeListViewWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + // set some values + writeLargeListViewVector(writer1, new int[] {1, 2}); + writeLargeListViewVector(writer1, new int[] {3, 4}); + writeLargeListViewVector(writer1, new int[] {5, 6}); + writer1.setValueCount(3); + + UnionLargeListViewWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + // set some values + writeLargeListViewVector(writer2, new int[] {1, 2}); + writeLargeListViewVector(writer2, new int[] {3, 4}); + writer2.setValueCount(2); + + VectorEqualsVisitor visitor = new VectorEqualsVisitor(); + assertFalse(visitor.vectorEquals(vector1, vector2)); + + writeLargeListViewVector(writer2, new int[] {5, 6}); + writer2.setValueCount(3); + + assertTrue(visitor.vectorEquals(vector1, vector2)); + } + } + @Test public void testListVectorSetNull() { try (final ListVector vector = ListVector.empty("list", allocator)) { @@ -3020,6 +3084,29 @@ public void testListViewVectorSetNull() { } } + @Test + public void testLargeListViewVectorSetNull() { + try (final LargeListViewVector vector = LargeListViewVector.empty("largelistview", allocator)) { + UnionLargeListViewWriter writer = vector.getWriter(); + writer.allocate(); + + writeLargeListViewVector(writer, new int[] {1, 2}); + writeLargeListViewVector(writer, new int[] {3, 4}); + 
writeLargeListViewVector(writer, new int[] {5, 6}); + vector.setNull(3); + vector.setNull(4); + vector.setNull(5); + writer.setValueCount(6); + + assertEquals(vector.getObject(0), Arrays.asList(1, 2)); + assertEquals(vector.getObject(1), Arrays.asList(3, 4)); + assertEquals(vector.getObject(2), Arrays.asList(5, 6)); + assertTrue(vector.isNull(3)); + assertTrue(vector.isNull(4)); + assertTrue(vector.isNull(5)); + } + } + @Test public void testStructVectorEqualsWithNull() { @@ -3359,6 +3446,14 @@ private void writeListViewVector(UnionListViewWriter writer, int[] values) { writer.endListView(); } + private void writeLargeListViewVector(UnionLargeListViewWriter writer, int[] values) { + writer.startListView(); + for (int v : values) { + writer.integer().writeInt(v); + } + writer.endListView(); + } + @Test public void testVariableVectorGetEndOffset() { try (final VarCharVector vector1 = new VarCharVector("v1", allocator); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java index 7e91b760430..eca5c2d9b2a 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java @@ -36,12 +36,14 @@ import org.apache.arrow.vector.compare.util.ValueEpsilonEqualizers; import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; import org.apache.arrow.vector.complex.impl.NullableStructWriter; import org.apache.arrow.vector.complex.impl.UnionFixedSizeListWriter; +import org.apache.arrow.vector.complex.impl.UnionLargeListViewWriter; import org.apache.arrow.vector.complex.impl.UnionListViewWriter; import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.holders.NullableBigIntHolder; @@ -221,6 +223,25 @@ public void testListViewVectorWithDifferentChild() { } } + @Test + public void testLargeListViewVectorWithDifferentChild() { + try (final LargeListViewVector vector1 = LargeListViewVector.empty("largelistview", allocator); + final LargeListViewVector vector2 = + LargeListViewVector.empty("largelistview", allocator); ) { + + vector1.allocateNew(); + vector1.initializeChildrenFromFields( + Arrays.asList(Field.nullable("child", new ArrowType.Int(32, true)))); + + vector2.allocateNew(); + vector2.initializeChildrenFromFields( + Arrays.asList(Field.nullable("child", new ArrowType.Int(64, true)))); + + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector1, vector2); + assertFalse(visitor.rangeEquals(new Range(0, 0, 0))); + } + } + @Test public void testListVectorRangeEquals() { try (final ListVector vector1 = ListVector.empty("list", allocator); @@ -285,6 +306,39 @@ public void testListViewVectorRangeEquals() { } } + @Test + public void testLargeListViewVectorRangeEquals() { + try (final LargeListViewVector vector1 = LargeListViewVector.empty("largelistview", allocator); + final LargeListViewVector vector2 = + LargeListViewVector.empty("largelistview", allocator); ) { + + UnionLargeListViewWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + // set some values + 
writeLargeListViewVector(writer1, new int[] {1, 2}); + writeLargeListViewVector(writer1, new int[] {3, 4}); + writeLargeListViewVector(writer1, new int[] {5, 6}); + writeLargeListViewVector(writer1, new int[] {7, 8}); + writeLargeListViewVector(writer1, new int[] {9, 10}); + writer1.setValueCount(5); + + UnionLargeListViewWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + // set some values + writeLargeListViewVector(writer2, new int[] {0, 0}); + writeLargeListViewVector(writer2, new int[] {3, 4}); + writeLargeListViewVector(writer2, new int[] {5, 6}); + writeLargeListViewVector(writer2, new int[] {7, 8}); + writeLargeListViewVector(writer2, new int[] {0, 0}); + writer2.setValueCount(5); + + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector1, vector2); + assertTrue(visitor.rangeEquals(new Range(1, 1, 3))); + } + } + @Test public void testBitVectorRangeEquals() { try (final BitVector vector1 = new BitVector("v1", allocator); @@ -903,6 +957,38 @@ public void testListViewVectorApproxEquals() { } } + @Test + public void testLargeListViewVectorApproxEquals() { + try (final LargeListViewVector right = LargeListViewVector.empty("largelistview", allocator); + final LargeListViewVector left1 = LargeListViewVector.empty("largelistview", allocator); + final LargeListViewVector left2 = LargeListViewVector.empty("largelistview", allocator); ) { + + final float epsilon = 1.0E-6f; + + UnionLargeListViewWriter rightWriter = right.getWriter(); + rightWriter.allocate(); + writeLargeListViewVector(rightWriter, new double[] {1, 2}); + writeLargeListViewVector(rightWriter, new double[] {1.01, 2.02}); + rightWriter.setValueCount(2); + + UnionLargeListViewWriter leftWriter1 = left1.getWriter(); + leftWriter1.allocate(); + writeLargeListViewVector(leftWriter1, new double[] {1, 2}); + writeLargeListViewVector(leftWriter1, new double[] {1.01 + epsilon / 2, 2.02 - epsilon / 2}); + leftWriter1.setValueCount(2); + + UnionLargeListViewWriter leftWriter2 = left2.getWriter(); + leftWriter2.allocate(); + writeLargeListViewVector(leftWriter2, new double[] {1, 2}); + writeLargeListViewVector(leftWriter2, new double[] {1.01 + epsilon * 2, 2.02 - epsilon * 2}); + leftWriter2.setValueCount(2); + + Range range = new Range(0, 0, right.getValueCount()); + assertTrue(new ApproxEqualsVisitor(left1, right, epsilon, epsilon).rangeEquals(range)); + assertFalse(new ApproxEqualsVisitor(left2, right, epsilon, epsilon).rangeEquals(range)); + } + } + private void writeStructVector(NullableStructWriter writer, int value1, long value2) { writer.start(); writer.integer("f0").writeInt(value1); @@ -933,6 +1019,14 @@ private void writeListViewVector(UnionListViewWriter writer, int[] values) { writer.endListView(); } + private void writeLargeListViewVector(UnionLargeListViewWriter writer, int[] values) { + writer.startListView(); + for (int v : values) { + writer.integer().writeInt(v); + } + writer.endListView(); + } + private void writeFixedSizeListVector(UnionFixedSizeListWriter writer, int[] values) { writer.startList(); for (int v : values) { @@ -956,4 +1050,12 @@ private void writeListViewVector(UnionListViewWriter writer, double[] values) { } writer.endListView(); } + + private void writeLargeListViewVector(UnionLargeListViewWriter writer, double[] values) { + writer.startListView(); + for (double v : values) { + writer.float8().writeFloat8(v); + } + writer.endListView(); + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestTypeEqualsVisitor.java 
b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestTypeEqualsVisitor.java index d65096205fd..ce029493473 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestTypeEqualsVisitor.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestTypeEqualsVisitor.java @@ -32,6 +32,7 @@ import org.apache.arrow.vector.ViewVarBinaryVector; import org.apache.arrow.vector.ViewVarCharVector; import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.StructVector; @@ -121,6 +122,22 @@ public void testListViewTypeEquals() { } } + @Test + public void testLargeListViewTypeEquals() { + try (final LargeListViewVector right = LargeListViewVector.empty("largelistview", allocator); + final LargeListViewVector left1 = LargeListViewVector.empty("largelistview", allocator); + final LargeListViewVector left2 = LargeListViewVector.empty("largelistview", allocator)) { + + right.addOrGetVector(FieldType.nullable(new ArrowType.Utf8())); + left1.addOrGetVector(FieldType.nullable(new ArrowType.Utf8())); + left2.addOrGetVector(FieldType.nullable(new ArrowType.FixedSizeBinary(2))); + + TypeEqualsVisitor visitor = new TypeEqualsVisitor(right); + assertTrue(visitor.equals(left1)); + assertFalse(visitor.equals(left2)); + } + } + @Test public void testStructTypeEquals() { try (final StructVector right = StructVector.empty("struct", allocator); From e8e9d1ac2b9761b40eb0e041127285b55655e49c Mon Sep 17 00:00:00 2001 From: Lysandros Nikolaou Date: Wed, 14 Aug 2024 01:46:02 +0200 Subject: [PATCH 003/157] GH-43536: [Python] Declare support for free-threading in Cython (#43606) ### Rationale for this change This is done by passing an extra flag when building the Cython extension modules. It is needed so that the GIL is not dynamically reenabled when importing `pyarrow.lib`. ### What changes are included in this PR? Changes to CMake so that the extra flag is passed when building Cython extension modules. * GitHub Issue: #43536 Lead-authored-by: Lysandros Nikolaou Co-authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- cpp/cmake_modules/UseCython.cmake | 5 +++++ python/CMakeLists.txt | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/cpp/cmake_modules/UseCython.cmake b/cpp/cmake_modules/UseCython.cmake index e15ac59490c..7d88daa4fad 100644 --- a/cpp/cmake_modules/UseCython.cmake +++ b/cpp/cmake_modules/UseCython.cmake @@ -184,4 +184,9 @@ function(cython_add_module _name pyx_target_name generated_files) add_dependencies(${_name} ${pyx_target_name}) endfunction() +execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "from Cython.Compiler.Version import version; print(version)" + OUTPUT_VARIABLE CYTHON_VERSION_OUTPUT + OUTPUT_STRIP_TRAILING_WHITESPACE) +set(CYTHON_VERSION "${CYTHON_VERSION_OUTPUT}") + include(CMakeParseArguments) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index a90dee70584..5d5eeaf8157 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -260,6 +260,7 @@ message(STATUS "Found NumPy version: ${Python3_NumPy_VERSION}") message(STATUS "NumPy include dir: ${NUMPY_INCLUDE_DIRS}") include(UseCython) +message(STATUS "Found Cython version: ${CYTHON_VERSION}") # Arrow C++ and set default PyArrow build options include(GNUInstallDirs) @@ -855,6 +856,10 @@ set(CYTHON_FLAGS "${CYTHON_FLAGS}" "--warning-errors") # undocumented Cython feature. 
set(CYTHON_FLAGS "${CYTHON_FLAGS}" "--no-c-in-traceback") +if(CYTHON_VERSION VERSION_GREATER_EQUAL "3.1.0a0") + list(APPEND CYTHON_FLAGS "-Xfreethreading_compatible=True") +endif() + foreach(module ${CYTHON_EXTENSIONS}) string(REPLACE "." ";" directories ${module}) list(GET directories -1 module_name) From fc80d7d8b9f80152415fc333e0850358bf217db9 Mon Sep 17 00:00:00 2001 From: Dane Pitkin Date: Tue, 13 Aug 2024 19:48:47 -0400 Subject: [PATCH 004/157] GH-43378: [Java][CI] Don't configure multithreading when building javadocs (#43674) ### Rationale for this change Apparently some maven plugins are not thread safe and started throwing errors in the `test-debian-12-docs` CI job when building javadocs. ### What changes are included in this PR? * Remove multithreading config when building javadocs ### Are these changes tested? CI ### Are there any user-facing changes? No * GitHub Issue: #43378 Authored-by: Dane Pitkin Signed-off-by: Sutou Kouhei --- ci/scripts/java_build.sh | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ci/scripts/java_build.sh b/ci/scripts/java_build.sh index 0fa1edab429..212ec6eb114 100755 --- a/ci/scripts/java_build.sh +++ b/ci/scripts/java_build.sh @@ -72,9 +72,6 @@ if [ $ARROW_JAVA_SKIP_GIT_PLUGIN ]; then mvn="${mvn} -Dmaven.gitcommitid.skip=true" fi -# Use `2 * ncores` threads -mvn="${mvn} -T 2C" - # https://github.com/apache/arrow/issues/41429 # TODO: We want to out-of-source build. This is a workaround. We copy # all needed files to the build directory from the source directory @@ -98,10 +95,12 @@ if [ "${ARROW_JAVA_JNI}" = "ON" ]; then mvn="${mvn} -Darrow.cpp.build.dir=${java_jni_dist_dir} -Parrow-jni" fi -${mvn} clean install +# Use `2 * ncores` threads +${mvn} -T 2C clean install if [ "${BUILD_DOCS_JAVA}" == "ON" ]; then # HTTP pooling is turned of to avoid download issues https://issues.apache.org/jira/browse/ARROW-11633 + # GH-43378: Maven site plugins not compatible with multithreading mkdir -p ${build_dir}/docs/java/reference ${mvn} -Dcheckstyle.skip=true -Dhttp.keepAlive=false -Dmaven.wagon.http.pool=false clean install site rsync -a target/site/apidocs/ ${build_dir}/docs/java/reference From 88e8140ad7902435b5d1ac29205dda7517f2cc79 Mon Sep 17 00:00:00 2001 From: Oliver Layer Date: Wed, 14 Aug 2024 02:16:54 +0200 Subject: [PATCH 005/157] GH-43097: [C++] Implement `PathFromUri` support for Azure file system (#43098) ### Rationale for this change See #43097. ### What changes are included in this PR? Implements `AzureFS::PathFromUri` using existing URI parsing and path extraction inside the `AzureOptions`. ### Are these changes tested? Yes, added a unit test. ### Are there any user-facing changes? No, but calling `PathFromUri` will now work instead of throwing due to no implementation provided. 
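For illustration, a hedged usage sketch (not part of this patch) of the behavior described above; `options` is assumed to already carry a valid account configuration and credentials:

```cpp
#include <string>

#include "arrow/filesystem/azurefs.h"
#include "arrow/result.h"
#include "arrow/status.h"

arrow::Status PathFromUriSketch(const arrow::fs::AzureOptions& options) {
  // Assumes `options` is already configured (account, credentials, etc.).
  ARROW_ASSIGN_OR_RAISE(auto fs, arrow::fs::AzureFileSystem::Make(options));

  // (1) The authority is the storage account host and is not prepended to the path.
  ARROW_ASSIGN_OR_RAISE(
      auto path1,
      fs->PathFromUri("abfss://storageacc.blob.core.windows.net/container/some/path"));

  // (2) The authority carries credentials, so the container is part of the path.
  ARROW_ASSIGN_OR_RAISE(auto path2,
                        fs->PathFromUri("abfss://acc:pw@container/some/path"));

  // Both are expected to yield "container/some/path"; any scheme other than
  // abfs/abfss returns Status::Invalid.
  return path1 == path2 ? arrow::Status::OK()
                        : arrow::Status::Invalid("paths differ: ", path1, " vs ", path2);
}
```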
* GitHub Issue: #43097 Authored-by: Oliver Layer Signed-off-by: Sutou Kouhei --- cpp/src/arrow/filesystem/azurefs.cc | 27 ++++++++++++++++++++++++ cpp/src/arrow/filesystem/azurefs.h | 2 ++ cpp/src/arrow/filesystem/azurefs_test.cc | 9 ++++++++ 3 files changed, 38 insertions(+) diff --git a/cpp/src/arrow/filesystem/azurefs.cc b/cpp/src/arrow/filesystem/azurefs.cc index a3aa2c8e837..9b3c0c0c1d7 100644 --- a/cpp/src/arrow/filesystem/azurefs.cc +++ b/cpp/src/arrow/filesystem/azurefs.cc @@ -3199,4 +3199,31 @@ Result> AzureFileSystem::OpenAppendStream( return impl_->OpenAppendStream(location, metadata, false, this); } +Result AzureFileSystem::PathFromUri(const std::string& uri_string) const { + /// We can not use `internal::PathFromUriHelper` here because for Azure we have to + /// support different URI schemes where the authority is handled differently. + /// Example (both should yield the same path `container/some/path`): + /// - (1) abfss://storageacc.blob.core.windows.net/container/some/path + /// - (2) abfss://acc:pw@container/some/path + /// The authority handling is different with these two URIs. (1) requires no prepending + /// of the authority to the path, while (2) requires to preprend the authority to the + /// path. + std::string path; + Uri uri; + RETURN_NOT_OK(uri.Parse(uri_string)); + RETURN_NOT_OK(AzureOptions::FromUri(uri, &path)); + + std::vector supported_schemes = {"abfs", "abfss"}; + const auto scheme = uri.scheme(); + if (std::find(supported_schemes.begin(), supported_schemes.end(), scheme) == + supported_schemes.end()) { + std::string expected_schemes = + ::arrow::internal::JoinStrings(supported_schemes, ", "); + return Status::Invalid("The filesystem expected a URI with one of the schemes (", + expected_schemes, ") but received ", uri_string); + } + + return path; +} + } // namespace arrow::fs diff --git a/cpp/src/arrow/filesystem/azurefs.h b/cpp/src/arrow/filesystem/azurefs.h index 93d6ec2f945..072b061eeb2 100644 --- a/cpp/src/arrow/filesystem/azurefs.h +++ b/cpp/src/arrow/filesystem/azurefs.h @@ -367,6 +367,8 @@ class ARROW_EXPORT AzureFileSystem : public FileSystem { Result> OpenAppendStream( const std::string& path, const std::shared_ptr& metadata) override; + + Result PathFromUri(const std::string& uri_string) const override; }; } // namespace arrow::fs diff --git a/cpp/src/arrow/filesystem/azurefs_test.cc b/cpp/src/arrow/filesystem/azurefs_test.cc index 9a11a6f2499..36646f417cb 100644 --- a/cpp/src/arrow/filesystem/azurefs_test.cc +++ b/cpp/src/arrow/filesystem/azurefs_test.cc @@ -2958,5 +2958,14 @@ TEST_F(TestAzuriteFileSystem, OpenInputFileClosed) { ASSERT_RAISES(Invalid, stream->ReadAt(1, 1)); ASSERT_RAISES(Invalid, stream->Seek(2)); } + +TEST_F(TestAzuriteFileSystem, PathFromUri) { + ASSERT_EQ( + "container/some/path", + fs()->PathFromUri("abfss://storageacc.blob.core.windows.net/container/some/path")); + ASSERT_EQ("container/some/path", + fs()->PathFromUri("abfss://acc:pw@container/some/path")); + ASSERT_RAISES(Invalid, fs()->PathFromUri("http://acc:pw@container/some/path")); +} } // namespace fs } // namespace arrow From 01fd7fc18ca737edf0afbcc6afa349206b055a09 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 14 Aug 2024 09:26:27 +0900 Subject: [PATCH 006/157] MINOR: [Go] Bump github.com/substrait-io/substrait-go from 0.5.0 to 0.6.0 in /go (#43647) Bumps [github.com/substrait-io/substrait-go](https://github.com/substrait-io/substrait-go) from 0.5.0 to 0.6.0.
Release notes (sourced from github.com/substrait-io/substrait-go's releases):

v0.6.0 (2024-08-11)

Features
- type: add support for type PrecisionTimestamp and PrecisionTimestampTz (#41) (5040d09)
- substrait: update to Substrait v0.53.0 (#40) (0ea5482)
  - Update the substrait dependency to v0.53.0
  - Accommodate UserDefined Literal changes where the literal value became a oneof in proto instead of a direct value
  - Fix the AdvanceExtension interface to accommodate a breaking change in AdvanceExtensionProto
  - Add a linter rule to ignore internal use of deprecated methods

Commits
- 5040d09 feat(type): add support for type PrecisionTimestamp and PrecisionTimestampTz ...
- 0ea5482 feat(substrait): Update to Substrait v0.53.0 (#40)
- 2fc8f58 ci(build-test): Use grep to exclude protobuf from coverage report (#38)
- b3aa515 ci(build-test): Update codecov to ignore protobuf files
- 15314a8 ci(build-test): Add codecov and release branch action badges. (#36)
- 663c26d ci(build-test): Add codecov reports (#35)
- See full diff in compare view

Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Sutou Kouhei --- go/go.mod | 2 +- go/go.sum | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/go/go.mod b/go/go.mod index 09869b7a383..9f4222a541b 100644 --- a/go/go.mod +++ b/go/go.mod @@ -49,7 +49,7 @@ require ( github.com/google/uuid v1.6.0 github.com/hamba/avro/v2 v2.24.1 github.com/huandu/xstrings v1.4.0 - github.com/substrait-io/substrait-go v0.5.0 + github.com/substrait-io/substrait-go v0.6.0 github.com/tidwall/sjson v1.2.5 ) diff --git a/go/go.sum b/go/go.sum index 2e89a769024..c7eb3a66dee 100644 --- a/go/go.sum +++ b/go/go.sum @@ -24,8 +24,8 @@ github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8c github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= -github.com/go-playground/validator/v10 v10.4.1 h1:pH2c5ADXtd66mxoE0Zm9SUhxE20r7aM3F26W0hOn+GE= -github.com/go-playground/validator/v10 v10.4.1/go.mod h1:nlOn6nFhuKACm19sB/8EGNn9GlaMV7XkbRSipzJ0Ii4= +github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ= +github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.11.0 h1:n7Z+zx8S9f9KgzG6KtQKf+kwqXZlLNR2F6018Dgau54= @@ -99,8 +99,8 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/substrait-io/substrait-go v0.5.0 h1:8sYsoqcrzoNpThPyot1CQpwF6OokxvplLUQJTGlKws4= -github.com/substrait-io/substrait-go v0.5.0/go.mod h1:Co7ko6iIjdqCGcN3LfkKWPVlxONkNZem9omWAGIaOrQ= +github.com/substrait-io/substrait-go v0.6.0 h1:n2G/SGmrn7U5Q39VA8WeM2UfVL5Y/6HX8WAP9uJLNk4= +github.com/substrait-io/substrait-go v0.6.0/go.mod h1:cl8Wsc7aBPDfcHp9+OrUqGpjkgrYlhcDsH/lMP6KUZA= github.com/tidwall/gjson v1.14.2 h1:6BBkirS0rAHjumnjHF6qgy5d2YAJ1TLIaFE2lzfOLqo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= From 69bce8f0cd02297ecc31caef22db67e654c16e28 Mon Sep 17 00:00:00 2001 From: Felipe Oliveira Carvalho Date: Tue, 13 Aug 2024 21:27:36 -0300 Subject: [PATCH 007/157] GH-43677: [C++][FlightRPC] Move the FlightTestServer to its own .cc and .h files (#43678) ### Rationale for this change One way of learning about a codebase is reading the tests. As it is now, it's hard to see the minimal `FlightServerBase` sub-class in `flight/test_util.cc`, so I moved it to its own file. ### What changes are included in this PR? - Renaming `FlightTestServer` to `TestFlightServer` - Moving the class to `test_flight_server.{h,cc}` - Bonus: Moving the server and client auth handlers to `test_auth_handlers.{h,cc}` ### Are these changes tested? By existing tests. ### Are there any user-facing changes? 
`ExampleTestServer` is removed from the testing library in favor of `FlightTestServer::Make`. * GitHub Issue: #43677 Authored-by: Felipe Oliveira Carvalho Signed-off-by: Felipe Oliveira Carvalho --- cpp/src/arrow/flight/CMakeLists.txt | 2 + cpp/src/arrow/flight/flight_test.cc | 8 +- .../integration_tests/test_integration.cc | 1 + cpp/src/arrow/flight/test_auth_handlers.cc | 141 +++++ cpp/src/arrow/flight/test_auth_handlers.h | 89 ++++ cpp/src/arrow/flight/test_definitions.cc | 15 +- cpp/src/arrow/flight/test_flight_server.cc | 417 +++++++++++++++ cpp/src/arrow/flight/test_flight_server.h | 92 ++++ cpp/src/arrow/flight/test_server.cc | 3 +- cpp/src/arrow/flight/test_util.cc | 486 +----------------- cpp/src/arrow/flight/test_util.h | 65 --- 11 files changed, 759 insertions(+), 560 deletions(-) create mode 100644 cpp/src/arrow/flight/test_auth_handlers.cc create mode 100644 cpp/src/arrow/flight/test_auth_handlers.h create mode 100644 cpp/src/arrow/flight/test_flight_server.cc create mode 100644 cpp/src/arrow/flight/test_flight_server.h diff --git a/cpp/src/arrow/flight/CMakeLists.txt b/cpp/src/arrow/flight/CMakeLists.txt index 43ac48b8767..98f93705f6f 100644 --- a/cpp/src/arrow/flight/CMakeLists.txt +++ b/cpp/src/arrow/flight/CMakeLists.txt @@ -262,7 +262,9 @@ if(ARROW_TESTING) OUTPUTS ARROW_FLIGHT_TESTING_LIBRARIES SOURCES + test_auth_handlers.cc test_definitions.cc + test_flight_server.cc test_util.cc DEPENDENCIES flight_grpc_gen diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 101bb06b212..3d52bc3f5ae 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -52,7 +52,9 @@ // Include before test_util.h (boost), contains Windows fixes #include "arrow/flight/platform.h" #include "arrow/flight/serialization_internal.h" +#include "arrow/flight/test_auth_handlers.h" #include "arrow/flight/test_definitions.h" +#include "arrow/flight/test_flight_server.h" #include "arrow/flight/test_util.h" // OTel includes must come after any gRPC includes, and // client_header_internal.h includes gRPC. 
See: @@ -247,7 +249,7 @@ TEST(TestFlight, ConnectUriUnix) { // CI environments don't have an IPv6 interface configured TEST(TestFlight, DISABLED_IpV6Port) { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("[::1]", 0)); FlightServerOptions options(location); @@ -261,7 +263,7 @@ TEST(TestFlight, DISABLED_IpV6Port) { } TEST(TestFlight, ServerCallContextIncomingHeaders) { - auto server = ExampleTestServer(); + auto server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); @@ -290,7 +292,7 @@ TEST(TestFlight, ServerCallContextIncomingHeaders) { class TestFlightClient : public ::testing::Test { public: void SetUp() { - server_ = ExampleTestServer(); + server_ = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0)); FlightServerOptions options(location); diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 665c1f1ba03..da6fcf81eb7 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -36,6 +36,7 @@ #include "arrow/flight/sql/server.h" #include "arrow/flight/sql/server_session_middleware.h" #include "arrow/flight/sql/types.h" +#include "arrow/flight/test_auth_handlers.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/ipc/dictionary.h" diff --git a/cpp/src/arrow/flight/test_auth_handlers.cc b/cpp/src/arrow/flight/test_auth_handlers.cc new file mode 100644 index 00000000000..856ccf0f2b2 --- /dev/null +++ b/cpp/src/arrow/flight/test_auth_handlers.cc @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include + +#include "arrow/flight/client_auth.h" +#include "arrow/flight/server.h" +#include "arrow/flight/server_auth.h" +#include "arrow/flight/test_auth_handlers.h" +#include "arrow/flight/types.h" +#include "arrow/flight/visibility.h" +#include "arrow/status.h" + +namespace arrow::flight { + +// TestServerAuthHandler + +TestServerAuthHandler::TestServerAuthHandler(const std::string& username, + const std::string& password) + : username_(username), password_(password) {} + +TestServerAuthHandler::~TestServerAuthHandler() {} + +Status TestServerAuthHandler::Authenticate(const ServerCallContext& context, + ServerAuthSender* outgoing, + ServerAuthReader* incoming) { + std::string token; + RETURN_NOT_OK(incoming->Read(&token)); + if (token != password_) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + RETURN_NOT_OK(outgoing->Write(username_)); + return Status::OK(); +} + +Status TestServerAuthHandler::IsValid(const ServerCallContext& context, + const std::string& token, + std::string* peer_identity) { + if (token != password_) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + *peer_identity = username_; + return Status::OK(); +} + +// TestServerBasicAuthHandler + +TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username, + const std::string& password) { + basic_auth_.username = username; + basic_auth_.password = password; +} + +TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {} + +Status TestServerBasicAuthHandler::Authenticate(const ServerCallContext& context, + ServerAuthSender* outgoing, + ServerAuthReader* incoming) { + std::string token; + RETURN_NOT_OK(incoming->Read(&token)); + ARROW_ASSIGN_OR_RAISE(BasicAuth incoming_auth, BasicAuth::Deserialize(token)); + if (incoming_auth.username != basic_auth_.username || + incoming_auth.password != basic_auth_.password) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + RETURN_NOT_OK(outgoing->Write(basic_auth_.username)); + return Status::OK(); +} + +Status TestServerBasicAuthHandler::IsValid(const ServerCallContext& context, + const std::string& token, + std::string* peer_identity) { + if (token != basic_auth_.username) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + *peer_identity = basic_auth_.username; + return Status::OK(); +} + +// TestClientAuthHandler + +TestClientAuthHandler::TestClientAuthHandler(const std::string& username, + const std::string& password) + : username_(username), password_(password) {} + +TestClientAuthHandler::~TestClientAuthHandler() {} + +Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing, + ClientAuthReader* incoming) { + RETURN_NOT_OK(outgoing->Write(password_)); + std::string username; + RETURN_NOT_OK(incoming->Read(&username)); + if (username != username_) { + return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); + } + return Status::OK(); +} + +Status TestClientAuthHandler::GetToken(std::string* token) { + *token = password_; + return Status::OK(); +} + +// TestClientBasicAuthHandler + +TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username, + const std::string& password) { + basic_auth_.username = username; + basic_auth_.password = password; +} + +TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {} + +Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing, + ClientAuthReader* incoming) { + ARROW_ASSIGN_OR_RAISE(std::string 
pb_result, basic_auth_.SerializeToString()); + RETURN_NOT_OK(outgoing->Write(pb_result)); + RETURN_NOT_OK(incoming->Read(&token_)); + return Status::OK(); +} + +Status TestClientBasicAuthHandler::GetToken(std::string* token) { + *token = token_; + return Status::OK(); +} + +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_auth_handlers.h b/cpp/src/arrow/flight/test_auth_handlers.h new file mode 100644 index 00000000000..74f48798f3b --- /dev/null +++ b/cpp/src/arrow/flight/test_auth_handlers.h @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/flight/client_auth.h" +#include "arrow/flight/server.h" +#include "arrow/flight/server_auth.h" +#include "arrow/flight/types.h" +#include "arrow/flight/visibility.h" +#include "arrow/status.h" + +// A pair of authentication handlers that check for a predefined password +// and set the peer identity to a predefined username. + +namespace arrow::flight { + +class ARROW_FLIGHT_EXPORT TestServerAuthHandler : public ServerAuthHandler { + public: + explicit TestServerAuthHandler(const std::string& username, + const std::string& password); + ~TestServerAuthHandler() override; + Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, + ServerAuthReader* incoming) override; + Status IsValid(const ServerCallContext& context, const std::string& token, + std::string* peer_identity) override; + + private: + std::string username_; + std::string password_; +}; + +class ARROW_FLIGHT_EXPORT TestServerBasicAuthHandler : public ServerAuthHandler { + public: + explicit TestServerBasicAuthHandler(const std::string& username, + const std::string& password); + ~TestServerBasicAuthHandler() override; + Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, + ServerAuthReader* incoming) override; + Status IsValid(const ServerCallContext& context, const std::string& token, + std::string* peer_identity) override; + + private: + BasicAuth basic_auth_; +}; + +class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler { + public: + explicit TestClientAuthHandler(const std::string& username, + const std::string& password); + ~TestClientAuthHandler() override; + Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; + Status GetToken(std::string* token) override; + + private: + std::string username_; + std::string password_; +}; + +class ARROW_FLIGHT_EXPORT TestClientBasicAuthHandler : public ClientAuthHandler { + public: + explicit TestClientBasicAuthHandler(const std::string& username, + const std::string& password); + ~TestClientBasicAuthHandler() override; + Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; + Status 
GetToken(std::string* token) override; + + private: + BasicAuth basic_auth_; + std::string token_; +}; + +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_definitions.cc b/cpp/src/arrow/flight/test_definitions.cc index c43b693d84a..273d394c288 100644 --- a/cpp/src/arrow/flight/test_definitions.cc +++ b/cpp/src/arrow/flight/test_definitions.cc @@ -27,6 +27,7 @@ #include "arrow/array/util.h" #include "arrow/flight/api.h" #include "arrow/flight/client_middleware.h" +#include "arrow/flight/test_flight_server.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/flight/types_async.h" @@ -53,7 +54,7 @@ using arrow::internal::checked_cast; // Tests of initialization/shutdown void ConnectivityTest::TestGetPort() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -61,7 +62,7 @@ void ConnectivityTest::TestGetPort() { ASSERT_GT(server->port(), 0); } void ConnectivityTest::TestBuilderHook() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -80,7 +81,7 @@ void ConnectivityTest::TestShutdown() { constexpr int kIterations = 10; ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); for (int i = 0; i < kIterations; i++) { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); @@ -92,7 +93,7 @@ void ConnectivityTest::TestShutdown() { } } void ConnectivityTest::TestShutdownWithDeadline() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -105,7 +106,7 @@ void ConnectivityTest::TestShutdownWithDeadline() { ASSERT_OK(server->Wait()); } void ConnectivityTest::TestBrokenConnection() { - std::unique_ptr server = ExampleTestServer(); + std::unique_ptr server = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); ASSERT_OK(server->Init(options)); @@ -151,7 +152,7 @@ class GetFlightInfoListener : public AsyncListener { } // namespace void DataTest::SetUpTest() { - server_ = ExampleTestServer(); + server_ = TestFlightServer::Make(); ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); FlightServerOptions options(location); @@ -1822,7 +1823,7 @@ void AsyncClientTest::SetUpTest() { ASSERT_OK_AND_ASSIGN(auto location, Location::ForScheme(transport(), "127.0.0.1", 0)); - server_ = ExampleTestServer(); + server_ = TestFlightServer::Make(); FlightServerOptions server_options(location); ASSERT_OK(server_->Init(server_options)); diff --git a/cpp/src/arrow/flight/test_flight_server.cc b/cpp/src/arrow/flight/test_flight_server.cc new file mode 100644 index 00000000000..0ea95ebd15b --- /dev/null +++ b/cpp/src/arrow/flight/test_flight_server.cc @@ -0,0 +1,417 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/flight/test_flight_server.h" + +#include "arrow/array/array_base.h" +#include "arrow/array/array_primitive.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/flight/server.h" +#include "arrow/flight/test_util.h" +#include "arrow/flight/type_fwd.h" +#include "arrow/status.h" + +namespace arrow::flight { +namespace { + +class ErrorRecordBatchReader : public RecordBatchReader { + public: + ErrorRecordBatchReader() : schema_(arrow::schema({})) {} + + std::shared_ptr schema() const override { return schema_; } + + Status ReadNext(std::shared_ptr* out) override { + *out = nullptr; + return Status::OK(); + } + + Status Close() override { + // This should be propagated over DoGet to the client + return Status::IOError("Expected error"); + } + + private: + std::shared_ptr schema_; +}; + +Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr* out) { + if (ticket.ticket == "ticket-ints-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else if (ticket.ticket == "ticket-floats-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleFloatBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else if (ticket.ticket == "ticket-dicts-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleDictBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else if (ticket.ticket == "ticket-large-batch-1") { + RecordBatchVector batches; + RETURN_NOT_OK(ExampleLargeBatches(&batches)); + ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); + return Status::OK(); + } else { + return Status::NotImplemented("no stream implemented for ticket: " + ticket.ticket); + } +} + +} // namespace + +std::unique_ptr TestFlightServer::Make() { + return std::make_unique(); +} + +Status TestFlightServer::ListFlights(const ServerCallContext& context, + const Criteria* criteria, + std::unique_ptr* listings) { + std::vector flights = ExampleFlightInfo(); + if (criteria && criteria->expression != "") { + // For test purposes, if we get criteria, return no results + flights.clear(); + } + *listings = std::make_unique(flights); + return Status::OK(); +} + +Status TestFlightServer::GetFlightInfo(const ServerCallContext& context, + const FlightDescriptor& request, + std::unique_ptr* out) { + // Test that Arrow-C++ status codes make it through the transport + if (request.type == FlightDescriptor::DescriptorType::CMD && + request.cmd == "status-outofmemory") { + return Status::OutOfMemory("Sentinel"); + } + + std::vector flights = ExampleFlightInfo(); + + for (const auto& info : flights) { + if (info.descriptor().Equals(request)) { + *out = 
std::make_unique(info); + return Status::OK(); + } + } + return Status::Invalid("Flight not found: ", request.ToString()); +} + +Status TestFlightServer::DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) { + // Test for ARROW-5095 + if (request.ticket == "ARROW-5095-fail") { + return Status::UnknownError("Server-side error"); + } + if (request.ticket == "ARROW-5095-success") { + return Status::OK(); + } + if (request.ticket == "ARROW-13253-DoGet-Batch") { + // Make batch > 2GiB in size + ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); + ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); + *data_stream = std::make_unique(std::move(reader)); + return Status::OK(); + } + if (request.ticket == "ticket-stream-error") { + auto reader = std::make_shared(); + *data_stream = std::make_unique(std::move(reader)); + return Status::OK(); + } + + std::shared_ptr batch_reader; + RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader)); + + *data_stream = std::make_unique(batch_reader); + return Status::OK(); +} + +Status TestFlightServer::DoPut(const ServerCallContext&, + std::unique_ptr reader, + std::unique_ptr writer) { + return reader->ToRecordBatches().status(); +} + +Status TestFlightServer::DoExchange(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) { + // Test various scenarios for a DoExchange + if (reader->descriptor().type != FlightDescriptor::DescriptorType::CMD) { + return Status::Invalid("Must provide a command descriptor"); + } + + const std::string& cmd = reader->descriptor().cmd; + if (cmd == "error") { + // Immediately return an error to the client. + return Status::NotImplemented("Expected error"); + } else if (cmd == "get") { + return RunExchangeGet(std::move(reader), std::move(writer)); + } else if (cmd == "put") { + return RunExchangePut(std::move(reader), std::move(writer)); + } else if (cmd == "counter") { + return RunExchangeCounter(std::move(reader), std::move(writer)); + } else if (cmd == "total") { + return RunExchangeTotal(std::move(reader), std::move(writer)); + } else if (cmd == "echo") { + return RunExchangeEcho(std::move(reader), std::move(writer)); + } else if (cmd == "large_batch") { + return RunExchangeLargeBatch(std::move(reader), std::move(writer)); + } else if (cmd == "TestUndrained") { + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + return Status::OK(); + } else { + return Status::NotImplemented("Scenario not implemented: ", cmd); + } +} + +// A simple example - act like DoGet. 
+Status TestFlightServer::RunExchangeGet(std::unique_ptr reader, + std::unique_ptr writer) { + RETURN_NOT_OK(writer->Begin(ExampleIntSchema())); + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + for (const auto& batch : batches) { + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + } + return Status::OK(); +} + +// A simple example - act like DoPut +Status TestFlightServer::RunExchangePut(std::unique_ptr reader, + std::unique_ptr writer) { + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + if (!schema->Equals(ExampleIntSchema(), false)) { + return Status::Invalid("Schema is not as expected"); + } + RecordBatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + FlightStreamChunk chunk; + for (const auto& batch : batches) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data) { + return Status::Invalid("Expected another batch"); + } + if (!batch->Equals(*chunk.data)) { + return Status::Invalid("Batch does not match"); + } + } + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (chunk.data || chunk.app_metadata) { + return Status::Invalid("Too many batches"); + } + + RETURN_NOT_OK(writer->WriteMetadata(Buffer::FromString("done"))); + return Status::OK(); +} + +// Read some number of record batches from the client, send a +// metadata message back with the count, then echo the batches back. +Status TestFlightServer::RunExchangeCounter(std::unique_ptr reader, + std::unique_ptr writer) { + std::vector> batches; + FlightStreamChunk chunk; + int chunks = 0; + while (true) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data && !chunk.app_metadata) { + break; + } + if (chunk.data) { + batches.push_back(chunk.data); + chunks++; + } + } + + // Echo back the number of record batches read. + std::shared_ptr buf = Buffer::FromString(std::to_string(chunks)); + RETURN_NOT_OK(writer->WriteMetadata(buf)); + // Echo the record batches themselves. + if (chunks > 0) { + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + RETURN_NOT_OK(writer->Begin(schema)); + + for (const auto& batch : batches) { + RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); + } + } + + return Status::OK(); +} + +// Read int64 batches from the client, each time sending back a +// batch with a running sum of columns. 
+Status TestFlightServer::RunExchangeTotal(std::unique_ptr reader, + std::unique_ptr writer) { + FlightStreamChunk chunk{}; + ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); + // Ensure the schema contains only int64 columns + for (const auto& field : schema->fields()) { + if (field->type()->id() != Type::type::INT64) { + return Status::Invalid("Field is not INT64: ", field->name()); + } + } + std::vector sums(schema->num_fields()); + std::vector> columns(schema->num_fields()); + RETURN_NOT_OK(writer->Begin(schema)); + while (true) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data && !chunk.app_metadata) { + break; + } + if (chunk.data) { + if (!chunk.data->schema()->Equals(schema, false)) { + // A compliant client implementation would make this impossible + return Status::Invalid("Schemas are incompatible"); + } + + // Update the running totals + auto builder = std::make_shared(); + int col_index = 0; + for (const auto& column : chunk.data->columns()) { + auto arr = std::dynamic_pointer_cast(column); + if (!arr) { + return MakeFlightError(FlightStatusCode::Internal, "Could not cast array"); + } + for (int row = 0; row < column->length(); row++) { + if (!arr->IsNull(row)) { + sums[col_index] += arr->Value(row); + } + } + + builder->Reset(); + RETURN_NOT_OK(builder->Append(sums[col_index])); + RETURN_NOT_OK(builder->Finish(&columns[col_index])); + + col_index++; + } + + // Echo the totals to the client + auto response = RecordBatch::Make(schema, /* num_rows */ 1, columns); + RETURN_NOT_OK(writer->WriteRecordBatch(*response)); + } + } + return Status::OK(); +} + +// Echo the client's messages back. +Status TestFlightServer::RunExchangeEcho(std::unique_ptr reader, + std::unique_ptr writer) { + FlightStreamChunk chunk; + bool begun = false; + while (true) { + ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); + if (!chunk.data && !chunk.app_metadata) { + break; + } + if (!begun && chunk.data) { + begun = true; + RETURN_NOT_OK(writer->Begin(chunk.data->schema())); + } + if (chunk.data && chunk.app_metadata) { + RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata)); + } else if (chunk.data) { + RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data)); + } else if (chunk.app_metadata) { + RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata)); + } + } + return Status::OK(); +} + +// Regression test for ARROW-13253 +Status TestFlightServer::RunExchangeLargeBatch( + std::unique_ptr, std::unique_ptr writer) { + ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); + RETURN_NOT_OK(writer->Begin(batch->schema())); + return writer->WriteRecordBatch(*batch); +} + +Status TestFlightServer::RunAction1(const Action& action, + std::unique_ptr* out) { + std::vector results; + for (int i = 0; i < 3; ++i) { + Result result; + std::string value = action.body->ToString() + "-part" + std::to_string(i); + result.body = Buffer::FromString(std::move(value)); + results.push_back(result); + } + *out = std::make_unique(std::move(results)); + return Status::OK(); +} + +Status TestFlightServer::RunAction2(std::unique_ptr* out) { + // Empty + *out = std::make_unique(std::vector{}); + return Status::OK(); +} + +Status TestFlightServer::ListIncomingHeaders(const ServerCallContext& context, + const Action& action, + std::unique_ptr* out) { + std::vector results; + std::string_view prefix(*action.body); + for (const auto& header : context.incoming_headers()) { + if (header.first.substr(0, prefix.size()) != prefix) { + continue; + } + Result result; + result.body = + 
Buffer::FromString(std::string(header.first) + ": " + std::string(header.second)); + results.push_back(result); + } + *out = std::make_unique(std::move(results)); + return Status::OK(); +} + +Status TestFlightServer::DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* out) { + if (action.type == "action1") { + return RunAction1(action, out); + } else if (action.type == "action2") { + return RunAction2(out); + } else if (action.type == "list-incoming-headers") { + return ListIncomingHeaders(context, action, out); + } else { + return Status::NotImplemented(action.type); + } +} + +Status TestFlightServer::ListActions(const ServerCallContext& context, + std::vector* out) { + std::vector actions = ExampleActionTypes(); + *out = std::move(actions); + return Status::OK(); +} + +Status TestFlightServer::GetSchema(const ServerCallContext& context, + const FlightDescriptor& request, + std::unique_ptr* schema) { + std::vector flights = ExampleFlightInfo(); + + for (const auto& info : flights) { + if (info.descriptor().Equals(request)) { + *schema = std::make_unique(info.serialized_schema()); + return Status::OK(); + } + } + return Status::Invalid("Flight not found: ", request.ToString()); +} + +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_flight_server.h b/cpp/src/arrow/flight/test_flight_server.h new file mode 100644 index 00000000000..794dd834c01 --- /dev/null +++ b/cpp/src/arrow/flight/test_flight_server.h @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/flight/server.h" +#include "arrow/flight/type_fwd.h" +#include "arrow/flight/visibility.h" +#include "arrow/status.h" + +namespace arrow::flight { + +class ARROW_FLIGHT_EXPORT TestFlightServer : public FlightServerBase { + public: + static std::unique_ptr Make(); + + Status ListFlights(const ServerCallContext& context, const Criteria* criteria, + std::unique_ptr* listings) override; + + Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* out) override; + + Status DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) override; + + Status DoPut(const ServerCallContext&, std::unique_ptr reader, + std::unique_ptr writer) override; + + Status DoExchange(const ServerCallContext& context, + std::unique_ptr reader, + std::unique_ptr writer) override; + + // A simple example - act like DoGet. 
+ Status RunExchangeGet(std::unique_ptr reader, + std::unique_ptr writer); + + // A simple example - act like DoPut + Status RunExchangePut(std::unique_ptr reader, + std::unique_ptr writer); + + // Read some number of record batches from the client, send a + // metadata message back with the count, then echo the batches back. + Status RunExchangeCounter(std::unique_ptr reader, + std::unique_ptr writer); + + // Read int64 batches from the client, each time sending back a + // batch with a running sum of columns. + Status RunExchangeTotal(std::unique_ptr reader, + std::unique_ptr writer); + + // Echo the client's messages back. + Status RunExchangeEcho(std::unique_ptr reader, + std::unique_ptr writer); + + // Regression test for ARROW-13253 + Status RunExchangeLargeBatch(std::unique_ptr, + std::unique_ptr writer); + + Status RunAction1(const Action& action, std::unique_ptr* out); + + Status RunAction2(std::unique_ptr* out); + + Status ListIncomingHeaders(const ServerCallContext& context, const Action& action, + std::unique_ptr* out); + + Status DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* out) override; + + Status ListActions(const ServerCallContext& context, + std::vector* out) override; + + Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* schema) override; +}; + +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_server.cc b/cpp/src/arrow/flight/test_server.cc index 18bf2b41359..ba84b8f532e 100644 --- a/cpp/src/arrow/flight/test_server.cc +++ b/cpp/src/arrow/flight/test_server.cc @@ -26,6 +26,7 @@ #include #include "arrow/flight/server.h" +#include "arrow/flight/test_flight_server.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/util/logging.h" @@ -38,7 +39,7 @@ std::unique_ptr g_server; int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - g_server = arrow::flight::ExampleTestServer(); + g_server = arrow::flight::TestFlightServer::Make(); arrow::flight::Location location; if (FLAGS_unix.empty()) { diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index 8b4245e74e8..127827ff38c 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -49,8 +49,7 @@ #include "arrow/flight/api.h" #include "arrow/flight/serialization_internal.h" -namespace arrow { -namespace flight { +namespace arrow::flight { namespace bp = boost::process; namespace fs = boost::filesystem; @@ -90,25 +89,6 @@ Status ResolveCurrentExecutable(fs::path* out) { } } -class ErrorRecordBatchReader : public RecordBatchReader { - public: - ErrorRecordBatchReader() : schema_(arrow::schema({})) {} - - std::shared_ptr schema() const override { return schema_; } - - Status ReadNext(std::shared_ptr* out) override { - *out = nullptr; - return Status::OK(); - } - - Status Close() override { - // This should be propagated over DoGet to the client - return Status::IOError("Expected error"); - } - - private: - std::shared_ptr schema_; -}; } // namespace void TestServer::Start(const std::vector& extra_args) { @@ -171,364 +151,6 @@ int TestServer::port() const { return port_; } const std::string& TestServer::unix_sock() const { return unix_sock_; } -Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr* out) { - if (ticket.ticket == "ticket-ints-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - 
return Status::OK(); - } else if (ticket.ticket == "ticket-floats-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleFloatBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - return Status::OK(); - } else if (ticket.ticket == "ticket-dicts-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleDictBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - return Status::OK(); - } else if (ticket.ticket == "ticket-large-batch-1") { - RecordBatchVector batches; - RETURN_NOT_OK(ExampleLargeBatches(&batches)); - ARROW_ASSIGN_OR_RAISE(*out, RecordBatchReader::Make(batches)); - return Status::OK(); - } else { - return Status::NotImplemented("no stream implemented for ticket: " + ticket.ticket); - } -} - -class FlightTestServer : public FlightServerBase { - Status ListFlights(const ServerCallContext& context, const Criteria* criteria, - std::unique_ptr* listings) override { - std::vector flights = ExampleFlightInfo(); - if (criteria && criteria->expression != "") { - // For test purposes, if we get criteria, return no results - flights.clear(); - } - *listings = std::make_unique(flights); - return Status::OK(); - } - - Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, - std::unique_ptr* out) override { - // Test that Arrow-C++ status codes make it through the transport - if (request.type == FlightDescriptor::DescriptorType::CMD && - request.cmd == "status-outofmemory") { - return Status::OutOfMemory("Sentinel"); - } - - std::vector flights = ExampleFlightInfo(); - - for (const auto& info : flights) { - if (info.descriptor().Equals(request)) { - *out = std::make_unique(info); - return Status::OK(); - } - } - return Status::Invalid("Flight not found: ", request.ToString()); - } - - Status DoGet(const ServerCallContext& context, const Ticket& request, - std::unique_ptr* data_stream) override { - // Test for ARROW-5095 - if (request.ticket == "ARROW-5095-fail") { - return Status::UnknownError("Server-side error"); - } - if (request.ticket == "ARROW-5095-success") { - return Status::OK(); - } - if (request.ticket == "ARROW-13253-DoGet-Batch") { - // Make batch > 2GiB in size - ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); - ARROW_ASSIGN_OR_RAISE(auto reader, RecordBatchReader::Make({batch})); - *data_stream = std::make_unique(std::move(reader)); - return Status::OK(); - } - if (request.ticket == "ticket-stream-error") { - auto reader = std::make_shared(); - *data_stream = std::make_unique(std::move(reader)); - return Status::OK(); - } - - std::shared_ptr batch_reader; - RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader)); - - *data_stream = std::make_unique(batch_reader); - return Status::OK(); - } - - Status DoPut(const ServerCallContext&, std::unique_ptr reader, - std::unique_ptr writer) override { - return reader->ToRecordBatches().status(); - } - - Status DoExchange(const ServerCallContext& context, - std::unique_ptr reader, - std::unique_ptr writer) override { - // Test various scenarios for a DoExchange - if (reader->descriptor().type != FlightDescriptor::DescriptorType::CMD) { - return Status::Invalid("Must provide a command descriptor"); - } - - const std::string& cmd = reader->descriptor().cmd; - if (cmd == "error") { - // Immediately return an error to the client. 
- return Status::NotImplemented("Expected error"); - } else if (cmd == "get") { - return RunExchangeGet(std::move(reader), std::move(writer)); - } else if (cmd == "put") { - return RunExchangePut(std::move(reader), std::move(writer)); - } else if (cmd == "counter") { - return RunExchangeCounter(std::move(reader), std::move(writer)); - } else if (cmd == "total") { - return RunExchangeTotal(std::move(reader), std::move(writer)); - } else if (cmd == "echo") { - return RunExchangeEcho(std::move(reader), std::move(writer)); - } else if (cmd == "large_batch") { - return RunExchangeLargeBatch(std::move(reader), std::move(writer)); - } else if (cmd == "TestUndrained") { - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - return Status::OK(); - } else { - return Status::NotImplemented("Scenario not implemented: ", cmd); - } - } - - // A simple example - act like DoGet. - Status RunExchangeGet(std::unique_ptr reader, - std::unique_ptr writer) { - RETURN_NOT_OK(writer->Begin(ExampleIntSchema())); - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - for (const auto& batch : batches) { - RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); - } - return Status::OK(); - } - - // A simple example - act like DoPut - Status RunExchangePut(std::unique_ptr reader, - std::unique_ptr writer) { - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - if (!schema->Equals(ExampleIntSchema(), false)) { - return Status::Invalid("Schema is not as expected"); - } - RecordBatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - FlightStreamChunk chunk; - for (const auto& batch : batches) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data) { - return Status::Invalid("Expected another batch"); - } - if (!batch->Equals(*chunk.data)) { - return Status::Invalid("Batch does not match"); - } - } - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (chunk.data || chunk.app_metadata) { - return Status::Invalid("Too many batches"); - } - - RETURN_NOT_OK(writer->WriteMetadata(Buffer::FromString("done"))); - return Status::OK(); - } - - // Read some number of record batches from the client, send a - // metadata message back with the count, then echo the batches back. - Status RunExchangeCounter(std::unique_ptr reader, - std::unique_ptr writer) { - std::vector> batches; - FlightStreamChunk chunk; - int chunks = 0; - while (true) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data && !chunk.app_metadata) { - break; - } - if (chunk.data) { - batches.push_back(chunk.data); - chunks++; - } - } - - // Echo back the number of record batches read. - std::shared_ptr buf = Buffer::FromString(std::to_string(chunks)); - RETURN_NOT_OK(writer->WriteMetadata(buf)); - // Echo the record batches themselves. - if (chunks > 0) { - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - RETURN_NOT_OK(writer->Begin(schema)); - - for (const auto& batch : batches) { - RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); - } - } - - return Status::OK(); - } - - // Read int64 batches from the client, each time sending back a - // batch with a running sum of columns. 
- Status RunExchangeTotal(std::unique_ptr reader, - std::unique_ptr writer) { - FlightStreamChunk chunk{}; - ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); - // Ensure the schema contains only int64 columns - for (const auto& field : schema->fields()) { - if (field->type()->id() != Type::type::INT64) { - return Status::Invalid("Field is not INT64: ", field->name()); - } - } - std::vector sums(schema->num_fields()); - std::vector> columns(schema->num_fields()); - RETURN_NOT_OK(writer->Begin(schema)); - while (true) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data && !chunk.app_metadata) { - break; - } - if (chunk.data) { - if (!chunk.data->schema()->Equals(schema, false)) { - // A compliant client implementation would make this impossible - return Status::Invalid("Schemas are incompatible"); - } - - // Update the running totals - auto builder = std::make_shared(); - int col_index = 0; - for (const auto& column : chunk.data->columns()) { - auto arr = std::dynamic_pointer_cast(column); - if (!arr) { - return MakeFlightError(FlightStatusCode::Internal, "Could not cast array"); - } - for (int row = 0; row < column->length(); row++) { - if (!arr->IsNull(row)) { - sums[col_index] += arr->Value(row); - } - } - - builder->Reset(); - RETURN_NOT_OK(builder->Append(sums[col_index])); - RETURN_NOT_OK(builder->Finish(&columns[col_index])); - - col_index++; - } - - // Echo the totals to the client - auto response = RecordBatch::Make(schema, /* num_rows */ 1, columns); - RETURN_NOT_OK(writer->WriteRecordBatch(*response)); - } - } - return Status::OK(); - } - - // Echo the client's messages back. - Status RunExchangeEcho(std::unique_ptr reader, - std::unique_ptr writer) { - FlightStreamChunk chunk; - bool begun = false; - while (true) { - ARROW_ASSIGN_OR_RAISE(chunk, reader->Next()); - if (!chunk.data && !chunk.app_metadata) { - break; - } - if (!begun && chunk.data) { - begun = true; - RETURN_NOT_OK(writer->Begin(chunk.data->schema())); - } - if (chunk.data && chunk.app_metadata) { - RETURN_NOT_OK(writer->WriteWithMetadata(*chunk.data, chunk.app_metadata)); - } else if (chunk.data) { - RETURN_NOT_OK(writer->WriteRecordBatch(*chunk.data)); - } else if (chunk.app_metadata) { - RETURN_NOT_OK(writer->WriteMetadata(chunk.app_metadata)); - } - } - return Status::OK(); - } - - // Regression test for ARROW-13253 - Status RunExchangeLargeBatch(std::unique_ptr, - std::unique_ptr writer) { - ARROW_ASSIGN_OR_RAISE(auto batch, VeryLargeBatch()); - RETURN_NOT_OK(writer->Begin(batch->schema())); - return writer->WriteRecordBatch(*batch); - } - - Status RunAction1(const Action& action, std::unique_ptr* out) { - std::vector results; - for (int i = 0; i < 3; ++i) { - Result result; - std::string value = action.body->ToString() + "-part" + std::to_string(i); - result.body = Buffer::FromString(std::move(value)); - results.push_back(result); - } - *out = std::make_unique(std::move(results)); - return Status::OK(); - } - - Status RunAction2(std::unique_ptr* out) { - // Empty - *out = std::make_unique(std::vector{}); - return Status::OK(); - } - - Status ListIncomingHeaders(const ServerCallContext& context, const Action& action, - std::unique_ptr* out) { - std::vector results; - std::string_view prefix(*action.body); - for (const auto& header : context.incoming_headers()) { - if (header.first.substr(0, prefix.size()) != prefix) { - continue; - } - Result result; - result.body = Buffer::FromString(std::string(header.first) + ": " + - std::string(header.second)); - results.push_back(result); - } - 
*out = std::make_unique(std::move(results)); - return Status::OK(); - } - - Status DoAction(const ServerCallContext& context, const Action& action, - std::unique_ptr* out) override { - if (action.type == "action1") { - return RunAction1(action, out); - } else if (action.type == "action2") { - return RunAction2(out); - } else if (action.type == "list-incoming-headers") { - return ListIncomingHeaders(context, action, out); - } else { - return Status::NotImplemented(action.type); - } - } - - Status ListActions(const ServerCallContext& context, - std::vector* out) override { - std::vector actions = ExampleActionTypes(); - *out = std::move(actions); - return Status::OK(); - } - - Status GetSchema(const ServerCallContext& context, const FlightDescriptor& request, - std::unique_ptr* schema) override { - std::vector flights = ExampleFlightInfo(); - - for (const auto& info : flights) { - if (info.descriptor().Equals(request)) { - *schema = std::make_unique(info.serialized_schema()); - return Status::OK(); - } - } - return Status::Invalid("Flight not found: ", request.ToString()); - } -}; - -std::unique_ptr ExampleTestServer() { - return std::make_unique(); -} - FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, const std::vector& endpoints, int64_t total_records, int64_t total_bytes, bool ordered, @@ -701,109 +323,6 @@ std::vector ExampleActionTypes() { return {{"drop", "drop a dataset"}, {"cache", "cache a dataset"}}; } -TestServerAuthHandler::TestServerAuthHandler(const std::string& username, - const std::string& password) - : username_(username), password_(password) {} - -TestServerAuthHandler::~TestServerAuthHandler() {} - -Status TestServerAuthHandler::Authenticate(const ServerCallContext& context, - ServerAuthSender* outgoing, - ServerAuthReader* incoming) { - std::string token; - RETURN_NOT_OK(incoming->Read(&token)); - if (token != password_) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - RETURN_NOT_OK(outgoing->Write(username_)); - return Status::OK(); -} - -Status TestServerAuthHandler::IsValid(const ServerCallContext& context, - const std::string& token, - std::string* peer_identity) { - if (token != password_) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - *peer_identity = username_; - return Status::OK(); -} - -TestServerBasicAuthHandler::TestServerBasicAuthHandler(const std::string& username, - const std::string& password) { - basic_auth_.username = username; - basic_auth_.password = password; -} - -TestServerBasicAuthHandler::~TestServerBasicAuthHandler() {} - -Status TestServerBasicAuthHandler::Authenticate(const ServerCallContext& context, - ServerAuthSender* outgoing, - ServerAuthReader* incoming) { - std::string token; - RETURN_NOT_OK(incoming->Read(&token)); - ARROW_ASSIGN_OR_RAISE(BasicAuth incoming_auth, BasicAuth::Deserialize(token)); - if (incoming_auth.username != basic_auth_.username || - incoming_auth.password != basic_auth_.password) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - RETURN_NOT_OK(outgoing->Write(basic_auth_.username)); - return Status::OK(); -} - -Status TestServerBasicAuthHandler::IsValid(const ServerCallContext& context, - const std::string& token, - std::string* peer_identity) { - if (token != basic_auth_.username) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - *peer_identity = basic_auth_.username; - return Status::OK(); -} - 
-TestClientAuthHandler::TestClientAuthHandler(const std::string& username, - const std::string& password) - : username_(username), password_(password) {} - -TestClientAuthHandler::~TestClientAuthHandler() {} - -Status TestClientAuthHandler::Authenticate(ClientAuthSender* outgoing, - ClientAuthReader* incoming) { - RETURN_NOT_OK(outgoing->Write(password_)); - std::string username; - RETURN_NOT_OK(incoming->Read(&username)); - if (username != username_) { - return MakeFlightError(FlightStatusCode::Unauthenticated, "Invalid token"); - } - return Status::OK(); -} - -Status TestClientAuthHandler::GetToken(std::string* token) { - *token = password_; - return Status::OK(); -} - -TestClientBasicAuthHandler::TestClientBasicAuthHandler(const std::string& username, - const std::string& password) { - basic_auth_.username = username; - basic_auth_.password = password; -} - -TestClientBasicAuthHandler::~TestClientBasicAuthHandler() {} - -Status TestClientBasicAuthHandler::Authenticate(ClientAuthSender* outgoing, - ClientAuthReader* incoming) { - ARROW_ASSIGN_OR_RAISE(std::string pb_result, basic_auth_.SerializeToString()); - RETURN_NOT_OK(outgoing->Write(pb_result)); - RETURN_NOT_OK(incoming->Read(&token_)); - return Status::OK(); -} - -Status TestClientBasicAuthHandler::GetToken(std::string* token) { - *token = token_; - return Status::OK(); -} - Status ExampleTlsCertificates(std::vector* out) { std::string root; RETURN_NOT_OK(GetTestResourceRoot(&root)); @@ -860,5 +379,4 @@ Status ExampleTlsCertificateRoot(CertKeyPair* out) { } } -} // namespace flight -} // namespace arrow +} // namespace arrow::flight diff --git a/cpp/src/arrow/flight/test_util.h b/cpp/src/arrow/flight/test_util.h index c0b42d9b90c..15ba6145ecd 100644 --- a/cpp/src/arrow/flight/test_util.h +++ b/cpp/src/arrow/flight/test_util.h @@ -32,9 +32,7 @@ #include "arrow/testing/util.h" #include "arrow/flight/client.h" -#include "arrow/flight/client_auth.h" #include "arrow/flight/server.h" -#include "arrow/flight/server_auth.h" #include "arrow/flight/types.h" #include "arrow/flight/visibility.h" @@ -95,10 +93,6 @@ class ARROW_FLIGHT_EXPORT TestServer { std::shared_ptr<::boost::process::child> server_process_; }; -/// \brief Create a simple Flight server for testing -ARROW_FLIGHT_EXPORT -std::unique_ptr ExampleTestServer(); - // Helper to initialize a server and matching client with callbacks to // populate options. template @@ -195,65 +189,6 @@ FlightInfo MakeFlightInfo(const Schema& schema, const FlightDescriptor& descript int64_t total_records, int64_t total_bytes, bool ordered, std::string app_metadata); -// ---------------------------------------------------------------------- -// A pair of authentication handlers that check for a predefined password -// and set the peer identity to a predefined username. 
- -class ARROW_FLIGHT_EXPORT TestServerAuthHandler : public ServerAuthHandler { - public: - explicit TestServerAuthHandler(const std::string& username, - const std::string& password); - ~TestServerAuthHandler() override; - Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, - ServerAuthReader* incoming) override; - Status IsValid(const ServerCallContext& context, const std::string& token, - std::string* peer_identity) override; - - private: - std::string username_; - std::string password_; -}; - -class ARROW_FLIGHT_EXPORT TestServerBasicAuthHandler : public ServerAuthHandler { - public: - explicit TestServerBasicAuthHandler(const std::string& username, - const std::string& password); - ~TestServerBasicAuthHandler() override; - Status Authenticate(const ServerCallContext& context, ServerAuthSender* outgoing, - ServerAuthReader* incoming) override; - Status IsValid(const ServerCallContext& context, const std::string& token, - std::string* peer_identity) override; - - private: - BasicAuth basic_auth_; -}; - -class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler { - public: - explicit TestClientAuthHandler(const std::string& username, - const std::string& password); - ~TestClientAuthHandler() override; - Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; - Status GetToken(std::string* token) override; - - private: - std::string username_; - std::string password_; -}; - -class ARROW_FLIGHT_EXPORT TestClientBasicAuthHandler : public ClientAuthHandler { - public: - explicit TestClientBasicAuthHandler(const std::string& username, - const std::string& password); - ~TestClientBasicAuthHandler() override; - Status Authenticate(ClientAuthSender* outgoing, ClientAuthReader* incoming) override; - Status GetToken(std::string* token) override; - - private: - BasicAuth basic_auth_; - std::string token_; -}; - ARROW_FLIGHT_EXPORT Status ExampleTlsCertificates(std::vector* out); From 4d200dc17daf268863df6f0d7c458cb460904a7c Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 14 Aug 2024 06:31:47 +0530 Subject: [PATCH 008/157] GH-43577: [Java] getBuffers method needs correction on clear flag usage (#43583) ### Rationale for this change `getBuffers` method provides the capability to clear the buffers in the vector, this has not been properly tested while clear flag is not properly used in the implementation across various types of vectors. ### What changes are included in this PR? Updating the vector `getBuffers` method to use `clear` flag as expected and adding corresponding test cases. ### Are these changes tested? Yes, via existing test cases and new test cases. ### Are there any user-facing changes? 
Yes * GitHub Issue: #43577 Authored-by: Vibhatha Abeykoon Signed-off-by: David Li --- .../arrow/vector/complex/AbstractStructVector.java | 11 +++++++++++ .../vector/complex/BaseRepeatedValueVector.java | 11 +++++++++++ .../arrow/vector/complex/FixedSizeListVector.java | 11 +++++++++++ .../arrow/vector/complex/LargeListVector.java | 9 +++++---- .../arrow/vector/complex/LargeListViewVector.java | 5 +++-- .../org/apache/arrow/vector/complex/ListVector.java | 9 +++++---- .../apache/arrow/vector/complex/ListViewVector.java | 3 ++- .../apache/arrow/vector/complex/StructVector.java | 9 +++++---- .../org/apache/arrow/vector/TestVectorReset.java | 13 ++++++++++++- 9 files changed, 65 insertions(+), 16 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java index feb7edfec94..2921e43cb64 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractStructVector.java @@ -382,6 +382,17 @@ public VectorWithOrdinal getChildVectorWithOrdinal(String name) { return new VectorWithOrdinal(vector, ordinal); } + /** + * Return the underlying buffers associated with this vector. Note that this doesn't impact the + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it + * (unless they change it). + * + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. + * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. + */ @Override public ArrowBuf[] getBuffers(boolean clear) { final List buffers = new ArrayList<>(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java index 1cdb87eba03..fbe83bad52c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java @@ -271,6 +271,17 @@ public void reset() { valueCount = 0; } + /** + * Return the underlying buffers associated with this vector. Note that this doesn't impact the + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it + * (unless they change it). + * + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. + * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. 
+ */ @Override public ArrowBuf[] getBuffers(boolean clear) { final ArrowBuf[] buffers; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java index cb455084808..c762eb51725 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java @@ -360,6 +360,17 @@ public void reset() { valueCount = 0; } + /** + * Return the underlying buffers associated with this vector. Note that this doesn't impact the + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it + * (unless they change it). + * + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. + * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. + */ @Override public ArrowBuf[] getBuffers(boolean clear) { setReaderAndWriterIndex(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java index b5b32c8032d..ed075352c93 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListVector.java @@ -882,12 +882,13 @@ public void reset() { /** * Return the underlying buffers associated with this vector. Note that this doesn't impact the - * reference counts for this buffer so it only should be used for in-context access. Also note - * that this buffer changes regularly thus external classes shouldn't hold a reference to it + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it * (unless they change it). * - * @param clear Whether to clear vector before returning; the buffers will still be refcounted but - * the returned array will be the only reference to them + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. */ @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java index 17ccdbf0eae..f6b3de88b77 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java @@ -546,7 +546,8 @@ public void reset() { * (unless they change it). * * @param clear Whether to clear vector before returning, the buffers will still be refcounted but - * the returned array will be the only reference to them + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. 
*/ @Override @@ -561,7 +562,7 @@ public ArrowBuf[] getBuffers(boolean clear) { list.add(validityBuffer); list.add(offsetBuffer); list.add(sizeBuffer); - list.addAll(Arrays.asList(vector.getBuffers(clear))); + list.addAll(Arrays.asList(vector.getBuffers(false))); buffers = list.toArray(new ArrowBuf[list.size()]); } if (clear) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index a1e18210fc6..76682c28fe6 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -726,12 +726,13 @@ public void reset() { /** * Return the underlying buffers associated with this vector. Note that this doesn't impact the - * reference counts for this buffer so it only should be used for in-context access. Also note - * that this buffer changes regularly thus external classes shouldn't hold a reference to it + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it * (unless they change it). * - * @param clear Whether to clear vector before returning; the buffers will still be refcounted but - * the returned array will be the only reference to them + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. */ @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java index 6ced66d81ec..7f6d92f3be9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java @@ -704,7 +704,8 @@ public void reset() { * (unless they change it). * * @param clear Whether to clear vector before returning, the buffers will still be refcounted but - * the returned array will be the only reference to them + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. */ @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java index dda9b6547f7..ca5f572034c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/StructVector.java @@ -396,12 +396,13 @@ public int getValueCapacity() { /** * Return the underlying buffers associated with this vector. Note that this doesn't impact the - * reference counts for this buffer so it only should be used for in-context access. Also note - * that this buffer changes regularly thus external classes shouldn't hold a reference to it + * reference counts for this buffer, so it only should be used for in-context access. Also note + * that this buffer changes regularly, thus external classes shouldn't hold a reference to it * (unless they change it). 
* - * @param clear Whether to clear vector before returning; the buffers will still be refcounted but - * the returned array will be the only reference to them + * @param clear Whether to clear vector before returning, the buffers will still be refcounted but + * the returned array will be the only reference to them. Also, this won't clear the child + * buffers. * @return The underlying {@link ArrowBuf buffers} that is used by this vector instance. */ @Override diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReset.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReset.java index 48cf78a4c2e..28d73a8fdff 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReset.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReset.java @@ -25,6 +25,7 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.NonNullableStructVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; @@ -122,7 +123,10 @@ public void testListTypeReset() { "VarList", allocator, FieldType.nullable(MinorType.INT.getType()), null); final FixedSizeListVector fixedList = new FixedSizeListVector( - "FixedList", allocator, FieldType.nullable(new FixedSizeList(2)), null)) { + "FixedList", allocator, FieldType.nullable(new FixedSizeList(2)), null); + final ListViewVector variableViewList = + new ListViewVector( + "VarListView", allocator, FieldType.nullable(MinorType.INT.getType()), null)) { // ListVector variableList.allocateNewSafe(); variableList.startNewValue(0); @@ -136,6 +140,13 @@ public void testListTypeReset() { fixedList.setNull(0); fixedList.setValueCount(1); resetVectorAndVerify(fixedList, fixedList.getBuffers(false)); + + // ListViewVector + variableViewList.allocateNewSafe(); + variableViewList.startNewValue(0); + variableViewList.endValue(0, 0); + variableViewList.setValueCount(1); + resetVectorAndVerify(variableViewList, variableViewList.getBuffers(false)); } } From 6e7125b61f2ff587a09dbe45ab05d2f28632a702 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 14 Aug 2024 10:41:38 +0900 Subject: [PATCH 009/157] GH-43454: [C++][Python] Add Opaque canonical extension type (#43458) ### Rationale for this change Add the newly ratified extension type. ### What changes are included in this PR? The C++/Python implementation only. ### Are these changes tested? Yes ### Are there any user-facing changes? No. 
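For reference, a minimal C++ sketch of the added API, mirroring the unit tests included in this change (`ArrayFromJSON` is a helper from the Arrow testing utilities, not part of the new public API):

```cpp
// Sketch only: construct the new Opaque extension type and wrap a storage array,
// following the usage exercised in cpp/src/arrow/extension/opaque_test.cc below.
#include <memory>

#include "arrow/api.h"
#include "arrow/extension_type.h"
#include "arrow/extension/opaque.h"
#include "arrow/testing/gtest_util.h"  // ArrayFromJSON (test-only helper)

std::shared_ptr<arrow::Array> MakeOpaqueGeometryExample() {
  // Placeholder for PostGIS geometry values that Arrow cannot interpret itself.
  std::shared_ptr<arrow::DataType> type =
      arrow::extension::opaque(arrow::binary(), "geometry", "adbc.postgresql");
  std::shared_ptr<arrow::Array> storage =
      arrow::ArrayFromJSON(arrow::binary(), R"(["foobar", null])");
  // Wrap the binary storage array in the extension type.
  return arrow::ExtensionType::WrapArray(type, storage);
}
```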
* GitHub Issue: #43454 Lead-authored-by: David Li Co-authored-by: Weston Pace Signed-off-by: David Li --- cpp/src/arrow/CMakeLists.txt | 1 + .../compute/kernels/scalar_cast_numeric.cc | 23 ++ cpp/src/arrow/extension/CMakeLists.txt | 6 + cpp/src/arrow/extension/opaque.cc | 109 ++++++++++ cpp/src/arrow/extension/opaque.h | 69 ++++++ cpp/src/arrow/extension/opaque_test.cc | 197 ++++++++++++++++++ docs/source/python/api/arrays.rst | 3 + docs/source/python/api/datatypes.rst | 10 + python/pyarrow/__init__.py | 8 +- python/pyarrow/array.pxi | 28 +++ python/pyarrow/includes/libarrow.pxd | 13 ++ python/pyarrow/lib.pxd | 5 + python/pyarrow/public-api.pxi | 2 + python/pyarrow/scalar.pxi | 6 + python/pyarrow/tests/test_extension_type.py | 46 ++++ python/pyarrow/tests/test_misc.py | 3 + python/pyarrow/types.pxi | 101 +++++++++ 17 files changed, 627 insertions(+), 3 deletions(-) create mode 100644 cpp/src/arrow/extension/opaque.cc create mode 100644 cpp/src/arrow/extension/opaque.h create mode 100644 cpp/src/arrow/extension/opaque_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 9c66a58c542..67d2c19f98a 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -907,6 +907,7 @@ endif() if(ARROW_JSON) arrow_add_object_library(ARROW_JSON extension/fixed_shape_tensor.cc + extension/opaque.cc json/options.cc json/chunked_builder.cc json/chunker.cc diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc index 3df86e7d693..bd9be3e8a95 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_numeric.cc @@ -865,6 +865,25 @@ std::shared_ptr GetCastToHalfFloat() { return func; } +struct NullExtensionTypeMatcher : public TypeMatcher { + ~NullExtensionTypeMatcher() override = default; + + bool Matches(const DataType& type) const override { + return type.id() == Type::EXTENSION && + checked_cast(type).storage_id() == Type::NA; + } + + std::string ToString() const override { return "extension"; } + + bool Equals(const TypeMatcher& other) const override { + if (this == &other) { + return true; + } + auto casted = dynamic_cast(&other); + return casted != nullptr; + } +}; + } // namespace std::vector> GetNumericCasts() { @@ -875,6 +894,10 @@ std::vector> GetNumericCasts() { auto cast_null = std::make_shared("cast_null", Type::NA); DCHECK_OK(cast_null->AddKernel(Type::DICTIONARY, {InputType(Type::DICTIONARY)}, null(), OutputAllNull)); + // Explicitly allow casting extension type with null backing array to null + DCHECK_OK(cast_null->AddKernel( + Type::EXTENSION, {InputType(std::make_shared())}, null(), + OutputAllNull)); functions.push_back(cast_null); functions.push_back(GetCastToInteger("cast_int8")); diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index c15c42874d4..6741ab602f5 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -21,4 +21,10 @@ add_arrow_test(test PREFIX "arrow-fixed-shape-tensor") +add_arrow_test(test + SOURCES + opaque_test.cc + PREFIX + "arrow-extension-opaque") + arrow_install_all_headers("arrow/extension") diff --git a/cpp/src/arrow/extension/opaque.cc b/cpp/src/arrow/extension/opaque.cc new file mode 100644 index 00000000000..c430bb5d2ea --- /dev/null +++ b/cpp/src/arrow/extension/opaque.cc @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/opaque.h" + +#include + +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep +#include "arrow/util/logging.h" + +#include +#include +#include + +namespace arrow::extension { + +std::string OpaqueType::ToString(bool show_metadata) const { + std::stringstream ss; + ss << "extension<" << this->extension_name() + << "[storage_type=" << storage_type_->ToString(show_metadata) + << ", type_name=" << type_name_ << ", vendor_name=" << vendor_name_ << "]>"; + return ss.str(); +} + +bool OpaqueType::ExtensionEquals(const ExtensionType& other) const { + if (extension_name() != other.extension_name()) { + return false; + } + const auto& opaque = internal::checked_cast(other); + return storage_type()->Equals(*opaque.storage_type()) && + type_name() == opaque.type_name() && vendor_name() == opaque.vendor_name(); +} + +std::string OpaqueType::Serialize() const { + rapidjson::Document document; + document.SetObject(); + rapidjson::Document::AllocatorType& allocator = document.GetAllocator(); + + rapidjson::Value type_name(rapidjson::StringRef(type_name_)); + document.AddMember(rapidjson::Value("type_name", allocator), type_name, allocator); + rapidjson::Value vendor_name(rapidjson::StringRef(vendor_name_)); + document.AddMember(rapidjson::Value("vendor_name", allocator), vendor_name, allocator); + + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + document.Accept(writer); + return buffer.GetString(); +} + +Result> OpaqueType::Deserialize( + std::shared_ptr storage_type, const std::string& serialized_data) const { + rapidjson::Document document; + const auto& parsed = document.Parse(serialized_data.data(), serialized_data.length()); + if (parsed.HasParseError()) { + return Status::Invalid("Invalid serialized JSON data for OpaqueType: ", + rapidjson::GetParseError_En(parsed.GetParseError()), ": ", + serialized_data); + } else if (!document.IsObject()) { + return Status::Invalid("Invalid serialized JSON data for OpaqueType: not an object"); + } + if (!document.HasMember("type_name")) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: missing type_name"); + } else if (!document.HasMember("vendor_name")) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: missing vendor_name"); + } + + const auto& type_name = document["type_name"]; + const auto& vendor_name = document["vendor_name"]; + if (!type_name.IsString()) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: type_name is not a string"); + } else if (!vendor_name.IsString()) { + return Status::Invalid( + "Invalid serialized JSON data for OpaqueType: vendor_name is not a string"); + } + + return opaque(std::move(storage_type), type_name.GetString(), vendor_name.GetString()); +} + +std::shared_ptr OpaqueType::MakeArray(std::shared_ptr data) 
const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("arrow.opaque", + internal::checked_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +std::shared_ptr opaque(std::shared_ptr storage_type, + std::string type_name, std::string vendor_name) { + return std::make_shared(std::move(storage_type), std::move(type_name), + std::move(vendor_name)); +} + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/opaque.h b/cpp/src/arrow/extension/opaque.h new file mode 100644 index 00000000000..9814b391cba --- /dev/null +++ b/cpp/src/arrow/extension/opaque.h @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension_type.h" +#include "arrow/type.h" + +namespace arrow::extension { + +/// \brief Opaque is a placeholder for a type from an external (usually +/// non-Arrow) system that could not be interpreted. +class ARROW_EXPORT OpaqueType : public ExtensionType { + public: + /// \brief Construct an OpaqueType. + /// + /// \param[in] storage_type The underlying storage type. Should be + /// arrow::null if there is no data. + /// \param[in] type_name The name of the type in the external system. + /// \param[in] vendor_name The name of the external system. + explicit OpaqueType(std::shared_ptr storage_type, std::string type_name, + std::string vendor_name) + : ExtensionType(std::move(storage_type)), + type_name_(std::move(type_name)), + vendor_name_(std::move(vendor_name)) {} + + std::string extension_name() const override { return "arrow.opaque"; } + std::string ToString(bool show_metadata) const override; + bool ExtensionEquals(const ExtensionType& other) const override; + std::string Serialize() const override; + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized_data) const override; + /// Create an OpaqueArray from ArrayData + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + std::string_view type_name() const { return type_name_; } + std::string_view vendor_name() const { return vendor_name_; } + + private: + std::string type_name_; + std::string vendor_name_; +}; + +/// \brief Opaque is a wrapper for (usually binary) data from an external +/// (often non-Arrow) system that could not be interpreted. +class ARROW_EXPORT OpaqueArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Return an OpaqueType instance. 
+ARROW_EXPORT std::shared_ptr opaque(std::shared_ptr storage_type, + std::string type_name, + std::string vendor_name); + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/opaque_test.cc b/cpp/src/arrow/extension/opaque_test.cc new file mode 100644 index 00000000000..1629cdb3965 --- /dev/null +++ b/cpp/src/arrow/extension/opaque_test.cc @@ -0,0 +1,197 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/extension/fixed_shape_tensor.h" +#include "arrow/extension/opaque.h" +#include "arrow/extension_type.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/testing/extension_type.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type_fwd.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { + +TEST(OpaqueType, Basics) { + auto type = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + auto type2 = internal::checked_pointer_cast( + extension::opaque(null(), "type2", "vendor")); + ASSERT_EQ("arrow.opaque", type->extension_name()); + ASSERT_EQ(*type, *type); + ASSERT_NE(*arrow::null(), *type); + ASSERT_NE(*type, *type2); + ASSERT_EQ(*arrow::null(), *type->storage_type()); + ASSERT_THAT(type->Serialize(), ::testing::Not(::testing::IsEmpty())); + ASSERT_EQ(R"({"type_name":"type","vendor_name":"vendor"})", type->Serialize()); + ASSERT_EQ("type", type->type_name()); + ASSERT_EQ("vendor", type->vendor_name()); + ASSERT_EQ( + "extension", + type->ToString(false)); +} + +TEST(OpaqueType, Equals) { + auto type = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + auto type2 = internal::checked_pointer_cast( + extension::opaque(null(), "type2", "vendor")); + auto type3 = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor2")); + auto type4 = internal::checked_pointer_cast( + extension::opaque(int64(), "type", "vendor")); + auto type5 = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + auto type6 = internal::checked_pointer_cast( + extension::fixed_shape_tensor(float64(), {1})); + + ASSERT_EQ(*type, *type); + ASSERT_EQ(*type2, *type2); + ASSERT_EQ(*type3, *type3); + ASSERT_EQ(*type4, *type4); + ASSERT_EQ(*type5, *type5); + + ASSERT_EQ(*type, *type5); + + ASSERT_NE(*type, *type2); + ASSERT_NE(*type, *type3); + ASSERT_NE(*type, *type4); + ASSERT_NE(*type, *type6); + + ASSERT_NE(*type2, *type); + ASSERT_NE(*type2, *type3); + ASSERT_NE(*type2, *type4); + ASSERT_NE(*type2, *type6); + + ASSERT_NE(*type3, *type); + ASSERT_NE(*type3, *type2); + ASSERT_NE(*type3, *type4); + ASSERT_NE(*type3, *type6); + + ASSERT_NE(*type4, *type); + ASSERT_NE(*type4, *type2); + ASSERT_NE(*type4, *type3); 
+ ASSERT_NE(*type4, *type6); + ASSERT_NE(*type6, *type4); +} + +TEST(OpaqueType, CreateFromArray) { + auto type = internal::checked_pointer_cast( + extension::opaque(binary(), "geometry", "adbc.postgresql")); + auto storage = ArrayFromJSON(binary(), R"(["foobar", null])"); + auto array = ExtensionType::WrapArray(type, storage); + ASSERT_EQ(2, array->length()); + ASSERT_EQ(1, array->null_count()); +} + +void CheckDeserialize(const std::string& serialized, + const std::shared_ptr& expected) { + auto type = internal::checked_pointer_cast(expected); + ASSERT_OK_AND_ASSIGN(auto deserialized, + type->Deserialize(type->storage_type(), serialized)); + ASSERT_EQ(*expected, *deserialized); +} + +TEST(OpaqueType, Deserialize) { + ASSERT_NO_FATAL_FAILURE( + CheckDeserialize(R"({"type_name": "type", "vendor_name": "vendor"})", + extension::opaque(null(), "type", "vendor"))); + ASSERT_NO_FATAL_FAILURE( + CheckDeserialize(R"({"type_name": "long name", "vendor_name": "long name"})", + extension::opaque(null(), "long name", "long name"))); + ASSERT_NO_FATAL_FAILURE( + CheckDeserialize(R"({"type_name": "名前", "vendor_name": "名字"})", + extension::opaque(null(), "名前", "名字"))); + ASSERT_NO_FATAL_FAILURE(CheckDeserialize( + R"({"type_name": "type", "vendor_name": "vendor", "extra_field": 2})", + extension::opaque(null(), "type", "vendor"))); + + auto type = internal::checked_pointer_cast( + extension::opaque(null(), "type", "vendor")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("The document is empty"), + type->Deserialize(null(), R"()")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + testing::HasSubstr("Missing a name for object member"), + type->Deserialize(null(), R"({)")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("not an object"), + type->Deserialize(null(), R"([])")); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, testing::HasSubstr("missing type_name"), + type->Deserialize(null(), R"({})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("type_name is not a string"), + type->Deserialize(null(), R"({"type_name": 2, "vendor_name": ""})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("type_name is not a string"), + type->Deserialize(null(), R"({"type_name": null, "vendor_name": ""})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("vendor_name is not a string"), + type->Deserialize(null(), R"({"vendor_name": 2, "type_name": ""})")); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("vendor_name is not a string"), + type->Deserialize(null(), R"({"vendor_name": null, "type_name": ""})")); +} + +TEST(OpaqueType, MetadataRoundTrip) { + for (const auto& type : { + extension::opaque(null(), "foo", "bar"), + extension::opaque(binary(), "geometry", "postgis"), + extension::opaque(fixed_size_list(int64(), 4), "foo", "bar"), + extension::opaque(utf8(), "foo", "bar"), + }) { + auto opaque = internal::checked_pointer_cast(type); + std::string serialized = opaque->Serialize(); + ASSERT_OK_AND_ASSIGN(auto deserialized, + opaque->Deserialize(opaque->storage_type(), serialized)); + ASSERT_EQ(*type, *deserialized); + } +} + +TEST(OpaqueType, BatchRoundTrip) { + auto type = internal::checked_pointer_cast( + extension::opaque(binary(), "geometry", "adbc.postgresql")); + ExtensionTypeGuard guard(type); + + auto storage = ArrayFromJSON(binary(), R"(["foobar", null])"); + auto array = ExtensionType::WrapArray(type, storage); + auto batch = + RecordBatch::Make(schema({field("field", type)}), array->length(), {array}); + + 
std::shared_ptr written; + { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + ASSERT_OK(batch_reader->ReadNext(&written)); + } + + ASSERT_EQ(*batch->schema(), *written->schema()); + ASSERT_BATCHES_EQUAL(*batch, *written); +} + +} // namespace arrow diff --git a/docs/source/python/api/arrays.rst b/docs/source/python/api/arrays.rst index aefed00b3d2..4ad35b190cd 100644 --- a/docs/source/python/api/arrays.rst +++ b/docs/source/python/api/arrays.rst @@ -85,6 +85,7 @@ may expose data type-specific methods or properties. UnionArray ExtensionArray FixedShapeTensorArray + OpaqueArray .. _api.scalar: @@ -143,3 +144,5 @@ classes may expose data type-specific methods or properties. StructScalar UnionScalar ExtensionScalar + FixedShapeTensorScalar + OpaqueScalar diff --git a/docs/source/python/api/datatypes.rst b/docs/source/python/api/datatypes.rst index 7edb4e16154..a43c5299eae 100644 --- a/docs/source/python/api/datatypes.rst +++ b/docs/source/python/api/datatypes.rst @@ -67,6 +67,8 @@ These should be used to create Arrow data types and schemas. struct dictionary run_end_encoded + fixed_shape_tensor + opaque field schema from_numpy_dtype @@ -117,6 +119,14 @@ Specific classes and functions for extension types. register_extension_type unregister_extension_type +:doc:`Canonical extension types <../../format/CanonicalExtensions>` +implemented by PyArrow. + +.. autosummary:: + :toctree: ../generated/ + + FixedShapeTensorType + OpaqueType .. _api.types.checking: .. 
currentmodule:: pyarrow.types diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index e52e0d242be..aa7bab9f97e 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -173,6 +173,7 @@ def print_entry(label, value): dictionary, run_end_encoded, fixed_shape_tensor, + opaque, field, type_for_alias, DataType, DictionaryType, StructType, @@ -182,7 +183,7 @@ def print_entry(label, value): TimestampType, Time32Type, Time64Type, DurationType, FixedSizeBinaryType, Decimal128Type, Decimal256Type, BaseExtensionType, ExtensionType, - RunEndEncodedType, FixedShapeTensorType, + RunEndEncodedType, FixedShapeTensorType, OpaqueType, PyExtensionType, UnknownExtensionType, register_extension_type, unregister_extension_type, DictionaryMemo, @@ -216,7 +217,7 @@ def print_entry(label, value): Time32Array, Time64Array, DurationArray, MonthDayNanoIntervalArray, Decimal128Array, Decimal256Array, StructArray, ExtensionArray, - RunEndEncodedArray, FixedShapeTensorArray, + RunEndEncodedArray, FixedShapeTensorArray, OpaqueArray, scalar, NA, _NULL as NULL, Scalar, NullScalar, BooleanScalar, Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar, @@ -233,7 +234,8 @@ def print_entry(label, value): StringScalar, LargeStringScalar, StringViewScalar, FixedSizeBinaryScalar, DictionaryScalar, MapScalar, StructScalar, UnionScalar, - RunEndEncodedScalar, ExtensionScalar) + RunEndEncodedScalar, ExtensionScalar, + FixedShapeTensorScalar, OpaqueScalar) # Buffers, allocation from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager, diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 997f208a5de..6c40a21db96 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -4448,6 +4448,34 @@ cdef class FixedShapeTensorArray(ExtensionArray): ) +cdef class OpaqueArray(ExtensionArray): + """ + Concrete class for opaque extension arrays. + + Examples + -------- + Define the extension type for an opaque array + + >>> import pyarrow as pa + >>> opaque_type = pa.opaque( + ... pa.binary(), + ... type_name="geometry", + ... vendor_name="postgis", + ... 
) + + Create an extension array + + >>> arr = [None, b"data"] + >>> storage = pa.array(arr, pa.binary()) + >>> pa.ExtensionArray.from_storage(opaque_type, storage) + + [ + null, + 64617461 + ] + """ + + cdef dict _array_classes = { _Type_NA: NullArray, _Type_BOOL: BooleanArray, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 0d871f411b1..9b008d150f1 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2882,6 +2882,19 @@ cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extens " arrow::extension::FixedShapeTensorArray"(CExtensionArray): const CResult[shared_ptr[CTensor]] ToTensor() const + +cdef extern from "arrow/extension/opaque.h" namespace "arrow::extension" nogil: + cdef cppclass COpaqueType \ + " arrow::extension::OpaqueType"(CExtensionType): + + c_string type_name() + c_string vendor_name() + + cdef cppclass COpaqueArray \ + " arrow::extension::OpaqueArray"(CExtensionArray): + pass + + cdef extern from "arrow/util/compression.h" namespace "arrow" nogil: cdef enum CCompressionType" arrow::Compression::type": CCompressionType_UNCOMPRESSED" arrow::Compression::UNCOMPRESSED" diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 082d8470cdb..2cb302d20a8 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -215,6 +215,11 @@ cdef class FixedShapeTensorType(BaseExtensionType): const CFixedShapeTensorType* tensor_ext_type +cdef class OpaqueType(BaseExtensionType): + cdef: + const COpaqueType* opaque_ext_type + + cdef class PyExtensionType(ExtensionType): pass diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 966273b4bea..2f9fc1c5542 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -124,6 +124,8 @@ cdef api object pyarrow_wrap_data_type( return cpy_ext_type.GetInstance() elif ext_type.extension_name() == b"arrow.fixed_shape_tensor": out = FixedShapeTensorType.__new__(FixedShapeTensorType) + elif ext_type.extension_name() == b"arrow.opaque": + out = OpaqueType.__new__(OpaqueType) else: out = BaseExtensionType.__new__(BaseExtensionType) else: diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index 41bfde39adb..12a99c2aece 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -1085,6 +1085,12 @@ cdef class FixedShapeTensorScalar(ExtensionScalar): return pyarrow_wrap_tensor(ctensor) +cdef class OpaqueScalar(ExtensionScalar): + """ + Concrete class for opaque extension scalar. 
+ """ + + cdef dict _scalar_classes = { _Type_BOOL: BooleanScalar, _Type_UINT8: UInt8Scalar, diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 1c4d0175a2d..58c54189f22 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1661,3 +1661,49 @@ def test_legacy_int_type(): batch = ipc_read_batch(buf) assert isinstance(batch.column(0).type, LegacyIntType) assert batch.column(0) == ext_arr + + +@pytest.mark.parametrize("storage_type,storage", [ + (pa.null(), [None] * 4), + (pa.int64(), [1, 2, None, 4]), + (pa.binary(), [None, b"foobar"]), + (pa.list_(pa.int64()), [[], [1, 2], None, [3, None]]), +]) +def test_opaque_type(pickle_module, storage_type, storage): + opaque_type = pa.opaque(storage_type, "type", "vendor") + assert opaque_type.extension_name == "arrow.opaque" + assert opaque_type.storage_type == storage_type + assert opaque_type.type_name == "type" + assert opaque_type.vendor_name == "vendor" + assert "arrow.opaque" in str(opaque_type) + + assert opaque_type == opaque_type + assert opaque_type != storage_type + assert opaque_type != pa.opaque(storage_type, "type2", "vendor") + assert opaque_type != pa.opaque(storage_type, "type", "vendor2") + assert opaque_type != pa.opaque(pa.decimal128(12, 3), "type", "vendor") + + # Pickle roundtrip + result = pickle_module.loads(pickle_module.dumps(opaque_type)) + assert result == opaque_type + + # IPC roundtrip + opaque_arr_class = opaque_type.__arrow_ext_class__() + storage = pa.array(storage, storage_type) + arr = pa.ExtensionArray.from_storage(opaque_type, storage) + assert isinstance(arr, opaque_arr_class) + + with registered_extension_type(opaque_type): + buf = ipc_write_batch(pa.RecordBatch.from_arrays([arr], ["ext"])) + batch = ipc_read_batch(buf) + + assert batch.column(0).type.extension_name == "arrow.opaque" + assert isinstance(batch.column(0), opaque_arr_class) + + # cast storage -> extension type + result = storage.cast(opaque_type) + assert result == arr + + # cast extension type -> storage type + inner = arr.cast(storage_type) + assert inner == storage diff --git a/python/pyarrow/tests/test_misc.py b/python/pyarrow/tests/test_misc.py index c42e4fbdfc2..9a55a38177f 100644 --- a/python/pyarrow/tests/test_misc.py +++ b/python/pyarrow/tests/test_misc.py @@ -247,6 +247,9 @@ def test_set_timezone_db_path_non_windows(): pa.ProxyMemoryPool, pa.Device, pa.MemoryManager, + pa.OpaqueArray, + pa.OpaqueScalar, + pa.OpaqueType, ]) def test_extension_type_constructor_errors(klass): # ARROW-2638: prevent calling extension class constructors directly diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 039870accdd..93d68fb8478 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1837,6 +1837,50 @@ cdef class FixedShapeTensorType(BaseExtensionType): return FixedShapeTensorScalar +cdef class OpaqueType(BaseExtensionType): + """ + Concrete class for opaque extension type. + + Opaque is a placeholder for a type from an external (often non-Arrow) + system that could not be interpreted. + + Examples + -------- + Create an instance of opaque extension type: + + >>> import pyarrow as pa + >>> pa.opaque(pa.int32(), "geometry", "postgis") + OpaqueType(extension) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.opaque_ext_type = type.get() + + @property + def type_name(self): + """ + The name of the type in the external system. 
+ """ + return frombytes(c_string(self.opaque_ext_type.type_name())) + + @property + def vendor_name(self): + """ + The name of the external system. + """ + return frombytes(c_string(self.opaque_ext_type.vendor_name())) + + def __arrow_ext_class__(self): + return OpaqueArray + + def __reduce__(self): + return opaque, (self.storage_type, self.type_name, self.vendor_name) + + def __arrow_ext_scalar_class__(self): + return OpaqueScalar + + _py_extension_type_auto_load = False @@ -5234,6 +5278,63 @@ def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=N return out +def opaque(DataType storage_type, str type_name not None, str vendor_name not None): + """ + Create instance of opaque extension type. + + Parameters + ---------- + storage_type : DataType + The underlying data type. + type_name : str + The name of the type in the external system. + vendor_name : str + The name of the external system. + + Examples + -------- + Create an instance of an opaque extension type: + + >>> import pyarrow as pa + >>> type = pa.opaque(pa.binary(), "other", "jdbc") + >>> type + OpaqueType(extension) + + Inspect the data type: + + >>> type.storage_type + DataType(binary) + >>> type.type_name + 'other' + >>> type.vendor_name + 'jdbc' + + Create a table with an opaque array: + + >>> arr = [None, b"foobar"] + >>> storage = pa.array(arr, pa.binary()) + >>> other = pa.ExtensionArray.from_storage(type, storage) + >>> pa.table([other], names=["unknown_col"]) + pyarrow.Table + unknown_col: extension + ---- + unknown_col: [[null,666F6F626172]] + + Returns + ------- + type : OpaqueType + """ + + cdef: + c_string c_type_name = tobytes(type_name) + c_string c_vendor_name = tobytes(vendor_name) + shared_ptr[CDataType] c_type = make_shared[COpaqueType]( + storage_type.sp_type, c_type_name, c_vendor_name) + OpaqueType out = OpaqueType.__new__(OpaqueType) + out.init(c_type) + return out + + cdef dict _type_aliases = { 'null': null, 'bool': bool_, From ce251a6721cfcd27ed76bbaa5cb1c824a5f23a94 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 14 Aug 2024 08:58:15 +0530 Subject: [PATCH 010/157] GH-41291: [Java] LargeListViewVector Implementation transferPair implementation (#43637) ### Rationale for this change Integrating the `transferPair` and `copyFrom` functionality to `LargeListViewVector` - [X] https://github.com/apache/arrow/issues/41292 ### What changes are included in this PR? This PR includes the `TransferPairImpl`, corresponding functions and test cases. ### Are these changes tested? Yes ### Are there any user-facing changes? 
No * GitHub Issue: #41291 Authored-by: Vibhatha Abeykoon Signed-off-by: David Li --- .../BaseLargeRepeatedValueViewVector.java | 2 +- .../vector/complex/LargeListViewVector.java | 163 ++++++- .../arrow/vector/TestLargeListViewVector.java | 456 ++++++++++++++++++ .../arrow/vector/TestSplitAndTransfer.java | 20 + 4 files changed, 634 insertions(+), 7 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java index 26079cbee95..f643306cfdc 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java @@ -102,7 +102,7 @@ private void allocateBuffers() { sizeBuffer = allocateBuffers(sizeAllocationSizeInBytes); } - private ArrowBuf allocateBuffers(final long size) { + protected ArrowBuf allocateBuffers(final long size) { final int curSize = (int) size; ArrowBuf buffer = allocator.buffer(curSize); buffer.readerIndex(0); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java index f6b3de88b77..2c61f799a4c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java @@ -39,6 +39,7 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueIterableVector; import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.compare.VectorVisitor; import org.apache.arrow.vector.complex.impl.UnionLargeListViewReader; import org.apache.arrow.vector.complex.impl.UnionLargeListViewWriter; @@ -361,20 +362,17 @@ public TransferPair getTransferPair(Field field, BufferAllocator allocator) { @Override public TransferPair getTransferPair(String ref, BufferAllocator allocator, CallBack callBack) { - throw new UnsupportedOperationException( - "LargeListViewVector does not support getTransferPair(String, BufferAllocator, CallBack) yet"); + return new TransferImpl(ref, allocator, callBack); } @Override public TransferPair getTransferPair(Field field, BufferAllocator allocator, CallBack callBack) { - throw new UnsupportedOperationException( - "LargeListViewVector does not support getTransferPair(Field, BufferAllocator, CallBack) yet"); + return new TransferImpl(field, allocator, callBack); } @Override public TransferPair makeTransferPair(ValueVector target) { - throw new UnsupportedOperationException( - "LargeListViewVector does not support makeTransferPair(ValueVector) yet"); + return new TransferImpl((LargeListViewVector) target); } @Override @@ -452,6 +450,159 @@ public OUT accept(VectorVisitor visitor, IN value) { return visitor.visit(this, value); } + private class TransferImpl implements TransferPair { + + LargeListViewVector to; + TransferPair dataTransferPair; + + public TransferImpl(String name, BufferAllocator allocator, CallBack callBack) { + this(new LargeListViewVector(name, allocator, field.getFieldType(), callBack)); + } + + public TransferImpl(Field field, BufferAllocator allocator, CallBack callBack) { + this(new LargeListViewVector(field, allocator, callBack)); + } + + public TransferImpl(LargeListViewVector to) { + this.to = to; + 
to.addOrGetVector(vector.getField().getFieldType()); + if (to.getDataVector() instanceof ZeroVector) { + to.addOrGetVector(vector.getField().getFieldType()); + } + dataTransferPair = getDataVector().makeTransferPair(to.getDataVector()); + } + + @Override + public void transfer() { + to.clear(); + dataTransferPair.transfer(); + to.validityBuffer = transferBuffer(validityBuffer, to.allocator); + to.offsetBuffer = transferBuffer(offsetBuffer, to.allocator); + to.sizeBuffer = transferBuffer(sizeBuffer, to.allocator); + if (valueCount > 0) { + to.setValueCount(valueCount); + } + clear(); + } + + @Override + public void splitAndTransfer(int startIndex, int length) { + Preconditions.checkArgument( + startIndex >= 0 && length >= 0 && startIndex + length <= valueCount, + "Invalid parameters startIndex: %s, length: %s for valueCount: %s", + startIndex, + length, + valueCount); + to.clear(); + if (length > 0) { + // we have to scan by index since there are out-of-order offsets + to.offsetBuffer = to.allocateBuffers((long) length * OFFSET_WIDTH); + to.sizeBuffer = to.allocateBuffers((long) length * SIZE_WIDTH); + + /* splitAndTransfer the size buffer */ + int maxOffsetAndSizeSum = Integer.MIN_VALUE; + int minOffsetValue = Integer.MAX_VALUE; + for (int i = 0; i < length; i++) { + final int offsetValue = offsetBuffer.getInt((long) (startIndex + i) * OFFSET_WIDTH); + final int sizeValue = sizeBuffer.getInt((long) (startIndex + i) * SIZE_WIDTH); + to.sizeBuffer.setInt((long) i * SIZE_WIDTH, sizeValue); + maxOffsetAndSizeSum = Math.max(maxOffsetAndSizeSum, offsetValue + sizeValue); + minOffsetValue = Math.min(minOffsetValue, offsetValue); + } + + /* splitAndTransfer the offset buffer */ + for (int i = 0; i < length; i++) { + final int offsetValue = offsetBuffer.getInt((long) (startIndex + i) * OFFSET_WIDTH); + final int relativeOffset = offsetValue - minOffsetValue; + to.offsetBuffer.setInt((long) i * OFFSET_WIDTH, relativeOffset); + } + + /* splitAndTransfer the validity buffer */ + splitAndTransferValidityBuffer(startIndex, length, to); + + /* splitAndTransfer the data buffer */ + final int childSliceLength = maxOffsetAndSizeSum - minOffsetValue; + dataTransferPair.splitAndTransfer(minOffsetValue, childSliceLength); + to.setValueCount(length); + } + } + + /* + * transfer the validity. + */ + private void splitAndTransferValidityBuffer( + int startIndex, int length, LargeListViewVector target) { + int firstByteSource = BitVectorHelper.byteIndex(startIndex); + int lastByteSource = BitVectorHelper.byteIndex(valueCount - 1); + int byteSizeTarget = getValidityBufferSizeFromCount(length); + int offset = startIndex % 8; + + if (length > 0) { + if (offset == 0) { + // slice + if (target.validityBuffer != null) { + target.validityBuffer.getReferenceManager().release(); + } + target.validityBuffer = validityBuffer.slice(firstByteSource, byteSizeTarget); + target.validityBuffer.getReferenceManager().retain(1); + } else { + /* Copy data + * When the first bit starts from the middle of a byte (offset != 0), + * copy data from src BitVector. + * Each byte in the target is composed by a part in i-th byte, + * another part in (i+1)-th byte. 
+ */ + target.allocateValidityBuffer(byteSizeTarget); + + for (int i = 0; i < byteSizeTarget - 1; i++) { + byte b1 = + BitVectorHelper.getBitsFromCurrentByte(validityBuffer, firstByteSource + i, offset); + byte b2 = + BitVectorHelper.getBitsFromNextByte( + validityBuffer, firstByteSource + i + 1, offset); + + target.validityBuffer.setByte(i, (b1 + b2)); + } + + /* Copying the last piece is done in the following manner: + * if the source vector has 1 or more bytes remaining, we copy + * the last piece as a byte formed by shifting data + * from the current byte and the next byte. + * + * if the source vector has no more bytes remaining + * (we are at the last byte), we copy the last piece as a byte + * by shifting data from the current byte. + */ + if ((firstByteSource + byteSizeTarget - 1) < lastByteSource) { + byte b1 = + BitVectorHelper.getBitsFromCurrentByte( + validityBuffer, firstByteSource + byteSizeTarget - 1, offset); + byte b2 = + BitVectorHelper.getBitsFromNextByte( + validityBuffer, firstByteSource + byteSizeTarget, offset); + + target.validityBuffer.setByte(byteSizeTarget - 1, b1 + b2); + } else { + byte b1 = + BitVectorHelper.getBitsFromCurrentByte( + validityBuffer, firstByteSource + byteSizeTarget - 1, offset); + target.validityBuffer.setByte(byteSizeTarget - 1, b1); + } + } + } + } + + @Override + public ValueVector getTo() { + return to; + } + + @Override + public void copyValueSafe(int from, int to) { + this.to.copyFrom(from, to, LargeListViewVector.this); + } + } + @Override protected FieldReader getReaderImpl() { throw new UnsupportedOperationException( diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java index 563ac811c4f..2ed8d4d7005 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java @@ -18,6 +18,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.ArrayList; @@ -32,6 +33,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.TransferPair; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -1639,6 +1641,460 @@ public void testOutOfOrderOffset1() { } } + private int validateSizeBufferAndCalculateMinOffset( + int start, + int splitLength, + ArrowBuf fromOffsetBuffer, + ArrowBuf fromSizeBuffer, + ArrowBuf toSizeBuffer) { + int minOffset = fromOffsetBuffer.getInt((long) start * LargeListViewVector.OFFSET_WIDTH); + int fromDataLength; + int toDataLength; + + for (int i = 0; i < splitLength; i++) { + fromDataLength = fromSizeBuffer.getInt((long) (start + i) * LargeListViewVector.SIZE_WIDTH); + toDataLength = toSizeBuffer.getInt((long) (i) * LargeListViewVector.SIZE_WIDTH); + + /* validate size */ + assertEquals( + fromDataLength, + toDataLength, + "Different data lengths at index: " + i + " and start: " + start); + + /* calculate minimum offset */ + int currentOffset = + fromOffsetBuffer.getInt((long) (start + i) * LargeListViewVector.OFFSET_WIDTH); + if (currentOffset < minOffset) { + minOffset = currentOffset; + } + } + 
+ return minOffset; + } + + private void validateOffsetBuffer( + int start, + int splitLength, + ArrowBuf fromOffsetBuffer, + ArrowBuf toOffsetBuffer, + int minOffset) { + int offset1; + int offset2; + + for (int i = 0; i < splitLength; i++) { + offset1 = fromOffsetBuffer.getInt((long) (start + i) * LargeListViewVector.OFFSET_WIDTH); + offset2 = toOffsetBuffer.getInt((long) (i) * LargeListViewVector.OFFSET_WIDTH); + assertEquals( + offset1 - minOffset, + offset2, + "Different offset values at index: " + i + " and start: " + start); + } + } + + private void validateDataBuffer( + int start, + int splitLength, + ArrowBuf fromOffsetBuffer, + ArrowBuf fromSizeBuffer, + BigIntVector fromDataVector, + ArrowBuf toOffsetBuffer, + BigIntVector toDataVector) { + int dataLength; + Long fromValue; + for (int i = 0; i < splitLength; i++) { + dataLength = fromSizeBuffer.getInt((long) (start + i) * LargeListViewVector.SIZE_WIDTH); + for (int j = 0; j < dataLength; j++) { + fromValue = + fromDataVector.getObject( + (fromOffsetBuffer.getInt((long) (start + i) * LargeListViewVector.OFFSET_WIDTH) + + j)); + Long toValue = + toDataVector.getObject( + (toOffsetBuffer.getInt((long) i * LargeListViewVector.OFFSET_WIDTH) + j)); + assertEquals( + fromValue, toValue, "Different data values at index: " + i + " and start: " + start); + } + } + } + + /** + * Validate split and transfer of data from fromVector to toVector. Note that this method assumes + * that the child vector is BigIntVector. + * + * @param start start index + * @param splitLength length of data to split and transfer + * @param fromVector fromVector + * @param toVector toVector + */ + private void validateSplitAndTransfer( + TransferPair transferPair, + int start, + int splitLength, + LargeListViewVector fromVector, + LargeListViewVector toVector) { + + transferPair.splitAndTransfer(start, splitLength); + + /* get offsetBuffer of toVector */ + final ArrowBuf toOffsetBuffer = toVector.getOffsetBuffer(); + + /* get sizeBuffer of toVector */ + final ArrowBuf toSizeBuffer = toVector.getSizeBuffer(); + + /* get dataVector of toVector */ + BigIntVector toDataVector = (BigIntVector) toVector.getDataVector(); + + /* get offsetBuffer of toVector */ + final ArrowBuf fromOffsetBuffer = fromVector.getOffsetBuffer(); + + /* get sizeBuffer of toVector */ + final ArrowBuf fromSizeBuffer = fromVector.getSizeBuffer(); + + /* get dataVector of toVector */ + BigIntVector fromDataVector = (BigIntVector) fromVector.getDataVector(); + + /* validate size buffers */ + int minOffset = + validateSizeBufferAndCalculateMinOffset( + start, splitLength, fromOffsetBuffer, fromSizeBuffer, toSizeBuffer); + /* validate offset buffers */ + validateOffsetBuffer(start, splitLength, fromOffsetBuffer, toOffsetBuffer, minOffset); + /* validate data */ + validateDataBuffer( + start, + splitLength, + fromOffsetBuffer, + fromSizeBuffer, + fromDataVector, + toOffsetBuffer, + toDataVector); + } + + @Test + public void testSplitAndTransfer() throws Exception { + try (LargeListViewVector fromVector = LargeListViewVector.empty("sourceVector", allocator)) { + + /* Explicitly add the dataVector */ + MinorType type = MinorType.BIGINT; + fromVector.addOrGetVector(FieldType.nullable(type.getType())); + + UnionLargeListViewWriter listViewWriter = fromVector.getWriter(); + + /* allocate memory */ + listViewWriter.allocate(); + + /* populate data */ + listViewWriter.setPosition(0); + listViewWriter.startListView(); + listViewWriter.bigInt().writeBigInt(10); + listViewWriter.bigInt().writeBigInt(11); + 
listViewWriter.bigInt().writeBigInt(12); + listViewWriter.endListView(); + + listViewWriter.setPosition(1); + listViewWriter.startListView(); + listViewWriter.bigInt().writeBigInt(13); + listViewWriter.bigInt().writeBigInt(14); + listViewWriter.endListView(); + + listViewWriter.setPosition(2); + listViewWriter.startListView(); + listViewWriter.bigInt().writeBigInt(15); + listViewWriter.bigInt().writeBigInt(16); + listViewWriter.bigInt().writeBigInt(17); + listViewWriter.bigInt().writeBigInt(18); + listViewWriter.endListView(); + + listViewWriter.setPosition(3); + listViewWriter.startListView(); + listViewWriter.bigInt().writeBigInt(19); + listViewWriter.endListView(); + + listViewWriter.setPosition(4); + listViewWriter.startListView(); + listViewWriter.bigInt().writeBigInt(20); + listViewWriter.bigInt().writeBigInt(21); + listViewWriter.bigInt().writeBigInt(22); + listViewWriter.bigInt().writeBigInt(23); + listViewWriter.endListView(); + + fromVector.setValueCount(5); + + /* get offset buffer */ + final ArrowBuf offsetBuffer = fromVector.getOffsetBuffer(); + + /* get size buffer */ + final ArrowBuf sizeBuffer = fromVector.getSizeBuffer(); + + /* get dataVector */ + BigIntVector dataVector = (BigIntVector) fromVector.getDataVector(); + + /* check the vector output */ + + int index = 0; + int offset; + int size = 0; + Long actual; + + /* index 0 */ + assertFalse(fromVector.isNull(index)); + offset = offsetBuffer.getInt(index * LargeListViewVector.OFFSET_WIDTH); + assertEquals(Integer.toString(0), Integer.toString(offset)); + + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(10), actual); + offset++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(11), actual); + offset++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(12), actual); + assertEquals( + Integer.toString(3), + Integer.toString(sizeBuffer.getInt(index * LargeListViewVector.SIZE_WIDTH))); + + /* index 1 */ + index++; + assertFalse(fromVector.isNull(index)); + offset = offsetBuffer.getInt(index * LargeListViewVector.OFFSET_WIDTH); + assertEquals(Integer.toString(3), Integer.toString(offset)); + + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(13), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(14), actual); + size++; + assertEquals( + Integer.toString(size), + Integer.toString(sizeBuffer.getInt(index * LargeListViewVector.SIZE_WIDTH))); + + /* index 2 */ + size = 0; + index++; + assertFalse(fromVector.isNull(index)); + offset = offsetBuffer.getInt(index * LargeListViewVector.OFFSET_WIDTH); + assertEquals(Integer.toString(5), Integer.toString(offset)); + size++; + + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(15), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(16), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(17), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(18), actual); + assertEquals( + Integer.toString(size), + Integer.toString(sizeBuffer.getInt(index * LargeListViewVector.SIZE_WIDTH))); + + /* index 3 */ + size = 0; + index++; + assertFalse(fromVector.isNull(index)); + offset = offsetBuffer.getInt(index * LargeListViewVector.OFFSET_WIDTH); + assertEquals(Integer.toString(9), Integer.toString(offset)); + + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(19), actual); + size++; + assertEquals( + 
Integer.toString(size), + Integer.toString(sizeBuffer.getInt(index * LargeListViewVector.SIZE_WIDTH))); + + /* index 4 */ + size = 0; + index++; + assertFalse(fromVector.isNull(index)); + offset = offsetBuffer.getInt(index * LargeListViewVector.OFFSET_WIDTH); + assertEquals(Integer.toString(10), Integer.toString(offset)); + + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(20), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(21), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(22), actual); + offset++; + size++; + actual = dataVector.getObject(offset); + assertEquals(Long.valueOf(23), actual); + size++; + assertEquals( + Integer.toString(size), + Integer.toString(sizeBuffer.getInt(index * LargeListViewVector.SIZE_WIDTH))); + + /* do split and transfer */ + try (LargeListViewVector toVector = LargeListViewVector.empty("toVector", allocator)) { + int[][] transferLengths = {{0, 2}, {3, 1}, {4, 1}}; + TransferPair transferPair = fromVector.makeTransferPair(toVector); + + for (final int[] transferLength : transferLengths) { + int start = transferLength[0]; + int splitLength = transferLength[1]; + validateSplitAndTransfer(transferPair, start, splitLength, fromVector, toVector); + } + } + } + } + + @Test + public void testGetTransferPairWithField() throws Exception { + try (final LargeListViewVector fromVector = LargeListViewVector.empty("listview", allocator)) { + + UnionLargeListViewWriter writer = fromVector.getWriter(); + writer.allocate(); + + // set some values + writer.startListView(); + writer.integer().writeInt(1); + writer.integer().writeInt(2); + writer.endListView(); + fromVector.setValueCount(2); + + final TransferPair transferPair = + fromVector.getTransferPair(fromVector.getField(), allocator); + final LargeListViewVector toVector = (LargeListViewVector) transferPair.getTo(); + // Field inside a new vector created by reusing a field should be the same in memory as the + // original field. + assertSame(toVector.getField(), fromVector.getField()); + } + } + + @Test + public void testOutOfOrderOffsetSplitAndTransfer() { + // [[12, -7, 25], null, [0, -127, 127, 50], [], [50, 12]] + try (LargeListViewVector fromVector = LargeListViewVector.empty("fromVector", allocator)) { + // Allocate buffers in LargeListViewVector by calling `allocateNew` method. + fromVector.allocateNew(); + + // Initialize the child vector using `initializeChildrenFromFields` method. + + FieldType fieldType = new FieldType(true, new ArrowType.Int(64, true), null, null); + Field field = new Field("child-vector", fieldType, null); + fromVector.initializeChildrenFromFields(Collections.singletonList(field)); + + // Set values in the child vector. + FieldVector fieldVector = fromVector.getDataVector(); + fieldVector.clear(); + + BigIntVector childVector = (BigIntVector) fieldVector; + + childVector.allocateNew(7); + + childVector.set(0, 0); + childVector.set(1, -127); + childVector.set(2, 127); + childVector.set(3, 50); + childVector.set(4, 12); + childVector.set(5, -7); + childVector.set(6, 25); + + childVector.setValueCount(7); + + // Set validity, offset and size buffers using `setValidity`, + // `setOffset` and `setSize` methods. 
+ fromVector.setValidity(0, 1); + fromVector.setValidity(1, 0); + fromVector.setValidity(2, 1); + fromVector.setValidity(3, 1); + fromVector.setValidity(4, 1); + + fromVector.setOffset(0, 4); + fromVector.setOffset(1, 7); + fromVector.setOffset(2, 0); + fromVector.setOffset(3, 0); + fromVector.setOffset(4, 3); + + fromVector.setSize(0, 3); + fromVector.setSize(1, 0); + fromVector.setSize(2, 4); + fromVector.setSize(3, 0); + fromVector.setSize(4, 2); + + // Set value count using `setValueCount` method. + fromVector.setValueCount(5); + + final ArrowBuf offSetBuffer = fromVector.getOffsetBuffer(); + final ArrowBuf sizeBuffer = fromVector.getSizeBuffer(); + + // check offset buffer + assertEquals(4, offSetBuffer.getInt(0 * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + assertEquals(7, offSetBuffer.getInt(1 * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + assertEquals(0, offSetBuffer.getInt(2 * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + assertEquals(0, offSetBuffer.getInt(3 * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + assertEquals(3, offSetBuffer.getInt(4 * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + + // check size buffer + assertEquals(3, sizeBuffer.getInt(0 * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + assertEquals(0, sizeBuffer.getInt(1 * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + assertEquals(4, sizeBuffer.getInt(2 * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + assertEquals(0, sizeBuffer.getInt(3 * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + assertEquals(2, sizeBuffer.getInt(4 * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + + // check child vector + assertEquals(0, ((BigIntVector) fromVector.getDataVector()).get(0)); + assertEquals(-127, ((BigIntVector) fromVector.getDataVector()).get(1)); + assertEquals(127, ((BigIntVector) fromVector.getDataVector()).get(2)); + assertEquals(50, ((BigIntVector) fromVector.getDataVector()).get(3)); + assertEquals(12, ((BigIntVector) fromVector.getDataVector()).get(4)); + assertEquals(-7, ((BigIntVector) fromVector.getDataVector()).get(5)); + assertEquals(25, ((BigIntVector) fromVector.getDataVector()).get(6)); + + // check values + Object result = fromVector.getObject(0); + ArrayList resultSet = (ArrayList) result; + assertEquals(3, resultSet.size()); + assertEquals(Long.valueOf(12), resultSet.get(0)); + assertEquals(Long.valueOf(-7), resultSet.get(1)); + assertEquals(Long.valueOf(25), resultSet.get(2)); + + assertTrue(fromVector.isNull(1)); + + result = fromVector.getObject(2); + resultSet = (ArrayList) result; + assertEquals(4, resultSet.size()); + assertEquals(Long.valueOf(0), resultSet.get(0)); + assertEquals(Long.valueOf(-127), resultSet.get(1)); + assertEquals(Long.valueOf(127), resultSet.get(2)); + assertEquals(Long.valueOf(50), resultSet.get(3)); + + assertTrue(fromVector.isEmpty(3)); + + result = fromVector.getObject(4); + resultSet = (ArrayList) result; + assertEquals(2, resultSet.size()); + assertEquals(Long.valueOf(50), resultSet.get(0)); + assertEquals(Long.valueOf(12), resultSet.get(1)); + + fromVector.validate(); + + /* do split and transfer */ + try (LargeListViewVector toVector = LargeListViewVector.empty("toVector", allocator)) { + int[][] transferLengths = {{2, 3}, {0, 1}, {0, 3}}; + TransferPair transferPair = fromVector.makeTransferPair(toVector); + + for (final int[] transferLength : transferLengths) { + int start = transferLength[0]; + int splitLength = transferLength[1]; + validateSplitAndTransfer(transferPair, start, splitLength, fromVector, toVector); + } + } + } + } + private 
void writeIntValues(UnionLargeListViewWriter writer, int[] values) { writer.startListView(); for (int v : values) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java b/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java index d20dc3348b1..a3f25bc5207 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestSplitAndTransfer.java @@ -29,6 +29,7 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; @@ -852,6 +853,25 @@ public void testListVectorZeroStartIndexAndLength() { } } + @Test + public void testLargeListViewVectorZeroStartIndexAndLength() { + try (final LargeListViewVector listVector = + LargeListViewVector.empty("largelistview", allocator); + final LargeListViewVector newListVector = LargeListViewVector.empty("newList", allocator)) { + + listVector.allocateNew(); + final int valueCount = 0; + listVector.setValueCount(valueCount); + + final TransferPair tp = listVector.makeTransferPair(newListVector); + + tp.splitAndTransfer(0, 0); + assertEquals(valueCount, newListVector.getValueCount()); + + newListVector.clear(); + } + } + @Test public void testStructVectorZeroStartIndexAndLength() { Map metadata = new HashMap<>(); From 712cfe6d84bd344cfe57a1e4c791f8a4d052c76d Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 14 Aug 2024 09:09:19 +0530 Subject: [PATCH 011/157] GH-43643: [Java] LargeListViewVector IPC Integration (#43681) ### Rationale for this change Newly introduced `LargeListViewVector` requires the IPC integration for C Data integration tests while mainly supporting IPC format to include this type. ### What changes are included in this PR? Includes the `JsonFileWriter` and `JsonFileReader` along with the corresponding test cases. ### Are these changes tested? Yes, using existing tests but adding new configurations. ### Are there any user-facing changes? 
No * GitHub Issue: #43643 Authored-by: Vibhatha Abeykoon Signed-off-by: David Li --- .../apache/arrow/vector/ipc/JsonFileReader.java | 9 +++++++-- .../apache/arrow/vector/ipc/JsonFileWriter.java | 16 ++++++++++++++-- .../apache/arrow/vector/ipc/TestJSONFile.java | 8 ++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java index 626619a9483..5668325a87e 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java @@ -73,6 +73,7 @@ import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.ArrowType.LargeListView; import org.apache.arrow.vector.types.pojo.ArrowType.ListView; import org.apache.arrow.vector.types.pojo.ArrowType.Union; import org.apache.arrow.vector.types.pojo.Field; @@ -729,7 +730,8 @@ private List readIntoBuffer( } else if (bufferType.equals(OFFSET) || bufferType.equals(SIZE)) { if (type == MinorType.LARGELIST || type == MinorType.LARGEVARCHAR - || type == MinorType.LARGEVARBINARY) { + || type == MinorType.LARGEVARBINARY + || type == MinorType.LARGELISTVIEW) { reader = helper.INT8; } else { reader = helper.INT4; @@ -890,7 +892,10 @@ private void readFromJsonIntoVector(Field field, FieldVector vector) throws IOEx BufferType bufferType = vectorTypes.get(v); nextFieldIs(bufferType.getName()); int innerBufferValueCount = valueCount; - if (bufferType.equals(OFFSET) && !(type instanceof Union) && !(type instanceof ListView)) { + if (bufferType.equals(OFFSET) + && !(type instanceof Union) + && !(type instanceof ListView) + && !(type instanceof LargeListView)) { /* offset buffer has 1 additional value capacity except for dense unions and ListView */ innerBufferValueCount = valueCount + 1; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java index 929c8c97c05..68700fe6afd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java @@ -73,6 +73,7 @@ import org.apache.arrow.vector.UInt4Vector; import org.apache.arrow.vector.UInt8Vector; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.BaseLargeRepeatedValueViewVector; import org.apache.arrow.vector.complex.BaseRepeatedValueViewVector; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; @@ -232,7 +233,8 @@ private void writeFromVectorIntoJson(Field field, FieldVector vector) throws IOE final int bufferValueCount = (bufferType.equals(OFFSET) && vector.getMinorType() != MinorType.DENSEUNION - && vector.getMinorType() != MinorType.LISTVIEW) + && vector.getMinorType() != MinorType.LISTVIEW + && vector.getMinorType() != MinorType.LARGELISTVIEW) ? 
valueCount + 1 : valueCount; for (int i = 0; i < bufferValueCount; i++) { @@ -274,6 +276,7 @@ private void writeFromVectorIntoJson(Field field, FieldVector vector) throws IOE } else if (bufferType.equals(OFFSET) && vector.getValueCount() == 0 && (vector.getMinorType() == MinorType.LARGELIST + || vector.getMinorType() == MinorType.LARGELISTVIEW || vector.getMinorType() == MinorType.LARGEVARBINARY || vector.getMinorType() == MinorType.LARGEVARCHAR)) { // Empty vectors may not have allocated an offsets buffer @@ -427,6 +430,10 @@ private void writeValueToGenerator( generator.writeNumber( buffer.getInt((long) index * BaseRepeatedValueViewVector.OFFSET_WIDTH)); break; + case LARGELISTVIEW: + generator.writeNumber( + buffer.getInt((long) index * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH)); + break; case LARGELIST: case LARGEVARBINARY: case LARGEVARCHAR: @@ -582,7 +589,12 @@ private void writeValueToGenerator( throw new UnsupportedOperationException("minor type: " + vector.getMinorType()); } } else if (bufferType.equals(SIZE)) { - generator.writeNumber(buffer.getInt((long) index * BaseRepeatedValueViewVector.SIZE_WIDTH)); + if (vector.getMinorType() == MinorType.LISTVIEW) { + generator.writeNumber(buffer.getInt((long) index * BaseRepeatedValueViewVector.SIZE_WIDTH)); + } else { + generator.writeNumber( + buffer.getInt((long) index * BaseLargeRepeatedValueViewVector.SIZE_WIDTH)); + } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java index c69a3bfbc1e..8037212aaea 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java @@ -437,10 +437,18 @@ public void testRoundtripEmptyVector() throws Exception { "list", FieldType.nullable(ArrowType.List.INSTANCE), Collections.singletonList(Field.nullable("items", new ArrowType.Int(32, true)))), + new Field( + "listview", + FieldType.nullable(ArrowType.ListView.INSTANCE), + Collections.singletonList(Field.nullable("items", new ArrowType.Int(32, true)))), new Field( "largelist", FieldType.nullable(ArrowType.LargeList.INSTANCE), Collections.singletonList(Field.nullable("items", new ArrowType.Int(32, true)))), + new Field( + "largelistview", + FieldType.nullable(ArrowType.LargeListView.INSTANCE), + Collections.singletonList(Field.nullable("items", new ArrowType.Int(32, true)))), new Field( "map", FieldType.nullable(new ArrowType.Map(/*keyssorted*/ false)), From 7c8909a144f2e8d593dc8ad363ac95b2865b04ca Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 14 Aug 2024 14:27:07 +0200 Subject: [PATCH 012/157] MINOR: [Dev][C++] Allow ubuntu-cpp-thread-sanitizer Docker build with Ubuntu 24.04 (#43619) Install the clang-rt libraries that are necessary to link Thread Sanitizer-enabled binaries. Also fix use of deprecated `BufferReader` constructor in some tests, so that compilation with CLang 18 succeeds. Note that the C++ test suite still fails on Flight tests, as tracked in GH-36552. 
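For reference, a minimal sketch (not part of the patch) of the `BufferReader` migration applied in the tests below; the `MakeReader` helper name is illustrative only, while `BufferReader::FromString` and the `std::shared_ptr<Buffer>` constructor are the spellings the diff itself switches to:

```cpp
#include <memory>
#include <string>
#include <utility>

#include "arrow/buffer.h"
#include "arrow/io/memory.h"

// Before (deprecated constructor whose use broke these tests under Clang 18):
//   auto reader = std::make_shared<arrow::io::BufferReader>(std::string_view("abc"));
//
// After: hand the reader an owning std::shared_ptr<Buffer>, e.g.
//   std::make_shared<arrow::io::BufferReader>(std::make_shared<arrow::Buffer>(data));
// or use the FromString convenience factory as in memory_test.cc / compressed_test.cc.
std::shared_ptr<arrow::io::BufferReader> MakeReader(std::string data) {
  // FromString moves the string into a Buffer owned by the returned reader,
  // so there is no lifetime dependency on the caller's string.
  return arrow::io::BufferReader::FromString(std::move(data));
}
```
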
Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- ci/docker/ubuntu-24.04-cpp.dockerfile | 1 + cpp/src/arrow/dataset/dataset_writer_test.cc | 2 +- cpp/src/arrow/io/compressed_test.cc | 2 +- cpp/src/arrow/io/memory_test.cc | 6 +++--- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/ci/docker/ubuntu-24.04-cpp.dockerfile b/ci/docker/ubuntu-24.04-cpp.dockerfile index ecfb5e2f509..7d0772c33a2 100644 --- a/ci/docker/ubuntu-24.04-cpp.dockerfile +++ b/ci/docker/ubuntu-24.04-cpp.dockerfile @@ -57,6 +57,7 @@ RUN latest_system_llvm=18 && \ clang-${llvm} \ clang-format-${clang_tools} \ clang-tidy-${clang_tools} \ + libclang-rt-${llvm}-dev \ llvm-${llvm}-dev && \ apt-get clean && \ rm -rf /var/lib/apt/lists* diff --git a/cpp/src/arrow/dataset/dataset_writer_test.cc b/cpp/src/arrow/dataset/dataset_writer_test.cc index 871b6ef6f55..32ae8d7ee12 100644 --- a/cpp/src/arrow/dataset/dataset_writer_test.cc +++ b/cpp/src/arrow/dataset/dataset_writer_test.cc @@ -157,7 +157,7 @@ class DatasetWriterTestFixture : public testing::Test { std::shared_ptr ReadAsBatch(std::string_view data, int* num_batches) { std::shared_ptr in_stream = - std::make_shared(data); + std::make_shared(std::make_shared(data)); EXPECT_OK_AND_ASSIGN(std::shared_ptr reader, ipc::RecordBatchFileReader::Open(in_stream)); RecordBatchVector batches; diff --git a/cpp/src/arrow/io/compressed_test.cc b/cpp/src/arrow/io/compressed_test.cc index 12d116e3395..7724c65e9dd 100644 --- a/cpp/src/arrow/io/compressed_test.cc +++ b/cpp/src/arrow/io/compressed_test.cc @@ -262,7 +262,7 @@ TEST_P(CompressedOutputStreamTest, RandomData) { TEST(TestSnappyInputStream, NotImplemented) { std::unique_ptr codec; ASSERT_OK_AND_ASSIGN(codec, Codec::Create(Compression::SNAPPY)); - std::shared_ptr stream = std::make_shared(""); + std::shared_ptr stream = BufferReader::FromString(""); ASSERT_RAISES(NotImplemented, CompressedInputStream::Make(codec.get(), stream)); } diff --git a/cpp/src/arrow/io/memory_test.cc b/cpp/src/arrow/io/memory_test.cc index bd898f17181..58f51ffa8d0 100644 --- a/cpp/src/arrow/io/memory_test.cc +++ b/cpp/src/arrow/io/memory_test.cc @@ -404,7 +404,7 @@ template void TestSlowInputStream() { using clock = std::chrono::high_resolution_clock; - auto stream = std::make_shared(std::string_view("abcdefghijkl")); + std::shared_ptr stream = BufferReader::FromString("abcdefghijkl"); const double latency = 0.6; auto slow = std::make_shared(stream, latency); @@ -519,7 +519,7 @@ class TestTransformInputStream : public ::testing::Test { TransformInputStream::TransformFunc transform() const { return T(); } void TestEmptyStream() { - auto wrapped = std::make_shared(std::string_view()); + std::shared_ptr wrapped = BufferReader::FromString({}); auto stream = std::make_shared(wrapped, transform()); ASSERT_OK_AND_EQ(0, stream->Tell()); @@ -797,7 +797,7 @@ TEST(RangeReadCache, Basics) { TEST(RangeReadCache, Concurrency) { std::string data = "abcdefghijklmnopqrstuvwxyz"; - auto file = std::make_shared(Buffer(data)); + auto file = std::make_shared(std::make_shared(data)); std::vector ranges{{1, 2}, {3, 2}, {8, 2}, {20, 2}, {25, 0}, {10, 4}, {14, 0}, {15, 4}}; From ab432b1362208696e60824b45a5599a4e91e6301 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Wed, 14 Aug 2024 07:50:04 -0700 Subject: [PATCH 013/157] GH-43627: [R] Fix summarize() performance regression (pushdown) (#43649) ### Rationale for this change See https://github.com/apache/arrow/issues/43627#issuecomment-2284259559 ### What changes are included in this PR? 
An extra `dplyr::select()` ### Are these changes tested? Conbench should show that the performance is much better ### Are there any user-facing changes? Not slow * GitHub Issue: #43627 --- r/R/dplyr-summarize.R | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R index f4fda0f13aa..a9ad750de7c 100644 --- a/r/R/dplyr-summarize.R +++ b/r/R/dplyr-summarize.R @@ -43,6 +43,15 @@ do_arrow_summarize <- function(.data, ..., .groups = NULL) { hash = length(.data$group_by_vars) > 0 ) + # Do a projection here to keep only the columns we need in summarize(). + # If possible, this will push down the column selection into the SourceNode, + # saving lots of wasted processing for columns we don't need. (GH-43627) + vars_to_keep <- unique(c( + unlist(lapply(exprs, all.vars)), # vars referenced in summarize + dplyr::group_vars(.data) # vars needed for grouping + )) + .data <- dplyr::select(.data, intersect(vars_to_keep, names(.data))) + # nolint start # summarize() is complicated because you can do a mixture of scalar operations # and aggregations, but that's not how Acero works. For example, for us to do From f518d6beb0c70f00688d08a3e70deff0d3c24c86 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Thu, 15 Aug 2024 10:41:08 +0200 Subject: [PATCH 014/157] GH-38041: [C++][CI] Improve IPC fuzzing seed corpus (#43621) 1. Add fuzz seeds with newer datatypes such as Run-End Encoded and String Views 2. Add fuzz seeds with buffer compression 3. Build seed corpus generation utilities even when fuzzing isn't enabled, for convenience * GitHub Issue: #38041 Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/arrow/ipc/CMakeLists.txt | 7 ++- cpp/src/arrow/ipc/generate_fuzz_corpus.cc | 44 +++++++++++++------ .../arrow/ipc/generate_tensor_fuzz_corpus.cc | 2 +- 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/cpp/src/arrow/ipc/CMakeLists.txt b/cpp/src/arrow/ipc/CMakeLists.txt index 2fc9b145ccc..9e0b1d723b9 100644 --- a/cpp/src/arrow/ipc/CMakeLists.txt +++ b/cpp/src/arrow/ipc/CMakeLists.txt @@ -71,7 +71,12 @@ endif() add_arrow_benchmark(read_write_benchmark PREFIX "arrow-ipc") -if(ARROW_FUZZING) +if(ARROW_FUZZING + OR (ARROW_BUILD_UTILITIES + AND ARROW_TESTING + AND ARROW_WITH_LZ4 + AND ARROW_WITH_ZSTD + )) add_executable(arrow-ipc-generate-fuzz-corpus generate_fuzz_corpus.cc) target_link_libraries(arrow-ipc-generate-fuzz-corpus ${ARROW_UTIL_LIB} ${ARROW_TEST_LINK_LIBS}) diff --git a/cpp/src/arrow/ipc/generate_fuzz_corpus.cc b/cpp/src/arrow/ipc/generate_fuzz_corpus.cc index 682c352132a..6ccf1155d12 100644 --- a/cpp/src/arrow/ipc/generate_fuzz_corpus.cc +++ b/cpp/src/arrow/ipc/generate_fuzz_corpus.cc @@ -33,11 +33,11 @@ #include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/testing/extension_type.h" +#include "arrow/util/compression.h" #include "arrow/util/io_util.h" #include "arrow/util/key_value_metadata.h" -namespace arrow { -namespace ipc { +namespace arrow::ipc { using ::arrow::internal::CreateDir; using ::arrow::internal::PlatformFilename; @@ -88,6 +88,13 @@ Result>> Batches() { batches.push_back(batch); RETURN_NOT_OK(test::MakeFixedSizeListRecordBatch(&batch)); batches.push_back(batch); + RETURN_NOT_OK(test::MakeStringTypesRecordBatch(&batch)); + batches.push_back(batch); + RETURN_NOT_OK(test::MakeUuid(&batch)); + batches.push_back(batch); + RETURN_NOT_OK(test::MakeRunEndEncoded(&batch)); + batches.push_back(batch); + ARROW_ASSIGN_OR_RAISE(batch, MakeExtensionBatch()); batches.push_back(batch); 
ARROW_ASSIGN_OR_RAISE(batch, MakeMapBatch()); @@ -97,13 +104,14 @@ Result>> Batches() { } Result> SerializeRecordBatch( - const std::shared_ptr& batch, bool is_stream_format) { + const std::shared_ptr& batch, const IpcWriteOptions& options, + bool is_stream_format) { ARROW_ASSIGN_OR_RAISE(auto sink, io::BufferOutputStream::Create(1024)); std::shared_ptr writer; if (is_stream_format) { - ARROW_ASSIGN_OR_RAISE(writer, MakeStreamWriter(sink, batch->schema())); + ARROW_ASSIGN_OR_RAISE(writer, MakeStreamWriter(sink, batch->schema(), options)); } else { - ARROW_ASSIGN_OR_RAISE(writer, MakeFileWriter(sink, batch->schema())); + ARROW_ASSIGN_OR_RAISE(writer, MakeFileWriter(sink, batch->schema(), options)); } RETURN_NOT_OK(writer->WriteRecordBatch(*batch)); RETURN_NOT_OK(writer->Close()); @@ -119,16 +127,27 @@ Status DoMain(bool is_stream_format, const std::string& out_dir) { return "batch-" + std::to_string(sample_num++); }; + // codec 0 is uncompressed + std::vector> codecs(3, nullptr); + ARROW_ASSIGN_OR_RAISE(codecs[1], util::Codec::Create(Compression::LZ4_FRAME)); + ARROW_ASSIGN_OR_RAISE(codecs[2], util::Codec::Create(Compression::ZSTD)); + ARROW_ASSIGN_OR_RAISE(auto batches, Batches()); + // Emit a separate file for each (batch, codec) pair for (const auto& batch : batches) { RETURN_NOT_OK(batch->ValidateFull()); - ARROW_ASSIGN_OR_RAISE(auto buf, SerializeRecordBatch(batch, is_stream_format)); - ARROW_ASSIGN_OR_RAISE(auto sample_fn, dir_fn.Join(sample_name())); - std::cerr << sample_fn.ToString() << std::endl; - ARROW_ASSIGN_OR_RAISE(auto file, io::FileOutputStream::Open(sample_fn.ToString())); - RETURN_NOT_OK(file->Write(buf)); - RETURN_NOT_OK(file->Close()); + for (const auto& codec : codecs) { + IpcWriteOptions options = IpcWriteOptions::Defaults(); + options.codec = codec; + ARROW_ASSIGN_OR_RAISE(auto buf, + SerializeRecordBatch(batch, options, is_stream_format)); + ARROW_ASSIGN_OR_RAISE(auto sample_fn, dir_fn.Join(sample_name())); + std::cerr << sample_fn.ToString() << std::endl; + ARROW_ASSIGN_OR_RAISE(auto file, io::FileOutputStream::Open(sample_fn.ToString())); + RETURN_NOT_OK(file->Write(buf)); + RETURN_NOT_OK(file->Close()); + } } return Status::OK(); } @@ -157,7 +176,6 @@ int Main(int argc, char** argv) { return 0; } -} // namespace ipc -} // namespace arrow +} // namespace arrow::ipc int main(int argc, char** argv) { return arrow::ipc::Main(argc, argv); } diff --git a/cpp/src/arrow/ipc/generate_tensor_fuzz_corpus.cc b/cpp/src/arrow/ipc/generate_tensor_fuzz_corpus.cc index dd40ef0ab2f..870f4586708 100644 --- a/cpp/src/arrow/ipc/generate_tensor_fuzz_corpus.cc +++ b/cpp/src/arrow/ipc/generate_tensor_fuzz_corpus.cc @@ -41,7 +41,7 @@ using ::arrow::internal::PlatformFilename; Result PrepareDirectory(const std::string& dir) { ARROW_ASSIGN_OR_RAISE(auto dir_fn, PlatformFilename::FromString(dir)); RETURN_NOT_OK(::arrow::internal::CreateDir(dir_fn)); - return std::move(dir_fn); + return dir_fn; } Result> MakeSerializedBuffer( From 894f72f735c7074a40908bbc4d04bc4d07cbc3ea Mon Sep 17 00:00:00 2001 From: Lysandros Nikolaou Date: Thu, 15 Aug 2024 14:24:35 +0200 Subject: [PATCH 015/157] GH-43536: [Python] Do not use borrowed references APIs (#43540) ### Rationale for this change For better reference safety under Python free-threaded builds (i.e. with the GIL removed), we should be using `Py(List|Dict)_GetItemRef` that return strong references and are implemented in a thread-safe manner. ### What changes are included in this PR? 
- Vendor a copy of https://github.com/python/pythoncapi-compat - Port to strong reference APIs for lists and dicts ### Are these changes tested? I ran the tests with the free-threaded build before and after, and there's the same expected failures. * GitHub Issue: #43536 Lead-authored-by: Lysandros Nikolaou Co-authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- dev/release/rat_exclude_files.txt | 1 + .../pyarrow/src/arrow/python/CMakeLists.txt | 1 + .../pyarrow/src/arrow/python/deserialize.cc | 14 +- .../src/arrow/python/numpy_to_arrow.cc | 7 +- .../src/arrow/python/python_to_arrow.cc | 17 +- .../src/arrow/python/vendored/CMakeLists.txt | 18 + .../arrow/python/vendored/pythoncapi_compat.h | 1519 +++++++++++++++++ 7 files changed, 1566 insertions(+), 11 deletions(-) create mode 100644 python/pyarrow/src/arrow/python/vendored/CMakeLists.txt create mode 100644 python/pyarrow/src/arrow/python/vendored/pythoncapi_compat.h diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt index ef325090f2f..e149c179813 100644 --- a/dev/release/rat_exclude_files.txt +++ b/dev/release/rat_exclude_files.txt @@ -95,6 +95,7 @@ python/manylinux1/.dockerignore python/pyarrow/includes/__init__.pxd python/pyarrow/tests/__init__.py python/pyarrow/vendored/* +python/pyarrow/src/arrow/python/vendored/* python/requirements*.txt pax_global_header MANIFEST.in diff --git a/python/pyarrow/src/arrow/python/CMakeLists.txt b/python/pyarrow/src/arrow/python/CMakeLists.txt index ff355e46a4b..67508982eab 100644 --- a/python/pyarrow/src/arrow/python/CMakeLists.txt +++ b/python/pyarrow/src/arrow/python/CMakeLists.txt @@ -16,3 +16,4 @@ # under the License. arrow_install_all_headers("arrow/python") +add_subdirectory(vendored) diff --git a/python/pyarrow/src/arrow/python/deserialize.cc b/python/pyarrow/src/arrow/python/deserialize.cc index 961a1686e0a..ab300a182fa 100644 --- a/python/pyarrow/src/arrow/python/deserialize.cc +++ b/python/pyarrow/src/arrow/python/deserialize.cc @@ -46,6 +46,7 @@ #include "arrow/python/numpy_convert.h" #include "arrow/python/pyarrow.h" #include "arrow/python/serialize.h" +#include "arrow/python/vendored/pythoncapi_compat.h" namespace arrow { @@ -88,8 +89,13 @@ Status DeserializeDict(PyObject* context, const Array& array, int64_t start_idx, // The latter two steal references whereas PyDict_SetItem does not. So we need // to make sure the reference count is decremented by letting the OwnedRef // go out of scope at the end. 
- int ret = PyDict_SetItem(result.obj(), PyList_GET_ITEM(keys.obj(), i - start_idx), - PyList_GET_ITEM(vals.obj(), i - start_idx)); + PyObject* key = PyList_GetItemRef(keys.obj(), i - start_idx); + RETURN_IF_PYERROR(); + OwnedRef keyref(key); + PyObject* val = PyList_GetItemRef(vals.obj(), i - start_idx); + RETURN_IF_PYERROR(); + OwnedRef valref(val); + int ret = PyDict_SetItem(result.obj(), key, val); if (ret != 0) { return ConvertPyError(); } @@ -398,7 +404,9 @@ Status GetSerializedFromComponents(int num_tensors, auto GetBuffer = [&data](Py_ssize_t index, std::shared_ptr* out) { ARROW_CHECK_LE(index, PyList_Size(data)); - PyObject* py_buf = PyList_GET_ITEM(data, index); + PyObject* py_buf = PyList_GetItemRef(data, index); + RETURN_IF_PYERROR(); + OwnedRef py_buf_ref(py_buf); return unwrap_buffer(py_buf).Value(out); }; diff --git a/python/pyarrow/src/arrow/python/numpy_to_arrow.cc b/python/pyarrow/src/arrow/python/numpy_to_arrow.cc index 460b1d0ce3f..e78a301bce3 100644 --- a/python/pyarrow/src/arrow/python/numpy_to_arrow.cc +++ b/python/pyarrow/src/arrow/python/numpy_to_arrow.cc @@ -57,6 +57,7 @@ #include "arrow/python/numpy_internal.h" #include "arrow/python/python_to_arrow.h" #include "arrow/python/type_traits.h" +#include "arrow/python/vendored/pythoncapi_compat.h" namespace arrow { @@ -757,8 +758,10 @@ Status NumPyConverter::Visit(const StructType& type) { } for (auto field : type.fields()) { - PyObject* tup = - PyDict_GetItemString(PyDataType_FIELDS(dtype_), field->name().c_str()); + PyObject* tup; + PyDict_GetItemStringRef(PyDataType_FIELDS(dtype_), field->name().c_str(), &tup); + RETURN_IF_PYERROR(); + OwnedRef tupref(tup); if (tup == NULL) { return Status::Invalid("Missing field '", field->name(), "' in struct array"); } diff --git a/python/pyarrow/src/arrow/python/python_to_arrow.cc b/python/pyarrow/src/arrow/python/python_to_arrow.cc index a2a325fde8d..ce9e15c894c 100644 --- a/python/pyarrow/src/arrow/python/python_to_arrow.cc +++ b/python/pyarrow/src/arrow/python/python_to_arrow.cc @@ -54,6 +54,7 @@ #include "arrow/python/iterators.h" #include "arrow/python/numpy_convert.h" #include "arrow/python/type_traits.h" +#include "arrow/python/vendored/pythoncapi_compat.h" #include "arrow/visit_type_inline.h" namespace arrow { @@ -1107,11 +1108,13 @@ class PyStructConverter : public StructConverter Status AppendDict(PyObject* dict, PyObject* field_names) { // NOTE we're ignoring any extraneous dict items for (int i = 0; i < num_fields_; i++) { - PyObject* name = PyList_GET_ITEM(field_names, i); // borrowed - PyObject* value = PyDict_GetItem(dict, name); // borrowed - if (value == NULL) { - RETURN_IF_PYERROR(); - } + PyObject* name = PyList_GetItemRef(field_names, i); + RETURN_IF_PYERROR(); + OwnedRef nameref(name); + PyObject* value; + PyDict_GetItemRef(dict, name, &value); + RETURN_IF_PYERROR(); + OwnedRef valueref(value); RETURN_NOT_OK(this->children_[i]->Append(value ? 
value : Py_None)); } return Status::OK(); @@ -1141,7 +1144,9 @@ class PyStructConverter : public StructConverter ARROW_ASSIGN_OR_RAISE(auto pair, GetKeyValuePair(items, i)); // validate that the key and the field name are equal - PyObject* name = PyList_GET_ITEM(field_names, i); + PyObject* name = PyList_GetItemRef(field_names, i); + RETURN_IF_PYERROR(); + OwnedRef nameref(name); bool are_equal = PyObject_RichCompareBool(pair.first, name, Py_EQ); RETURN_IF_PYERROR(); diff --git a/python/pyarrow/src/arrow/python/vendored/CMakeLists.txt b/python/pyarrow/src/arrow/python/vendored/CMakeLists.txt new file mode 100644 index 00000000000..6190072c0d3 --- /dev/null +++ b/python/pyarrow/src/arrow/python/vendored/CMakeLists.txt @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +arrow_install_all_headers("arrow/python/vendored") diff --git a/python/pyarrow/src/arrow/python/vendored/pythoncapi_compat.h b/python/pyarrow/src/arrow/python/vendored/pythoncapi_compat.h new file mode 100644 index 00000000000..4baa7b34a93 --- /dev/null +++ b/python/pyarrow/src/arrow/python/vendored/pythoncapi_compat.h @@ -0,0 +1,1519 @@ +// Header file providing new C API functions to old Python versions. +// +// File distributed under the Zero Clause BSD (0BSD) license. +// Copyright Contributors to the pythoncapi_compat project. +// +// Homepage: +// https://github.com/python/pythoncapi_compat +// +// Latest version: +// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h +// +// Vendored from git revision: +// 39e2663e6acc0b68d5dd75bdaad0af33152552ae +// https://raw.githubusercontent.com/python/pythoncapi-compat/39e2663e6acc0b68d5dd75bdaad0af33152552ae/pythoncapi_compat.h +// +// SPDX-License-Identifier: 0BSD + +/* clang-format off */ + +#ifndef PYTHONCAPI_COMPAT +#define PYTHONCAPI_COMPAT + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +// Python 3.11.0b4 added PyFrame_Back() to Python.h +#if PY_VERSION_HEX < 0x030b00B4 && !defined(PYPY_VERSION) +# include "frameobject.h" // PyFrameObject, PyFrame_GetBack() +#endif + + +#ifndef _Py_CAST +# define _Py_CAST(type, expr) ((type)(expr)) +#endif + +// Static inline functions should use _Py_NULL rather than using directly NULL +// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer, +// _Py_NULL is defined as nullptr. +#if (defined (__STDC_VERSION__) && __STDC_VERSION__ > 201710L) \ + || (defined(__cplusplus) && __cplusplus >= 201103) +# define _Py_NULL nullptr +#else +# define _Py_NULL NULL +#endif + +// Cast argument to PyObject* type. 
+#ifndef _PyObject_CAST +# define _PyObject_CAST(op) _Py_CAST(PyObject*, op) +#endif + + +// bpo-42262 added Py_NewRef() to Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_NewRef) +static inline PyObject* _Py_NewRef(PyObject *obj) +{ + Py_INCREF(obj); + return obj; +} +#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj)) +#endif + + +// bpo-42262 added Py_XNewRef() to Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 && !defined(Py_XNewRef) +static inline PyObject* _Py_XNewRef(PyObject *obj) +{ + Py_XINCREF(obj); + return obj; +} +#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj)) +#endif + + +// bpo-39573 added Py_SET_REFCNT() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_REFCNT) +static inline void _Py_SET_REFCNT(PyObject *ob, Py_ssize_t refcnt) +{ + ob->ob_refcnt = refcnt; +} +#define Py_SET_REFCNT(ob, refcnt) _Py_SET_REFCNT(_PyObject_CAST(ob), refcnt) +#endif + + +// Py_SETREF() and Py_XSETREF() were added to Python 3.5.2. +// It is excluded from the limited C API. +#if (PY_VERSION_HEX < 0x03050200 && !defined(Py_SETREF)) && !defined(Py_LIMITED_API) +#define Py_SETREF(dst, src) \ + do { \ + PyObject **_tmp_dst_ptr = _Py_CAST(PyObject**, &(dst)); \ + PyObject *_tmp_dst = (*_tmp_dst_ptr); \ + *_tmp_dst_ptr = _PyObject_CAST(src); \ + Py_DECREF(_tmp_dst); \ + } while (0) + +#define Py_XSETREF(dst, src) \ + do { \ + PyObject **_tmp_dst_ptr = _Py_CAST(PyObject**, &(dst)); \ + PyObject *_tmp_dst = (*_tmp_dst_ptr); \ + *_tmp_dst_ptr = _PyObject_CAST(src); \ + Py_XDECREF(_tmp_dst); \ + } while (0) +#endif + + +// bpo-43753 added Py_Is(), Py_IsNone(), Py_IsTrue() and Py_IsFalse() +// to Python 3.10.0b1. +#if PY_VERSION_HEX < 0x030A00B1 && !defined(Py_Is) +# define Py_Is(x, y) ((x) == (y)) +#endif +#if PY_VERSION_HEX < 0x030A00B1 && !defined(Py_IsNone) +# define Py_IsNone(x) Py_Is(x, Py_None) +#endif +#if (PY_VERSION_HEX < 0x030A00B1 || defined(PYPY_VERSION)) && !defined(Py_IsTrue) +# define Py_IsTrue(x) Py_Is(x, Py_True) +#endif +#if (PY_VERSION_HEX < 0x030A00B1 || defined(PYPY_VERSION)) && !defined(Py_IsFalse) +# define Py_IsFalse(x) Py_Is(x, Py_False) +#endif + + +// bpo-39573 added Py_SET_TYPE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_TYPE) +static inline void _Py_SET_TYPE(PyObject *ob, PyTypeObject *type) +{ + ob->ob_type = type; +} +#define Py_SET_TYPE(ob, type) _Py_SET_TYPE(_PyObject_CAST(ob), type) +#endif + + +// bpo-39573 added Py_SET_SIZE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_SET_SIZE) +static inline void _Py_SET_SIZE(PyVarObject *ob, Py_ssize_t size) +{ + ob->ob_size = size; +} +#define Py_SET_SIZE(ob, size) _Py_SET_SIZE((PyVarObject*)(ob), size) +#endif + + +// bpo-40421 added PyFrame_GetCode() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 || defined(PYPY_VERSION) +static inline PyCodeObject* PyFrame_GetCode(PyFrameObject *frame) +{ + assert(frame != _Py_NULL); + assert(frame->f_code != _Py_NULL); + return _Py_CAST(PyCodeObject*, Py_NewRef(frame->f_code)); +} +#endif + +static inline PyCodeObject* _PyFrame_GetCodeBorrow(PyFrameObject *frame) +{ + PyCodeObject *code = PyFrame_GetCode(frame); + Py_DECREF(code); + return code; +} + + +// bpo-40421 added PyFrame_GetBack() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) +static inline PyFrameObject* PyFrame_GetBack(PyFrameObject *frame) +{ + assert(frame != _Py_NULL); + return _Py_CAST(PyFrameObject*, Py_XNewRef(frame->f_back)); +} +#endif + +#if !defined(PYPY_VERSION) +static inline PyFrameObject* 
_PyFrame_GetBackBorrow(PyFrameObject *frame) +{ + PyFrameObject *back = PyFrame_GetBack(frame); + Py_XDECREF(back); + return back; +} +#endif + + +// bpo-40421 added PyFrame_GetLocals() to Python 3.11.0a7 +#if PY_VERSION_HEX < 0x030B00A7 && !defined(PYPY_VERSION) +static inline PyObject* PyFrame_GetLocals(PyFrameObject *frame) +{ +#if PY_VERSION_HEX >= 0x030400B1 + if (PyFrame_FastToLocalsWithError(frame) < 0) { + return NULL; + } +#else + PyFrame_FastToLocals(frame); +#endif + return Py_NewRef(frame->f_locals); +} +#endif + + +// bpo-40421 added PyFrame_GetGlobals() to Python 3.11.0a7 +#if PY_VERSION_HEX < 0x030B00A7 && !defined(PYPY_VERSION) +static inline PyObject* PyFrame_GetGlobals(PyFrameObject *frame) +{ + return Py_NewRef(frame->f_globals); +} +#endif + + +// bpo-40421 added PyFrame_GetBuiltins() to Python 3.11.0a7 +#if PY_VERSION_HEX < 0x030B00A7 && !defined(PYPY_VERSION) +static inline PyObject* PyFrame_GetBuiltins(PyFrameObject *frame) +{ + return Py_NewRef(frame->f_builtins); +} +#endif + + +// bpo-40421 added PyFrame_GetLasti() to Python 3.11.0b1 +#if PY_VERSION_HEX < 0x030B00B1 && !defined(PYPY_VERSION) +static inline int PyFrame_GetLasti(PyFrameObject *frame) +{ +#if PY_VERSION_HEX >= 0x030A00A7 + // bpo-27129: Since Python 3.10.0a7, f_lasti is an instruction offset, + // not a bytes offset anymore. Python uses 16-bit "wordcode" (2 bytes) + // instructions. + if (frame->f_lasti < 0) { + return -1; + } + return frame->f_lasti * 2; +#else + return frame->f_lasti; +#endif +} +#endif + + +// gh-91248 added PyFrame_GetVar() to Python 3.12.0a2 +#if PY_VERSION_HEX < 0x030C00A2 && !defined(PYPY_VERSION) +static inline PyObject* PyFrame_GetVar(PyFrameObject *frame, PyObject *name) +{ + PyObject *locals, *value; + + locals = PyFrame_GetLocals(frame); + if (locals == NULL) { + return NULL; + } +#if PY_VERSION_HEX >= 0x03000000 + value = PyDict_GetItemWithError(locals, name); +#else + value = _PyDict_GetItemWithError(locals, name); +#endif + Py_DECREF(locals); + + if (value == NULL) { + if (PyErr_Occurred()) { + return NULL; + } +#if PY_VERSION_HEX >= 0x03000000 + PyErr_Format(PyExc_NameError, "variable %R does not exist", name); +#else + PyErr_SetString(PyExc_NameError, "variable does not exist"); +#endif + return NULL; + } + return Py_NewRef(value); +} +#endif + + +// gh-91248 added PyFrame_GetVarString() to Python 3.12.0a2 +#if PY_VERSION_HEX < 0x030C00A2 && !defined(PYPY_VERSION) +static inline PyObject* +PyFrame_GetVarString(PyFrameObject *frame, const char *name) +{ + PyObject *name_obj, *value; +#if PY_VERSION_HEX >= 0x03000000 + name_obj = PyUnicode_FromString(name); +#else + name_obj = PyString_FromString(name); +#endif + if (name_obj == NULL) { + return NULL; + } + value = PyFrame_GetVar(frame, name_obj); + Py_DECREF(name_obj); + return value; +} +#endif + + +// bpo-39947 added PyThreadState_GetInterpreter() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 || defined(PYPY_VERSION) +static inline PyInterpreterState * +PyThreadState_GetInterpreter(PyThreadState *tstate) +{ + assert(tstate != _Py_NULL); + return tstate->interp; +} +#endif + + +// bpo-40429 added PyThreadState_GetFrame() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) +static inline PyFrameObject* PyThreadState_GetFrame(PyThreadState *tstate) +{ + assert(tstate != _Py_NULL); + return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame)); +} +#endif + +#if !defined(PYPY_VERSION) +static inline PyFrameObject* +_PyThreadState_GetFrameBorrow(PyThreadState *tstate) +{ + PyFrameObject 
*frame = PyThreadState_GetFrame(tstate); + Py_XDECREF(frame); + return frame; +} +#endif + + +// bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 || defined(PYPY_VERSION) +static inline PyInterpreterState* PyInterpreterState_Get(void) +{ + PyThreadState *tstate; + PyInterpreterState *interp; + + tstate = PyThreadState_GET(); + if (tstate == _Py_NULL) { + Py_FatalError("GIL released (tstate is NULL)"); + } + interp = tstate->interp; + if (interp == _Py_NULL) { + Py_FatalError("no current interpreter"); + } + return interp; +} +#endif + + +// bpo-39947 added PyInterpreterState_Get() to Python 3.9.0a6 +#if 0x030700A1 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION) +static inline uint64_t PyThreadState_GetID(PyThreadState *tstate) +{ + assert(tstate != _Py_NULL); + return tstate->id; +} +#endif + +// bpo-43760 added PyThreadState_EnterTracing() to Python 3.11.0a2 +#if PY_VERSION_HEX < 0x030B00A2 && !defined(PYPY_VERSION) +static inline void PyThreadState_EnterTracing(PyThreadState *tstate) +{ + tstate->tracing++; +#if PY_VERSION_HEX >= 0x030A00A1 + tstate->cframe->use_tracing = 0; +#else + tstate->use_tracing = 0; +#endif +} +#endif + +// bpo-43760 added PyThreadState_LeaveTracing() to Python 3.11.0a2 +#if PY_VERSION_HEX < 0x030B00A2 && !defined(PYPY_VERSION) +static inline void PyThreadState_LeaveTracing(PyThreadState *tstate) +{ + int use_tracing = (tstate->c_tracefunc != _Py_NULL + || tstate->c_profilefunc != _Py_NULL); + tstate->tracing--; +#if PY_VERSION_HEX >= 0x030A00A1 + tstate->cframe->use_tracing = use_tracing; +#else + tstate->use_tracing = use_tracing; +#endif +} +#endif + + +// bpo-37194 added PyObject_CallNoArgs() to Python 3.9.0a1 +// PyObject_CallNoArgs() added to PyPy 3.9.16-v7.3.11 +#if !defined(PyObject_CallNoArgs) && PY_VERSION_HEX < 0x030900A1 +static inline PyObject* PyObject_CallNoArgs(PyObject *func) +{ + return PyObject_CallFunctionObjArgs(func, NULL); +} +#endif + + +// bpo-39245 made PyObject_CallOneArg() public (previously called +// _PyObject_CallOneArg) in Python 3.9.0a4 +// PyObject_CallOneArg() added to PyPy 3.9.16-v7.3.11 +#if !defined(PyObject_CallOneArg) && PY_VERSION_HEX < 0x030900A4 +static inline PyObject* PyObject_CallOneArg(PyObject *func, PyObject *arg) +{ + return PyObject_CallFunctionObjArgs(func, arg, NULL); +} +#endif + + +// bpo-1635741 added PyModule_AddObjectRef() to Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 +static inline int +PyModule_AddObjectRef(PyObject *module, const char *name, PyObject *value) +{ + int res; + + if (!value && !PyErr_Occurred()) { + // PyModule_AddObject() raises TypeError in this case + PyErr_SetString(PyExc_SystemError, + "PyModule_AddObjectRef() must be called " + "with an exception raised if value is NULL"); + return -1; + } + + Py_XINCREF(value); + res = PyModule_AddObject(module, name, value); + if (res < 0) { + Py_XDECREF(value); + } + return res; +} +#endif + + +// bpo-40024 added PyModule_AddType() to Python 3.9.0a5 +#if PY_VERSION_HEX < 0x030900A5 +static inline int PyModule_AddType(PyObject *module, PyTypeObject *type) +{ + const char *name, *dot; + + if (PyType_Ready(type) < 0) { + return -1; + } + + // inline _PyType_Name() + name = type->tp_name; + assert(name != _Py_NULL); + dot = strrchr(name, '.'); + if (dot != _Py_NULL) { + name = dot + 1; + } + + return PyModule_AddObjectRef(module, name, _PyObject_CAST(type)); +} +#endif + + +// bpo-40241 added PyObject_GC_IsTracked() to Python 3.9.0a6. 
+// bpo-4688 added _PyObject_GC_IS_TRACKED() to Python 2.7.0a2. +#if PY_VERSION_HEX < 0x030900A6 && !defined(PYPY_VERSION) +static inline int PyObject_GC_IsTracked(PyObject* obj) +{ + return (PyObject_IS_GC(obj) && _PyObject_GC_IS_TRACKED(obj)); +} +#endif + +// bpo-40241 added PyObject_GC_IsFinalized() to Python 3.9.0a6. +// bpo-18112 added _PyGCHead_FINALIZED() to Python 3.4.0 final. +#if PY_VERSION_HEX < 0x030900A6 && PY_VERSION_HEX >= 0x030400F0 && !defined(PYPY_VERSION) +static inline int PyObject_GC_IsFinalized(PyObject *obj) +{ + PyGC_Head *gc = _Py_CAST(PyGC_Head*, obj) - 1; + return (PyObject_IS_GC(obj) && _PyGCHead_FINALIZED(gc)); +} +#endif + + +// bpo-39573 added Py_IS_TYPE() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 && !defined(Py_IS_TYPE) +static inline int _Py_IS_TYPE(PyObject *ob, PyTypeObject *type) { + return Py_TYPE(ob) == type; +} +#define Py_IS_TYPE(ob, type) _Py_IS_TYPE(_PyObject_CAST(ob), type) +#endif + + +// bpo-46906 added PyFloat_Pack2() and PyFloat_Unpack2() to Python 3.11a7. +// bpo-11734 added _PyFloat_Pack2() and _PyFloat_Unpack2() to Python 3.6.0b1. +// Python 3.11a2 moved _PyFloat_Pack2() and _PyFloat_Unpack2() to the internal +// C API: Python 3.11a2-3.11a6 versions are not supported. +#if 0x030600B1 <= PY_VERSION_HEX && PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION) +static inline int PyFloat_Pack2(double x, char *p, int le) +{ return _PyFloat_Pack2(x, (unsigned char*)p, le); } + +static inline double PyFloat_Unpack2(const char *p, int le) +{ return _PyFloat_Unpack2((const unsigned char *)p, le); } +#endif + + +// bpo-46906 added PyFloat_Pack4(), PyFloat_Pack8(), PyFloat_Unpack4() and +// PyFloat_Unpack8() to Python 3.11a7. +// Python 3.11a2 moved _PyFloat_Pack4(), _PyFloat_Pack8(), _PyFloat_Unpack4() +// and _PyFloat_Unpack8() to the internal C API: Python 3.11a2-3.11a6 versions +// are not supported. +#if PY_VERSION_HEX <= 0x030B00A1 && !defined(PYPY_VERSION) +static inline int PyFloat_Pack4(double x, char *p, int le) +{ return _PyFloat_Pack4(x, (unsigned char*)p, le); } + +static inline int PyFloat_Pack8(double x, char *p, int le) +{ return _PyFloat_Pack8(x, (unsigned char*)p, le); } + +static inline double PyFloat_Unpack4(const char *p, int le) +{ return _PyFloat_Unpack4((const unsigned char *)p, le); } + +static inline double PyFloat_Unpack8(const char *p, int le) +{ return _PyFloat_Unpack8((const unsigned char *)p, le); } +#endif + + +// gh-92154 added PyCode_GetCode() to Python 3.11.0b1 +#if PY_VERSION_HEX < 0x030B00B1 && !defined(PYPY_VERSION) +static inline PyObject* PyCode_GetCode(PyCodeObject *code) +{ + return Py_NewRef(code->co_code); +} +#endif + + +// gh-95008 added PyCode_GetVarnames() to Python 3.11.0rc1 +#if PY_VERSION_HEX < 0x030B00C1 && !defined(PYPY_VERSION) +static inline PyObject* PyCode_GetVarnames(PyCodeObject *code) +{ + return Py_NewRef(code->co_varnames); +} +#endif + +// gh-95008 added PyCode_GetFreevars() to Python 3.11.0rc1 +#if PY_VERSION_HEX < 0x030B00C1 && !defined(PYPY_VERSION) +static inline PyObject* PyCode_GetFreevars(PyCodeObject *code) +{ + return Py_NewRef(code->co_freevars); +} +#endif + +// gh-95008 added PyCode_GetCellvars() to Python 3.11.0rc1 +#if PY_VERSION_HEX < 0x030B00C1 && !defined(PYPY_VERSION) +static inline PyObject* PyCode_GetCellvars(PyCodeObject *code) +{ + return Py_NewRef(code->co_cellvars); +} +#endif + + +// Py_UNUSED() was added to Python 3.4.0b2. 
+#if PY_VERSION_HEX < 0x030400B2 && !defined(Py_UNUSED) +# if defined(__GNUC__) || defined(__clang__) +# define Py_UNUSED(name) _unused_ ## name __attribute__((unused)) +# else +# define Py_UNUSED(name) _unused_ ## name +# endif +#endif + + +// gh-105922 added PyImport_AddModuleRef() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A0 +static inline PyObject* PyImport_AddModuleRef(const char *name) +{ + return Py_XNewRef(PyImport_AddModule(name)); +} +#endif + + +// gh-105927 added PyWeakref_GetRef() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D0000 +static inline int PyWeakref_GetRef(PyObject *ref, PyObject **pobj) +{ + PyObject *obj; + if (ref != NULL && !PyWeakref_Check(ref)) { + *pobj = NULL; + PyErr_SetString(PyExc_TypeError, "expected a weakref"); + return -1; + } + obj = PyWeakref_GetObject(ref); + if (obj == NULL) { + // SystemError if ref is NULL + *pobj = NULL; + return -1; + } + if (obj == Py_None) { + *pobj = NULL; + return 0; + } + *pobj = Py_NewRef(obj); + return (*pobj != NULL); +} +#endif + + +// bpo-36974 added PY_VECTORCALL_ARGUMENTS_OFFSET to Python 3.8b1 +#ifndef PY_VECTORCALL_ARGUMENTS_OFFSET +# define PY_VECTORCALL_ARGUMENTS_OFFSET (_Py_CAST(size_t, 1) << (8 * sizeof(size_t) - 1)) +#endif + +// bpo-36974 added PyVectorcall_NARGS() to Python 3.8b1 +#if PY_VERSION_HEX < 0x030800B1 +static inline Py_ssize_t PyVectorcall_NARGS(size_t n) +{ + return n & ~PY_VECTORCALL_ARGUMENTS_OFFSET; +} +#endif + + +// gh-105922 added PyObject_Vectorcall() to Python 3.9.0a4 +#if PY_VERSION_HEX < 0x030900A4 +static inline PyObject* +PyObject_Vectorcall(PyObject *callable, PyObject *const *args, + size_t nargsf, PyObject *kwnames) +{ +#if PY_VERSION_HEX >= 0x030800B1 && !defined(PYPY_VERSION) + // bpo-36974 added _PyObject_Vectorcall() to Python 3.8.0b1 + return _PyObject_Vectorcall(callable, args, nargsf, kwnames); +#else + PyObject *posargs = NULL, *kwargs = NULL; + PyObject *res; + Py_ssize_t nposargs, nkwargs, i; + + if (nargsf != 0 && args == NULL) { + PyErr_BadInternalCall(); + goto error; + } + if (kwnames != NULL && !PyTuple_Check(kwnames)) { + PyErr_BadInternalCall(); + goto error; + } + + nposargs = (Py_ssize_t)PyVectorcall_NARGS(nargsf); + if (kwnames) { + nkwargs = PyTuple_GET_SIZE(kwnames); + } + else { + nkwargs = 0; + } + + posargs = PyTuple_New(nposargs); + if (posargs == NULL) { + goto error; + } + if (nposargs) { + for (i=0; i < nposargs; i++) { + PyTuple_SET_ITEM(posargs, i, Py_NewRef(*args)); + args++; + } + } + + if (nkwargs) { + kwargs = PyDict_New(); + if (kwargs == NULL) { + goto error; + } + + for (i = 0; i < nkwargs; i++) { + PyObject *key = PyTuple_GET_ITEM(kwnames, i); + PyObject *value = *args; + args++; + if (PyDict_SetItem(kwargs, key, value) < 0) { + goto error; + } + } + } + else { + kwargs = NULL; + } + + res = PyObject_Call(callable, posargs, kwargs); + Py_DECREF(posargs); + Py_XDECREF(kwargs); + return res; + +error: + Py_DECREF(posargs); + Py_XDECREF(kwargs); + return NULL; +#endif +} +#endif + + +// gh-106521 added PyObject_GetOptionalAttr() and +// PyObject_GetOptionalAttrString() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyObject_GetOptionalAttr(PyObject *obj, PyObject *attr_name, PyObject **result) +{ + // bpo-32571 added _PyObject_LookupAttr() to Python 3.7.0b1 +#if PY_VERSION_HEX >= 0x030700B1 && !defined(PYPY_VERSION) + return _PyObject_LookupAttr(obj, attr_name, result); +#else + *result = PyObject_GetAttr(obj, attr_name); + if (*result != NULL) { + return 1; + } + if (!PyErr_Occurred()) { + return 0; + } + if 
(PyErr_ExceptionMatches(PyExc_AttributeError)) { + PyErr_Clear(); + return 0; + } + return -1; +#endif +} + +static inline int +PyObject_GetOptionalAttrString(PyObject *obj, const char *attr_name, PyObject **result) +{ + PyObject *name_obj; + int rc; +#if PY_VERSION_HEX >= 0x03000000 + name_obj = PyUnicode_FromString(attr_name); +#else + name_obj = PyString_FromString(attr_name); +#endif + if (name_obj == NULL) { + *result = NULL; + return -1; + } + rc = PyObject_GetOptionalAttr(obj, name_obj, result); + Py_DECREF(name_obj); + return rc; +} +#endif + + +// gh-106307 added PyObject_GetOptionalAttr() and +// PyMapping_GetOptionalItemString() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyMapping_GetOptionalItem(PyObject *obj, PyObject *key, PyObject **result) +{ + *result = PyObject_GetItem(obj, key); + if (*result) { + return 1; + } + if (!PyErr_ExceptionMatches(PyExc_KeyError)) { + return -1; + } + PyErr_Clear(); + return 0; +} + +static inline int +PyMapping_GetOptionalItemString(PyObject *obj, const char *key, PyObject **result) +{ + PyObject *key_obj; + int rc; +#if PY_VERSION_HEX >= 0x03000000 + key_obj = PyUnicode_FromString(key); +#else + key_obj = PyString_FromString(key); +#endif + if (key_obj == NULL) { + *result = NULL; + return -1; + } + rc = PyMapping_GetOptionalItem(obj, key_obj, result); + Py_DECREF(key_obj); + return rc; +} +#endif + +// gh-108511 added PyMapping_HasKeyWithError() and +// PyMapping_HasKeyStringWithError() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyMapping_HasKeyWithError(PyObject *obj, PyObject *key) +{ + PyObject *res; + int rc = PyMapping_GetOptionalItem(obj, key, &res); + Py_XDECREF(res); + return rc; +} + +static inline int +PyMapping_HasKeyStringWithError(PyObject *obj, const char *key) +{ + PyObject *res; + int rc = PyMapping_GetOptionalItemString(obj, key, &res); + Py_XDECREF(res); + return rc; +} +#endif + + +// gh-108511 added PyObject_HasAttrWithError() and +// PyObject_HasAttrStringWithError() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyObject_HasAttrWithError(PyObject *obj, PyObject *attr) +{ + PyObject *res; + int rc = PyObject_GetOptionalAttr(obj, attr, &res); + Py_XDECREF(res); + return rc; +} + +static inline int +PyObject_HasAttrStringWithError(PyObject *obj, const char *attr) +{ + PyObject *res; + int rc = PyObject_GetOptionalAttrString(obj, attr, &res); + Py_XDECREF(res); + return rc; +} +#endif + + +// gh-106004 added PyDict_GetItemRef() and PyDict_GetItemStringRef() +// to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyDict_GetItemRef(PyObject *mp, PyObject *key, PyObject **result) +{ +#if PY_VERSION_HEX >= 0x03000000 + PyObject *item = PyDict_GetItemWithError(mp, key); +#else + PyObject *item = _PyDict_GetItemWithError(mp, key); +#endif + if (item != NULL) { + *result = Py_NewRef(item); + return 1; // found + } + if (!PyErr_Occurred()) { + *result = NULL; + return 0; // not found + } + *result = NULL; + return -1; +} + +static inline int +PyDict_GetItemStringRef(PyObject *mp, const char *key, PyObject **result) +{ + int res; +#if PY_VERSION_HEX >= 0x03000000 + PyObject *key_obj = PyUnicode_FromString(key); +#else + PyObject *key_obj = PyString_FromString(key); +#endif + if (key_obj == NULL) { + *result = NULL; + return -1; + } + res = PyDict_GetItemRef(mp, key_obj, result); + Py_DECREF(key_obj); + return res; +} +#endif + + +// gh-106307 added PyModule_Add() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline 
int +PyModule_Add(PyObject *mod, const char *name, PyObject *value) +{ + int res = PyModule_AddObjectRef(mod, name, value); + Py_XDECREF(value); + return res; +} +#endif + + +// gh-108014 added Py_IsFinalizing() to Python 3.13.0a1 +// bpo-1856 added _Py_Finalizing to Python 3.2.1b1. +// _Py_IsFinalizing() was added to PyPy 7.3.0. +#if (0x030201B1 <= PY_VERSION_HEX && PY_VERSION_HEX < 0x030D00A1) \ + && (!defined(PYPY_VERSION_NUM) || PYPY_VERSION_NUM >= 0x7030000) +static inline int Py_IsFinalizing(void) +{ +#if PY_VERSION_HEX >= 0x030700A1 + // _Py_IsFinalizing() was added to Python 3.7.0a1. + return _Py_IsFinalizing(); +#else + return (_Py_Finalizing != NULL); +#endif +} +#endif + + +// gh-108323 added PyDict_ContainsString() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int PyDict_ContainsString(PyObject *op, const char *key) +{ + PyObject *key_obj = PyUnicode_FromString(key); + if (key_obj == NULL) { + return -1; + } + int res = PyDict_Contains(op, key_obj); + Py_DECREF(key_obj); + return res; +} +#endif + + +// gh-108445 added PyLong_AsInt() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int PyLong_AsInt(PyObject *obj) +{ +#ifdef PYPY_VERSION + long value = PyLong_AsLong(obj); + if (value == -1 && PyErr_Occurred()) { + return -1; + } + if (value < (long)INT_MIN || (long)INT_MAX < value) { + PyErr_SetString(PyExc_OverflowError, + "Python int too large to convert to C int"); + return -1; + } + return (int)value; +#else + return _PyLong_AsInt(obj); +#endif +} +#endif + + +// gh-107073 added PyObject_VisitManagedDict() to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyObject_VisitManagedDict(PyObject *obj, visitproc visit, void *arg) +{ + PyObject **dict = _PyObject_GetDictPtr(obj); + if (*dict == NULL) { + return -1; + } + Py_VISIT(*dict); + return 0; +} + +static inline void +PyObject_ClearManagedDict(PyObject *obj) +{ + PyObject **dict = _PyObject_GetDictPtr(obj); + if (*dict == NULL) { + return; + } + Py_CLEAR(*dict); +} +#endif + +// gh-108867 added PyThreadState_GetUnchecked() to Python 3.13.0a1 +// Python 3.5.2 added _PyThreadState_UncheckedGet(). +#if PY_VERSION_HEX >= 0x03050200 && PY_VERSION_HEX < 0x030D00A1 +static inline PyThreadState* +PyThreadState_GetUnchecked(void) +{ + return _PyThreadState_UncheckedGet(); +} +#endif + +// gh-110289 added PyUnicode_EqualToUTF8() and PyUnicode_EqualToUTF8AndSize() +// to Python 3.13.0a1 +#if PY_VERSION_HEX < 0x030D00A1 +static inline int +PyUnicode_EqualToUTF8AndSize(PyObject *unicode, const char *str, Py_ssize_t str_len) +{ + Py_ssize_t len; + const void *utf8; + PyObject *exc_type, *exc_value, *exc_tb; + int res; + + // API cannot report errors so save/restore the exception + PyErr_Fetch(&exc_type, &exc_value, &exc_tb); + + // Python 3.3.0a1 added PyUnicode_AsUTF8AndSize() +#if PY_VERSION_HEX >= 0x030300A1 + if (PyUnicode_IS_ASCII(unicode)) { + utf8 = PyUnicode_DATA(unicode); + len = PyUnicode_GET_LENGTH(unicode); + } + else { + utf8 = PyUnicode_AsUTF8AndSize(unicode, &len); + if (utf8 == NULL) { + // Memory allocation failure. The API cannot report error, + // so ignore the exception and return 0. + res = 0; + goto done; + } + } + + if (len != str_len) { + res = 0; + goto done; + } + res = (memcmp(utf8, str, (size_t)len) == 0); +#else + PyObject *bytes = PyUnicode_AsUTF8String(unicode); + if (bytes == NULL) { + // Memory allocation failure. The API cannot report error, + // so ignore the exception and return 0. 
+ res = 0; + goto done; + } + +#if PY_VERSION_HEX >= 0x03000000 + len = PyBytes_GET_SIZE(bytes); + utf8 = PyBytes_AS_STRING(bytes); +#else + len = PyString_GET_SIZE(bytes); + utf8 = PyString_AS_STRING(bytes); +#endif + if (len != str_len) { + Py_DECREF(bytes); + res = 0; + goto done; + } + + res = (memcmp(utf8, str, (size_t)len) == 0); + Py_DECREF(bytes); +#endif + +done: + PyErr_Restore(exc_type, exc_value, exc_tb); + return res; +} + +static inline int +PyUnicode_EqualToUTF8(PyObject *unicode, const char *str) +{ + return PyUnicode_EqualToUTF8AndSize(unicode, str, (Py_ssize_t)strlen(str)); +} +#endif + + +// gh-111138 added PyList_Extend() and PyList_Clear() to Python 3.13.0a2 +#if PY_VERSION_HEX < 0x030D00A2 +static inline int +PyList_Extend(PyObject *list, PyObject *iterable) +{ + return PyList_SetSlice(list, PY_SSIZE_T_MAX, PY_SSIZE_T_MAX, iterable); +} + +static inline int +PyList_Clear(PyObject *list) +{ + return PyList_SetSlice(list, 0, PY_SSIZE_T_MAX, NULL); +} +#endif + +// gh-111262 added PyDict_Pop() and PyDict_PopString() to Python 3.13.0a2 +#if PY_VERSION_HEX < 0x030D00A2 +static inline int +PyDict_Pop(PyObject *dict, PyObject *key, PyObject **result) +{ + PyObject *value; + + if (!PyDict_Check(dict)) { + PyErr_BadInternalCall(); + if (result) { + *result = NULL; + } + return -1; + } + + // bpo-16991 added _PyDict_Pop() to Python 3.5.0b2. + // Python 3.6.0b3 changed _PyDict_Pop() first argument type to PyObject*. + // Python 3.13.0a1 removed _PyDict_Pop(). +#if defined(PYPY_VERSION) || PY_VERSION_HEX < 0x030500b2 || PY_VERSION_HEX >= 0x030D0000 + value = PyObject_CallMethod(dict, "pop", "O", key); +#elif PY_VERSION_HEX < 0x030600b3 + value = _PyDict_Pop(_Py_CAST(PyDictObject*, dict), key, NULL); +#else + value = _PyDict_Pop(dict, key, NULL); +#endif + if (value == NULL) { + if (result) { + *result = NULL; + } + if (PyErr_Occurred() && !PyErr_ExceptionMatches(PyExc_KeyError)) { + return -1; + } + PyErr_Clear(); + return 0; + } + if (result) { + *result = value; + } + else { + Py_DECREF(value); + } + return 1; +} + +static inline int +PyDict_PopString(PyObject *dict, const char *key, PyObject **result) +{ + PyObject *key_obj = PyUnicode_FromString(key); + if (key_obj == NULL) { + if (result != NULL) { + *result = NULL; + } + return -1; + } + + int res = PyDict_Pop(dict, key_obj, result); + Py_DECREF(key_obj); + return res; +} +#endif + + +#if PY_VERSION_HEX < 0x030200A4 +// Python 3.2.0a4 added Py_hash_t type +typedef Py_ssize_t Py_hash_t; +#endif + + +// gh-111545 added Py_HashPointer() to Python 3.13.0a3 +#if PY_VERSION_HEX < 0x030D00A3 +static inline Py_hash_t Py_HashPointer(const void *ptr) +{ +#if PY_VERSION_HEX >= 0x030900A4 && !defined(PYPY_VERSION) + return _Py_HashPointer(ptr); +#else + return _Py_HashPointer(_Py_CAST(void*, ptr)); +#endif +} +#endif + + +// Python 3.13a4 added a PyTime API. +// Use the private API added to Python 3.5. 
+#if PY_VERSION_HEX < 0x030D00A4 && PY_VERSION_HEX >= 0x03050000 +typedef _PyTime_t PyTime_t; +#define PyTime_MIN _PyTime_MIN +#define PyTime_MAX _PyTime_MAX + +static inline double PyTime_AsSecondsDouble(PyTime_t t) +{ return _PyTime_AsSecondsDouble(t); } + +static inline int PyTime_Monotonic(PyTime_t *result) +{ return _PyTime_GetMonotonicClockWithInfo(result, NULL); } + +static inline int PyTime_Time(PyTime_t *result) +{ return _PyTime_GetSystemClockWithInfo(result, NULL); } + +static inline int PyTime_PerfCounter(PyTime_t *result) +{ +#if PY_VERSION_HEX >= 0x03070000 && !defined(PYPY_VERSION) + return _PyTime_GetPerfCounterWithInfo(result, NULL); +#elif PY_VERSION_HEX >= 0x03070000 + // Call time.perf_counter_ns() and convert Python int object to PyTime_t. + // Cache time.perf_counter_ns() function for best performance. + static PyObject *func = NULL; + if (func == NULL) { + PyObject *mod = PyImport_ImportModule("time"); + if (mod == NULL) { + return -1; + } + + func = PyObject_GetAttrString(mod, "perf_counter_ns"); + Py_DECREF(mod); + if (func == NULL) { + return -1; + } + } + + PyObject *res = PyObject_CallNoArgs(func); + if (res == NULL) { + return -1; + } + long long value = PyLong_AsLongLong(res); + Py_DECREF(res); + + if (value == -1 && PyErr_Occurred()) { + return -1; + } + + Py_BUILD_ASSERT(sizeof(value) >= sizeof(PyTime_t)); + *result = (PyTime_t)value; + return 0; +#else + // Call time.perf_counter() and convert C double to PyTime_t. + // Cache time.perf_counter() function for best performance. + static PyObject *func = NULL; + if (func == NULL) { + PyObject *mod = PyImport_ImportModule("time"); + if (mod == NULL) { + return -1; + } + + func = PyObject_GetAttrString(mod, "perf_counter"); + Py_DECREF(mod); + if (func == NULL) { + return -1; + } + } + + PyObject *res = PyObject_CallNoArgs(func); + if (res == NULL) { + return -1; + } + double d = PyFloat_AsDouble(res); + Py_DECREF(res); + + if (d == -1.0 && PyErr_Occurred()) { + return -1; + } + + // Avoid floor() to avoid having to link to libm + *result = (PyTime_t)(d * 1e9); + return 0; +#endif +} + +#endif + +// gh-111389 added hash constants to Python 3.13.0a5. These constants were +// added first as private macros to Python 3.4.0b1 and PyPy 7.3.9. 
+#if (!defined(PyHASH_BITS) \ + && ((!defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x030400B1) \ + || (defined(PYPY_VERSION) && PY_VERSION_HEX >= 0x03070000 \ + && PYPY_VERSION_NUM >= 0x07090000))) +# define PyHASH_BITS _PyHASH_BITS +# define PyHASH_MODULUS _PyHASH_MODULUS +# define PyHASH_INF _PyHASH_INF +# define PyHASH_IMAG _PyHASH_IMAG +#endif + + +// gh-111545 added Py_GetConstant() and Py_GetConstantBorrowed() +// to Python 3.13.0a6 +#if PY_VERSION_HEX < 0x030D00A6 && !defined(Py_CONSTANT_NONE) + +#define Py_CONSTANT_NONE 0 +#define Py_CONSTANT_FALSE 1 +#define Py_CONSTANT_TRUE 2 +#define Py_CONSTANT_ELLIPSIS 3 +#define Py_CONSTANT_NOT_IMPLEMENTED 4 +#define Py_CONSTANT_ZERO 5 +#define Py_CONSTANT_ONE 6 +#define Py_CONSTANT_EMPTY_STR 7 +#define Py_CONSTANT_EMPTY_BYTES 8 +#define Py_CONSTANT_EMPTY_TUPLE 9 + +static inline PyObject* Py_GetConstant(unsigned int constant_id) +{ + static PyObject* constants[Py_CONSTANT_EMPTY_TUPLE + 1] = {NULL}; + + if (constants[Py_CONSTANT_NONE] == NULL) { + constants[Py_CONSTANT_NONE] = Py_None; + constants[Py_CONSTANT_FALSE] = Py_False; + constants[Py_CONSTANT_TRUE] = Py_True; + constants[Py_CONSTANT_ELLIPSIS] = Py_Ellipsis; + constants[Py_CONSTANT_NOT_IMPLEMENTED] = Py_NotImplemented; + + constants[Py_CONSTANT_ZERO] = PyLong_FromLong(0); + if (constants[Py_CONSTANT_ZERO] == NULL) { + goto fatal_error; + } + + constants[Py_CONSTANT_ONE] = PyLong_FromLong(1); + if (constants[Py_CONSTANT_ONE] == NULL) { + goto fatal_error; + } + + constants[Py_CONSTANT_EMPTY_STR] = PyUnicode_FromStringAndSize("", 0); + if (constants[Py_CONSTANT_EMPTY_STR] == NULL) { + goto fatal_error; + } + + constants[Py_CONSTANT_EMPTY_BYTES] = PyBytes_FromStringAndSize("", 0); + if (constants[Py_CONSTANT_EMPTY_BYTES] == NULL) { + goto fatal_error; + } + + constants[Py_CONSTANT_EMPTY_TUPLE] = PyTuple_New(0); + if (constants[Py_CONSTANT_EMPTY_TUPLE] == NULL) { + goto fatal_error; + } + // goto dance to avoid compiler warnings about Py_FatalError() + goto init_done; + +fatal_error: + // This case should never happen + Py_FatalError("Py_GetConstant() failed to get constants"); + } + +init_done: + if (constant_id <= Py_CONSTANT_EMPTY_TUPLE) { + return Py_NewRef(constants[constant_id]); + } + else { + PyErr_BadInternalCall(); + return NULL; + } +} + +static inline PyObject* Py_GetConstantBorrowed(unsigned int constant_id) +{ + PyObject *obj = Py_GetConstant(constant_id); + Py_XDECREF(obj); + return obj; +} +#endif + + +// gh-114329 added PyList_GetItemRef() to Python 3.13.0a4 +#if PY_VERSION_HEX < 0x030D00A4 +static inline PyObject * +PyList_GetItemRef(PyObject *op, Py_ssize_t index) +{ + PyObject *item = PyList_GetItem(op, index); + Py_XINCREF(item); + return item; +} +#endif + + +// gh-114329 added PyList_GetItemRef() to Python 3.13.0a4 +#if PY_VERSION_HEX < 0x030D00A4 +static inline int +PyDict_SetDefaultRef(PyObject *d, PyObject *key, PyObject *default_value, + PyObject **result) +{ + PyObject *value; + if (PyDict_GetItemRef(d, key, &value) < 0) { + // get error + if (result) { + *result = NULL; + } + return -1; + } + if (value != NULL) { + // present + if (result) { + *result = value; + } + else { + Py_DECREF(value); + } + return 1; + } + + // missing: set the item + if (PyDict_SetItem(d, key, default_value) < 0) { + // set error + if (result) { + *result = NULL; + } + return -1; + } + if (result) { + *result = Py_NewRef(default_value); + } + return 0; +} +#endif + +#if PY_VERSION_HEX < 0x030E0000 && PY_VERSION_HEX >= 0x03060000 && !defined(PYPY_VERSION) +typedef struct PyUnicodeWriter 
PyUnicodeWriter; + +static inline void PyUnicodeWriter_Discard(PyUnicodeWriter *writer) +{ + _PyUnicodeWriter_Dealloc((_PyUnicodeWriter*)writer); + PyMem_Free(writer); +} + +static inline PyUnicodeWriter* PyUnicodeWriter_Create(Py_ssize_t length) +{ + if (length < 0) { + PyErr_SetString(PyExc_ValueError, + "length must be positive"); + return NULL; + } + + const size_t size = sizeof(_PyUnicodeWriter); + PyUnicodeWriter *pub_writer = (PyUnicodeWriter *)PyMem_Malloc(size); + if (pub_writer == _Py_NULL) { + PyErr_NoMemory(); + return _Py_NULL; + } + _PyUnicodeWriter *writer = (_PyUnicodeWriter *)pub_writer; + + _PyUnicodeWriter_Init(writer); + if (_PyUnicodeWriter_Prepare(writer, length, 127) < 0) { + PyUnicodeWriter_Discard(pub_writer); + return NULL; + } + writer->overallocate = 1; + return pub_writer; +} + +static inline PyObject* PyUnicodeWriter_Finish(PyUnicodeWriter *writer) +{ + PyObject *str = _PyUnicodeWriter_Finish((_PyUnicodeWriter*)writer); + assert(((_PyUnicodeWriter*)writer)->buffer == NULL); + PyMem_Free(writer); + return str; +} + +static inline int +PyUnicodeWriter_WriteChar(PyUnicodeWriter *writer, Py_UCS4 ch) +{ + if (ch > 0x10ffff) { + PyErr_SetString(PyExc_ValueError, + "character must be in range(0x110000)"); + return -1; + } + + return _PyUnicodeWriter_WriteChar((_PyUnicodeWriter*)writer, ch); +} + +static inline int +PyUnicodeWriter_WriteStr(PyUnicodeWriter *writer, PyObject *obj) +{ + PyObject *str = PyObject_Str(obj); + if (str == NULL) { + return -1; + } + + int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str); + Py_DECREF(str); + return res; +} + +static inline int +PyUnicodeWriter_WriteRepr(PyUnicodeWriter *writer, PyObject *obj) +{ + PyObject *str = PyObject_Repr(obj); + if (str == NULL) { + return -1; + } + + int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str); + Py_DECREF(str); + return res; +} + +static inline int +PyUnicodeWriter_WriteUTF8(PyUnicodeWriter *writer, + const char *str, Py_ssize_t size) +{ + if (size < 0) { + size = (Py_ssize_t)strlen(str); + } + + PyObject *str_obj = PyUnicode_FromStringAndSize(str, size); + if (str_obj == _Py_NULL) { + return -1; + } + + int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str_obj); + Py_DECREF(str_obj); + return res; +} + +static inline int +PyUnicodeWriter_WriteWideChar(PyUnicodeWriter *writer, + const wchar_t *str, Py_ssize_t size) +{ + if (size < 0) { + size = (Py_ssize_t)wcslen(str); + } + + PyObject *str_obj = PyUnicode_FromWideChar(str, size); + if (str_obj == _Py_NULL) { + return -1; + } + + int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str_obj); + Py_DECREF(str_obj); + return res; +} + +static inline int +PyUnicodeWriter_WriteSubstring(PyUnicodeWriter *writer, PyObject *str, + Py_ssize_t start, Py_ssize_t end) +{ + if (!PyUnicode_Check(str)) { + PyErr_Format(PyExc_TypeError, "expect str, not %T", str); + return -1; + } + if (start < 0 || start > end) { + PyErr_Format(PyExc_ValueError, "invalid start argument"); + return -1; + } + if (end > PyUnicode_GET_LENGTH(str)) { + PyErr_Format(PyExc_ValueError, "invalid end argument"); + return -1; + } + + return _PyUnicodeWriter_WriteSubstring((_PyUnicodeWriter*)writer, str, + start, end); +} + +static inline int +PyUnicodeWriter_Format(PyUnicodeWriter *writer, const char *format, ...) 
+{ + va_list vargs; + va_start(vargs, format); + PyObject *str = PyUnicode_FromFormatV(format, vargs); + va_end(vargs); + if (str == _Py_NULL) { + return -1; + } + + int res = _PyUnicodeWriter_WriteStr((_PyUnicodeWriter*)writer, str); + Py_DECREF(str); + return res; +} +#endif // PY_VERSION_HEX < 0x030E0000 + +// gh-116560 added PyLong_GetSign() to Python 3.14.0a0 +#if PY_VERSION_HEX < 0x030E00A0 +static inline int PyLong_GetSign(PyObject *obj, int *sign) +{ + if (!PyLong_Check(obj)) { + PyErr_Format(PyExc_TypeError, "expect int, got %s", Py_TYPE(obj)->tp_name); + return -1; + } + + *sign = _PyLong_Sign(obj); + return 0; +} +#endif + + +#ifdef __cplusplus +} +#endif +#endif // PYTHONCAPI_COMPAT From fb202ee66d73572f46035c5b2f21ac22f74ba951 Mon Sep 17 00:00:00 2001 From: mwish Date: Thu, 15 Aug 2024 21:04:39 +0800 Subject: [PATCH 016/157] GH-43703: [C++][Parquet][CI] Parquet: Introducing more bad_data for testing (#43708) ### Rationale for this change Introducing more bad_data for testing ### What changes are included in this PR? * Upgrade parquet-testing * Introduce more bad_data * Update fuzz generation ### Are these changes tested? They're tests :-) ### Are there any user-facing changes? no * GitHub Issue: #43703 Authored-by: mwish Signed-off-by: Antoine Pitrou --- cpp/build-support/fuzzing/generate_corpuses.sh | 1 + cpp/src/parquet/arrow/arrow_reader_writer_test.cc | 12 +++++++++--- cpp/submodules/parquet-testing | 2 +- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/cpp/build-support/fuzzing/generate_corpuses.sh b/cpp/build-support/fuzzing/generate_corpuses.sh index e3f00e64782..ffd5c54e443 100755 --- a/cpp/build-support/fuzzing/generate_corpuses.sh +++ b/cpp/build-support/fuzzing/generate_corpuses.sh @@ -56,4 +56,5 @@ rm -rf ${CORPUS_DIR} ${OUT}/parquet-arrow-generate-fuzz-corpus ${CORPUS_DIR} # Add Parquet testing examples cp ${ARROW_CPP}/submodules/parquet-testing/data/*.parquet ${CORPUS_DIR} +cp ${ARROW_CPP}/submodules/parquet-testing/bad_data/*.parquet ${CORPUS_DIR} ${ARROW_CPP}/build-support/fuzzing/pack_corpus.py ${CORPUS_DIR} ${OUT}/parquet-arrow-fuzz_seed_corpus.zip diff --git a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc index aad1e933c4f..64030e0f90d 100644 --- a/cpp/src/parquet/arrow/arrow_reader_writer_test.cc +++ b/cpp/src/parquet/arrow/arrow_reader_writer_test.cc @@ -5298,14 +5298,20 @@ TEST(TestArrowReadWrite, MultithreadedWrite) { TEST(TestArrowReadWrite, FuzzReader) { constexpr size_t kMaxFileSize = 1024 * 1024 * 1; - { - auto path = test::get_data_file("PARQUET-1481.parquet", /*is_good=*/false); + auto check_bad_file = [&](const std::string& file_name) { + SCOPED_TRACE(file_name); + auto path = test::get_data_file(file_name, /*is_good=*/false); PARQUET_ASSIGN_OR_THROW(auto source, ::arrow::io::MemoryMappedFile::Open( path, ::arrow::io::FileMode::READ)); PARQUET_ASSIGN_OR_THROW(auto buffer, source->Read(kMaxFileSize)); auto s = internal::FuzzReader(buffer->data(), buffer->size()); ASSERT_NOT_OK(s); - } + }; + check_bad_file("PARQUET-1481.parquet"); + check_bad_file("ARROW-GH-41317.parquet"); + check_bad_file("ARROW-GH-41321.parquet"); + check_bad_file("ARROW-RS-GH-6229-LEVELS.parquet"); + check_bad_file("ARROW-RS-GH-6229-DICTHEADER.parquet"); { auto path = test::get_data_file("alltypes_plain.parquet", /*is_good=*/true); PARQUET_ASSIGN_OR_THROW(auto source, ::arrow::io::MemoryMappedFile::Open( diff --git a/cpp/submodules/parquet-testing b/cpp/submodules/parquet-testing index 
74278bc4a11..cb7a9674142 160000 --- a/cpp/submodules/parquet-testing +++ b/cpp/submodules/parquet-testing @@ -1 +1 @@ -Subproject commit 74278bc4a1122d74945969e6dec405abd1533ec3 +Subproject commit cb7a9674142c137367bf75a01b79c6e214a73199 From dfe6c50cf81a6893e44b1e2056301bfdfc2be48b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 15 Aug 2024 07:46:35 -0700 Subject: [PATCH 017/157] MINOR: [C#] Bump BenchmarkDotNet and System.Runtime.CompilerServices.Unsafe in /csharp (#43651) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [BenchmarkDotNet](https://github.com/dotnet/BenchmarkDotNet) and [System.Runtime.CompilerServices.Unsafe](https://github.com/dotnet/runtime). These dependencies needed to be updated together. Updates `BenchmarkDotNet` from 0.13.12 to 0.14.0
Release notes

Sourced from BenchmarkDotNet's releases.

0.14.0

Full changelog: https://benchmarkdotnet.org/changelog/v0.14.0.html

Highlights

  • Introduce BenchmarkDotNet.Diagnostics.dotMemory #2549: memory allocation profile of your benchmarks using dotMemory, see @BenchmarkDotNet.Samples.IntroDotMemoryDiagnoser
  • Introduce BenchmarkDotNet.Exporters.Plotting #2560: plotting via ScottPlot (initial version)
  • Multiple bugfixes
  • The default build toolchains have been updated to pass IntermediateOutputPath, OutputPath, and OutDir properties to the dotnet build command. This change forces all build outputs to be placed in a new directory generated by BenchmarkDotNet, and fixes many issues that have been reported with builds. You can also access these paths in your own .csproj and .props from those properties if you need to copy custom files to the output.

Bug fixes

  • Fixed multiple build-related bugs including passing MsBuildArguments and .Net 8's UseArtifactsOutput.

Breaking Changes

  • DotNetCliBuilder removed retryFailedBuildWithNoDeps constructor option.
  • DotNetCliCommand removed RetryFailedBuildWithNoDeps property and BuildNoRestoreNoDependencies() and PublishNoBuildAndNoRestore() methods (replaced with PublishNoRestore()).
Commits
  • cf882d3 Add macOS Sequoia in OsBrandStringHelper
  • 17cf3b0 [docs] Prepare v0.14.0 changelog
  • b3fbe7c Set next BenchmarkDotNet version: 0.14.0
  • 23e6c52 Fix InvalidOperationException in DotMemoryDiagnoser
  • 3d34edb Bump JetBrains.Profiler.SelfApi: 2.5.2->2.5.9
  • bf0a49d fix(CI): Deprecation issues (#2605)
  • 0275649 Fixed crash from TaskbarProgress when BuiltInComInteropSupport is disabled. ...
  • 15200d4 [build] Add BenchmarkDotNet.Exporters.Plotting.Tests to unit-tests
  • 834417a Improve logging in ScottPlotExporterTests
  • f8082a2 Fix IntroSummaryStyle compilation
  • Additional commits viewable in compare view

Updates `System.Runtime.CompilerServices.Unsafe` from 4.7.1 to 5.0.0
Release notes

Sourced from System.Runtime.CompilerServices.Unsafe's releases.

.NET 5

Release Notes Install Instructions

Repo

Commits

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@ dependabot rebase` will rebase this PR - `@ dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@ dependabot merge` will merge this PR after your CI passes on it - `@ dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@ dependabot cancel merge` will cancel a previously requested merge and block automerging - `@ dependabot reopen` will reopen this PR if it is closed - `@ dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@ dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@ dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@ dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@ dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Curt Hagenlocher --- .../test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj b/csharp/test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj index f735f01b022..5bf51f5c305 100644 --- a/csharp/test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj +++ b/csharp/test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj @@ -6,7 +6,7 @@ - + From 8b634ad2998b6a670cc7d4d3ef0e43dea3b7aca1 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 15 Aug 2024 08:07:56 -0700 Subject: [PATCH 018/157] MINOR: [C#] Bump BenchmarkDotNet.Diagnostics.Windows and System.Runtime.CompilerServices.Unsafe in /csharp (#43711) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [BenchmarkDotNet.Diagnostics.Windows](https://github.com/dotnet/BenchmarkDotNet) and [System.Runtime.CompilerServices.Unsafe](https://github.com/dotnet/runtime). These dependencies needed to be updated together. Updates `BenchmarkDotNet.Diagnostics.Windows` from 0.13.12 to 0.14.0
Release notes

Sourced from BenchmarkDotNet.Diagnostics.Windows's releases.

0.14.0

Full changelog: https://benchmarkdotnet.org/changelog/v0.14.0.html

Highlights

  • Introduce BenchmarkDotNet.Diagnostics.dotMemory #2549: memory allocation profile of your benchmarks using dotMemory, see @BenchmarkDotNet.Samples.IntroDotMemoryDiagnoser
  • Introduce BenchmarkDotNet.Exporters.Plotting #2560: plotting via ScottPlot (initial version)
  • Multiple bugfixes
  • The default build toolchains have been updated to pass IntermediateOutputPath, OutputPath, and OutDir properties to the dotnet build command. This change forces all build outputs to be placed in a new directory generated by BenchmarkDotNet, and fixes many issues that have been reported with builds. You can also access these paths in your own .csproj and .props from those properties if you need to copy custom files to the output.

Bug fixes

  • Fixed multiple build-related bugs including passing MsBuildArguments and .Net 8's UseArtifactsOutput.

Breaking Changes

  • DotNetCliBuilder removed retryFailedBuildWithNoDeps constructor option.
  • DotNetCliCommand removed RetryFailedBuildWithNoDeps property and BuildNoRestoreNoDependencies() and PublishNoBuildAndNoRestore() methods (replaced with PublishNoRestore()).
Commits
  • cf882d3 Add macOS Sequoia in OsBrandStringHelper
  • 17cf3b0 [docs] Prepare v0.14.0 changelog
  • b3fbe7c Set next BenchmarkDotNet version: 0.14.0
  • 23e6c52 Fix InvalidOperationException in DotMemoryDiagnoser
  • 3d34edb Bump JetBrains.Profiler.SelfApi: 2.5.2->2.5.9
  • bf0a49d fix(CI): Deprecation issues (#2605)
  • 0275649 Fixed crash from TaskbarProgress when BuiltInComInteropSupport is disabled. ...
  • 15200d4 [build] Add BenchmarkDotNet.Exporters.Plotting.Tests to unit-tests
  • 834417a Improve logging in ScottPlotExporterTests
  • f8082a2 Fix IntroSummaryStyle compilation
  • Additional commits viewable in compare view

Updates `System.Runtime.CompilerServices.Unsafe` from 4.7.1 to 5.0.0
Commits

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@ dependabot rebase` will rebase this PR - `@ dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@ dependabot merge` will merge this PR after your CI passes on it - `@ dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@ dependabot cancel merge` will cancel a previously requested merge and block automerging - `@ dependabot reopen` will reopen this PR if it is closed - `@ dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@ dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@ dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@ dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@ dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Curt Hagenlocher --- .../test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csharp/test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj b/csharp/test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj index 5bf51f5c305..0a3e3341041 100644 --- a/csharp/test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj +++ b/csharp/test/Apache.Arrow.Benchmarks/Apache.Arrow.Benchmarks.csproj @@ -7,7 +7,7 @@ - + From 2e434dad9b0cdcc57524dc2a0cc7f7b3ed23ccc4 Mon Sep 17 00:00:00 2001 From: mwish Date: Fri, 16 Aug 2024 00:10:51 +0800 Subject: [PATCH 019/157] GH-43687: [C++] Compute: fix register kernel SimdLevel for AddMinMax512AggKernels (#43704) ### Rationale for this change See https://github.com/apache/arrow/issues/43687 ### What changes are included in this PR? Change Registered AVX2 to AVX512 ### Are these changes tested? No ### Are there any user-facing changes? maybe bugfix * GitHub Issue: #43687 Authored-by: mwish Signed-off-by: mwish --- cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc b/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc index 0d66ed2ec3e..05356e0aa5e 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_avx512.cc @@ -80,8 +80,8 @@ void AddMinMaxAvx512AggKernels(ScalarAggregateFunction* func) { AddMinMaxKernels(MinMaxInitAvx512, {int32(), uint32(), int64(), uint64()}, func, SimdLevel::AVX512); AddMinMaxKernels(MinMaxInitAvx512, TemporalTypes(), func, SimdLevel::AVX512); - AddMinMaxKernels(MinMaxInitAvx512, BaseBinaryTypes(), func, SimdLevel::AVX2); - AddMinMaxKernel(MinMaxInitAvx512, Type::FIXED_SIZE_BINARY, func, SimdLevel::AVX2); + AddMinMaxKernels(MinMaxInitAvx512, BaseBinaryTypes(), func, SimdLevel::AVX512); + AddMinMaxKernel(MinMaxInitAvx512, Type::FIXED_SIZE_BINARY, func, SimdLevel::AVX512); AddMinMaxKernel(MinMaxInitAvx512, Type::INTERVAL_MONTHS, func, SimdLevel::AVX512); } From 2767dc55cb41377af6895141f717475d73b2892d Mon Sep 17 00:00:00 2001 From: Chungmin Lee Date: Thu, 15 Aug 2024 09:32:22 -0700 Subject: [PATCH 020/157] GH-41579: [C++][Python][Parquet] Support reading/writing key-value metadata from/to ColumnChunkMetaData (#41580) ### Rationale for this change Parquet standard allows reading/writing key-value metadata from/to ColumnChunkMetaData, but there is no way to do that with Parquet C++. ### What changes are included in this PR? Support reading/writing key-value metadata from/to ColumnChunkMetaData with Parquet C++ reader/writer. Support reading key-value metadata from ColumnChunkMetaData with pyarrow.parquet. ### Are these changes tested? Yes, unit tests are added ### Are there any user-facing changes? Yes. - Users can read or write key-value metadata for column chunks with Parquet C++. - Users can read key-value metadata for column chunks with PyArrow. - parquet-reader tool prints key-value metadata in column chunks when `--print-key-value-metadata` option is used. 
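As a rough usage sketch of the APIs described above (adapted from the `WriteKeyValueMetadataEndToEnd` unit test added in this PR; the include list, the free-standing helper function, and the `schema_root` argument are illustrative assumptions, not part of the patch):

```cpp
// Sketch: attach key-value metadata to a column chunk, then read it back.
#include <memory>

#include "arrow/io/memory.h"
#include "arrow/util/key_value_metadata.h"
#include "parquet/exception.h"
#include "parquet/file_reader.h"
#include "parquet/file_writer.h"
#include "parquet/platform.h"
#include "parquet/schema.h"

void RoundTripColumnChunkKeyValueMetadata(
    const std::shared_ptr<parquet::schema::GroupNode>& schema_root) {
  auto sink = parquet::CreateOutputStream();
  {
    auto file_writer = parquet::ParquetFileWriter::Open(sink, schema_root);
    auto rg_writer = file_writer->AppendRowGroup();
    auto col_writer = rg_writer->NextColumn();
    // New in this patch: key-value metadata set on the writer before it is
    // closed ends up in the ColumnChunkMetaData of that column chunk.
    col_writer->AddKeyValueMetadata(
        arrow::KeyValueMetadata::Make({"foo"}, {"bar"}));
    file_writer->Close();
  }

  PARQUET_ASSIGN_OR_THROW(auto buffer, sink->Finish());
  auto file_reader = parquet::ParquetFileReader::Open(
      std::make_shared<arrow::io::BufferReader>(buffer));
  // New in this patch: ColumnChunkMetaData::key_value_metadata() returns the
  // chunk-level metadata, or nullptr if none was written.
  auto kv =
      file_reader->metadata()->RowGroup(0)->ColumnChunk(0)->key_value_metadata();
  if (kv != nullptr) {
    PARQUET_ASSIGN_OR_THROW(auto value, kv->Get("foo"));  // value == "bar"
  }
}
```

From Python the same information is exposed read-only through `ColumnChunkMetaData.metadata` (e.g. `pq.read_metadata(path).row_group(0).column(0).metadata`), which returns `None` when the column chunk carries no key-value metadata.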
* GitHub Issue: #41579 Lead-authored-by: Chungmin Lee Co-authored-by: mwish Signed-off-by: mwish --- cpp/src/parquet/column_writer.cc | 24 ++++++ cpp/src/parquet/column_writer.h | 12 +++ cpp/src/parquet/column_writer_test.cc | 69 +++++++++++++++ cpp/src/parquet/metadata.cc | 84 ++++++++++++++----- cpp/src/parquet/metadata.h | 5 ++ cpp/src/parquet/printer.cc | 32 +++++-- python/pyarrow/_parquet.pxd | 1 + python/pyarrow/_parquet.pyx | 13 +++ python/pyarrow/tests/parquet/conftest.py | 12 +++ python/pyarrow/tests/parquet/test_metadata.py | 9 ++ 10 files changed, 235 insertions(+), 26 deletions(-) diff --git a/cpp/src/parquet/column_writer.cc b/cpp/src/parquet/column_writer.cc index f859ec9653f..40d19d38e10 100644 --- a/cpp/src/parquet/column_writer.cc +++ b/cpp/src/parquet/column_writer.cc @@ -40,6 +40,7 @@ #include "arrow/util/crc32.h" #include "arrow/util/endian.h" #include "arrow/util/float16.h" +#include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/rle_encoding_internal.h" #include "arrow/util/type_traits.h" @@ -832,6 +833,9 @@ class ColumnWriterImpl { void FlushBufferedDataPages(); ColumnChunkMetaDataBuilder* metadata_; + // key_value_metadata_ for the column chunk + // It would be nullptr if there is no KeyValueMetadata set. + std::shared_ptr key_value_metadata_; const ColumnDescriptor* descr_; // scratch buffer if validity bits need to be recalculated. std::shared_ptr bits_buffer_; @@ -1100,6 +1104,7 @@ int64_t ColumnWriterImpl::Close() { if (rows_written_ > 0 && chunk_statistics.is_set()) { metadata_->SetStatistics(chunk_statistics); } + metadata_->SetKeyValueMetadata(key_value_metadata_); pager_->Close(has_dictionary_, fallback_); } @@ -1397,6 +1402,25 @@ class TypedColumnWriterImpl : public ColumnWriterImpl, public TypedColumnWriter< return pages_change_on_record_boundaries_; } + void AddKeyValueMetadata( + const std::shared_ptr& key_value_metadata) override { + if (closed_) { + throw ParquetException("Cannot add key-value metadata to closed column"); + } + if (key_value_metadata_ == nullptr) { + key_value_metadata_ = key_value_metadata; + } else if (key_value_metadata != nullptr) { + key_value_metadata_ = key_value_metadata_->Merge(*key_value_metadata); + } + } + + void ResetKeyValueMetadata() override { + if (closed_) { + throw ParquetException("Cannot add key-value metadata to closed column"); + } + key_value_metadata_ = nullptr; + } + private: using ValueEncoderType = typename EncodingTraits::Encoder; using TypedStats = TypedStatistics; diff --git a/cpp/src/parquet/column_writer.h b/cpp/src/parquet/column_writer.h index a278670fa81..845bf9aa896 100644 --- a/cpp/src/parquet/column_writer.h +++ b/cpp/src/parquet/column_writer.h @@ -21,6 +21,7 @@ #include #include +#include "arrow/type_fwd.h" #include "arrow/util/compression.h" #include "parquet/exception.h" #include "parquet/platform.h" @@ -181,6 +182,17 @@ class PARQUET_EXPORT ColumnWriter { /// \brief The file-level writer properties virtual const WriterProperties* properties() = 0; + /// \brief Add key-value metadata to the ColumnChunk. + /// \param[in] key_value_metadata the metadata to add. + /// \note This will overwrite any existing metadata with the same key. + /// \throw ParquetException if Close() has been called. + virtual void AddKeyValueMetadata( + const std::shared_ptr& key_value_metadata) = 0; + + /// \brief Reset the ColumnChunk key-value metadata. + /// \throw ParquetException if Close() has been called. 
+ virtual void ResetKeyValueMetadata() = 0; + /// \brief Write Apache Arrow columnar data directly to ColumnWriter. Returns /// error status if the array data type is not compatible with the concrete /// writer type. diff --git a/cpp/src/parquet/column_writer_test.cc b/cpp/src/parquet/column_writer_test.cc index c99efd17961..d2b3aa0dff0 100644 --- a/cpp/src/parquet/column_writer_test.cc +++ b/cpp/src/parquet/column_writer_test.cc @@ -23,10 +23,12 @@ #include #include "arrow/io/buffered.h" +#include "arrow/io/file.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_builders.h" #include "arrow/util/config.h" +#include "arrow/util/key_value_metadata.h" #include "parquet/column_page.h" #include "parquet/column_reader.h" @@ -51,6 +53,9 @@ using schema::PrimitiveNode; namespace test { +using ::testing::IsNull; +using ::testing::NotNull; + // The default size used in most tests. const int SMALL_SIZE = 100; #ifdef PARQUET_VALGRIND @@ -385,6 +390,15 @@ class TestPrimitiveWriter : public PrimitiveTypedTest { return metadata_accessor->encoding_stats(); } + std::shared_ptr metadata_key_value_metadata() { + // Metadata accessor must be created lazily. + // This is because the ColumnChunkMetaData semantics dictate the metadata object is + // complete (no changes to the metadata buffer can be made after instantiation) + auto metadata_accessor = + ColumnChunkMetaData::Make(metadata_->contents(), this->descr_); + return metadata_accessor->key_value_metadata(); + } + protected: int64_t values_read_; // Keep the reader alive as for ByteArray the lifetime of the ByteArray @@ -1705,5 +1719,60 @@ TEST(TestColumnWriter, WriteDataPageV2HeaderNullCount) { } } +using TestInt32Writer = TestPrimitiveWriter; + +TEST_F(TestInt32Writer, NoWriteKeyValueMetadata) { + auto writer = this->BuildWriter(); + writer->Close(); + auto key_value_metadata = metadata_key_value_metadata(); + ASSERT_THAT(key_value_metadata, IsNull()); +} + +TEST_F(TestInt32Writer, WriteKeyValueMetadata) { + auto writer = this->BuildWriter(); + writer->AddKeyValueMetadata( + KeyValueMetadata::Make({"hello", "bye"}, {"world", "earth"})); + // overwrite the previous value + writer->AddKeyValueMetadata(KeyValueMetadata::Make({"bye"}, {"moon"})); + writer->Close(); + auto key_value_metadata = metadata_key_value_metadata(); + ASSERT_THAT(key_value_metadata, NotNull()); + ASSERT_EQ(2, key_value_metadata->size()); + ASSERT_OK_AND_ASSIGN(auto value, key_value_metadata->Get("hello")); + ASSERT_EQ("world", value); + ASSERT_OK_AND_ASSIGN(value, key_value_metadata->Get("bye")); + ASSERT_EQ("moon", value); +} + +TEST_F(TestInt32Writer, ResetKeyValueMetadata) { + auto writer = this->BuildWriter(); + writer->AddKeyValueMetadata(KeyValueMetadata::Make({"hello"}, {"world"})); + writer->ResetKeyValueMetadata(); + writer->Close(); + auto key_value_metadata = metadata_key_value_metadata(); + ASSERT_THAT(key_value_metadata, IsNull()); +} + +TEST_F(TestInt32Writer, WriteKeyValueMetadataEndToEnd) { + auto sink = CreateOutputStream(); + { + auto file_writer = ParquetFileWriter::Open( + sink, std::dynamic_pointer_cast(schema_.schema_root())); + auto rg_writer = file_writer->AppendRowGroup(); + auto col_writer = rg_writer->NextColumn(); + col_writer->AddKeyValueMetadata(KeyValueMetadata::Make({"foo"}, {"bar"})); + file_writer->Close(); + } + ASSERT_OK_AND_ASSIGN(auto buffer, sink->Finish()); + auto file_reader = + ParquetFileReader::Open(std::make_shared<::arrow::io::BufferReader>(buffer)); + auto key_value_metadata = + 
file_reader->metadata()->RowGroup(0)->ColumnChunk(0)->key_value_metadata(); + ASSERT_THAT(key_value_metadata, NotNull()); + ASSERT_EQ(1U, key_value_metadata->size()); + ASSERT_OK_AND_ASSIGN(auto value, key_value_metadata->Get("foo")); + ASSERT_EQ("bar", value); +} + } // namespace test } // namespace parquet diff --git a/cpp/src/parquet/metadata.cc b/cpp/src/parquet/metadata.cc index 10c8afaf375..4f2aa6e3732 100644 --- a/cpp/src/parquet/metadata.cc +++ b/cpp/src/parquet/metadata.cc @@ -135,6 +135,39 @@ std::shared_ptr MakeColumnStats(const format::ColumnMetaData& meta_d throw ParquetException("Can't decode page statistics for selected column type"); } +// Get KeyValueMetadata from parquet Thrift RowGroup or ColumnChunk metadata. +// +// Returns nullptr if the metadata is not set. +template +std::shared_ptr FromThriftKeyValueMetadata(const Metadata& source) { + std::shared_ptr metadata = nullptr; + if (source.__isset.key_value_metadata) { + std::vector keys; + std::vector values; + keys.reserve(source.key_value_metadata.size()); + values.reserve(source.key_value_metadata.size()); + for (const auto& it : source.key_value_metadata) { + keys.push_back(it.key); + values.push_back(it.value); + } + metadata = std::make_shared(std::move(keys), std::move(values)); + } + return metadata; +} + +template +void ToThriftKeyValueMetadata(const KeyValueMetadata& source, Metadata* metadata) { + std::vector key_value_metadata; + key_value_metadata.reserve(static_cast(source.size())); + for (int64_t i = 0; i < source.size(); ++i) { + format::KeyValue kv_pair; + kv_pair.__set_key(source.key(i)); + kv_pair.__set_value(source.value(i)); + key_value_metadata.emplace_back(std::move(kv_pair)); + } + metadata->__set_key_value_metadata(std::move(key_value_metadata)); +} + // MetaData Accessor // ColumnCryptoMetaData @@ -233,6 +266,7 @@ class ColumnChunkMetaData::ColumnChunkMetaDataImpl { encoding_stats.count}); } possible_stats_ = nullptr; + InitKeyValueMetadata(); } bool Equals(const ColumnChunkMetaDataImpl& other) const { @@ -343,7 +377,15 @@ class ColumnChunkMetaData::ColumnChunkMetaDataImpl { return std::nullopt; } + const std::shared_ptr& key_value_metadata() const { + return key_value_metadata_; + } + private: + void InitKeyValueMetadata() { + key_value_metadata_ = FromThriftKeyValueMetadata(*column_metadata_); + } + mutable std::shared_ptr possible_stats_; std::vector encodings_; std::vector encoding_stats_; @@ -353,6 +395,7 @@ class ColumnChunkMetaData::ColumnChunkMetaDataImpl { const ColumnDescriptor* descr_; const ReaderProperties properties_; const ApplicationVersion* writer_version_; + std::shared_ptr key_value_metadata_; }; std::unique_ptr ColumnChunkMetaData::Make( @@ -471,6 +514,11 @@ bool ColumnChunkMetaData::Equals(const ColumnChunkMetaData& other) const { return impl_->Equals(*other.impl_); } +const std::shared_ptr& ColumnChunkMetaData::key_value_metadata() + const { + return impl_->key_value_metadata(); +} + // row-group metadata class RowGroupMetaData::RowGroupMetaDataImpl { public: @@ -913,7 +961,7 @@ class FileMetaData::FileMetaDataImpl { std::vector column_orders; if (metadata_->__isset.column_orders) { column_orders.reserve(metadata_->column_orders.size()); - for (auto column_order : metadata_->column_orders) { + for (auto& column_order : metadata_->column_orders) { if (column_order.__isset.TYPE_ORDER) { column_orders.push_back(ColumnOrder::type_defined_); } else { @@ -928,14 +976,7 @@ class FileMetaData::FileMetaDataImpl { } void InitKeyValueMetadata() { - std::shared_ptr metadata = nullptr; 
- if (metadata_->__isset.key_value_metadata) { - metadata = std::make_shared(); - for (const auto& it : metadata_->key_value_metadata) { - metadata->Append(it.key, it.value); - } - } - key_value_metadata_ = std::move(metadata); + key_value_metadata_ = FromThriftKeyValueMetadata(*metadata_); } }; @@ -1590,6 +1631,10 @@ class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl { column_chunk_->meta_data.__set_encodings(std::move(thrift_encodings)); column_chunk_->meta_data.__set_encoding_stats(std::move(thrift_encoding_stats)); + if (key_value_metadata_) { + ToThriftKeyValueMetadata(*key_value_metadata_, &column_chunk_->meta_data); + } + const auto& encrypt_md = properties_->column_encryption_properties(column_->path()->ToDotString()); // column is encrypted @@ -1656,6 +1701,10 @@ class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl { return column_chunk_->meta_data.total_compressed_size; } + void SetKeyValueMetadata(std::shared_ptr key_value_metadata) { + key_value_metadata_ = std::move(key_value_metadata); + } + private: void Init(format::ColumnChunk* column_chunk) { column_chunk_ = column_chunk; @@ -1670,6 +1719,7 @@ class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl { std::unique_ptr owned_column_chunk_; const std::shared_ptr properties_; const ColumnDescriptor* column_; + std::shared_ptr key_value_metadata_; }; std::unique_ptr ColumnChunkMetaDataBuilder::Make( @@ -1727,6 +1777,11 @@ void ColumnChunkMetaDataBuilder::SetStatistics(const EncodedStatistics& result) impl_->SetStatistics(result); } +void ColumnChunkMetaDataBuilder::SetKeyValueMetadata( + std::shared_ptr key_value_metadata) { + impl_->SetKeyValueMetadata(std::move(key_value_metadata)); +} + int64_t ColumnChunkMetaDataBuilder::total_compressed_size() const { return impl_->total_compressed_size(); } @@ -1925,16 +1980,7 @@ class FileMetaDataBuilder::FileMetaDataBuilderImpl { } else if (key_value_metadata) { key_value_metadata_ = key_value_metadata_->Merge(*key_value_metadata); } - metadata_->key_value_metadata.clear(); - metadata_->key_value_metadata.reserve( - static_cast(key_value_metadata_->size())); - for (int64_t i = 0; i < key_value_metadata_->size(); ++i) { - format::KeyValue kv_pair; - kv_pair.__set_key(key_value_metadata_->key(i)); - kv_pair.__set_value(key_value_metadata_->value(i)); - metadata_->key_value_metadata.push_back(std::move(kv_pair)); - } - metadata_->__isset.key_value_metadata = true; + ToThriftKeyValueMetadata(*key_value_metadata_, metadata_.get()); } int32_t file_version = 0; diff --git a/cpp/src/parquet/metadata.h b/cpp/src/parquet/metadata.h index e46297540ba..d1e2d1904a6 100644 --- a/cpp/src/parquet/metadata.h +++ b/cpp/src/parquet/metadata.h @@ -184,6 +184,7 @@ class PARQUET_EXPORT ColumnChunkMetaData { std::unique_ptr crypto_metadata() const; std::optional GetColumnIndexLocation() const; std::optional GetOffsetIndexLocation() const; + const std::shared_ptr& key_value_metadata() const; private: explicit ColumnChunkMetaData( @@ -466,8 +467,12 @@ class PARQUET_EXPORT ColumnChunkMetaDataBuilder { // column chunk // Used when a dataset is spread across multiple files void set_file_path(const std::string& path); + // column metadata void SetStatistics(const EncodedStatistics& stats); + + void SetKeyValueMetadata(std::shared_ptr key_value_metadata); + // get the column descriptor const ColumnDescriptor* descr() const; diff --git a/cpp/src/parquet/printer.cc b/cpp/src/parquet/printer.cc index 33df5925a1c..60adfc697f9 100644 --- a/cpp/src/parquet/printer.cc +++ 
b/cpp/src/parquet/printer.cc @@ -64,6 +64,25 @@ void PrintPageEncodingStats(std::ostream& stream, // the fixed initial size is just for an example #define COL_WIDTH 30 +void PutChars(std::ostream& stream, char c, int n) { + for (int i = 0; i < n; ++i) { + stream.put(c); + } +} + +void PrintKeyValueMetadata(std::ostream& stream, + const KeyValueMetadata& key_value_metadata, + int indent_level = 0, int indent_width = 1) { + const int64_t size_of_key_value_metadata = key_value_metadata.size(); + PutChars(stream, ' ', indent_level * indent_width); + stream << "Key Value Metadata: " << size_of_key_value_metadata << " entries\n"; + for (int64_t i = 0; i < size_of_key_value_metadata; i++) { + PutChars(stream, ' ', (indent_level + 1) * indent_width); + stream << "Key nr " << i << " " << key_value_metadata.key(i) << ": " + << key_value_metadata.value(i) << "\n"; + } +} + void ParquetFilePrinter::DebugPrint(std::ostream& stream, std::list selected_columns, bool print_values, bool format_dump, bool print_key_value_metadata, const char* filename) { @@ -76,12 +95,7 @@ void ParquetFilePrinter::DebugPrint(std::ostream& stream, std::list selecte if (print_key_value_metadata && file_metadata->key_value_metadata()) { auto key_value_metadata = file_metadata->key_value_metadata(); - int64_t size_of_key_value_metadata = key_value_metadata->size(); - stream << "Key Value File Metadata: " << size_of_key_value_metadata << " entries\n"; - for (int64_t i = 0; i < size_of_key_value_metadata; i++) { - stream << " Key nr " << i << " " << key_value_metadata->key(i) << ": " - << key_value_metadata->value(i) << "\n"; - } + PrintKeyValueMetadata(stream, *key_value_metadata); } stream << "Number of RowGroups: " << file_metadata->num_row_groups() << "\n"; @@ -136,7 +150,11 @@ void ParquetFilePrinter::DebugPrint(std::ostream& stream, std::list selecte std::shared_ptr stats = column_chunk->statistics(); const ColumnDescriptor* descr = file_metadata->schema()->Column(i); - stream << "Column " << i << std::endl << " Values: " << column_chunk->num_values(); + stream << "Column " << i << std::endl; + if (print_key_value_metadata && column_chunk->key_value_metadata()) { + PrintKeyValueMetadata(stream, *column_chunk->key_value_metadata(), 1, 2); + } + stream << " Values: " << column_chunk->num_values(); if (column_chunk->is_stats_set()) { std::string min = stats->EncodeMin(), max = stats->EncodeMax(); stream << ", Null Values: " << stats->null_count() diff --git a/python/pyarrow/_parquet.pxd b/python/pyarrow/_parquet.pxd index 35d15227ee5..d6aebd8284f 100644 --- a/python/pyarrow/_parquet.pxd +++ b/python/pyarrow/_parquet.pxd @@ -328,6 +328,7 @@ cdef extern from "parquet/api/reader.h" namespace "parquet" nogil: unique_ptr[CColumnCryptoMetaData] crypto_metadata() const optional[ParquetIndexLocation] GetColumnIndexLocation() const optional[ParquetIndexLocation] GetOffsetIndexLocation() const + shared_ptr[const CKeyValueMetadata] key_value_metadata() const struct CSortingColumn" parquet::SortingColumn": int column_idx diff --git a/python/pyarrow/_parquet.pyx b/python/pyarrow/_parquet.pyx index 41b15b633d3..254bfe3b09a 100644 --- a/python/pyarrow/_parquet.pyx +++ b/python/pyarrow/_parquet.pyx @@ -508,6 +508,19 @@ cdef class ColumnChunkMetaData(_Weakrefable): """Whether the column chunk has a column index""" return self.metadata.GetColumnIndexLocation().has_value() + @property + def metadata(self): + """Additional metadata as key value pairs (dict[bytes, bytes]).""" + cdef: + unordered_map[c_string, c_string] metadata + const 
CKeyValueMetadata* underlying_metadata + underlying_metadata = self.metadata.key_value_metadata().get() + if underlying_metadata != NULL: + underlying_metadata.ToUnorderedMap(&metadata) + return metadata + else: + return None + cdef class SortingColumn: """ diff --git a/python/pyarrow/tests/parquet/conftest.py b/python/pyarrow/tests/parquet/conftest.py index 767e7f6b69d..80605e973cd 100644 --- a/python/pyarrow/tests/parquet/conftest.py +++ b/python/pyarrow/tests/parquet/conftest.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. +import os +import pathlib + import pytest from pyarrow.util import guid @@ -25,6 +28,15 @@ def datadir(base_datadir): return base_datadir / 'parquet' +@pytest.fixture(scope='module') +def parquet_test_datadir(): + result = os.environ.get('PARQUET_TEST_DATA') + if not result: + raise RuntimeError('Please point the PARQUET_TEST_DATA environment ' + 'variable to the test data directory') + return pathlib.Path(result) + + @pytest.fixture def s3_bucket(s3_server): boto3 = pytest.importorskip('boto3') diff --git a/python/pyarrow/tests/parquet/test_metadata.py b/python/pyarrow/tests/parquet/test_metadata.py index 528cf0110dd..c29213ebc3d 100644 --- a/python/pyarrow/tests/parquet/test_metadata.py +++ b/python/pyarrow/tests/parquet/test_metadata.py @@ -782,3 +782,12 @@ def test_write_metadata_fs_file_combinations(tempdir, s3_example_s3fs): assert meta1.read_bytes() == meta2.read_bytes() \ == meta3.read_bytes() == meta4.read_bytes() \ == s3_fs.open(meta5).read() + + +def test_column_chunk_key_value_metadata(parquet_test_datadir): + metadata = pq.read_metadata(parquet_test_datadir / + 'column_chunk_key_value_metadata.parquet') + key_value_metadata1 = metadata.row_group(0).column(0).metadata + assert key_value_metadata1 == {b'foo': b'bar', b'thisiskeywithoutvalue': b''} + key_value_metadata2 = metadata.row_group(0).column(1).metadata + assert key_value_metadata2 is None From a50ad422cff112efb022d081e34344249ac83530 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Fri, 16 Aug 2024 02:06:08 +0200 Subject: [PATCH 021/157] MINOR: [CI] Fix ubuntu-lint to not install into system Python (#43710) ### Rationale for this change Currently, the `ubuntu-lint` Docker build would install its Python dependencies directly into the system Python, which can fail depending on existing system Python packages. See example here: https://github.com/apache/arrow/actions/runs/10400929007/job/28802420047?pr=43539 where pip's dependency resolution fails with the following error message: ``` packaging.version.InvalidVersion: Invalid version: '2013-02-16' ``` ### What changes are included in this PR? This PR switches to use a virtual environment, guaranteeing that we're not interfering with the system Python and that we're not bound by already installed Python packages. ### Are these changes tested? By CI. ### Are there any user-facing changes? No. 
Authored-by: Antoine Pitrou Signed-off-by: Sutou Kouhei --- docker-compose.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docker-compose.yml b/docker-compose.yml index daa5c74bcb9..14eeeeee6e5 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1889,6 +1889,9 @@ services: command: > /bin/bash -c " git config --global --add safe.directory /arrow && + python3 -m venv /build/pyvenv && + source /build/pyvenv/bin/activate && + pip install -U pip setuptools && pip install arrow/dev/archery[lint] && archery lint --all --no-clang-tidy --no-iwyu --no-numpydoc --src /arrow" From a970fd72b3debbaf4ef797025e06efa45ba588f8 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Fri, 16 Aug 2024 04:32:25 +0200 Subject: [PATCH 022/157] GH-43688: [C++] Prevent Snappy from disabling RTTI when bundled (#43706) ### Rationale for this change Snappy's CMakeLists.txt unconditionally disables RTTI. This is incompatible with some other options, such as activating UBSAN for a fuzzing build: https://github.com/google/snappy/issues/189 ### What changes are included in this PR? Add `-frtti` at the end of compiler options when compiling a bundled Snappy build. ### Are these changes tested? On CI; also manually checked that this allows enabling Snappy on OSS-Fuzz builds. ### Are there any user-facing changes? No. * GitHub Issue: #43688 Lead-authored-by: Antoine Pitrou Co-authored-by: Antoine Pitrou Co-authored-by: Sutou Kouhei Co-authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 22 ++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 495aa704836..bc3a3a2249d 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1355,16 +1355,24 @@ macro(build_snappy) "-DCMAKE_INSTALL_PREFIX=${SNAPPY_PREFIX}") # Snappy unconditionally enables -Werror when building with clang this can lead # to build failures by way of new compiler warnings. This adds a flag to disable - # Werror to the very end of the invocation to override the snappy internal setting. + # -Werror to the very end of the invocation to override the snappy internal setting. + set(SNAPPY_ADDITIONAL_CXX_FLAGS "") if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") - foreach(CONFIG DEBUG MINSIZEREL RELEASE RELWITHDEBINFO) - list(APPEND - SNAPPY_CMAKE_ARGS - "-DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${EP_CXX_FLAGS_${CONFIG}} -Wno-error" - ) - endforeach() + string(APPEND SNAPPY_ADDITIONAL_CXX_FLAGS " -Wno-error") + endif() + # Snappy unconditionally disables RTTI, which is incompatible with some other + # build settings (https://github.com/apache/arrow/issues/43688). + if(NOT MSVC) + string(APPEND SNAPPY_ADDITIONAL_CXX_FLAGS " -frtti") endif() + foreach(CONFIG DEBUG MINSIZEREL RELEASE RELWITHDEBINFO) + list(APPEND + SNAPPY_CMAKE_ARGS + "-DCMAKE_CXX_FLAGS_${CONFIG}=${EP_CXX_FLAGS_${CONFIG}} ${SNAPPY_ADDITIONAL_CXX_FLAGS}" + ) + endforeach() + if(APPLE AND CMAKE_HOST_SYSTEM_VERSION VERSION_LESS 20) # On macOS 10.13 we need to explicitly add to avoid a missing include error # This can be removed once CRAN no longer checks on macOS 10.13 From e9767c1a268f543536077cf80f49b097739f308c Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Fri, 16 Aug 2024 11:43:32 +0900 Subject: [PATCH 023/157] GH-41396: [Ruby] Add workaround for re2.pc on Ubuntu 20.04 (#43721) ### Rationale for this change Old re2.pc add "-std=c++11" but it causes a build error. 
Because Apache Arrow C++ requires C++17. ### What changes are included in this PR? Remove "-std=c++11" as workaround. We can remove this workaround when we drop support for Ubuntu 20.04. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. * GitHub Issue: #41396 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- ruby/red-arrow/ext/arrow/extconf.rb | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ruby/red-arrow/ext/arrow/extconf.rb b/ruby/red-arrow/ext/arrow/extconf.rb index 7ef3c6c8343..28ccd0b2d59 100644 --- a/ruby/red-arrow/ext/arrow/extconf.rb +++ b/ruby/red-arrow/ext/arrow/extconf.rb @@ -66,6 +66,13 @@ exit(false) end +# Old re2.pc (e.g. re2.pc on Ubuntu 20.04) may add -std=c++11. It +# causes a build error because Apache Arrow C++ requires C++17 or +# later. +# +# We can remove this when we drop support for Ubuntu 20.04. +$CXXFLAGS.gsub!("-std=c++11", "") + [ ["glib2", "ext/glib2"], ].each do |name, relative_source_dir| From b80a51a65c8031bbd2d1d2e5645c541bd7076b5b Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Fri, 16 Aug 2024 14:23:05 +0900 Subject: [PATCH 024/157] GH-43594: [C++] Remove std::optional from arrow::ArrayStatistics::is_{min,max}_exact (#43595) ### Rationale for this change We don't need "unknown" state. If they aren't set, we can process they are not exact. ### What changes are included in this PR? Remove `std::optional` from `arrow::ArrayStatistics::is_{min,max}_exact`. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. * GitHub Issue: #43594 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- cpp/src/arrow/array/statistics.h | 8 ++++---- cpp/src/arrow/array/statistics_test.cc | 14 ++++++-------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/array/statistics.h b/cpp/src/arrow/array/statistics.h index 816d68e7776..523f877bbe4 100644 --- a/cpp/src/arrow/array/statistics.h +++ b/cpp/src/arrow/array/statistics.h @@ -43,14 +43,14 @@ struct ARROW_EXPORT ArrayStatistics { /// \brief The minimum value, may not be set std::optional min = std::nullopt; - /// \brief Whether the minimum value is exact or not, may not be set - std::optional is_min_exact = std::nullopt; + /// \brief Whether the minimum value is exact or not + bool is_min_exact = false; /// \brief The maximum value, may not be set std::optional max = std::nullopt; - /// \brief Whether the maximum value is exact or not, may not be set - std::optional is_max_exact = std::nullopt; + /// \brief Whether the maximum value is exact or not + bool is_max_exact = false; /// \brief Check two statistics for equality bool Equals(const ArrayStatistics& other) const { diff --git a/cpp/src/arrow/array/statistics_test.cc b/cpp/src/arrow/array/statistics_test.cc index f4f4f500151..cf15a5d3829 100644 --- a/cpp/src/arrow/array/statistics_test.cc +++ b/cpp/src/arrow/array/statistics_test.cc @@ -40,27 +40,25 @@ TEST(ArrayStatisticsTest, TestDistinctCount) { TEST(ArrayStatisticsTest, TestMin) { ArrayStatistics statistics; ASSERT_FALSE(statistics.min.has_value()); - ASSERT_FALSE(statistics.is_min_exact.has_value()); + ASSERT_FALSE(statistics.is_min_exact); statistics.min = static_cast(29); statistics.is_min_exact = true; ASSERT_TRUE(statistics.min.has_value()); ASSERT_TRUE(std::holds_alternative(statistics.min.value())); ASSERT_EQ(29, std::get(statistics.min.value())); - ASSERT_TRUE(statistics.is_min_exact.has_value()); - ASSERT_TRUE(statistics.is_min_exact.value()); + ASSERT_TRUE(statistics.is_min_exact); } 
TEST(ArrayStatisticsTest, TestMax) { ArrayStatistics statistics; ASSERT_FALSE(statistics.max.has_value()); - ASSERT_FALSE(statistics.is_max_exact.has_value()); + ASSERT_FALSE(statistics.is_max_exact); statistics.max = std::string("hello"); statistics.is_max_exact = false; ASSERT_TRUE(statistics.max.has_value()); ASSERT_TRUE(std::holds_alternative(statistics.max.value())); ASSERT_EQ("hello", std::get(statistics.max.value())); - ASSERT_TRUE(statistics.is_max_exact.has_value()); - ASSERT_FALSE(statistics.is_max_exact.value()); + ASSERT_FALSE(statistics.is_max_exact); } TEST(ArrayStatisticsTest, TestEquality) { @@ -84,9 +82,9 @@ TEST(ArrayStatisticsTest, TestEquality) { statistics2.min = std::string("world"); ASSERT_EQ(statistics1, statistics2); - statistics1.is_min_exact = false; + statistics1.is_min_exact = true; ASSERT_NE(statistics1, statistics2); - statistics2.is_min_exact = false; + statistics2.is_min_exact = true; ASSERT_EQ(statistics1, statistics2); statistics1.max = static_cast(-29); From bee2fc8021f3b5dabff0315fe20290f316a44ce4 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 16 Aug 2024 09:59:30 +0200 Subject: [PATCH 025/157] MINOR: [Docs][Python] Add LargeListType to Data Types docs (#43597) ### Rationale for this change The `LargeListType` is missing in the Data Types docs: https://arrow.apache.org/docs/python/api/datatypes.html#type-classes ### What changes are included in this PR? This PR adds the `LargeListType` to the Data Types docs. ### Are these changes tested? The change only affects the docs. I have generated the docs locally and they appear as expected. See comment below with screenshot: https://github.com/apache/arrow/pull/43597#issuecomment-2273139016 ### Are there any user-facing changes? The change is indeed an update in the docs. Authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Signed-off-by: Sutou Kouhei --- docs/source/python/api/datatypes.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/python/api/datatypes.rst b/docs/source/python/api/datatypes.rst index a43c5299eae..86c29296873 100644 --- a/docs/source/python/api/datatypes.rst +++ b/docs/source/python/api/datatypes.rst @@ -96,6 +96,7 @@ functions above. DataType DictionaryType ListType + LargeListType MapType StructType UnionType From d801daeddead7ceaca83424874ea006245430bc3 Mon Sep 17 00:00:00 2001 From: Xin Hao Date: Fri, 16 Aug 2024 18:16:49 +0800 Subject: [PATCH 026/157] MINOR: [Go][Doc] fix code format in the readme (#43725) ### Rationale for this change ### What changes are included in this PR? ### Are these changes tested? ### Are there any user-facing changes? Authored-by: Xin Hao Signed-off-by: mwish --- go/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/README.md b/go/README.md index 51ac06c87f1..ec824229729 100644 --- a/go/README.md +++ b/go/README.md @@ -40,7 +40,7 @@ import ( ) func main() { - dsn := "uri=grpc://localhost:12345;username=mickeymouse;password=p@55w0RD" + dsn := "uri=grpc://localhost:12345;username=mickeymouse;password=p@55w0RD" db, err := sql.Open("flightsql", dsn) ... 
} From 801301ee22ce802fd000f9f4b919abb47ae1d6c3 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Fri, 16 Aug 2024 14:40:56 -0700 Subject: [PATCH 027/157] GH-43633: [R] Add tests for packages that might be tricky to roundtrip data to Tables + Parquet files (#43634) ### Rationale for this change Add coverage for objects that might have issues roundtripping to Arrow Tables or Parquet files ### What changes are included in this PR? A new test file + a crossbow job that ensures these other packages are installed so the tests run. ### Are these changes tested? The changes are tests ### Are there any user-facing changes? No * GitHub Issue: #43633 Authored-by: Jonathan Keane Signed-off-by: Jonathan Keane --- dev/tasks/r/github.linux.extra.packages.yml | 53 +++++++++ dev/tasks/tasks.yml | 4 + .../testthat/test-extra-package-roundtrip.R | 105 ++++++++++++++++++ 3 files changed, 162 insertions(+) create mode 100644 dev/tasks/r/github.linux.extra.packages.yml create mode 100644 r/tests/testthat/test-extra-package-roundtrip.R diff --git a/dev/tasks/r/github.linux.extra.packages.yml b/dev/tasks/r/github.linux.extra.packages.yml new file mode 100644 index 00000000000..bb486c72a06 --- /dev/null +++ b/dev/tasks/r/github.linux.extra.packages.yml @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +{% import 'macros.jinja' as macros with context %} + +{{ macros.github_header() }} + +jobs: + extra-packages: + name: "extra package roundtrip tests" + runs-on: ubuntu-latest + strategy: + fail-fast: false + env: + ARROW_R_DEV: "FALSE" + ARROW_R_FORCE_EXTRA_PACKAGE_TESTS: TRUE + steps: + {{ macros.github_checkout_arrow()|indent }} + + - uses: r-lib/actions/setup-r@v2 + with: + use-public-rspm: true + - uses: r-lib/actions/setup-pandoc@v2 + - uses: r-lib/actions/setup-r-dependencies@v2 + with: + working-directory: 'arrow/r' + extra-packages: | + any::data.table + any::rcmdcheck + any::readr + any::units + - name: Build arrow package + run: | + R CMD build --no-build-vignettes arrow/r + R CMD INSTALL --install-tests --no-test-load --no-byte-compile arrow_*.tar.gz + - name: run tests + run: | + testthat::test_package("arrow", filter = "extra-package-roundtrip") + shell: Rscript {0} diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index 6e1f7609a98..a9da7eb2889 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -1309,6 +1309,10 @@ tasks: ci: github template: r/github.linux.rchk.yml + test-r-extra-packages: + ci: github + template: r/github.linux.extra.packages.yml + test-r-linux-as-cran: ci: github template: r/github.linux.cran.yml diff --git a/r/tests/testthat/test-extra-package-roundtrip.R b/r/tests/testthat/test-extra-package-roundtrip.R new file mode 100644 index 00000000000..09a87ef19d5 --- /dev/null +++ b/r/tests/testthat/test-extra-package-roundtrip.R @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +skip_on_cran() + +# Any additional package that we test here that is not already in DESCRIPTION should be +# added to dev/tasks/r/github.linux.extra.packages.yml in the r-lib/actions/setup-r-dependencies@v2 +# step so that they are installed + available in that CI job. 
+ +# So that we can force these in CI +load_or_skip <- function(pkg) { + if (identical(tolower(Sys.getenv("ARROW_R_FORCE_EXTRA_PACKAGE_TESTS")), "true")) { + # because of this indirection on the package name we also avoid a CHECK note and + # we don't otherwise need to Suggest this + requireNamespace(pkg, quietly = TRUE) + } else { + skip_if(!requireNamespace(pkg, quietly = TRUE)) + } + attachNamespace(pkg) +} + +library(dplyr) + +test_that("readr read csvs roundtrip", { + load_or_skip("readr") + + tbl <- example_data[, c("dbl", "lgl", "false", "chr")] + + tf <- tempfile() + on.exit(unlink(tf)) + write.csv(tbl, tf, row.names = FALSE) + + # we should still be able to turn this into a table + new_df <- read_csv(tf, show_col_types = FALSE) + expect_equal(new_df, as_tibble(arrow_table(new_df))) + + # we should still be able to turn this into a table + new_df <- read_csv(tf, show_col_types = FALSE, lazy = TRUE) + expect_equal(new_df, as_tibble(arrow_table(new_df))) + + # and can roundtrip to a parquet file + pq_tmp_file <- tempfile() + write_parquet(new_df, pq_tmp_file) + new_df_read <- read_parquet(pq_tmp_file) + + # we should still be able to turn this into a table + expect_equal(new_df, new_df_read) +}) + +test_that("data.table objects roundtrip", { + load_or_skip("data.table") + + # https://github.com/Rdatatable/data.table/blob/83fd2c05ce2d8555ceb8ba417833956b1b574f7e/R/cedta.R#L25-L27 + .datatable.aware=TRUE + + DT <- as.data.table(example_data) + + # Table -> collect which is what writing + reading to parquet uses under the hood to roundtrip + tab <- as_arrow_table(DT) + DT_read <- collect(tab) + + # we should still be able to turn this into a table + # the .internal.selfref attribute is automatically ignored by testthat/waldo + expect_equal(DT, DT_read) + + # and we can set keys + indices + create new columns + setkey(DT, chr) + setindex(DT, dbl) + DT[, dblshift := data.table::shift(dbl, 1)] + + # Table -> collect + tab <- as_arrow_table(DT) + DT_read <- collect(tab) + + # we should still be able to turn this into a table + expect_equal(DT, DT_read) +}) + +test_that("units roundtrip", { + load_or_skip("units") + + tbl <- example_data + units(tbl$dbl) <- "s" + + # Table -> collect which is what writing + reading to parquet uses under the hood to roundtrip + tab <- as_arrow_table(tbl) + tbl_read <- collect(tab) + + # we should still be able to turn this into a table + expect_equal(tbl, tbl_read) +}) From 8836535785ba3dd4ba335818a34e0479929b70e6 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sat, 17 Aug 2024 11:20:16 +0900 Subject: [PATCH 028/157] GH-43702: [C++][FS][Azure] Use the latest Azurite and update the bundled Azure SDK for C++ to azure-identity_1.9.0 (#43723) ### Rationale for this change Some our CI jobs (such as conda based jobs) use recent Azure SDK for C++ and they require latest Azurite. We need to update Azurite for these jobs. I wanted to use the latest Azurite on all environments but I didn't. Because I want to keep using `apt install nodejs` on old Ubuntu for easy to maintain. ### What changes are included in this PR? * Use the latest Azurite if possible * Use `--skipApiVersionCheck` for old Azurite * Update the bundled Azure SDK for C++ * This is not required. It's for detecting this problem in many CI jobs. ### Are these changes tested? Yes. ### Are there any user-facing changes? No. 
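As a rough illustration of what the test fixture does with the extra flag, here is a minimal sketch assuming Boost.Process and an `azurite` binary on `PATH`; it mirrors the `azurefs_test.cc` change below but is not the actual fixture code:

```cpp
#include <boost/process.hpp>

#include <iostream>

namespace bp = boost::process;

int main() {
  // Locate the npm-installed emulator on PATH.
  auto azurite = bp::search_path("azurite");
  if (azurite.empty()) {
    std::cerr << "azurite not found\n";
    return 1;
  }
  // Start the blob emulator; --skipApiVersionCheck keeps an old Azurite
  // (pinned where only Node.js 12 is available) usable with a newer
  // Azure SDK for C++ that sends newer API versions.
  bp::child server(azurite, "--silent", "--location", "/tmp/azurite-data",
                   "--skipApiVersionCheck");
  // ... run file system tests against the emulator here ...
  server.terminate();
  server.wait();
  return 0;
}
```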
* GitHub Issue: fix #41505 * GitHub Issue: #43702 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- ci/scripts/install_azurite.sh | 24 ++++++++++++++++++------ cpp/src/arrow/filesystem/azurefs_test.cc | 5 ++++- cpp/thirdparty/versions.txt | 4 ++-- python/pyarrow/tests/conftest.py | 3 +++ 4 files changed, 27 insertions(+), 9 deletions(-) diff --git a/ci/scripts/install_azurite.sh b/ci/scripts/install_azurite.sh index dda5e99405b..b8b1618bed3 100755 --- a/ci/scripts/install_azurite.sh +++ b/ci/scripts/install_azurite.sh @@ -19,20 +19,32 @@ set -e -# Pin azurite to 3.29.0 due to https://github.com/apache/arrow/issues/41505 +node_version="$(node --version)" +echo "node version = ${node_version}" + +case "${node_version}" in + v12*) + # Pin azurite to 3.29.0 due to https://github.com/apache/arrow/issues/41505 + azurite_version=v3.29.0 + ;; + *) + azurite_version=latest + ;; +esac + case "$(uname)" in Darwin) - npm install -g azurite@v3.29.0 + npm install -g azurite@${azurite_version} which azurite ;; MINGW*) choco install nodejs.install - npm install -g azurite@v3.29.0 + npm install -g azurite@${azurite_version} ;; Linux) - npm install -g azurite@v3.29.0 + npm install -g azurite@${azurite_version} which azurite ;; esac -echo "node version = $(node --version)" -echo "azurite version = $(azurite --version)" \ No newline at end of file + +echo "azurite version = $(azurite --version)" diff --git a/cpp/src/arrow/filesystem/azurefs_test.cc b/cpp/src/arrow/filesystem/azurefs_test.cc index 36646f417cb..5ff241b17ff 100644 --- a/cpp/src/arrow/filesystem/azurefs_test.cc +++ b/cpp/src/arrow/filesystem/azurefs_test.cc @@ -198,7 +198,10 @@ class AzuriteEnv : public AzureEnvImpl { self->temp_dir_->path().Join("debug.log")); auto server_process = bp::child( boost::this_process::environment(), exe_path, "--silent", "--location", - self->temp_dir_->path().ToString(), "--debug", self->debug_log_path_.ToString()); + self->temp_dir_->path().ToString(), "--debug", self->debug_log_path_.ToString(), + // For old Azurite. We can't install the latest Azurite with + // old Node.js on old Ubuntu. + "--skipApiVersionCheck"); if (!server_process.valid() || !server_process.running()) { server_process.terminate(); server_process.wait(); diff --git a/cpp/thirdparty/versions.txt b/cpp/thirdparty/versions.txt index 16689c17fba..30fa24a2094 100644 --- a/cpp/thirdparty/versions.txt +++ b/cpp/thirdparty/versions.txt @@ -54,8 +54,8 @@ ARROW_AWS_LC_BUILD_SHA256_CHECKSUM=ae96a3567161552744fc0cae8b4d68ed88b1ec0f3d3c9 ARROW_AWSSDK_BUILD_VERSION=1.10.55 ARROW_AWSSDK_BUILD_SHA256_CHECKSUM=2d552fb1a84bef4a9b65e34aa7031851ed2aef5319e02cc6e4cb735c48aa30de # Despite the confusing version name this is still the whole Azure SDK for C++ including core, keyvault, storage-common, etc. 
-ARROW_AZURE_SDK_BUILD_VERSION=azure-core_1.10.3 -ARROW_AZURE_SDK_BUILD_SHA256_CHECKSUM=dd624c2f86adf474d2d0a23066be6e27af9cbd7e3f8d9d8fd7bf981e884b7b48 +ARROW_AZURE_SDK_BUILD_VERSION=azure-identity_1.9.0 +ARROW_AZURE_SDK_BUILD_SHA256_CHECKSUM=97065bfc971ac8df450853ce805f820f52b59457bd7556510186a1569502e4a1 ARROW_BOOST_BUILD_VERSION=1.81.0 ARROW_BOOST_BUILD_SHA256_CHECKSUM=9e0ffae35528c35f90468997bc8d99500bf179cbae355415a89a600c38e13574 ARROW_BROTLI_BUILD_VERSION=v1.0.9 diff --git a/python/pyarrow/tests/conftest.py b/python/pyarrow/tests/conftest.py index 343b602995d..e1919497b51 100644 --- a/python/pyarrow/tests/conftest.py +++ b/python/pyarrow/tests/conftest.py @@ -263,6 +263,9 @@ def azure_server(tmpdir_factory): tmpdir = tmpdir_factory.getbasetemp() # We only need blob service emulator, not queue or table. args = ['azurite-blob', "--location", tmpdir, "--blobPort", str(port)] + # For old Azurite. We can't install the latest Azurite with old + # Node.js on old Ubuntu. + args += ["--skipApiVersionCheck"] proc = None try: proc = subprocess.Popen(args, env=env) From 49be60f5c424cca40bbc5a6d1948ad7e800afaab Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sat, 17 Aug 2024 11:50:46 +0900 Subject: [PATCH 029/157] GH-43175: [C++] Skip not Emscripten ready tests in CSV tests (#43724) ### Rationale for this change We can't use thread nor `%z` on Emacripten. Some CSV tests use them. ### What changes are included in this PR? Skip CSV tests that use thread or `%z`. ### Are these changes tested? Yes. ### Are there any user-facing changes? No. * GitHub Issue: #43175 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- ci/scripts/cpp_test.sh | 2 +- cpp/src/arrow/csv/column_decoder_test.cc | 11 +++++++++++ cpp/src/arrow/csv/converter_test.cc | 5 +++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ci/scripts/cpp_test.sh b/ci/scripts/cpp_test.sh index 2c640f2c1fb..7912bf23e49 100755 --- a/ci/scripts/cpp_test.sh +++ b/ci/scripts/cpp_test.sh @@ -80,7 +80,7 @@ case "$(uname)" in ;; esac -if [ "${ARROW_EMSCRIPTEN:-OFF}" = "ON" ]; then +if [ "${ARROW_EMSCRIPTEN:-OFF}" = "ON" ]; then n_jobs=1 # avoid spurious fails on emscripten due to loading too many big executables fi diff --git a/cpp/src/arrow/csv/column_decoder_test.cc b/cpp/src/arrow/csv/column_decoder_test.cc index ebac7a3da2f..56773264717 100644 --- a/cpp/src/arrow/csv/column_decoder_test.cc +++ b/cpp/src/arrow/csv/column_decoder_test.cc @@ -175,6 +175,9 @@ class NullColumnDecoderTest : public ColumnDecoderTest { } void TestThreaded() { +#ifndef ARROW_ENABLE_THREADING + GTEST_SKIP() << "Test requires threading support"; +#endif constexpr int NITERS = 10; auto type = int32(); MakeDecoder(type); @@ -257,6 +260,10 @@ class TypedColumnDecoderTest : public ColumnDecoderTest { } void TestThreaded() { +#ifndef ARROW_ENABLE_THREADING + GTEST_SKIP() << "Test requires threading support"; +#endif + constexpr int NITERS = 10; auto type = uint32(); MakeDecoder(type, default_options); @@ -305,6 +312,10 @@ class InferringColumnDecoderTest : public ColumnDecoderTest { } void TestThreaded() { +#ifndef ARROW_ENABLE_THREADING + GTEST_SKIP() << "Test requires threading support"; +#endif + constexpr int NITERS = 10; auto type = float64(); MakeDecoder(default_options); diff --git a/cpp/src/arrow/csv/converter_test.cc b/cpp/src/arrow/csv/converter_test.cc index ea4e171d57e..657e8d813ca 100644 --- a/cpp/src/arrow/csv/converter_test.cc +++ b/cpp/src/arrow/csv/converter_test.cc @@ -625,6 +625,11 @@ TEST(TimestampConversion, UserDefinedParsers) { } 
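The converter test below gets the same treatment as the column decoder tests above. As a standalone illustration of the guard pattern (the test name here is made up for the example):

```cpp
#include <gtest/gtest.h>

// Illustrative only: skip at runtime when the build has no threading support
// or when running under Emscripten, instead of failing.
TEST(GuardPatternExample, SkipsWhenUnsupported) {
#ifndef ARROW_ENABLE_THREADING
  GTEST_SKIP() << "Test requires threading support";
#endif
#ifdef __EMSCRIPTEN__
  GTEST_SKIP() << "Not supported under Emscripten";
#endif
  // ... body that spawns threads or parses a %z timezone offset ...
  SUCCEED();
}
```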
TEST(TimestampConversion, UserDefinedParsersWithZone) { +#ifdef __EMSCRIPTEN__ + GTEST_SKIP() << "Test temporarily disabled due to emscripten bug " + "https://github.com/emscripten-core/emscripten/issues/20467"; +#endif + auto options = ConvertOptions::Defaults(); auto type = timestamp(TimeUnit::SECOND, "America/Phoenix"); From fbac12c353cb6ead58a5ee765b37bd1bc46cd672 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Sat, 17 Aug 2024 17:16:39 -0500 Subject: [PATCH 030/157] MINOR: [R] Fix a package namespace warning (#43737) Oops, I should have caught this in #43633 Removes `data.table::` since the namespace is loaded. Also fix some linting errors and free up space on the force tests run. Authored-by: Jonathan Keane Signed-off-by: Jonathan Keane --- .github/workflows/r.yml | 3 +++ r/tests/testthat/test-extra-package-roundtrip.R | 16 ++++++++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index c4899ddcc49..bf7eb99e7e9 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -133,6 +133,9 @@ jobs: with: fetch-depth: 0 submodules: recursive + - name: Free up disk space + run: | + ci/scripts/util_free_space.sh - name: Cache Docker Volumes uses: actions/cache@13aacd865c20de90d75de3b17ebe84f7a17d57d2 # v4.0.0 with: diff --git a/r/tests/testthat/test-extra-package-roundtrip.R b/r/tests/testthat/test-extra-package-roundtrip.R index 09a87ef19d5..092288dffb9 100644 --- a/r/tests/testthat/test-extra-package-roundtrip.R +++ b/r/tests/testthat/test-extra-package-roundtrip.R @@ -24,7 +24,7 @@ skip_on_cran() # So that we can force these in CI load_or_skip <- function(pkg) { if (identical(tolower(Sys.getenv("ARROW_R_FORCE_EXTRA_PACKAGE_TESTS")), "true")) { - # because of this indirection on the package name we also avoid a CHECK note and + # because of this indirection on the package name we also avoid a CHECK note and # we don't otherwise need to Suggest this requireNamespace(pkg, quietly = TRUE) } else { @@ -46,11 +46,11 @@ test_that("readr read csvs roundtrip", { # we should still be able to turn this into a table new_df <- read_csv(tf, show_col_types = FALSE) - expect_equal(new_df, as_tibble(arrow_table(new_df))) + expect_equal(new_df, as_tibble(arrow_table(new_df))) # we should still be able to turn this into a table new_df <- read_csv(tf, show_col_types = FALSE, lazy = TRUE) - expect_equal(new_df, as_tibble(arrow_table(new_df))) + expect_equal(new_df, as_tibble(arrow_table(new_df))) # and can roundtrip to a parquet file pq_tmp_file <- tempfile() @@ -65,11 +65,11 @@ test_that("data.table objects roundtrip", { load_or_skip("data.table") # https://github.com/Rdatatable/data.table/blob/83fd2c05ce2d8555ceb8ba417833956b1b574f7e/R/cedta.R#L25-L27 - .datatable.aware=TRUE + .datatable.aware <- TRUE DT <- as.data.table(example_data) - # Table -> collect which is what writing + reading to parquet uses under the hood to roundtrip + # Table to collect which is what writing + reading to parquet uses under the hood to roundtrip tab <- as_arrow_table(DT) DT_read <- collect(tab) @@ -80,9 +80,9 @@ test_that("data.table objects roundtrip", { # and we can set keys + indices + create new columns setkey(DT, chr) setindex(DT, dbl) - DT[, dblshift := data.table::shift(dbl, 1)] + DT[, dblshift := shift(dbl, 1)] - # Table -> collect + # Table to collect tab <- as_arrow_table(DT) DT_read <- collect(tab) @@ -96,7 +96,7 @@ test_that("units roundtrip", { tbl <- example_data units(tbl$dbl) <- "s" - # Table -> collect which is what writing + 
reading to parquet uses under the hood to roundtrip + # Table to collect which is what writing + reading to parquet uses under the hood to roundtrip tab <- as_arrow_table(tbl) tbl_read <- collect(tab) From b7e618f088540a45e2ddab39696ce3d543821763 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sun, 18 Aug 2024 10:42:53 +0900 Subject: [PATCH 031/157] GH-43738: [GLib] Add `GArrowAzureFileSytem` (#43739) ### Rationale for this change The bindings for `arrow::fs::AzureFileSytem` is missing. ### What changes are included in this PR? Add the bindings for `arrow::fs::AzureFileSytem`. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. * GitHub Issue: #43738 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- c_glib/arrow-glib/file-system.cpp | 16 ++++++++++++++++ c_glib/arrow-glib/file-system.h | 12 ++++++++++++ 2 files changed, 28 insertions(+) diff --git a/c_glib/arrow-glib/file-system.cpp b/c_glib/arrow-glib/file-system.cpp index b6efa2b8726..9ba494e4059 100644 --- a/c_glib/arrow-glib/file-system.cpp +++ b/c_glib/arrow-glib/file-system.cpp @@ -56,6 +56,8 @@ G_BEGIN_DECLS * #GArrowS3FileSystem is a class for S3-backed file system. * * #GArrowGCSFileSystem is a class for GCS-backed file system. + * + * #GArrowAzureFileSystem is a class for Azure-backed file system. */ /* arrow::fs::FileInfo */ @@ -1561,6 +1563,18 @@ garrow_gcs_file_system_class_init(GArrowGCSFileSystemClass *klass) { } +G_DEFINE_TYPE(GArrowAzureFileSystem, garrow_azure_file_system, GARROW_TYPE_FILE_SYSTEM) + +static void +garrow_azure_file_system_init(GArrowAzureFileSystem *file_system) +{ +} + +static void +garrow_azure_file_system_class_init(GArrowAzureFileSystemClass *klass) +{ +} + G_END_DECLS GArrowFileInfo * @@ -1592,6 +1606,8 @@ garrow_file_system_new_raw(std::shared_ptr *arrow_file_sy file_system_type = GARROW_TYPE_S3_FILE_SYSTEM; } else if (type_name == "gcs") { file_system_type = GARROW_TYPE_GCS_FILE_SYSTEM; + } else if (type_name == "abfs") { + file_system_type = GARROW_TYPE_AZURE_FILE_SYSTEM; } else if (type_name == "mock") { file_system_type = GARROW_TYPE_MOCK_FILE_SYSTEM; } diff --git a/c_glib/arrow-glib/file-system.h b/c_glib/arrow-glib/file-system.h index 2e500672e14..9a903c6af68 100644 --- a/c_glib/arrow-glib/file-system.h +++ b/c_glib/arrow-glib/file-system.h @@ -337,4 +337,16 @@ struct _GArrowGCSFileSystemClass GArrowFileSystemClass parent_class; }; +#define GARROW_TYPE_AZURE_FILE_SYSTEM (garrow_azure_file_system_get_type()) +GARROW_AVAILABLE_IN_18_0 +G_DECLARE_DERIVABLE_TYPE(GArrowAzureFileSystem, + garrow_azure_file_system, + GARROW, + AZURE_FILE_SYSTEM, + GArrowFileSystem) +struct _GArrowAzureFileSystemClass +{ + GArrowFileSystemClass parent_class; +}; + G_END_DECLS From 5ef7e01053c526389acefddd6f961bf1fd9d274b Mon Sep 17 00:00:00 2001 From: Jin Chengcheng Date: Sun, 18 Aug 2024 15:28:52 +0800 Subject: [PATCH 032/157] GH-43506: [Java] Fix TestFragmentScanOptions result not match (#43639) ### Rationale for this change JNI test was not tested in CI. So the test failed but passed the CI. The parseChar function should return char but return bool, a typo error. ### What changes are included in this PR? ### Are these changes tested? Yes ### Are there any user-facing changes? 
No * GitHub Issue: #43506 Authored-by: Chengcheng Jin Signed-off-by: David Li --- java/dataset/src/main/cpp/jni_wrapper.cc | 2 +- .../dataset/TestFragmentScanOptions.java | 80 ++++++++++++------- 2 files changed, 52 insertions(+), 30 deletions(-) diff --git a/java/dataset/src/main/cpp/jni_wrapper.cc b/java/dataset/src/main/cpp/jni_wrapper.cc index 63b8dd73f47..49cc85251c8 100644 --- a/java/dataset/src/main/cpp/jni_wrapper.cc +++ b/java/dataset/src/main/cpp/jni_wrapper.cc @@ -368,7 +368,7 @@ std::shared_ptr LoadArrowBufferFromByteBuffer(JNIEnv* env, jobjec inline bool ParseBool(const std::string& value) { return value == "true" ? true : false; } -inline bool ParseChar(const std::string& key, const std::string& value) { +inline char ParseChar(const std::string& key, const std::string& value) { if (value.size() != 1) { JniThrow("Option " + key + " should be a char, but is " + value); } diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java b/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java index d5981905288..ed6344f0f9c 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java @@ -51,6 +51,16 @@ public class TestFragmentScanOptions { + private CsvFragmentScanOptions create( + ArrowSchema cSchema, + Map convertOptionsMap, + Map readOptions, + Map parseOptions) { + CsvConvertOptions convertOptions = new CsvConvertOptions(convertOptionsMap); + convertOptions.setArrowSchema(cSchema); + return new CsvFragmentScanOptions(convertOptions, readOptions, parseOptions); + } + @Test public void testCsvConvertOptions() throws Exception { final Schema schema = @@ -63,24 +73,29 @@ public void testCsvConvertOptions() throws Exception { String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv"; BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); try (ArrowSchema cSchema = ArrowSchema.allocateNew(allocator); + ArrowSchema cSchema2 = ArrowSchema.allocateNew(allocator); CDataDictionaryProvider provider = new CDataDictionaryProvider()) { Data.exportSchema(allocator, schema, provider, cSchema); - CsvConvertOptions convertOptions = new CsvConvertOptions(ImmutableMap.of("delimiter", ";")); - convertOptions.setArrowSchema(cSchema); - CsvFragmentScanOptions fragmentScanOptions = - new CsvFragmentScanOptions(convertOptions, ImmutableMap.of(), ImmutableMap.of()); + Data.exportSchema(allocator, schema, provider, cSchema2); + CsvFragmentScanOptions fragmentScanOptions1 = + create(cSchema, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of("delimiter", ";")); + CsvFragmentScanOptions fragmentScanOptions2 = + create(cSchema2, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of("delimiter", ";")); ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) - .fragmentScanOptions(fragmentScanOptions) + .fragmentScanOptions(fragmentScanOptions1) .build(); try (DatasetFactory datasetFactory = new FileSystemDatasetFactory( - allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path); + allocator, + NativeMemoryPool.getDefault(), + FileFormat.CSV, + path, + Optional.of(fragmentScanOptions2)); Dataset dataset = datasetFactory.finish(); Scanner scanner = dataset.newScan(options); ArrowReader reader = scanner.scanBatches()) { - assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); int rowCount = 0; while (reader.loadNextBatch()) 
{ @@ -106,30 +121,38 @@ public void testCsvConvertOptionsDelimiterNotSet() throws Exception { String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv"; BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); try (ArrowSchema cSchema = ArrowSchema.allocateNew(allocator); + ArrowSchema cSchema2 = ArrowSchema.allocateNew(allocator); CDataDictionaryProvider provider = new CDataDictionaryProvider()) { Data.exportSchema(allocator, schema, provider, cSchema); - CsvConvertOptions convertOptions = new CsvConvertOptions(ImmutableMap.of()); - convertOptions.setArrowSchema(cSchema); - CsvFragmentScanOptions fragmentScanOptions = - new CsvFragmentScanOptions(convertOptions, ImmutableMap.of(), ImmutableMap.of()); + Data.exportSchema(allocator, schema, provider, cSchema2); + CsvFragmentScanOptions fragmentScanOptions1 = + create(cSchema, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of()); + CsvFragmentScanOptions fragmentScanOptions2 = + create(cSchema2, ImmutableMap.of(), ImmutableMap.of(), ImmutableMap.of()); ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) .columns(Optional.empty()) - .fragmentScanOptions(fragmentScanOptions) + .fragmentScanOptions(fragmentScanOptions1) .build(); try (DatasetFactory datasetFactory = new FileSystemDatasetFactory( - allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path); + allocator, + NativeMemoryPool.getDefault(), + FileFormat.CSV, + path, + Optional.of(fragmentScanOptions2)); Dataset dataset = datasetFactory.finish(); Scanner scanner = dataset.newScan(options); ArrowReader reader = scanner.scanBatches()) { - - assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); int rowCount = 0; while (reader.loadNextBatch()) { - final ValueIterableVector idVector = - (ValueIterableVector) reader.getVectorSchemaRoot().getVector("Id"); - assertThat(idVector.getValueIterable(), IsIterableContainingInOrder.contains(1, 2, 3)); + final ValueIterableVector idVector = + (ValueIterableVector) + reader.getVectorSchemaRoot().getVector("Id;Name;Language"); + assertThat( + idVector.getValueIterable(), + IsIterableContainingInOrder.contains( + new Text("1;Juno;Java"), new Text("2;Peter;Python"), new Text("3;Celin;C++"))); rowCount += reader.getVectorSchemaRoot().getRowCount(); } assertEquals(3, rowCount); @@ -157,13 +180,12 @@ public void testCsvConvertOptionsNoOption() throws Exception { assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); int rowCount = 0; while (reader.loadNextBatch()) { - final ValueIterableVector idVector = - (ValueIterableVector) - reader.getVectorSchemaRoot().getVector("Id;Name;Language"); + final ValueIterableVector idVector = + (ValueIterableVector) reader.getVectorSchemaRoot().getVector("Id;Name;Language"); assertThat( idVector.getValueIterable(), IsIterableContainingInOrder.contains( - "1;Juno;Java\n" + "2;Peter;Python\n" + "3;Celin;C++")); + new Text("1;Juno;Java"), new Text("2;Peter;Python"), new Text("3;Celin;C++"))); rowCount += reader.getVectorSchemaRoot().getRowCount(); } assertEquals(3, rowCount); @@ -174,7 +196,10 @@ public void testCsvConvertOptionsNoOption() throws Exception { public void testCsvReadParseAndReadOptions() throws Exception { final Schema schema = new Schema( - Collections.singletonList(Field.nullable("Id;Name;Language", new ArrowType.Utf8())), + Arrays.asList( + Field.nullable("Id", new ArrowType.Int(64, true)), + Field.nullable("Name", new ArrowType.Utf8()), + Field.nullable("Language", new 
ArrowType.Utf8())), null); String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv"; BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); @@ -202,12 +227,9 @@ public void testCsvReadParseAndReadOptions() throws Exception { assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); int rowCount = 0; while (reader.loadNextBatch()) { - final ValueIterableVector idVector = - (ValueIterableVector) reader.getVectorSchemaRoot().getVector("Id;Name;Language"); - assertThat( - idVector.getValueIterable(), - IsIterableContainingInOrder.contains( - new Text("2;Peter;Python"), new Text("3;Celin;C++"))); + final ValueIterableVector idVector = + (ValueIterableVector) reader.getVectorSchemaRoot().getVector("Id"); + assertThat(idVector.getValueIterable(), IsIterableContainingInOrder.contains(2L, 3L)); rowCount += reader.getVectorSchemaRoot().getRowCount(); } assertEquals(2, rowCount); From 1ae38d0d42c1ae5800e42b613f22593673b7370c Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Sun, 18 Aug 2024 08:48:55 -0500 Subject: [PATCH 033/157] GH-43735: [R] AWS SDK fails to build on one of CRAN's M1 builders (#43736) Trying to replicate the issue's on CRAN's M1 machine so that we can fix them. * GitHub Issue: #43735 Lead-authored-by: Jonathan Keane Co-authored-by: Sutou Kouhei Signed-off-by: Jonathan Keane --- cpp/cmake_modules/ThirdpartyToolchain.cmake | 12 +++ dev/tasks/r/github.macos.cran.yml | 82 +++++++++++++++++++++ dev/tasks/tasks.yml | 4 + 3 files changed, 98 insertions(+) create mode 100644 dev/tasks/r/github.macos.cran.yml diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index bc3a3a2249d..63e2c036c9a 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -4965,8 +4965,20 @@ macro(build_awssdk) set(AWSSDK_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/awssdk_ep-install") set(AWSSDK_INCLUDE_DIR "${AWSSDK_PREFIX}/include") + # The AWS SDK has a few warnings around shortening lengths + set(AWS_C_FLAGS "${EP_C_FLAGS}") + set(AWS_CXX_FLAGS "${EP_CXX_FLAGS}") + if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL + "Clang") + # Negate warnings that AWS SDK cannot build under + string(APPEND AWS_C_FLAGS " -Wno-error=shorten-64-to-32") + string(APPEND AWS_CXX_FLAGS " -Wno-error=shorten-64-to-32") + endif() + set(AWSSDK_COMMON_CMAKE_ARGS ${EP_COMMON_CMAKE_ARGS} + -DCMAKE_C_FLAGS=${AWS_C_FLAGS} + -DCMAKE_CXX_FLAGS=${AWS_CXX_FLAGS} -DCPP_STANDARD=${CMAKE_CXX_STANDARD} -DCMAKE_INSTALL_PREFIX=${AWSSDK_PREFIX} -DCMAKE_PREFIX_PATH=${AWSSDK_PREFIX} diff --git a/dev/tasks/r/github.macos.cran.yml b/dev/tasks/r/github.macos.cran.yml new file mode 100644 index 00000000000..33965988e21 --- /dev/null +++ b/dev/tasks/r/github.macos.cran.yml @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +{% import 'macros.jinja' as macros with context %} + +{{ macros.github_header() }} + +jobs: + macos-cran: + name: "macOS similar to CRAN" + runs-on: macOS-latest + strategy: + fail-fast: false + + steps: + {{ macros.github_checkout_arrow()|indent }} + + - name: Configure dependencies (macos) + run: | + brew install openssl + # disable sccache on macos as it times out for unknown reasons + # see GH-33721 + # brew install sccache + # remove cmake so that we can test our cmake downloading abilities + brew uninstall cmake + - uses: r-lib/actions/setup-r@v2 + with: + use-public-rspm: true + # CRAN builders have the entire bin here added to the path. This sometimes + # includes things like GNU libtool which name-collide with what we expect + - name: Add R.framework/Resources/bin to the path + run: echo "/Library/Frameworks/R.framework/Resources/bin" >> $GITHUB_PATH + - name : Check whether libtool in R is used + run: | + if [ "$(which libtool)" != "/Library/Frameworks/R.framework/Resources/bin/libtool" ]; then + echo "libtool provided by R isn't found: $(which libtool)" + exit 1 + fi + - name: Install dependencies + uses: r-lib/actions/setup-r-dependencies@v2 + with: + cache: false # cache does not work on across branches + working-directory: arrow/r + extra-packages: | + any::rcmdcheck + any::sys + - name: Install + env: + _R_CHECK_CRAN_INCOMING_: false + CXX: "clang++ -mmacos-version-min=14.6" + CFLAGS: "-falign-functions=8 -g -O2 -Wall -pedantic -Wconversion -Wno-sign-conversion -Wstrict-prototypes" + CXXFLAGS: "-g -O2 -Wall -pedantic -Wconversion -Wno-sign-conversion" + NOT_CRAN: false + run: | + sccache --start-server || echo 'sccache not found' + cd arrow/r + R CMD INSTALL . --install-tests + - name: Run the tests + run: R -e 'if(tools::testInstalledPackage("arrow") != 0L) stop("There was a test failure.")' + - name: Dump test logs + run: cat arrow-tests/testthat.Rout* + if: failure() + - name: Save the test output + uses: actions/upload-artifact@v2 + with: + name: test-output + path: arrow-tests/testthat.Rout* + if: always() diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index a9da7eb2889..fe02fe9ce68 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -1319,6 +1319,10 @@ tasks: params: MATRIX: {{ "${{ matrix.r_image }}" }} + test-r-macos-as-cran: + ci: github + template: r/github.macos.cran.yml + test-r-arrow-backwards-compatibility: ci: github template: r/github.linux.arrow.version.back.compat.yml From 5e68513d62b0d216e916de6a1ad2db04f5d1a7bf Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Mon, 19 Aug 2024 18:39:05 +0800 Subject: [PATCH 034/157] GH-43495: [C++][Compute] Widen the row offset of the row table to 64-bit (#43389) ### Rationale for this change The row table uses `uint32_t` as the row offset within the row data buffer, effectively limiting the row data from growing beyond 4GB. This is quite restrictive, and the impact is described in more detail in #43495. This PR proposes to widen the row offset from 32-bit to 64-bit to address this limitation. #### Benefits Currently, the row table has three major limitations: 1. The overall data size cannot exceed 4GB. 2. The size of a single row cannot exceed 4GB. 3. The number of rows cannot exceed 2^32. This enhancement will eliminate the first limitation. Meanwhile, the second and third limitations are less likely to occur. 
Thus, this change will enable a significant range of use cases that are currently unsupported. #### Overhead Of course, this will introduce some overhead: 1. An extra 4 bytes of memory consumption for each row due to the offset size difference from 32-bit to 64-bit. 2. A wider offset type requires a few more SIMD instructions in each 8-row processing iteration. In my opinion, this overhead is justified by the benefits listed above. ### What changes are included in this PR? Change the row offset of the row table from 32-bit to 64-bit. Relative code in row comparison/encoding and swiss join has been updated accordingly. ### Are these changes tested? Test included. ### Are there any user-facing changes? Users could potentially see higher memory consumption when using acero's hash join and hash aggregation. However, on the other hand, certain use cases used to fail are now able to complete. * GitHub Issue: #43495 Authored-by: Ruoxi Sun Signed-off-by: Antoine Pitrou --- cpp/src/arrow/acero/hash_join_node_test.cc | 192 ++++++++++ cpp/src/arrow/acero/swiss_join.cc | 26 +- cpp/src/arrow/acero/swiss_join_avx2.cc | 126 +++++-- cpp/src/arrow/compute/row/compare_internal.cc | 39 +- cpp/src/arrow/compute/row/compare_internal.h | 27 +- .../compute/row/compare_internal_avx2.cc | 172 ++++----- cpp/src/arrow/compute/row/compare_test.cc | 333 +++++++++++++----- cpp/src/arrow/compute/row/encode_internal.cc | 47 ++- cpp/src/arrow/compute/row/encode_internal.h | 7 +- .../arrow/compute/row/encode_internal_avx2.cc | 10 +- cpp/src/arrow/compute/row/row_internal.cc | 38 +- cpp/src/arrow/compute/row/row_internal.h | 37 +- cpp/src/arrow/compute/row/row_test.cc | 66 ++-- cpp/src/arrow/testing/random.cc | 19 +- cpp/src/arrow/testing/random.h | 6 + 15 files changed, 802 insertions(+), 343 deletions(-) diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc b/cpp/src/arrow/acero/hash_join_node_test.cc index f7b442cc3c6..88f9a9e71b7 100644 --- a/cpp/src/arrow/acero/hash_join_node_test.cc +++ b/cpp/src/arrow/acero/hash_join_node_test.cc @@ -30,6 +30,7 @@ #include "arrow/compute/kernels/test_util.h" #include "arrow/compute/light_array_internal.h" #include "arrow/testing/extension_type.h" +#include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/testing/random.h" @@ -40,6 +41,10 @@ using testing::UnorderedElementsAreArray; namespace arrow { +using arrow::gen::Constant; +using arrow::random::kSeedMax; +using arrow::random::RandomArrayGenerator; +using compute::and_; using compute::call; using compute::default_exec_context; using compute::ExecBatchBuilder; @@ -3253,5 +3258,192 @@ TEST(HashJoin, ManyJoins) { ASSERT_OK_AND_ASSIGN(std::ignore, DeclarationToTable(std::move(root))); } +namespace { + +void AssertRowCountEq(Declaration source, int64_t expected) { + Declaration count{"aggregate", + {std::move(source)}, + AggregateNodeOptions{/*aggregates=*/{{"count_all", "count(*)"}}}}; + ASSERT_OK_AND_ASSIGN(auto batches, DeclarationToExecBatches(std::move(count))); + ASSERT_EQ(batches.batches.size(), 1); + ASSERT_EQ(batches.batches[0].values.size(), 1); + ASSERT_TRUE(batches.batches[0].values[0].is_scalar()); + ASSERT_EQ(batches.batches[0].values[0].scalar()->type->id(), Type::INT64); + ASSERT_TRUE(batches.batches[0].values[0].scalar_as().is_valid); + ASSERT_EQ(batches.batches[0].values[0].scalar_as().value, expected); +} + +} // namespace + +// GH-43495: Test that both the key and the payload of the right side (the build side) are +// fixed length and larger 
than 4GB, and the 64-bit offset in the hash table can handle it +// correctly. +TEST(HashJoin, LARGE_MEMORY_TEST(BuildSideOver4GBFixedLength)) { + constexpr int64_t k5GB = 5ll * 1024 * 1024 * 1024; + constexpr int fixed_length = 128; + const auto type = fixed_size_binary(fixed_length); + constexpr uint8_t byte_no_match_min = static_cast('A'); + constexpr uint8_t byte_no_match_max = static_cast('y'); + constexpr uint8_t byte_match = static_cast('z'); + const auto value_match = + std::make_shared(std::string(fixed_length, byte_match)); + constexpr int16_t num_rows_per_batch_left = 128; + constexpr int16_t num_rows_per_batch_right = 4096; + const int64_t num_batches_left = 8; + const int64_t num_batches_right = + k5GB / (num_rows_per_batch_right * type->byte_width()); + + // Left side composed of num_batches_left identical batches of num_rows_per_batch_left + // rows of value_match-es. + BatchesWithSchema batches_left; + { + // A column with num_rows_per_batch_left value_match-es. + ASSERT_OK_AND_ASSIGN(auto column, + Constant(value_match)->Generate(num_rows_per_batch_left)); + + // Use the column as both the key and the payload. + ExecBatch batch({column, column}, num_rows_per_batch_left); + batches_left = + BatchesWithSchema{std::vector(num_batches_left, std::move(batch)), + schema({field("l_key", type), field("l_payload", type)})}; + } + + // Right side composed of num_batches_right identical batches of + // num_rows_per_batch_right rows containing only 1 value_match. + BatchesWithSchema batches_right; + { + // A column with (num_rows_per_batch_right - 1) non-value_match-es (possibly null) and + // 1 value_match. + auto non_matches = RandomArrayGenerator(kSeedMax).FixedSizeBinary( + num_rows_per_batch_right - 1, fixed_length, + /*null_probability =*/0.01, /*min_byte=*/byte_no_match_min, + /*max_byte=*/byte_no_match_max); + ASSERT_OK_AND_ASSIGN(auto match, Constant(value_match)->Generate(1)); + ASSERT_OK_AND_ASSIGN(auto column, Concatenate({non_matches, match})); + + // Use the column as both the key and the payload. + ExecBatch batch({column, column}, num_rows_per_batch_right); + batches_right = + BatchesWithSchema{std::vector(num_batches_right, std::move(batch)), + schema({field("r_key", type), field("r_payload", type)})}; + } + + Declaration left{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_left.schema), + std::move(batches_left.batches))}; + + Declaration right{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_right.schema), + std::move(batches_right.batches))}; + + HashJoinNodeOptions join_opts(JoinType::INNER, /*left_keys=*/{"l_key"}, + /*right_keys=*/{"r_key"}); + Declaration join{"hashjoin", {std::move(left), std::move(right)}, join_opts}; + + ASSERT_OK_AND_ASSIGN(auto batches_result, DeclarationToExecBatches(std::move(join))); + Declaration result{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_result.schema), + std::move(batches_result.batches))}; + + // The row count of hash join should be (number of value_match-es in left side) * + // (number of value_match-es in right side). + AssertRowCountEq(result, + num_batches_left * num_rows_per_batch_left * num_batches_right); + + // All rows should be value_match-es. 
+ auto predicate = and_({equal(field_ref("l_key"), literal(value_match)), + equal(field_ref("l_payload"), literal(value_match)), + equal(field_ref("r_key"), literal(value_match)), + equal(field_ref("r_payload"), literal(value_match))}); + Declaration filter{"filter", {result}, FilterNodeOptions{std::move(predicate)}}; + AssertRowCountEq(std::move(filter), + num_batches_left * num_rows_per_batch_left * num_batches_right); +} + +// GH-43495: Test that both the key and the payload of the right side (the build side) are +// var length and larger than 4GB, and the 64-bit offset in the hash table can handle it +// correctly. +TEST(HashJoin, LARGE_MEMORY_TEST(BuildSideOver4GBVarLength)) { + constexpr int64_t k5GB = 5ll * 1024 * 1024 * 1024; + const auto type = utf8(); + constexpr int value_no_match_length_min = 128; + constexpr int value_no_match_length_max = 129; + constexpr int value_match_length = 130; + const auto value_match = + std::make_shared(std::string(value_match_length, 'X')); + constexpr int16_t num_rows_per_batch_left = 128; + constexpr int16_t num_rows_per_batch_right = 4096; + const int64_t num_batches_left = 8; + const int64_t num_batches_right = + k5GB / (num_rows_per_batch_right * value_no_match_length_min); + + // Left side composed of num_batches_left identical batches of num_rows_per_batch_left + // rows of value_match-es. + BatchesWithSchema batches_left; + { + // A column with num_rows_per_batch_left value_match-es. + ASSERT_OK_AND_ASSIGN(auto column, + Constant(value_match)->Generate(num_rows_per_batch_left)); + + // Use the column as both the key and the payload. + ExecBatch batch({column, column}, num_rows_per_batch_left); + batches_left = + BatchesWithSchema{std::vector(num_batches_left, std::move(batch)), + schema({field("l_key", type), field("l_payload", type)})}; + } + + // Right side composed of num_batches_right identical batches of + // num_rows_per_batch_right rows containing only 1 value_match. + BatchesWithSchema batches_right; + { + // A column with (num_rows_per_batch_right - 1) non-value_match-es (possibly null) and + // 1 value_match. + auto non_matches = + RandomArrayGenerator(kSeedMax).String(num_rows_per_batch_right - 1, + /*min_length=*/value_no_match_length_min, + /*max_length=*/value_no_match_length_max, + /*null_probability =*/0.01); + ASSERT_OK_AND_ASSIGN(auto match, Constant(value_match)->Generate(1)); + ASSERT_OK_AND_ASSIGN(auto column, Concatenate({non_matches, match})); + + // Use the column as both the key and the payload. 
+ ExecBatch batch({column, column}, num_rows_per_batch_right); + batches_right = + BatchesWithSchema{std::vector(num_batches_right, std::move(batch)), + schema({field("r_key", type), field("r_payload", type)})}; + } + + Declaration left{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_left.schema), + std::move(batches_left.batches))}; + + Declaration right{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_right.schema), + std::move(batches_right.batches))}; + + HashJoinNodeOptions join_opts(JoinType::INNER, /*left_keys=*/{"l_key"}, + /*right_keys=*/{"r_key"}); + Declaration join{"hashjoin", {std::move(left), std::move(right)}, join_opts}; + + ASSERT_OK_AND_ASSIGN(auto batches_result, DeclarationToExecBatches(std::move(join))); + Declaration result{"exec_batch_source", + ExecBatchSourceNodeOptions(std::move(batches_result.schema), + std::move(batches_result.batches))}; + + // The row count of hash join should be (number of value_match-es in left side) * + // (number of value_match-es in right side). + AssertRowCountEq(result, + num_batches_left * num_rows_per_batch_left * num_batches_right); + + // All rows should be value_match-es. + auto predicate = and_({equal(field_ref("l_key"), literal(value_match)), + equal(field_ref("l_payload"), literal(value_match)), + equal(field_ref("r_key"), literal(value_match)), + equal(field_ref("r_payload"), literal(value_match))}); + Declaration filter{"filter", {result}, FilterNodeOptions{std::move(predicate)}}; + AssertRowCountEq(std::move(filter), + num_batches_left * num_rows_per_batch_left * num_batches_right); +} + } // namespace acero } // namespace arrow diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc index 732deb72861..40a4b5886e4 100644 --- a/cpp/src/arrow/acero/swiss_join.cc +++ b/cpp/src/arrow/acero/swiss_join.cc @@ -122,7 +122,7 @@ void RowArrayAccessor::Visit(const RowTableImpl& rows, int column_id, int num_ro if (!is_fixed_length_column) { int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); const uint8_t* row_ptr_base = rows.data(2); - const uint32_t* row_offsets = rows.offsets(); + const RowTableImpl::offset_type* row_offsets = rows.offsets(); uint32_t field_offset_within_row, field_length; if (varbinary_column_id == 0) { @@ -173,7 +173,7 @@ void RowArrayAccessor::Visit(const RowTableImpl& rows, int column_id, int num_ro // Case 4: This is a fixed length column in a varying length row // const uint8_t* row_ptr_base = rows.data(2) + field_offset_within_row; - const uint32_t* row_offsets = rows.offsets(); + const RowTableImpl::offset_type* row_offsets = rows.offsets(); for (int i = 0; i < num_rows; ++i) { uint32_t row_id = row_ids[i]; const uint8_t* row_ptr = row_ptr_base + row_offsets[row_id]; @@ -473,17 +473,10 @@ Status RowArrayMerge::PrepareForMerge(RowArray* target, (*first_target_row_id)[sources.size()] = num_rows; } - if (num_bytes > std::numeric_limits::max()) { - return Status::Invalid( - "There are more than 2^32 bytes of key data. 
Acero cannot " - "process a join of this magnitude"); - } - // Allocate target memory // target->rows_.Clean(); - RETURN_NOT_OK(target->rows_.AppendEmpty(static_cast(num_rows), - static_cast(num_bytes))); + RETURN_NOT_OK(target->rows_.AppendEmpty(static_cast(num_rows), num_bytes)); // In case of varying length rows, // initialize the first row offset for each range of rows corresponding to a @@ -565,15 +558,15 @@ void RowArrayMerge::CopyVaryingLength(RowTableImpl* target, const RowTableImpl& int64_t first_target_row_offset, const int64_t* source_rows_permutation) { int64_t num_source_rows = source.length(); - uint32_t* target_offsets = target->mutable_offsets(); - const uint32_t* source_offsets = source.offsets(); + RowTableImpl::offset_type* target_offsets = target->mutable_offsets(); + const RowTableImpl::offset_type* source_offsets = source.offsets(); // Permutation of source rows is optional. // if (!source_rows_permutation) { int64_t target_row_offset = first_target_row_offset; for (int64_t i = 0; i < num_source_rows; ++i) { - target_offsets[first_target_row_id + i] = static_cast(target_row_offset); + target_offsets[first_target_row_id + i] = target_row_offset; target_row_offset += source_offsets[i + 1] - source_offsets[i]; } // We purposefully skip outputting of N+1 offset, to allow concurrent @@ -593,7 +586,10 @@ void RowArrayMerge::CopyVaryingLength(RowTableImpl* target, const RowTableImpl& int64_t source_row_id = source_rows_permutation[i]; const uint64_t* source_row_ptr = reinterpret_cast( source.data(2) + source_offsets[source_row_id]); - uint32_t length = source_offsets[source_row_id + 1] - source_offsets[source_row_id]; + int64_t length = source_offsets[source_row_id + 1] - source_offsets[source_row_id]; + // Though the row offset is 64-bit, the length of a single row must be 32-bit as + // required by current row table implementation. + DCHECK_LE(length, std::numeric_limits::max()); // Rows should be 64-bit aligned. // In that case we can copy them using a sequence of 64-bit read/writes. @@ -604,7 +600,7 @@ void RowArrayMerge::CopyVaryingLength(RowTableImpl* target, const RowTableImpl& *target_row_ptr++ = *source_row_ptr++; } - target_offsets[first_target_row_id + i] = static_cast(target_row_offset); + target_offsets[first_target_row_id + i] = target_row_offset; target_row_offset += length; } } diff --git a/cpp/src/arrow/acero/swiss_join_avx2.cc b/cpp/src/arrow/acero/swiss_join_avx2.cc index 0888dd89384..e42b0b40445 100644 --- a/cpp/src/arrow/acero/swiss_join_avx2.cc +++ b/cpp/src/arrow/acero/swiss_join_avx2.cc @@ -23,6 +23,9 @@ namespace arrow { namespace acero { +// TODO(GH-43693): The functions in this file are not wired anywhere. We may consider +// actually utilizing them or removing them. 
+ template int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int num_rows, const uint32_t* row_ids, @@ -45,48 +48,78 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu if (!is_fixed_length_column) { int varbinary_column_id = VarbinaryColumnId(rows.metadata(), column_id); const uint8_t* row_ptr_base = rows.data(2); - const uint32_t* row_offsets = rows.offsets(); + const RowTableImpl::offset_type* row_offsets = rows.offsets(); + static_assert( + sizeof(RowTableImpl::offset_type) == sizeof(int64_t), + "RowArrayAccessor::Visit_avx2 only supports 64-bit RowTableImpl::offset_type"); if (varbinary_column_id == 0) { // Case 1: This is the first varbinary column // __m256i field_offset_within_row = _mm256_set1_epi32(rows.metadata().fixed_length); __m256i varbinary_end_array_offset = - _mm256_set1_epi32(rows.metadata().varbinary_end_array_offset); + _mm256_set1_epi64x(rows.metadata().varbinary_end_array_offset); for (int i = 0; i < num_rows / unroll; ++i) { + // Load 8 32-bit row ids. __m256i row_id = _mm256_loadu_si256(reinterpret_cast(row_ids) + i); - __m256i row_offset = _mm256_i32gather_epi32( - reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); + // Gather the lower/higher 4 64-bit row offsets based on the lower/higher 4 32-bit + // row ids. + __m256i row_offset_lo = + _mm256_i32gather_epi64(row_offsets, _mm256_castsi256_si128(row_id), + sizeof(RowTableImpl::offset_type)); + __m256i row_offset_hi = + _mm256_i32gather_epi64(row_offsets, _mm256_extracti128_si256(row_id, 1), + sizeof(RowTableImpl::offset_type)); + // Gather the lower/higher 4 32-bit field lengths based on the lower/higher 4 + // 64-bit row offsets. + __m128i field_length_lo = _mm256_i64gather_epi32( + reinterpret_cast(row_ptr_base), + _mm256_add_epi64(row_offset_lo, varbinary_end_array_offset), 1); + __m128i field_length_hi = _mm256_i64gather_epi32( + reinterpret_cast(row_ptr_base), + _mm256_add_epi64(row_offset_hi, varbinary_end_array_offset), 1); + // The final 8 32-bit field lengths, subtracting the field offset within row. __m256i field_length = _mm256_sub_epi32( - _mm256_i32gather_epi32( - reinterpret_cast(row_ptr_base), - _mm256_add_epi32(row_offset, varbinary_end_array_offset), 1), - field_offset_within_row); + _mm256_set_m128i(field_length_hi, field_length_lo), field_offset_within_row); process_8_values_fn(i * unroll, row_ptr_base, - _mm256_add_epi32(row_offset, field_offset_within_row), + _mm256_add_epi64(row_offset_lo, field_offset_within_row), + _mm256_add_epi64(row_offset_hi, field_offset_within_row), field_length); } } else { // Case 2: This is second or later varbinary column // __m256i varbinary_end_array_offset = - _mm256_set1_epi32(rows.metadata().varbinary_end_array_offset + - sizeof(uint32_t) * (varbinary_column_id - 1)); + _mm256_set1_epi64x(rows.metadata().varbinary_end_array_offset + + sizeof(uint32_t) * (varbinary_column_id - 1)); auto row_ptr_base_i64 = reinterpret_cast(row_ptr_base); for (int i = 0; i < num_rows / unroll; ++i) { + // Load 8 32-bit row ids. 
__m256i row_id = _mm256_loadu_si256(reinterpret_cast(row_ids) + i); - __m256i row_offset = _mm256_i32gather_epi32( - reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); - __m256i end_array_offset = - _mm256_add_epi32(row_offset, varbinary_end_array_offset); - - __m256i field_offset_within_row_A = _mm256_i32gather_epi64( - row_ptr_base_i64, _mm256_castsi256_si128(end_array_offset), 1); - __m256i field_offset_within_row_B = _mm256_i32gather_epi64( - row_ptr_base_i64, _mm256_extracti128_si256(end_array_offset, 1), 1); + // Gather the lower/higher 4 64-bit row offsets based on the lower/higher 4 32-bit + // row ids. + __m256i row_offset_lo = + _mm256_i32gather_epi64(row_offsets, _mm256_castsi256_si128(row_id), + sizeof(RowTableImpl::offset_type)); + // Gather the lower/higher 4 32-bit field lengths based on the lower/higher 4 + // 64-bit row offsets. + __m256i row_offset_hi = + _mm256_i32gather_epi64(row_offsets, _mm256_extracti128_si256(row_id, 1), + sizeof(RowTableImpl::offset_type)); + // Prepare the lower/higher 4 64-bit end array offsets based on the lower/higher 4 + // 64-bit row offsets. + __m256i end_array_offset_lo = + _mm256_add_epi64(row_offset_lo, varbinary_end_array_offset); + __m256i end_array_offset_hi = + _mm256_add_epi64(row_offset_hi, varbinary_end_array_offset); + + __m256i field_offset_within_row_A = + _mm256_i64gather_epi64(row_ptr_base_i64, end_array_offset_lo, 1); + __m256i field_offset_within_row_B = + _mm256_i64gather_epi64(row_ptr_base_i64, end_array_offset_hi, 1); field_offset_within_row_A = _mm256_permutevar8x32_epi32( field_offset_within_row_A, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7)); field_offset_within_row_B = _mm256_permutevar8x32_epi32( @@ -110,8 +143,14 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu 0x4e); // Swapping low and high 128-bits field_length = _mm256_sub_epi32(field_length, field_offset_within_row); + field_offset_within_row_A = + _mm256_add_epi32(field_offset_within_row_A, alignment_padding); + field_offset_within_row_B = + _mm256_add_epi32(field_offset_within_row_B, alignment_padding); + process_8_values_fn(i * unroll, row_ptr_base, - _mm256_add_epi32(row_offset, field_offset_within_row), + _mm256_add_epi64(row_offset_lo, field_offset_within_row_A), + _mm256_add_epi64(row_offset_hi, field_offset_within_row_B), field_length); } } @@ -119,7 +158,7 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu if (is_fixed_length_column) { __m256i field_offset_within_row = - _mm256_set1_epi32(rows.metadata().encoded_field_offset( + _mm256_set1_epi64x(rows.metadata().encoded_field_offset( rows.metadata().pos_after_encoding(column_id))); __m256i field_length = _mm256_set1_epi32(rows.metadata().column_metadatas[column_id].fixed_length); @@ -130,24 +169,51 @@ int RowArrayAccessor::Visit_avx2(const RowTableImpl& rows, int column_id, int nu // const uint8_t* row_ptr_base = rows.data(1); for (int i = 0; i < num_rows / unroll; ++i) { + // Load 8 32-bit row ids. __m256i row_id = _mm256_loadu_si256(reinterpret_cast(row_ids) + i); - __m256i row_offset = _mm256_mullo_epi32(row_id, field_length); - __m256i field_offset = _mm256_add_epi32(row_offset, field_offset_within_row); - process_8_values_fn(i * unroll, row_ptr_base, field_offset, field_length); + // Widen the 32-bit row ids to 64-bit and store the lower/higher 4 of them into 2 + // 256-bit registers. 
+ __m256i row_id_lo = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(row_id)); + __m256i row_id_hi = _mm256_cvtepi32_epi64(_mm256_extracti128_si256(row_id, 1)); + // Calculate the lower/higher 4 64-bit row offsets based on the lower/higher 4 + // 64-bit row ids and the fixed field length. + __m256i row_offset_lo = _mm256_mul_epi32(row_id_lo, field_length); + __m256i row_offset_hi = _mm256_mul_epi32(row_id_hi, field_length); + // Calculate the lower/higher 4 64-bit field offsets based on the lower/higher 4 + // 64-bit row offsets and field offset within row. + __m256i field_offset_lo = + _mm256_add_epi64(row_offset_lo, field_offset_within_row); + __m256i field_offset_hi = + _mm256_add_epi64(row_offset_hi, field_offset_within_row); + process_8_values_fn(i * unroll, row_ptr_base, field_offset_lo, field_offset_hi, + field_length); } } else { // Case 4: This is a fixed length column in varying length row // const uint8_t* row_ptr_base = rows.data(2); - const uint32_t* row_offsets = rows.offsets(); + const RowTableImpl::offset_type* row_offsets = rows.offsets(); for (int i = 0; i < num_rows / unroll; ++i) { + // Load 8 32-bit row ids. __m256i row_id = _mm256_loadu_si256(reinterpret_cast(row_ids) + i); - __m256i row_offset = _mm256_i32gather_epi32( - reinterpret_cast(row_offsets), row_id, sizeof(uint32_t)); - __m256i field_offset = _mm256_add_epi32(row_offset, field_offset_within_row); - process_8_values_fn(i * unroll, row_ptr_base, field_offset, field_length); + // Gather the lower/higher 4 64-bit row offsets based on the lower/higher 4 32-bit + // row ids. + __m256i row_offset_lo = + _mm256_i32gather_epi64(row_offsets, _mm256_castsi256_si128(row_id), + sizeof(RowTableImpl::offset_type)); + __m256i row_offset_hi = + _mm256_i32gather_epi64(row_offsets, _mm256_extracti128_si256(row_id, 1), + sizeof(RowTableImpl::offset_type)); + // Calculate the lower/higher 4 64-bit field offsets based on the lower/higher 4 + // 64-bit row offsets and field offset within row. + __m256i field_offset_lo = + _mm256_add_epi64(row_offset_lo, field_offset_within_row); + __m256i field_offset_hi = + _mm256_add_epi64(row_offset_hi, field_offset_within_row); + process_8_values_fn(i * unroll, row_ptr_base, field_offset_lo, field_offset_hi, + field_length); } } } diff --git a/cpp/src/arrow/compute/row/compare_internal.cc b/cpp/src/arrow/compute/row/compare_internal.cc index 98aea901126..5e1a87b7952 100644 --- a/cpp/src/arrow/compute/row/compare_internal.cc +++ b/cpp/src/arrow/compute/row/compare_internal.cc @@ -104,18 +104,21 @@ void KeyCompare::CompareBinaryColumnToRowHelper( const uint8_t* rows_right = rows.data(1); for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? sel_left_maybe_null[i] : i; - uint32_t irow_right = left_to_right_map[irow_left]; - uint32_t offset_right = irow_right * fixed_length + offset_within_row; + // irow_right is used to index into row data so promote to the row offset type. + RowTableImpl::offset_type irow_right = left_to_right_map[irow_left]; + RowTableImpl::offset_type offset_right = + irow_right * fixed_length + offset_within_row; match_bytevector[i] = compare_fn(rows_left, rows_right, irow_left, offset_right); } } else { const uint8_t* rows_left = col.data(1); - const uint32_t* offsets_right = rows.offsets(); + const RowTableImpl::offset_type* offsets_right = rows.offsets(); const uint8_t* rows_right = rows.data(2); for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) { uint32_t irow_left = use_selection ? 
sel_left_maybe_null[i] : i; uint32_t irow_right = left_to_right_map[irow_left]; - uint32_t offset_right = offsets_right[irow_right] + offset_within_row; + RowTableImpl::offset_type offset_right = + offsets_right[irow_right] + offset_within_row; match_bytevector[i] = compare_fn(rows_left, rows_right, irow_left, offset_right); } } @@ -145,7 +148,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [bit_offset](const uint8_t* left_base, const uint8_t* right_base, - uint32_t irow_left, uint32_t offset_right) { + uint32_t irow_left, RowTableImpl::offset_type offset_right) { uint8_t left = bit_util::GetBit(left_base, irow_left + bit_offset) ? 0xff : 0x00; uint8_t right = right_base[offset_right]; @@ -156,7 +159,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left, - uint32_t offset_right) { + RowTableImpl::offset_type offset_right) { uint8_t left = left_base[irow_left]; uint8_t right = right_base[offset_right]; return left == right ? 0xff : 0; @@ -166,7 +169,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left, - uint32_t offset_right) { + RowTableImpl::offset_type offset_right) { util::CheckAlignment(left_base); util::CheckAlignment(right_base + offset_right); uint16_t left = reinterpret_cast(left_base)[irow_left]; @@ -178,7 +181,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left, - uint32_t offset_right) { + RowTableImpl::offset_type offset_right) { util::CheckAlignment(left_base); util::CheckAlignment(right_base + offset_right); uint32_t left = reinterpret_cast(left_base)[irow_left]; @@ -190,7 +193,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left, - uint32_t offset_right) { + RowTableImpl::offset_type offset_right) { util::CheckAlignment(left_base); util::CheckAlignment(right_base + offset_right); uint64_t left = reinterpret_cast(left_base)[irow_left]; @@ -202,7 +205,7 @@ void KeyCompare::CompareBinaryColumnToRow(uint32_t offset_within_row, offset_within_row, num_processed, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [&col](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left, - uint32_t offset_right) { + RowTableImpl::offset_type offset_right) { uint32_t length = col.metadata().fixed_length; // Non-zero length guarantees no underflow @@ -241,7 +244,7 @@ void KeyCompare::CompareVarBinaryColumnToRowHelper( const uint32_t* left_to_right_map, LightContext* ctx, const KeyColumnArray& col, const RowTableImpl& rows, uint8_t* match_bytevector) { const uint32_t* offsets_left = col.offsets(); 
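Throughout these hunks the right-hand row offsets move from `uint32_t` to the 64-bit `RowTableImpl::offset_type`. A minimal sketch of why the wider type matters (the per-row length below is hypothetical, chosen so the running total passes 4 GB quickly): accumulating row lengths in `uint32_t` silently wraps, while a 64-bit accumulator stays exact.

#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t kRowLength = 512LL * 1024 * 1024;  // 512 MiB per row (illustrative)
  uint32_t offset32 = 0;
  int64_t offset64 = 0;
  for (int row = 0; row < 9; ++row) {                  // 9 rows ~= 4.5 GiB of row data
    offset32 += static_cast<uint32_t>(kRowLength);     // wraps modulo 2^32
    offset64 += kRowLength;
  }
  std::printf("32-bit offset: %u\n64-bit offset: %lld\n", offset32,
              static_cast<long long>(offset64));
  // Prints 536870912 (wrapped) vs 4831838208 (exact).
  return 0;
}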
- const uint32_t* offsets_right = rows.offsets(); + const RowTableImpl::offset_type* offsets_right = rows.offsets(); const uint8_t* rows_left = col.data(2); const uint8_t* rows_right = rows.data(2); for (uint32_t i = first_row_to_compare; i < num_rows_to_compare; ++i) { @@ -249,7 +252,7 @@ void KeyCompare::CompareVarBinaryColumnToRowHelper( uint32_t irow_right = left_to_right_map[irow_left]; uint32_t begin_left = offsets_left[irow_left]; uint32_t length_left = offsets_left[irow_left + 1] - begin_left; - uint32_t begin_right = offsets_right[irow_right]; + RowTableImpl::offset_type begin_right = offsets_right[irow_right]; uint32_t length_right; uint32_t offset_within_row; if (!is_first_varbinary_col) { @@ -334,7 +337,13 @@ void KeyCompare::CompareColumnsToRows( const RowTableImpl& rows, bool are_cols_in_encoding_order, uint8_t* out_match_bitvector_maybe_null) { if (num_rows_to_compare == 0) { - *out_num_rows = 0; + if (out_match_bitvector_maybe_null) { + DCHECK_EQ(out_num_rows, nullptr); + DCHECK_EQ(out_sel_left_maybe_same, nullptr); + bit_util::ClearBitmap(out_match_bitvector_maybe_null, 0, num_rows_to_compare); + } else { + *out_num_rows = 0; + } return; } @@ -440,8 +449,8 @@ void KeyCompare::CompareColumnsToRows( match_bytevector_A, match_bitvector); if (out_match_bitvector_maybe_null) { - ARROW_DCHECK(out_num_rows == nullptr); - ARROW_DCHECK(out_sel_left_maybe_same == nullptr); + DCHECK_EQ(out_num_rows, nullptr); + DCHECK_EQ(out_sel_left_maybe_same, nullptr); memcpy(out_match_bitvector_maybe_null, match_bitvector, bit_util::BytesForBits(num_rows_to_compare)); } else { diff --git a/cpp/src/arrow/compute/row/compare_internal.h b/cpp/src/arrow/compute/row/compare_internal.h index a5a109b0b51..29d7f859e59 100644 --- a/cpp/src/arrow/compute/row/compare_internal.h +++ b/cpp/src/arrow/compute/row/compare_internal.h @@ -42,9 +42,30 @@ class ARROW_EXPORT KeyCompare { /*extra=*/util::MiniBatch::kMiniBatchLength; } - // Returns a single 16-bit selection vector of rows that failed comparison. - // If there is input selection on the left, the resulting selection is a filtered image - // of input selection. + /// \brief Compare a batch of rows in columnar format to the specified rows in row + /// format. + /// + /// The comparison result is populated in either a 16-bit selection vector of rows that + /// failed comparison, or a match bitvector with 1 for matched rows and 0 otherwise. + /// + /// @param num_rows_to_compare The number of rows to compare. + /// @param sel_left_maybe_null Optional input selection vector on the left, the + /// comparison is only performed on the selected rows. Null if all rows in + /// `left_to_right_map` are to be compared. + /// @param left_to_right_map The mapping from the left to the right rows. Left row `i` + /// in `cols` is compared to right row `left_to_right_map[i]` in `row`. + /// @param ctx The light context needed for the comparison. + /// @param out_num_rows The number of rows that failed comparison. Must be null if + /// `out_match_bitvector_maybe_null` is not null. + /// @param out_sel_left_maybe_same The selection vector of rows that failed comparison. + /// Can be the same as `sel_left_maybe_null` for in-place update. Must be null if + /// `out_match_bitvector_maybe_null` is not null. + /// @param cols The left rows in columnar format to compare. + /// @param rows The right rows in row format to compare. + /// @param are_cols_in_encoding_order Whether the columns are in encoding order. 
+ /// @param out_match_bitvector_maybe_null The optional output match bitvector, 1 for + /// matched rows and 0 otherwise. Won't be populated if `out_num_rows` and + /// `out_sel_left_maybe_same` are not null. static void CompareColumnsToRows( uint32_t num_rows_to_compare, const uint16_t* sel_left_maybe_null, const uint32_t* left_to_right_map, LightContext* ctx, uint32_t* out_num_rows, diff --git a/cpp/src/arrow/compute/row/compare_internal_avx2.cc b/cpp/src/arrow/compute/row/compare_internal_avx2.cc index 23238a3691c..96eed6fc03a 100644 --- a/cpp/src/arrow/compute/row/compare_internal_avx2.cc +++ b/cpp/src/arrow/compute/row/compare_internal_avx2.cc @@ -180,40 +180,6 @@ uint32_t KeyCompare::NullUpdateColumnToRowImp_avx2( } } -namespace { - -// Intrinsics `_mm256_i32gather_epi32/64` treat the `vindex` as signed integer, and we -// are using `uint32_t` to represent the offset, in range of [0, 4G), within the row -// table. When the offset is larger than `0x80000000` (2GB), those intrinsics will treat -// it as negative offset and gather the data from undesired address. To avoid this issue, -// we normalize the addresses by translating `base` `0x80000000` higher, and `offset` -// `0x80000000` lower. This way, the offset is always in range of [-2G, 2G) and those -// intrinsics are safe. - -constexpr uint64_t kTwoGB = 0x80000000ull; - -template -inline __m256i UnsignedOffsetSafeGather32(int const* base, __m256i offset) { - int const* normalized_base = base + kTwoGB / sizeof(int); - __m256i normalized_offset = - _mm256_sub_epi32(offset, _mm256_set1_epi32(static_cast(kTwoGB / kScale))); - return _mm256_i32gather_epi32(normalized_base, normalized_offset, - static_cast(kScale)); -} - -template -inline __m256i UnsignedOffsetSafeGather64(arrow::util::int64_for_gather_t const* base, - __m128i offset) { - arrow::util::int64_for_gather_t const* normalized_base = - base + kTwoGB / sizeof(arrow::util::int64_for_gather_t); - __m128i normalized_offset = - _mm_sub_epi32(offset, _mm_set1_epi32(static_cast(kTwoGB / kScale))); - return _mm256_i32gather_epi64(normalized_base, normalized_offset, - static_cast(kScale)); -} - -} // namespace - template uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( uint32_t offset_within_row, uint32_t num_rows_to_compare, @@ -240,12 +206,26 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( _mm256_loadu_si256(reinterpret_cast(left_to_right_map) + i); } - __m256i offset_right = - _mm256_mullo_epi32(irow_right, _mm256_set1_epi32(fixed_length)); - offset_right = _mm256_add_epi32(offset_right, _mm256_set1_epi32(offset_within_row)); - - reinterpret_cast(match_bytevector)[i] = - compare8_fn(rows_left, rows_right, i * unroll, irow_left, offset_right); + // Widen the 32-bit row ids to 64-bit and store the first/last 4 of them into 2 + // 256-bit registers. + __m256i irow_right_lo = _mm256_cvtepi32_epi64(_mm256_castsi256_si128(irow_right)); + __m256i irow_right_hi = + _mm256_cvtepi32_epi64(_mm256_extracti128_si256(irow_right, 1)); + // Calculate the lower/higher 4 64-bit row offsets based on the lower/higher 4 + // 64-bit row ids and the fixed length. + __m256i offset_right_lo = + _mm256_mul_epi32(irow_right_lo, _mm256_set1_epi64x(fixed_length)); + __m256i offset_right_hi = + _mm256_mul_epi32(irow_right_hi, _mm256_set1_epi64x(fixed_length)); + // Calculate the lower/higher 4 64-bit field offsets based on the lower/higher 4 + // 64-bit row offsets and field offset within row. 
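The multiply just above leans on a detail of `_mm256_mul_epi32`: it reads only the low 32 bits of each 64-bit lane and produces a full 64-bit product, so a 32-bit row id times a 32-bit fixed length can exceed 4 GB without truncation. A small standalone sketch of that behavior (the ids and the fixed length are illustrative values, not taken from the patch):

// Build with: g++ -O2 -mavx2 mul_sketch.cc
#include <immintrin.h>
#include <cassert>
#include <cstdint>

int main() {
  int32_t row_ids[4] = {1, 70000, 1000000, 2000000};
  const int64_t fixed_length = 65536;  // hypothetical fixed row width in bytes

  // Widen the 4 ids to 64-bit lanes, then multiply the low 32 bits of each lane.
  __m128i ids32 = _mm_loadu_si128(reinterpret_cast<const __m128i*>(row_ids));
  __m256i ids64 = _mm256_cvtepi32_epi64(ids32);
  __m256i offsets = _mm256_mul_epi32(ids64, _mm256_set1_epi64x(fixed_length));

  int64_t got[4];
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(got), offsets);
  for (int i = 0; i < 4; ++i) {
    assert(got[i] == static_cast<int64_t>(row_ids[i]) * fixed_length);
  }
  // e.g. 2000000 * 65536 = 131072000000, well past the 32-bit range.
  return 0;
}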
+ offset_right_lo = + _mm256_add_epi64(offset_right_lo, _mm256_set1_epi64x(offset_within_row)); + offset_right_hi = + _mm256_add_epi64(offset_right_hi, _mm256_set1_epi64x(offset_within_row)); + + reinterpret_cast(match_bytevector)[i] = compare8_fn( + rows_left, rows_right, i * unroll, irow_left, offset_right_lo, offset_right_hi); if (!use_selection) { irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(8)); @@ -254,7 +234,7 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( return num_rows_to_compare - (num_rows_to_compare % unroll); } else { const uint8_t* rows_left = col.data(1); - const uint32_t* offsets_right = rows.offsets(); + const RowTableImpl::offset_type* offsets_right = rows.offsets(); const uint8_t* rows_right = rows.data(2); constexpr uint32_t unroll = 8; __m256i irow_left = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7); @@ -270,12 +250,29 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( irow_right = _mm256_loadu_si256(reinterpret_cast(left_to_right_map) + i); } - __m256i offset_right = - UnsignedOffsetSafeGather32<4>((int const*)offsets_right, irow_right); - offset_right = _mm256_add_epi32(offset_right, _mm256_set1_epi32(offset_within_row)); - reinterpret_cast(match_bytevector)[i] = - compare8_fn(rows_left, rows_right, i * unroll, irow_left, offset_right); + static_assert(sizeof(RowTableImpl::offset_type) == sizeof(int64_t), + "KeyCompare::CompareBinaryColumnToRowHelper_avx2 only supports " + "64-bit RowTableImpl::offset_type"); + auto offsets_right_i64 = + reinterpret_cast(offsets_right); + // Gather the lower/higher 4 64-bit row offsets based on the lower/higher 4 32-bit + // row ids. + __m256i offset_right_lo = + _mm256_i32gather_epi64(offsets_right_i64, _mm256_castsi256_si128(irow_right), + sizeof(RowTableImpl::offset_type)); + __m256i offset_right_hi = _mm256_i32gather_epi64( + offsets_right_i64, _mm256_extracti128_si256(irow_right, 1), + sizeof(RowTableImpl::offset_type)); + // Calculate the lower/higher 4 64-bit field offsets based on the lower/higher 4 + // 64-bit row offsets and field offset within row. + offset_right_lo = + _mm256_add_epi64(offset_right_lo, _mm256_set1_epi64x(offset_within_row)); + offset_right_hi = + _mm256_add_epi64(offset_right_hi, _mm256_set1_epi64x(offset_within_row)); + + reinterpret_cast(match_bytevector)[i] = compare8_fn( + rows_left, rows_right, i * unroll, irow_left, offset_right_lo, offset_right_hi); if (!use_selection) { irow_left = _mm256_add_epi32(irow_left, _mm256_set1_epi32(8)); @@ -287,8 +284,8 @@ uint32_t KeyCompare::CompareBinaryColumnToRowHelper_avx2( template inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* right_base, - __m256i irow_left, __m256i offset_right, - int bit_offset = 0) { + __m256i irow_left, __m256i offset_right_lo, + __m256i offset_right_hi, int bit_offset = 0) { __m256i left; switch (column_width) { case 0: { @@ -315,7 +312,9 @@ inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* r ARROW_DCHECK(false); } - __m256i right = UnsignedOffsetSafeGather32<1>((int const*)right_base, offset_right); + __m128i right_lo = _mm256_i64gather_epi32((int const*)right_base, offset_right_lo, 1); + __m128i right_hi = _mm256_i64gather_epi32((int const*)right_base, offset_right_hi, 1); + __m256i right = _mm256_set_m128i(right_hi, right_lo); if (column_width != sizeof(uint32_t)) { constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 
0xff : 0xffff; right = _mm256_and_si256(right, _mm256_set1_epi32(mask)); @@ -333,8 +332,8 @@ inline uint64_t CompareSelected8_avx2(const uint8_t* left_base, const uint8_t* r template inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_base, - uint32_t irow_left_first, __m256i offset_right, - int bit_offset = 0) { + uint32_t irow_left_first, __m256i offset_right_lo, + __m256i offset_right_hi, int bit_offset = 0) { __m256i left; switch (column_width) { case 0: { @@ -364,7 +363,9 @@ inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_bas ARROW_DCHECK(false); } - __m256i right = UnsignedOffsetSafeGather32<1>((int const*)right_base, offset_right); + __m128i right_lo = _mm256_i64gather_epi32((int const*)right_base, offset_right_lo, 1); + __m128i right_hi = _mm256_i64gather_epi32((int const*)right_base, offset_right_hi, 1); + __m256i right = _mm256_set_m128i(right_hi, right_lo); if (column_width != sizeof(uint32_t)) { constexpr uint32_t mask = column_width == 0 || column_width == 1 ? 0xff : 0xffff; right = _mm256_and_si256(right, _mm256_set1_epi32(mask)); @@ -383,7 +384,7 @@ inline uint64_t Compare8_avx2(const uint8_t* left_base, const uint8_t* right_bas template inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* right_base, __m256i irow_left, uint32_t irow_left_first, - __m256i offset_right) { + __m256i offset_right_lo, __m256i offset_right_hi) { auto left_base_i64 = reinterpret_cast(left_base); __m256i left_lo, left_hi; @@ -400,10 +401,8 @@ inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* rig } auto right_base_i64 = reinterpret_cast(right_base); - __m256i right_lo = - UnsignedOffsetSafeGather64<1>(right_base_i64, _mm256_castsi256_si128(offset_right)); - __m256i right_hi = UnsignedOffsetSafeGather64<1>( - right_base_i64, _mm256_extracti128_si256(offset_right, 1)); + __m256i right_lo = _mm256_i64gather_epi64(right_base_i64, offset_right_lo, 1); + __m256i right_hi = _mm256_i64gather_epi64(right_base_i64, offset_right_hi, 1); uint32_t result_lo = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_lo, right_lo)); uint32_t result_hi = _mm256_movemask_epi8(_mm256_cmpeq_epi64(left_hi, right_hi)); return result_lo | (static_cast(result_hi) << 32); @@ -412,13 +411,19 @@ inline uint64_t Compare8_64bit_avx2(const uint8_t* left_base, const uint8_t* rig template inline uint64_t Compare8_Binary_avx2(uint32_t length, const uint8_t* left_base, const uint8_t* right_base, __m256i irow_left, - uint32_t irow_left_first, __m256i offset_right) { + uint32_t irow_left_first, __m256i offset_right_lo, + __m256i offset_right_hi) { uint32_t irow_left_array[8]; - uint32_t offset_right_array[8]; + RowTableImpl::offset_type offset_right_array[8]; if (use_selection) { _mm256_storeu_si256(reinterpret_cast<__m256i*>(irow_left_array), irow_left); } - _mm256_storeu_si256(reinterpret_cast<__m256i*>(offset_right_array), offset_right); + static_assert( + sizeof(RowTableImpl::offset_type) * 4 == sizeof(__m256i), + "Unexpected RowTableImpl::offset_type size in KeyCompare::Compare8_Binary_avx2"); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(offset_right_array), offset_right_lo); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(&offset_right_array[4]), + offset_right_hi); // Non-zero length guarantees no underflow int32_t num_loops_less_one = (static_cast(length) + 31) / 32 - 1; @@ -463,13 +468,14 @@ uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2( offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, 
col, rows, match_bytevector, [bit_offset](const uint8_t* left_base, const uint8_t* right_base, - uint32_t irow_left_base, __m256i irow_left, __m256i offset_right) { + uint32_t irow_left_base, __m256i irow_left, __m256i offset_right_lo, + __m256i offset_right_hi) { if (use_selection) { return CompareSelected8_avx2<0>(left_base, right_base, irow_left, - offset_right, bit_offset); + offset_right_lo, offset_right_hi, bit_offset); } else { - return Compare8_avx2<0>(left_base, right_base, irow_left_base, offset_right, - bit_offset); + return Compare8_avx2<0>(left_base, right_base, irow_left_base, + offset_right_lo, offset_right_hi, bit_offset); } }); } else if (col_width == 1) { @@ -477,12 +483,13 @@ uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2( offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base, - __m256i irow_left, __m256i offset_right) { + __m256i irow_left, __m256i offset_right_lo, __m256i offset_right_hi) { if (use_selection) { return CompareSelected8_avx2<1>(left_base, right_base, irow_left, - offset_right); + offset_right_lo, offset_right_hi); } else { - return Compare8_avx2<1>(left_base, right_base, irow_left_base, offset_right); + return Compare8_avx2<1>(left_base, right_base, irow_left_base, + offset_right_lo, offset_right_hi); } }); } else if (col_width == 2) { @@ -490,12 +497,13 @@ uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2( offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base, - __m256i irow_left, __m256i offset_right) { + __m256i irow_left, __m256i offset_right_lo, __m256i offset_right_hi) { if (use_selection) { return CompareSelected8_avx2<2>(left_base, right_base, irow_left, - offset_right); + offset_right_lo, offset_right_hi); } else { - return Compare8_avx2<2>(left_base, right_base, irow_left_base, offset_right); + return Compare8_avx2<2>(left_base, right_base, irow_left_base, + offset_right_lo, offset_right_hi); } }); } else if (col_width == 4) { @@ -503,12 +511,13 @@ uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2( offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base, - __m256i irow_left, __m256i offset_right) { + __m256i irow_left, __m256i offset_right_lo, __m256i offset_right_hi) { if (use_selection) { return CompareSelected8_avx2<4>(left_base, right_base, irow_left, - offset_right); + offset_right_lo, offset_right_hi); } else { - return Compare8_avx2<4>(left_base, right_base, irow_left_base, offset_right); + return Compare8_avx2<4>(left_base, right_base, irow_left_base, + offset_right_lo, offset_right_hi); } }); } else if (col_width == 8) { @@ -516,19 +525,22 @@ uint32_t KeyCompare::CompareBinaryColumnToRowImp_avx2( offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [](const uint8_t* left_base, const uint8_t* right_base, uint32_t irow_left_base, - __m256i irow_left, __m256i offset_right) { + __m256i irow_left, __m256i offset_right_lo, __m256i offset_right_hi) { return Compare8_64bit_avx2(left_base, right_base, irow_left, - irow_left_base, offset_right); + irow_left_base, offset_right_lo, + offset_right_hi); }); } else { return CompareBinaryColumnToRowHelper_avx2( 
offset_within_row, num_rows_to_compare, sel_left_maybe_null, left_to_right_map, ctx, col, rows, match_bytevector, [&col](const uint8_t* left_base, const uint8_t* right_base, - uint32_t irow_left_base, __m256i irow_left, __m256i offset_right) { + uint32_t irow_left_base, __m256i irow_left, __m256i offset_right_lo, + __m256i offset_right_hi) { uint32_t length = col.metadata().fixed_length; - return Compare8_Binary_avx2( - length, left_base, right_base, irow_left, irow_left_base, offset_right); + return Compare8_Binary_avx2(length, left_base, right_base, + irow_left, irow_left_base, + offset_right_lo, offset_right_hi); }); } } @@ -541,7 +553,7 @@ void KeyCompare::CompareVarBinaryColumnToRowImp_avx2( LightContext* ctx, const KeyColumnArray& col, const RowTableImpl& rows, uint8_t* match_bytevector) { const uint32_t* offsets_left = col.offsets(); - const uint32_t* offsets_right = rows.offsets(); + const RowTableImpl::offset_type* offsets_right = rows.offsets(); const uint8_t* rows_left = col.data(2); const uint8_t* rows_right = rows.data(2); for (uint32_t i = 0; i < num_rows_to_compare; ++i) { @@ -549,7 +561,7 @@ void KeyCompare::CompareVarBinaryColumnToRowImp_avx2( uint32_t irow_right = left_to_right_map[irow_left]; uint32_t begin_left = offsets_left[irow_left]; uint32_t length_left = offsets_left[irow_left + 1] - begin_left; - uint32_t begin_right = offsets_right[irow_right]; + RowTableImpl::offset_type begin_right = offsets_right[irow_right]; uint32_t length_right; uint32_t offset_within_row; if (!is_first_varbinary_col) { diff --git a/cpp/src/arrow/compute/row/compare_test.cc b/cpp/src/arrow/compute/row/compare_test.cc index 22af7e067d8..5e8ee7c58a7 100644 --- a/cpp/src/arrow/compute/row/compare_test.cc +++ b/cpp/src/arrow/compute/row/compare_test.cc @@ -27,7 +27,12 @@ namespace arrow { namespace compute { using arrow::bit_util::BytesForBits; +using arrow::bit_util::GetBit; +using arrow::gen::Constant; +using arrow::gen::Random; +using arrow::internal::CountSetBits; using arrow::internal::CpuInfo; +using arrow::random::kSeedMax; using arrow::random::RandomArrayGenerator; using arrow::util::MiniBatch; using arrow::util::TempVectorStack; @@ -106,7 +111,7 @@ TEST(KeyCompare, CompareColumnsToRowsCuriousFSB) { true, match_bitvector.data()); for (int i = 0; i < num_rows; ++i) { SCOPED_TRACE(i); - ASSERT_EQ(arrow::bit_util::GetBit(match_bitvector.data(), i), i != 6); + ASSERT_EQ(GetBit(match_bitvector.data(), i), i != 6); } } } @@ -166,9 +171,111 @@ TEST(KeyCompare, CompareColumnsToRowsTempStackUsage) { } } +namespace { + +Result MakeRowTableFromExecBatch(const ExecBatch& batch) { + RowTableImpl row_table; + + std::vector column_metadatas; + RETURN_NOT_OK(ColumnMetadatasFromExecBatch(batch, &column_metadatas)); + RowTableMetadata table_metadata; + table_metadata.FromColumnMetadataVector(column_metadatas, sizeof(uint64_t), + sizeof(uint64_t)); + RETURN_NOT_OK(row_table.Init(default_memory_pool(), table_metadata)); + std::vector row_ids(batch.length); + std::iota(row_ids.begin(), row_ids.end(), 0); + RowTableEncoder row_encoder; + row_encoder.Init(column_metadatas, sizeof(uint64_t), sizeof(uint64_t)); + std::vector column_arrays; + RETURN_NOT_OK(ColumnArraysFromExecBatch(batch, &column_arrays)); + row_encoder.PrepareEncodeSelected(0, batch.length, column_arrays); + RETURN_NOT_OK(row_encoder.EncodeSelected( + &row_table, static_cast(batch.length), row_ids.data())); + + return row_table; +} + +Result RepeatRowTableUntil(const RowTableImpl& seed, int64_t num_rows) { + RowTableImpl row_table; + + 
RETURN_NOT_OK(row_table.Init(default_memory_pool(), seed.metadata())); + // Append the seed row table repeatedly to grow the row table to big enough. + while (row_table.length() < num_rows) { + RETURN_NOT_OK(row_table.AppendSelectionFrom(seed, + static_cast(seed.length()), + /*source_row_ids=*/NULLPTR)); + } + + return row_table; +} + +void AssertCompareColumnsToRowsAllMatch(const std::vector& columns, + const RowTableImpl& row_table, + const std::vector& row_ids_to_compare) { + uint32_t num_rows_to_compare = static_cast(row_ids_to_compare.size()); + + TempVectorStack stack; + ASSERT_OK( + stack.Init(default_memory_pool(), + KeyCompare::CompareColumnsToRowsTempStackUsage(num_rows_to_compare))); + LightContext ctx{CpuInfo::GetInstance()->hardware_flags(), &stack}; + + { + // No selection, output no match row ids. + uint32_t num_rows_no_match; + std::vector row_ids_out(num_rows_to_compare); + KeyCompare::CompareColumnsToRows(num_rows_to_compare, /*sel_left_maybe_null=*/NULLPTR, + row_ids_to_compare.data(), &ctx, &num_rows_no_match, + row_ids_out.data(), columns, row_table, + /*are_cols_in_encoding_order=*/true, + /*out_match_bitvector_maybe_null=*/NULLPTR); + ASSERT_EQ(num_rows_no_match, 0); + } + + { + // No selection, output match bit vector. + std::vector match_bitvector(BytesForBits(num_rows_to_compare)); + KeyCompare::CompareColumnsToRows( + num_rows_to_compare, /*sel_left_maybe_null=*/NULLPTR, row_ids_to_compare.data(), + &ctx, + /*out_num_rows=*/NULLPTR, /*out_sel_left_maybe_same=*/NULLPTR, columns, row_table, + /*are_cols_in_encoding_order=*/true, match_bitvector.data()); + ASSERT_EQ(CountSetBits(match_bitvector.data(), 0, num_rows_to_compare), + num_rows_to_compare); + } + + std::vector selection_left(num_rows_to_compare); + std::iota(selection_left.begin(), selection_left.end(), 0); + + { + // With selection, output no match row ids. + uint32_t num_rows_no_match; + std::vector row_ids_out(num_rows_to_compare); + KeyCompare::CompareColumnsToRows(num_rows_to_compare, selection_left.data(), + row_ids_to_compare.data(), &ctx, &num_rows_no_match, + row_ids_out.data(), columns, row_table, + /*are_cols_in_encoding_order=*/true, + /*out_match_bitvector_maybe_null=*/NULLPTR); + ASSERT_EQ(num_rows_no_match, 0); + } + + { + // With selection, output match bit vector. + std::vector match_bitvector(BytesForBits(num_rows_to_compare)); + KeyCompare::CompareColumnsToRows( + num_rows_to_compare, selection_left.data(), row_ids_to_compare.data(), &ctx, + /*out_num_rows=*/NULLPTR, /*out_sel_left_maybe_same=*/NULLPTR, columns, row_table, + /*are_cols_in_encoding_order=*/true, match_bitvector.data()); + ASSERT_EQ(CountSetBits(match_bitvector.data(), 0, num_rows_to_compare), + num_rows_to_compare); + } +} + +} // namespace + // Compare columns to rows at offsets over 2GB within a row table. // Certain AVX2 instructions may behave unexpectedly causing troubles like GH-41813. -TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsLarge)) { +TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsOver2GB)) { if constexpr (sizeof(void*) == 4) { GTEST_SKIP() << "Test only works on 64-bit platforms"; } @@ -176,128 +283,194 @@ TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsLarge)) { // The idea of this case is to create a row table using several fixed length columns and // one var length column (so the row is hence var length and has offset buffer), with // the overall data size exceeding 2GB. Then compare each row with itself. 
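Concretely, the constants chosen below work out as follows. This compile-time sketch mirrors the test's sizing but is not part of the patch; it ignores row-alignment and null-mask overhead, which only make the encoded data larger.

#include <cstdint>

constexpr int64_t kTwoGB = 2LL * 1024 * 1024 * 1024;
constexpr int64_t kNumRows = 65536;                    // uint16 max + 1
constexpr int64_t kVarLength = kTwoGB / kNumRows - 1;  // 32767 bytes per row
constexpr int64_t kFixedWidth = 8 + 4;                 // uint64 + uint32 columns
constexpr int64_t kRowSize = kVarLength + kFixedWidth;

static_assert(kRowSize * kNumRows > kTwoGB,
              "the encoded row data must exceed 2GB for the test to be meaningful");

int main() { return 0; }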
- constexpr int64_t two_gb = 2ll * 1024ll * 1024ll * 1024ll; + constexpr int64_t k2GB = 2ll * 1024ll * 1024ll * 1024ll; // The compare function requires the row id of the left column to be uint16_t, hence the // number of rows. constexpr int64_t num_rows = std::numeric_limits::max() + 1; const std::vector> fixed_length_types{uint64(), uint32()}; // The var length column should be a little smaller than 2GB to workaround the capacity // limitation in the var length builder. - constexpr int32_t var_length = two_gb / num_rows - 1; + constexpr int32_t var_length = k2GB / num_rows - 1; auto row_size = std::accumulate(fixed_length_types.begin(), fixed_length_types.end(), static_cast(var_length), [](int64_t acc, const std::shared_ptr& type) { return acc + type->byte_width(); }); // The overall size should be larger than 2GB. - ASSERT_GT(row_size * num_rows, two_gb); - - MemoryPool* pool = default_memory_pool(); + ASSERT_GT(row_size * num_rows, k2GB); - // The left side columns. - std::vector columns_left; + // The left side batch. ExecBatch batch_left; { std::vector values; // Several fixed length arrays containing random content. for (const auto& type : fixed_length_types) { - ASSERT_OK_AND_ASSIGN(auto value, ::arrow::gen::Random(type)->Generate(num_rows)); + ASSERT_OK_AND_ASSIGN(auto value, Random(type)->Generate(num_rows)); values.push_back(std::move(value)); } // A var length array containing 'X' repeated var_length times. - ASSERT_OK_AND_ASSIGN(auto value_var_length, - ::arrow::gen::Constant( - std::make_shared(std::string(var_length, 'X'))) - ->Generate(num_rows)); + ASSERT_OK_AND_ASSIGN( + auto value_var_length, + Constant(std::make_shared(std::string(var_length, 'X'))) + ->Generate(num_rows)); values.push_back(std::move(value_var_length)); batch_left = ExecBatch(std::move(values), num_rows); - ASSERT_OK(ColumnArraysFromExecBatch(batch_left, &columns_left)); } + // The left side columns. + std::vector columns_left; + ASSERT_OK(ColumnArraysFromExecBatch(batch_left, &columns_left)); + // The right side row table. - RowTableImpl row_table_right; - { - // Encode the row table with the left columns. - std::vector column_metadatas; - ASSERT_OK(ColumnMetadatasFromExecBatch(batch_left, &column_metadatas)); - RowTableMetadata table_metadata; - table_metadata.FromColumnMetadataVector(column_metadatas, sizeof(uint64_t), - sizeof(uint64_t)); - ASSERT_OK(row_table_right.Init(pool, table_metadata)); - std::vector row_ids(num_rows); - std::iota(row_ids.begin(), row_ids.end(), 0); - RowTableEncoder row_encoder; - row_encoder.Init(column_metadatas, sizeof(uint64_t), sizeof(uint64_t)); - row_encoder.PrepareEncodeSelected(0, num_rows, columns_left); - ASSERT_OK(row_encoder.EncodeSelected( - &row_table_right, static_cast(num_rows), row_ids.data())); - - // The row table must contain an offset buffer. - ASSERT_NE(row_table_right.offsets(), NULLPTR); - // The whole point of this test. - ASSERT_GT(row_table_right.offsets()[num_rows - 1], two_gb); - } + ASSERT_OK_AND_ASSIGN(RowTableImpl row_table_right, + MakeRowTableFromExecBatch(batch_left)); + // The row table must contain an offset buffer. + ASSERT_NE(row_table_right.data(2), NULLPTR); + // The whole point of this test. + ASSERT_GT(row_table_right.offsets()[num_rows - 1], k2GB); // The rows to compare. 
std::vector row_ids_to_compare(num_rows); std::iota(row_ids_to_compare.begin(), row_ids_to_compare.end(), 0); - TempVectorStack stack; - ASSERT_OK(stack.Init(pool, KeyCompare::CompareColumnsToRowsTempStackUsage(num_rows))); - LightContext ctx{CpuInfo::GetInstance()->hardware_flags(), &stack}; + AssertCompareColumnsToRowsAllMatch(columns_left, row_table_right, row_ids_to_compare); +} - { - // No selection, output no match row ids. - uint32_t num_rows_no_match; - std::vector row_ids_out(num_rows); - KeyCompare::CompareColumnsToRows(num_rows, /*sel_left_maybe_null=*/NULLPTR, - row_ids_to_compare.data(), &ctx, &num_rows_no_match, - row_ids_out.data(), columns_left, row_table_right, - /*are_cols_in_encoding_order=*/true, - /*out_match_bitvector_maybe_null=*/NULLPTR); - ASSERT_EQ(num_rows_no_match, 0); +// GH-43495: Compare fixed length columns to rows over 4GB within a row table. +TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsOver4GBFixedLength)) { + if constexpr (sizeof(void*) == 4) { + GTEST_SKIP() << "Test only works on 64-bit platforms"; } + // The idea of this case is to create a row table using one fixed length column (so the + // row is hence fixed length), with more than 4GB data. Then compare the rows located at + // over 4GB. + + // A small batch to append to the row table repeatedly to grow the row table to big + // enough. + constexpr int64_t num_rows_batch = std::numeric_limits::max(); + constexpr int fixed_length = 256; + + // The size of the row table is one batch larger than 4GB, and we'll compare the last + // num_rows_batch rows. + constexpr int64_t k4GB = 4ll * 1024 * 1024 * 1024; + constexpr int64_t num_rows_row_table = + (k4GB / (fixed_length * num_rows_batch) + 1) * num_rows_batch; + static_assert(num_rows_row_table < std::numeric_limits::max(), + "row table length must be less than uint32 max"); + static_assert(num_rows_row_table * fixed_length > k4GB, + "row table size must be greater than 4GB"); + + // The left side batch with num_rows_batch rows. + ExecBatch batch_left; { - // No selection, output match bit vector. - std::vector match_bitvector(BytesForBits(num_rows)); - KeyCompare::CompareColumnsToRows( - num_rows, /*sel_left_maybe_null=*/NULLPTR, row_ids_to_compare.data(), &ctx, - /*out_num_rows=*/NULLPTR, /*out_sel_left_maybe_same=*/NULLPTR, columns_left, - row_table_right, - /*are_cols_in_encoding_order=*/true, match_bitvector.data()); - ASSERT_EQ(arrow::internal::CountSetBits(match_bitvector.data(), 0, num_rows), - num_rows); + std::vector values; + + // A fixed length array containing random values. + ASSERT_OK_AND_ASSIGN( + auto value_fixed_length, + Random(fixed_size_binary(fixed_length))->Generate(num_rows_batch)); + values.push_back(std::move(value_fixed_length)); + + batch_left = ExecBatch(std::move(values), num_rows_batch); } - std::vector selection_left(num_rows); - std::iota(selection_left.begin(), selection_left.end(), 0); + // The left side columns with num_rows_batch rows. + std::vector columns_left; + ASSERT_OK(ColumnArraysFromExecBatch(batch_left, &columns_left)); + + // The right side row table with num_rows_row_table rows. + ASSERT_OK_AND_ASSIGN( + RowTableImpl row_table_right, + RepeatRowTableUntil(MakeRowTableFromExecBatch(batch_left).ValueUnsafe(), + num_rows_row_table)); + // The row table must not contain a third buffer. + ASSERT_EQ(row_table_right.data(2), NULLPTR); + // The row data must be greater than 4GB. 
+ ASSERT_GT(row_table_right.buffer_size(1), k4GB); + + // The rows to compare: the last num_rows_batch rows in the row table VS. the whole + // batch. + std::vector row_ids_to_compare(num_rows_batch); + std::iota(row_ids_to_compare.begin(), row_ids_to_compare.end(), + static_cast(num_rows_row_table - num_rows_batch)); + + AssertCompareColumnsToRowsAllMatch(columns_left, row_table_right, row_ids_to_compare); +} - { - // With selection, output no match row ids. - uint32_t num_rows_no_match; - std::vector row_ids_out(num_rows); - KeyCompare::CompareColumnsToRows(num_rows, selection_left.data(), - row_ids_to_compare.data(), &ctx, &num_rows_no_match, - row_ids_out.data(), columns_left, row_table_right, - /*are_cols_in_encoding_order=*/true, - /*out_match_bitvector_maybe_null=*/NULLPTR); - ASSERT_EQ(num_rows_no_match, 0); +// GH-43495: Compare var length columns to rows at offset over 4GB within a row table. +TEST(KeyCompare, LARGE_MEMORY_TEST(CompareColumnsToRowsOver4GBVarLength)) { + if constexpr (sizeof(void*) == 4) { + GTEST_SKIP() << "Test only works on 64-bit platforms"; } + // The idea of this case is to create a row table using one fixed length column and one + // var length column (so the row is hence var length and has offset buffer), with more + // than 4GB data. Then compare the rows located at over 4GB. + + // A small batch to append to the row table repeatedly to grow the row table to big + // enough. + constexpr int64_t num_rows_batch = std::numeric_limits::max(); + constexpr int fixed_length = 128; + // Involve some small randomness in the var length column. + constexpr int var_length_min = 128; + constexpr int var_length_max = 129; + constexpr double null_probability = 0.01; + + // The size of the row table is one batch larger than 4GB, and we'll compare the last + // num_rows_batch rows. + constexpr int64_t k4GB = 4ll * 1024 * 1024 * 1024; + constexpr int64_t size_row_min = fixed_length + var_length_min; + constexpr int64_t num_rows_row_table = + (k4GB / (size_row_min * num_rows_batch) + 1) * num_rows_batch; + static_assert(num_rows_row_table < std::numeric_limits::max(), + "row table length must be less than uint32 max"); + static_assert(num_rows_row_table * size_row_min > k4GB, + "row table size must be greater than 4GB"); + + // The left side batch with num_rows_batch rows. + ExecBatch batch_left; { - // With selection, output match bit vector. - std::vector match_bitvector(BytesForBits(num_rows)); - KeyCompare::CompareColumnsToRows( - num_rows, selection_left.data(), row_ids_to_compare.data(), &ctx, - /*out_num_rows=*/NULLPTR, /*out_sel_left_maybe_same=*/NULLPTR, columns_left, - row_table_right, - /*are_cols_in_encoding_order=*/true, match_bitvector.data()); - ASSERT_EQ(arrow::internal::CountSetBits(match_bitvector.data(), 0, num_rows), - num_rows); + std::vector values; + + // A fixed length array containing random values. + ASSERT_OK_AND_ASSIGN( + auto value_fixed_length, + Random(fixed_size_binary(fixed_length))->Generate(num_rows_batch)); + values.push_back(std::move(value_fixed_length)); + + // A var length array containing random binary of 128 or 129 bytes with small portion + // of nulls. + auto value_var_length = RandomArrayGenerator(kSeedMax).String( + num_rows_batch, var_length_min, var_length_max, null_probability); + values.push_back(std::move(value_var_length)); + + batch_left = ExecBatch(std::move(values), num_rows_batch); } + + // The left side columns with num_rows_batch rows. 
+ std::vector columns_left; + ASSERT_OK(ColumnArraysFromExecBatch(batch_left, &columns_left)); + + // The right side row table with num_rows_row_table rows. + ASSERT_OK_AND_ASSIGN( + RowTableImpl row_table_right, + RepeatRowTableUntil(MakeRowTableFromExecBatch(batch_left).ValueUnsafe(), + num_rows_row_table)); + // The row table must contain an offset buffer. + ASSERT_NE(row_table_right.data(2), NULLPTR); + // At least the last row should be located at over 4GB. + ASSERT_GT(row_table_right.offsets()[num_rows_row_table - 1], k4GB); + + // The rows to compare: the last num_rows_batch rows in the row table VS. the whole + // batch. + std::vector row_ids_to_compare(num_rows_batch); + std::iota(row_ids_to_compare.begin(), row_ids_to_compare.end(), + static_cast(num_rows_row_table - num_rows_batch)); + + AssertCompareColumnsToRowsAllMatch(columns_left, row_table_right, row_ids_to_compare); } } // namespace compute diff --git a/cpp/src/arrow/compute/row/encode_internal.cc b/cpp/src/arrow/compute/row/encode_internal.cc index 658e0dffcac..127d43021d6 100644 --- a/cpp/src/arrow/compute/row/encode_internal.cc +++ b/cpp/src/arrow/compute/row/encode_internal.cc @@ -17,7 +17,6 @@ #include "arrow/compute/row/encode_internal.h" #include "arrow/util/checked_cast.h" -#include "arrow/util/int_util_overflow.h" namespace arrow { namespace compute { @@ -265,7 +264,8 @@ void EncoderInteger::Decode(uint32_t start_row, uint32_t num_rows, num_rows * row_size); } else if (rows.metadata().is_fixed_length) { uint32_t row_size = rows.metadata().fixed_length; - const uint8_t* row_base = rows.data(1) + start_row * row_size; + const uint8_t* row_base = + rows.data(1) + static_cast(start_row) * row_size; row_base += offset_within_row; uint8_t* col_base = col_prep.mutable_data(1); switch (col_prep.metadata().fixed_length) { @@ -296,7 +296,7 @@ void EncoderInteger::Decode(uint32_t start_row, uint32_t num_rows, DCHECK(false); } } else { - const uint32_t* row_offsets = rows.offsets() + start_row; + const RowTableImpl::offset_type* row_offsets = rows.offsets() + start_row; const uint8_t* row_base = rows.data(2); row_base += offset_within_row; uint8_t* col_base = col_prep.mutable_data(1); @@ -362,14 +362,14 @@ void EncoderBinary::EncodeSelectedImp(uint32_t offset_within_row, RowTableImpl* } else { const uint8_t* src_base = col.data(1); uint8_t* dst = rows->mutable_data(2) + offset_within_row; - const uint32_t* offsets = rows->offsets(); + const RowTableImpl::offset_type* offsets = rows->offsets(); for (uint32_t i = 0; i < num_selected; ++i) { copy_fn(dst + offsets[i], src_base, selection[i]); } if (col.data(0)) { const uint8_t* non_null_bits = col.data(0); uint8_t* dst = rows->mutable_data(2) + offset_within_row; - const uint32_t* offsets = rows->offsets(); + const RowTableImpl::offset_type* offsets = rows->offsets(); for (uint32_t i = 0; i < num_selected; ++i) { bool is_null = !bit_util::GetBit(non_null_bits, selection[i] + col.bit_offset(0)); if (is_null) { @@ -585,10 +585,12 @@ void EncoderBinaryPair::DecodeImp(uint32_t num_rows_to_skip, uint32_t start_row, uint8_t* dst_B = col2->mutable_data(1); uint32_t fixed_length = rows.metadata().fixed_length; - const uint32_t* offsets; + const RowTableImpl::offset_type* offsets; const uint8_t* src_base; if (is_row_fixed_length) { - src_base = rows.data(1) + fixed_length * start_row + offset_within_row; + src_base = rows.data(1) + + static_cast(start_row) * fixed_length + + offset_within_row; offsets = nullptr; } else { src_base = rows.data(2) + offset_within_row; @@ -640,7 +642,7 @@ 
void EncoderOffsets::Decode(uint32_t start_row, uint32_t num_rows, // The Nth element is the sum of all the lengths of varbinary columns data in // that row, up to and including Nth varbinary column. - const uint32_t* row_offsets = rows.offsets() + start_row; + const RowTableImpl::offset_type* row_offsets = rows.offsets() + start_row; // Set the base offset for each column for (size_t col = 0; col < varbinary_cols->size(); ++col) { @@ -658,8 +660,8 @@ void EncoderOffsets::Decode(uint32_t start_row, uint32_t num_rows, // Update the offset of each column uint32_t offset_within_row = rows.metadata().fixed_length; for (size_t col = 0; col < varbinary_cols->size(); ++col) { - offset_within_row += - RowTableMetadata::padding_for_alignment(offset_within_row, string_alignment); + offset_within_row += RowTableMetadata::padding_for_alignment_within_row( + offset_within_row, string_alignment); uint32_t length = varbinary_ends[col] - offset_within_row; offset_within_row = varbinary_ends[col]; uint32_t* col_offsets = (*varbinary_cols)[col].mutable_offsets(); @@ -676,7 +678,7 @@ Status EncoderOffsets::GetRowOffsetsSelected(RowTableImpl* rows, return Status::OK(); } - uint32_t* row_offsets = rows->mutable_offsets(); + RowTableImpl::offset_type* row_offsets = rows->mutable_offsets(); for (uint32_t i = 0; i < num_selected; ++i) { row_offsets[i] = rows->metadata().fixed_length; } @@ -688,7 +690,7 @@ Status EncoderOffsets::GetRowOffsetsSelected(RowTableImpl* rows, for (uint32_t i = 0; i < num_selected; ++i) { uint32_t irow = selection[i]; uint32_t length = col_offsets[irow + 1] - col_offsets[irow]; - row_offsets[i] += RowTableMetadata::padding_for_alignment( + row_offsets[i] += RowTableMetadata::padding_for_alignment_row( row_offsets[i], rows->metadata().string_alignment); row_offsets[i] += length; } @@ -708,20 +710,13 @@ Status EncoderOffsets::GetRowOffsetsSelected(RowTableImpl* rows, } } - uint32_t sum = 0; + int64_t sum = 0; int row_alignment = rows->metadata().row_alignment; for (uint32_t i = 0; i < num_selected; ++i) { - uint32_t length = row_offsets[i]; - length += RowTableMetadata::padding_for_alignment(length, row_alignment); + RowTableImpl::offset_type length = row_offsets[i]; + length += RowTableMetadata::padding_for_alignment_row(length, row_alignment); row_offsets[i] = sum; - uint32_t sum_maybe_overflow = 0; - if (ARROW_PREDICT_FALSE( - arrow::internal::AddWithOverflow(sum, length, &sum_maybe_overflow))) { - return Status::Invalid( - "Offset overflow detected in EncoderOffsets::GetRowOffsetsSelected for row ", i, - " of length ", length, " bytes, current length in total is ", sum, " bytes"); - } - sum = sum_maybe_overflow; + sum += length; } row_offsets[num_selected] = sum; @@ -732,7 +727,7 @@ template void EncoderOffsets::EncodeSelectedImp(uint32_t ivarbinary, RowTableImpl* rows, const std::vector& cols, uint32_t num_selected, const uint16_t* selection) { - const uint32_t* row_offsets = rows->offsets(); + const RowTableImpl::offset_type* row_offsets = rows->offsets(); uint8_t* row_base = rows->mutable_data(2) + rows->metadata().varbinary_end_array_offset + ivarbinary * sizeof(uint32_t); @@ -753,7 +748,7 @@ void EncoderOffsets::EncodeSelectedImp(uint32_t ivarbinary, RowTableImpl* rows, row[0] = rows->metadata().fixed_length + length; } else { row[0] = row[-1] + - RowTableMetadata::padding_for_alignment( + RowTableMetadata::padding_for_alignment_within_row( row[-1], rows->metadata().string_alignment) + length; } @@ -857,7 +852,7 @@ void EncoderNulls::Decode(uint32_t start_row, uint32_t num_rows, 
const RowTableI void EncoderVarBinary::EncodeSelected(uint32_t ivarbinary, RowTableImpl* rows, const KeyColumnArray& cols, uint32_t num_selected, const uint16_t* selection) { - const uint32_t* row_offsets = rows->offsets(); + const RowTableImpl::offset_type* row_offsets = rows->offsets(); uint8_t* row_base = rows->mutable_data(2); const uint32_t* col_offsets = cols.offsets(); const uint8_t* col_base = cols.data(2); diff --git a/cpp/src/arrow/compute/row/encode_internal.h b/cpp/src/arrow/compute/row/encode_internal.h index 0618ddd8e4b..37538fcc4b8 100644 --- a/cpp/src/arrow/compute/row/encode_internal.h +++ b/cpp/src/arrow/compute/row/encode_internal.h @@ -173,7 +173,7 @@ class EncoderBinary { copy_fn(dst, src, col_width); } } else { - const uint32_t* row_offsets = rows_const->offsets(); + const RowTableImpl::offset_type* row_offsets = rows_const->offsets(); for (uint32_t i = 0; i < num_rows; ++i) { const uint8_t* src; uint8_t* dst; @@ -267,7 +267,8 @@ class EncoderVarBinary { ARROW_DCHECK(!rows_const->metadata().is_fixed_length && !col_const->metadata().is_fixed_length); - const uint32_t* row_offsets_for_batch = rows_const->offsets() + start_row; + const RowTableImpl::offset_type* row_offsets_for_batch = + rows_const->offsets() + start_row; const uint32_t* col_offsets = col_const->offsets(); uint32_t col_offset_next = col_offsets[0]; @@ -275,7 +276,7 @@ class EncoderVarBinary { uint32_t col_offset = col_offset_next; col_offset_next = col_offsets[i + 1]; - uint32_t row_offset = row_offsets_for_batch[i]; + RowTableImpl::offset_type row_offset = row_offsets_for_batch[i]; const uint8_t* row = rows_const->data(2) + row_offset; uint32_t offset_within_row; diff --git a/cpp/src/arrow/compute/row/encode_internal_avx2.cc b/cpp/src/arrow/compute/row/encode_internal_avx2.cc index 50969c7bd60..26f8e3a63de 100644 --- a/cpp/src/arrow/compute/row/encode_internal_avx2.cc +++ b/cpp/src/arrow/compute/row/encode_internal_avx2.cc @@ -75,10 +75,12 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows uint8_t* col_vals_B = col2->mutable_data(1); uint32_t fixed_length = rows.metadata().fixed_length; - const uint32_t* offsets; + const RowTableImpl::offset_type* offsets; const uint8_t* src_base; if (is_row_fixed_length) { - src_base = rows.data(1) + fixed_length * start_row + offset_within_row; + src_base = rows.data(1) + + static_cast(fixed_length) * start_row + + offset_within_row; offsets = nullptr; } else { src_base = rows.data(2) + offset_within_row; @@ -99,7 +101,7 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows src2 = reinterpret_cast(src + fixed_length * 2); src3 = reinterpret_cast(src + fixed_length * 3); } else { - const uint32_t* row_offsets = offsets + i * unroll; + const RowTableImpl::offset_type* row_offsets = offsets + i * unroll; const uint8_t* src = src_base; src0 = reinterpret_cast(src + row_offsets[0]); src1 = reinterpret_cast(src + row_offsets[1]); @@ -140,7 +142,7 @@ uint32_t EncoderBinaryPair::DecodeImp_avx2(uint32_t start_row, uint32_t num_rows } } } else { - const uint32_t* row_offsets = offsets + i * unroll; + const RowTableImpl::offset_type* row_offsets = offsets + i * unroll; const uint8_t* src = src_base; for (int j = 0; j < unroll; ++j) { if (col_width == 1) { diff --git a/cpp/src/arrow/compute/row/row_internal.cc b/cpp/src/arrow/compute/row/row_internal.cc index 746ed950ffa..aa7e62add45 100644 --- a/cpp/src/arrow/compute/row/row_internal.cc +++ b/cpp/src/arrow/compute/row/row_internal.cc @@ -18,7 +18,6 @@ #include 
"arrow/compute/row/row_internal.h" #include "arrow/compute/util.h" -#include "arrow/util/int_util_overflow.h" namespace arrow { namespace compute { @@ -128,8 +127,8 @@ void RowTableMetadata::FromColumnMetadataVector( const KeyColumnMetadata& col = cols[column_order[i]]; if (col.is_fixed_length && col.fixed_length != 0 && ARROW_POPCOUNT64(col.fixed_length) != 1) { - offset_within_row += RowTableMetadata::padding_for_alignment(offset_within_row, - string_alignment, col); + offset_within_row += RowTableMetadata::padding_for_alignment_within_row( + offset_within_row, string_alignment, col); } column_offsets[i] = offset_within_row; if (!col.is_fixed_length) { @@ -155,7 +154,7 @@ void RowTableMetadata::FromColumnMetadataVector( is_fixed_length = (num_varbinary_cols == 0); fixed_length = offset_within_row + - RowTableMetadata::padding_for_alignment( + RowTableMetadata::padding_for_alignment_within_row( offset_within_row, num_varbinary_cols == 0 ? row_alignment : string_alignment); // We set the number of bytes per row storing null masks of individual key columns @@ -191,7 +190,7 @@ Status RowTableImpl::Init(MemoryPool* pool, const RowTableMetadata& metadata) { auto offsets, AllocateResizableBuffer(size_offsets(kInitialRowsCapacity), pool_)); offsets_ = std::move(offsets); memset(offsets_->mutable_data(), 0, size_offsets(kInitialRowsCapacity)); - reinterpret_cast(offsets_->mutable_data())[0] = 0; + reinterpret_cast(offsets_->mutable_data())[0] = 0; ARROW_ASSIGN_OR_RAISE( auto rows, @@ -226,7 +225,7 @@ void RowTableImpl::Clean() { has_any_nulls_ = false; if (!metadata_.is_fixed_length) { - reinterpret_cast(offsets_->mutable_data())[0] = 0; + reinterpret_cast(offsets_->mutable_data())[0] = 0; } } @@ -235,7 +234,7 @@ int64_t RowTableImpl::size_null_masks(int64_t num_rows) const { } int64_t RowTableImpl::size_offsets(int64_t num_rows) const { - return (num_rows + 1) * sizeof(uint32_t) + kPaddingForVectors; + return (num_rows + 1) * sizeof(offset_type) + kPaddingForVectors; } int64_t RowTableImpl::size_rows_fixed_length(int64_t num_rows) const { @@ -326,23 +325,15 @@ Status RowTableImpl::AppendSelectionFrom(const RowTableImpl& from, if (!metadata_.is_fixed_length) { // Varying-length rows - auto from_offsets = reinterpret_cast(from.offsets_->data()); - auto to_offsets = reinterpret_cast(offsets_->mutable_data()); - uint32_t total_length = to_offsets[num_rows_]; - uint32_t total_length_to_append = 0; + auto from_offsets = reinterpret_cast(from.offsets_->data()); + auto to_offsets = reinterpret_cast(offsets_->mutable_data()); + offset_type total_length = to_offsets[num_rows_]; + int64_t total_length_to_append = 0; for (uint32_t i = 0; i < num_rows_to_append; ++i) { uint16_t row_id = source_row_ids ? 
source_row_ids[i] : i; - uint32_t length = from_offsets[row_id + 1] - from_offsets[row_id]; + int64_t length = from_offsets[row_id + 1] - from_offsets[row_id]; total_length_to_append += length; - uint32_t to_offset_maybe_overflow = 0; - if (ARROW_PREDICT_FALSE(arrow::internal::AddWithOverflow( - total_length, total_length_to_append, &to_offset_maybe_overflow))) { - return Status::Invalid( - "Offset overflow detected in RowTableImpl::AppendSelectionFrom for row ", - num_rows_ + i, " of length ", length, " bytes, current length in total is ", - to_offsets[num_rows_ + i], " bytes"); - } - to_offsets[num_rows_ + i + 1] = to_offset_maybe_overflow; + to_offsets[num_rows_ + i + 1] = total_length + total_length_to_append; } RETURN_NOT_OK(ResizeOptionalVaryingLengthBuffer(total_length_to_append)); @@ -351,7 +342,8 @@ Status RowTableImpl::AppendSelectionFrom(const RowTableImpl& from, uint8_t* dst = rows_->mutable_data() + total_length; for (uint32_t i = 0; i < num_rows_to_append; ++i) { uint16_t row_id = source_row_ids ? source_row_ids[i] : i; - uint32_t length = from_offsets[row_id + 1] - from_offsets[row_id]; + int64_t length = from_offsets[row_id + 1] - from_offsets[row_id]; + DCHECK_LE(length, std::numeric_limits::max()); auto src64 = reinterpret_cast(src + from_offsets[row_id]); auto dst64 = reinterpret_cast(dst); for (uint32_t j = 0; j < bit_util::CeilDiv(length, 8); ++j) { @@ -397,7 +389,7 @@ Status RowTableImpl::AppendSelectionFrom(const RowTableImpl& from, } Status RowTableImpl::AppendEmpty(uint32_t num_rows_to_append, - uint32_t num_extra_bytes_to_append) { + int64_t num_extra_bytes_to_append) { RETURN_NOT_OK(ResizeFixedLengthBuffers(num_rows_to_append)); if (!metadata_.is_fixed_length) { RETURN_NOT_OK(ResizeOptionalVaryingLengthBuffer(num_extra_bytes_to_append)); diff --git a/cpp/src/arrow/compute/row/row_internal.h b/cpp/src/arrow/compute/row/row_internal.h index 93818fb14d6..094a9c31efe 100644 --- a/cpp/src/arrow/compute/row/row_internal.h +++ b/cpp/src/arrow/compute/row/row_internal.h @@ -30,6 +30,8 @@ namespace compute { /// Description of the data stored in a RowTable struct ARROW_EXPORT RowTableMetadata { + using offset_type = int64_t; + /// \brief True if there are no variable length columns in the table bool is_fixed_length; @@ -78,26 +80,35 @@ struct ARROW_EXPORT RowTableMetadata { /// Offsets within a row to fields in their encoding order. std::vector column_offsets; - /// Rounding up offset to the nearest multiple of alignment value. + /// Rounding up offset within row to the nearest multiple of alignment value. /// Alignment must be a power of 2. - static inline uint32_t padding_for_alignment(uint32_t offset, int required_alignment) { + static inline uint32_t padding_for_alignment_within_row(uint32_t offset, + int required_alignment) { ARROW_DCHECK(ARROW_POPCOUNT64(required_alignment) == 1); return static_cast((-static_cast(offset)) & (required_alignment - 1)); } - /// Rounding up offset to the beginning of next column, + /// Rounding up offset within row to the beginning of next column, /// choosing required alignment based on the data type of that column. 
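All of the padding helpers in this struct use the same power-of-two trick: negating the offset and masking with `alignment - 1` yields the distance to the next aligned boundary. A short standalone sketch with concrete numbers (the helper name here is just for illustration):

#include <cassert>
#include <cstdint>

// Distance from `offset` to the next multiple of `alignment` (a power of 2).
constexpr int64_t PaddingFor(int64_t offset, int64_t alignment) {
  return (-offset) & (alignment - 1);
}

int main() {
  assert(PaddingFor(13, 8) == 3);             // 13 -> 16
  assert(PaddingFor(16, 8) == 0);             // already aligned
  assert(PaddingFor(4294967299LL, 4) == 1);   // still correct past 4 GB with 64-bit offsets
  return 0;
}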
- static inline uint32_t padding_for_alignment(uint32_t offset, int string_alignment, - const KeyColumnMetadata& col_metadata) { + static inline uint32_t padding_for_alignment_within_row( + uint32_t offset, int string_alignment, const KeyColumnMetadata& col_metadata) { if (!col_metadata.is_fixed_length || ARROW_POPCOUNT64(col_metadata.fixed_length) <= 1) { return 0; } else { - return padding_for_alignment(offset, string_alignment); + return padding_for_alignment_within_row(offset, string_alignment); } } + /// Rounding up row offset to the nearest multiple of alignment value. + /// Alignment must be a power of 2. + static inline offset_type padding_for_alignment_row(offset_type row_offset, + int required_alignment) { + ARROW_DCHECK(ARROW_POPCOUNT64(required_alignment) == 1); + return (-row_offset) & (required_alignment - 1); + } + /// Returns an array of offsets within a row of ends of varbinary fields. inline const uint32_t* varbinary_end_array(const uint8_t* row) const { ARROW_DCHECK(!is_fixed_length); @@ -127,7 +138,7 @@ struct ARROW_EXPORT RowTableMetadata { ARROW_DCHECK(varbinary_id > 0); const uint32_t* varbinary_end = varbinary_end_array(row); uint32_t offset = varbinary_end[varbinary_id - 1]; - offset += padding_for_alignment(offset, string_alignment); + offset += padding_for_alignment_within_row(offset, string_alignment); *out_offset = offset; *out_length = varbinary_end[varbinary_id] - offset; } @@ -161,6 +172,8 @@ struct ARROW_EXPORT RowTableMetadata { /// The row table is not safe class ARROW_EXPORT RowTableImpl { public: + using offset_type = RowTableMetadata::offset_type; + RowTableImpl(); /// \brief Initialize a row array for use /// @@ -175,7 +188,7 @@ class ARROW_EXPORT RowTableImpl { /// \param num_extra_bytes_to_append For tables storing variable-length data this /// should be a guess of how many data bytes will be needed to populate the /// data. 
This is ignored if there are no variable-length columns - Status AppendEmpty(uint32_t num_rows_to_append, uint32_t num_extra_bytes_to_append); + Status AppendEmpty(uint32_t num_rows_to_append, int64_t num_extra_bytes_to_append); /// \brief Append rows from a source table /// \param from The table to append from /// \param num_rows_to_append The number of rows to append @@ -201,8 +214,12 @@ class ARROW_EXPORT RowTableImpl { } return NULLPTR; } - const uint32_t* offsets() const { return reinterpret_cast(data(1)); } - uint32_t* mutable_offsets() { return reinterpret_cast(mutable_data(1)); } + const offset_type* offsets() const { + return reinterpret_cast(data(1)); + } + offset_type* mutable_offsets() { + return reinterpret_cast(mutable_data(1)); + } const uint8_t* null_masks() const { return null_masks_->data(); } uint8_t* null_masks() { return null_masks_->mutable_data(); } diff --git a/cpp/src/arrow/compute/row/row_test.cc b/cpp/src/arrow/compute/row/row_test.cc index 75f981fb128..6aed9e43278 100644 --- a/cpp/src/arrow/compute/row/row_test.cc +++ b/cpp/src/arrow/compute/row/row_test.cc @@ -123,7 +123,7 @@ TEST(RowTableMemoryConsumption, Encode) { ASSERT_GT(actual_null_mask_size * 2, row_table.buffer_size(0) - padding_for_vectors); - int64_t actual_offset_size = num_rows * sizeof(uint32_t); + int64_t actual_offset_size = num_rows * sizeof(RowTableImpl::offset_type); ASSERT_LE(actual_offset_size, row_table.buffer_size(1) - padding_for_vectors); ASSERT_GT(actual_offset_size * 2, row_table.buffer_size(1) - padding_for_vectors); @@ -134,15 +134,14 @@ TEST(RowTableMemoryConsumption, Encode) { } } -// GH-43202: Ensure that when offset overflow happens in encoding the row table, an -// explicit error is raised instead of a silent wrong result. -TEST(RowTableOffsetOverflow, LARGE_MEMORY_TEST(Encode)) { +// GH-43495: Ensure that we can build a row table with more than 4GB row data. +TEST(RowTableLarge, LARGE_MEMORY_TEST(Encode)) { if constexpr (sizeof(void*) == 4) { GTEST_SKIP() << "Test only works on 64-bit platforms"; } - // Use 8 512MB var-length rows (occupies 4GB+) to overflow the offset in the row table. - constexpr int64_t num_rows = 8; + // Use 9 512MB var-length rows to occupy more than 4GB memory. + constexpr int64_t num_rows = 9; constexpr int64_t length_per_binary = 512 * 1024 * 1024; constexpr int64_t row_alignment = sizeof(uint32_t); constexpr int64_t var_length_alignment = sizeof(uint32_t); @@ -174,39 +173,24 @@ TEST(RowTableOffsetOverflow, LARGE_MEMORY_TEST(Encode)) { // The rows to encode. std::vector row_ids(num_rows, 0); - // Encoding 7 rows should be fine. - { - row_encoder.PrepareEncodeSelected(0, num_rows - 1, columns); - ASSERT_OK(row_encoder.EncodeSelected(&row_table, static_cast(num_rows - 1), - row_ids.data())); - } + // Encode num_rows rows. + row_encoder.PrepareEncodeSelected(0, num_rows, columns); + ASSERT_OK(row_encoder.EncodeSelected(&row_table, static_cast(num_rows), + row_ids.data())); - // Encoding 8 rows should overflow. 
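The offset assertions added below are easy to check by hand. This sketch only spells out the arithmetic; `kFixedLength` stands in for `table_metadata.fixed_length`, whose exact value depends on the row metadata, so treat it as a placeholder.

#include <cstdint>

constexpr int64_t kLengthPerBinary = 512LL * 1024 * 1024;  // 512 MiB per row
constexpr int64_t kFixedLength = 16;                       // hypothetical fixed part
constexpr int64_t kEncodedRowLength = kFixedLength + kLengthPerBinary;

// The 9th row starts past 4 GB, beyond what uint32_t offsets could represent,
// which is exactly the regime the rewritten tests exercise.
static_assert(kEncodedRowLength * 8 > 4LL * 1024 * 1024 * 1024,
              "offset of the last row must exceed 4GB");

int main() { return 0; }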
- { - int64_t length_per_row = table_metadata.fixed_length + length_per_binary; - std::stringstream expected_error_message; - expected_error_message << "Invalid: Offset overflow detected in " - "EncoderOffsets::GetRowOffsetsSelected for row " - << num_rows - 1 << " of length " << length_per_row - << " bytes, current length in total is " - << length_per_row * (num_rows - 1) << " bytes"; - row_encoder.PrepareEncodeSelected(0, num_rows, columns); - ASSERT_RAISES_WITH_MESSAGE( - Invalid, expected_error_message.str(), - row_encoder.EncodeSelected(&row_table, static_cast(num_rows), - row_ids.data())); - } + auto encoded_row_length = table_metadata.fixed_length + length_per_binary; + ASSERT_EQ(row_table.offsets()[num_rows - 1], encoded_row_length * (num_rows - 1)); + ASSERT_EQ(row_table.offsets()[num_rows], encoded_row_length * num_rows); } -// GH-43202: Ensure that when offset overflow happens in appending to the row table, an -// explicit error is raised instead of a silent wrong result. -TEST(RowTableOffsetOverflow, LARGE_MEMORY_TEST(AppendFrom)) { +// GH-43495: Ensure that we can build a row table with more than 4GB row data. +TEST(RowTableLarge, LARGE_MEMORY_TEST(AppendFrom)) { if constexpr (sizeof(void*) == 4) { GTEST_SKIP() << "Test only works on 64-bit platforms"; } - // Use 8 512MB var-length rows (occupies 4GB+) to overflow the offset in the row table. - constexpr int64_t num_rows = 8; + // Use 9 512MB var-length rows to occupy more than 4GB memory. + constexpr int64_t num_rows = 9; constexpr int64_t length_per_binary = 512 * 1024 * 1024; constexpr int64_t num_rows_seed = 1; constexpr int64_t row_alignment = sizeof(uint32_t); @@ -244,23 +228,15 @@ TEST(RowTableOffsetOverflow, LARGE_MEMORY_TEST(AppendFrom)) { RowTableImpl row_table; ASSERT_OK(row_table.Init(pool, table_metadata)); - // Appending the seed 7 times should be fine. - for (int i = 0; i < num_rows - 1; ++i) { + // Append seed num_rows times. + for (int i = 0; i < num_rows; ++i) { ASSERT_OK(row_table.AppendSelectionFrom(row_table_seed, num_rows_seed, /*source_row_ids=*/NULLPTR)); } - // Appending the seed the 8-th time should overflow. 
- int64_t length_per_row = table_metadata.fixed_length + length_per_binary; - std::stringstream expected_error_message; - expected_error_message - << "Invalid: Offset overflow detected in RowTableImpl::AppendSelectionFrom for row " - << num_rows - 1 << " of length " << length_per_row - << " bytes, current length in total is " << length_per_row * (num_rows - 1) - << " bytes"; - ASSERT_RAISES_WITH_MESSAGE(Invalid, expected_error_message.str(), - row_table.AppendSelectionFrom(row_table_seed, num_rows_seed, - /*source_row_ids=*/NULLPTR)); + auto encoded_row_length = table_metadata.fixed_length + length_per_binary; + ASSERT_EQ(row_table.offsets()[num_rows - 1], encoded_row_length * (num_rows - 1)); + ASSERT_EQ(row_table.offsets()[num_rows], encoded_row_length * num_rows); } } // namespace compute diff --git a/cpp/src/arrow/testing/random.cc b/cpp/src/arrow/testing/random.cc index c317fe7aef4..59de09fff83 100644 --- a/cpp/src/arrow/testing/random.cc +++ b/cpp/src/arrow/testing/random.cc @@ -473,19 +473,16 @@ std::shared_ptr RandomArrayGenerator::StringWithRepeats( return result; } -std::shared_ptr RandomArrayGenerator::FixedSizeBinary(int64_t size, - int32_t byte_width, - double null_probability, - int64_t alignment, - MemoryPool* memory_pool) { +std::shared_ptr RandomArrayGenerator::FixedSizeBinary( + int64_t size, int32_t byte_width, double null_probability, uint8_t min_byte, + uint8_t max_byte, int64_t alignment, MemoryPool* memory_pool) { if (null_probability < 0 || null_probability > 1) { ABORT_NOT_OK(Status::Invalid("null_probability must be between 0 and 1")); } // Visual Studio does not implement uniform_int_distribution for char types. using GenOpt = GenerateOptions>; - GenOpt options(seed(), static_cast('A'), static_cast('z'), - null_probability); + GenOpt options(seed(), min_byte, max_byte, null_probability); int64_t null_count = 0; auto null_bitmap = *AllocateEmptyBitmap(size, alignment, memory_pool); @@ -1087,7 +1084,9 @@ std::shared_ptr RandomArrayGenerator::ArrayOf(const Field& field, int64_t case Type::type::FIXED_SIZE_BINARY: { auto byte_width = internal::checked_pointer_cast(field.type())->byte_width(); - return *FixedSizeBinary(length, byte_width, null_probability, alignment, + return *FixedSizeBinary(length, byte_width, null_probability, + /*min_byte=*/static_cast('A'), + /*min_byte=*/static_cast('z'), alignment, memory_pool) ->View(field.type()); } @@ -1143,7 +1142,9 @@ std::shared_ptr RandomArrayGenerator::ArrayOf(const Field& field, int64_t // type means it's not a (useful) composition of other generators GENERATE_INTEGRAL_CASE_VIEW(Int64Type, DayTimeIntervalType); case Type::type::INTERVAL_MONTH_DAY_NANO: { - return *FixedSizeBinary(length, /*byte_width=*/16, null_probability, alignment, + return *FixedSizeBinary(length, /*byte_width=*/16, null_probability, + /*min_byte=*/static_cast('A'), + /*min_byte=*/static_cast('z'), alignment, memory_pool) ->View(month_day_nano_interval()); } diff --git a/cpp/src/arrow/testing/random.h b/cpp/src/arrow/testing/random.h index 1d97a3ada72..9c0c5baae0f 100644 --- a/cpp/src/arrow/testing/random.h +++ b/cpp/src/arrow/testing/random.h @@ -434,12 +434,18 @@ class ARROW_TESTING_EXPORT RandomArrayGenerator { /// \param[in] size the size of the array to generate /// \param[in] byte_width the byte width of fixed-size binary items /// \param[in] null_probability the probability of a value being null + /// \param[in] min_byte the lower bound of each byte in the binary determined by the + /// uniform distribution + /// \param[in] max_byte the 
upper bound of each byte in the binary determined by the + /// uniform distribution /// \param[in] alignment alignment for memory allocations (in bytes) /// \param[in] memory_pool memory pool to allocate memory from /// /// \return a generated Array std::shared_ptr FixedSizeBinary(int64_t size, int32_t byte_width, double null_probability = 0, + uint8_t min_byte = static_cast('A'), + uint8_t max_byte = static_cast('z'), int64_t alignment = kDefaultBufferAlignment, MemoryPool* memory_pool = default_memory_pool()); From c599fa0064a627d3b58d4eff821a34391120bcf6 Mon Sep 17 00:00:00 2001 From: Tom Scott-Coombes <62209801+tscottcoombes1@users.noreply.github.com> Date: Mon, 19 Aug 2024 16:13:35 +0100 Subject: [PATCH 035/157] GH-43554: [Go] Handle excluded fields (#43555) ### Rationale for this change We want to be able to handle excluded fields. ### What changes are included in this PR? * we no longer use the value of the field when getting the element type of a list (as the values are invalid for excluded fields) * similarly for map, key value pairs, we don't use the value is there is none * add some tests ### Are these changes tested? yes ### Are there any user-facing changes? no * GitHub Issue: #43554 Lead-authored-by: Tom Scott-Coombes Co-authored-by: Tom Scott-Coombes <62209801+tscottcoombes1@users.noreply.github.com> Co-authored-by: Matt Topol Co-authored-by: tscottcoombes1 <62209801+tscottcoombes1@users.noreply.github.com> Co-authored-by: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Signed-off-by: Joel Lubinitsky --- go/arrow/util/messages/types.proto | 46 ++ go/arrow/util/protobuf_reflect.go | 31 +- go/arrow/util/protobuf_reflect_test.go | 421 +++++++++++----- go/arrow/util/util_message/types.pb.go | 654 +++++++++++++++++++++++-- 4 files changed, 996 insertions(+), 156 deletions(-) diff --git a/go/arrow/util/messages/types.proto b/go/arrow/util/messages/types.proto index c085273ca35..79b922a22a3 100644 --- a/go/arrow/util/messages/types.proto +++ b/go/arrow/util/messages/types.proto @@ -54,3 +54,49 @@ message AllTheTypes { OPTION_1 = 1; } } + +message AllTheTypesNoAny { + string str = 1; + int32 int32 = 2; + int64 int64 = 3; + sint32 sint32 = 4; + sint64 sin64 = 5; + uint32 uint32 = 6; + uint64 uint64 = 7; + fixed32 fixed32 = 8; + fixed64 fixed64 = 9; + sfixed32 sfixed32 = 10; + bool bool = 11; + bytes bytes = 12; + double double = 13; + ExampleEnum enum = 14; + ExampleMessage message = 15; + oneof oneof { + string oneofstring = 16; + ExampleMessage oneofmessage = 17; + } + map simple_map = 19; + map complex_map = 20; + repeated string simple_list = 21; + repeated ExampleMessage complex_list = 22; + + enum ExampleEnum { + OPTION_0 = 0; + OPTION_1 = 1; + } +} + +message SimpleNested { + repeated ExampleMessage simple_a = 1; + repeated ExampleMessage simple_b = 2; +} + +message ComplexNested { + repeated AllTheTypesNoAny all_the_types_no_any_a = 1; + repeated AllTheTypesNoAny all_the_types_no_any_b = 2; +} + +message DeepNested { + ComplexNested complex_nested = 1; + SimpleNested simple_nested = 2; +} diff --git a/go/arrow/util/protobuf_reflect.go b/go/arrow/util/protobuf_reflect.go index 03153563b8c..c8cda96acf9 100644 --- a/go/arrow/util/protobuf_reflect.go +++ b/go/arrow/util/protobuf_reflect.go @@ -60,6 +60,7 @@ type ProtobufFieldReflection struct { rValue reflect.Value schemaOptions arrow.Field + isListItem bool } func (pfr *ProtobufFieldReflection) isNull() bool { @@ -170,7 +171,7 @@ func (pfr *ProtobufFieldReflection) isEnum() bool { } func (pfr 
*ProtobufFieldReflection) isStruct() bool { - return pfr.descriptor.Kind() == protoreflect.MessageKind && !pfr.descriptor.IsMap() && pfr.rValue.Kind() != reflect.Slice + return pfr.descriptor.Kind() == protoreflect.MessageKind && !pfr.descriptor.IsMap() && !pfr.isList() } func (pfr *ProtobufFieldReflection) isMap() bool { @@ -178,7 +179,7 @@ func (pfr *ProtobufFieldReflection) isMap() bool { } func (pfr *ProtobufFieldReflection) isList() bool { - return pfr.descriptor.IsList() && pfr.rValue.Kind() == reflect.Slice + return pfr.descriptor.IsList() && !pfr.isListItem } // ProtobufMessageReflection represents the metadata and values of a protobuf message @@ -218,11 +219,7 @@ func (psr ProtobufMessageReflection) getArrowFields() []arrow.Field { var fields []arrow.Field for pfr := range psr.generateStructFields() { - fields = append(fields, arrow.Field{ - Name: pfr.name(), - Type: pfr.getDataType(), - Nullable: true, - }) + fields = append(fields, pfr.arrowField()) } return fields @@ -237,12 +234,10 @@ func (pfr *ProtobufFieldReflection) asList() protobufListReflection { } func (plr protobufListReflection) getDataType() arrow.DataType { - for li := range plr.generateListItems() { - return arrow.ListOf(li.getDataType()) - } pfr := ProtobufFieldReflection{ descriptor: plr.descriptor, schemaOptions: plr.schemaOptions, + isListItem: true, } return arrow.ListOf(pfr.getDataType()) } @@ -401,6 +396,22 @@ func (pmr protobufMapReflection) generateKeyValuePairs() chan protobufMapKeyValu go func() { defer close(out) + if !pmr.rValue.IsValid() { + kvp := protobufMapKeyValuePairReflection{ + k: ProtobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapKey(), + schemaOptions: pmr.schemaOptions, + }, + v: ProtobufFieldReflection{ + parent: pmr.parent, + descriptor: pmr.descriptor.MapValue(), + schemaOptions: pmr.schemaOptions, + }, + } + out <- kvp + return + } for _, k := range pmr.rValue.MapKeys() { kvp := protobufMapKeyValuePairReflection{ k: ProtobufFieldReflection{ diff --git a/go/arrow/util/protobuf_reflect_test.go b/go/arrow/util/protobuf_reflect_test.go index 220552df8d8..7420aa72633 100644 --- a/go/arrow/util/protobuf_reflect_test.go +++ b/go/arrow/util/protobuf_reflect_test.go @@ -17,9 +17,12 @@ package util import ( - "strings" + "encoding/json" + "fmt" "testing" + "google.golang.org/protobuf/proto" + "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/memory" @@ -30,14 +33,52 @@ import ( "google.golang.org/protobuf/types/known/anypb" ) -func SetupTest() util_message.AllTheTypes { - msg := util_message.ExampleMessage{ - Field1: "Example", +type Fixture struct { + msg proto.Message + schema string + jsonStr string +} + +type J map[string]any + +func AllTheTypesFixture() Fixture { + e := J{"field1": "Example"} + + m := J{ + "str": "Hello", + "int32": 10, + "int64": 100, + "sint32": -10, + "sin64": -100, + "uint32": 10, + "uint64": 100, + "fixed32": 10, + "fixed64": 1000, + "sfixed32": 10, + "bool": false, + "bytes": "SGVsbG8sIHdvcmxkIQ==", + "double": 1.1, + "enum": "OPTION_1", + "message": e, + "oneof": []any{0, "World"}, + "any": J{"field1": "Example"}, + "simple_map": []J{{"key": 99, "value": "Hello"}}, + "complex_map": []J{{"key": "complex", "value": e}}, + "simple_list": []any{"Hello", "World"}, + "complex_list": []J{e}, } + jm, err := json.Marshal(m) + if err != nil { + panic(err) + } + jsonString := string(jm) - anyMsg, _ := anypb.New(&msg) + exampleMsg := util_message.ExampleMessage{ + Field1: 
"Example", + } + anyMsg, _ := anypb.New(&exampleMsg) - return util_message.AllTheTypes{ + msg := util_message.AllTheTypes{ Str: "Hello", Int32: 10, Int64: 100, @@ -52,23 +93,80 @@ func SetupTest() util_message.AllTheTypes { Bytes: []byte("Hello, world!"), Double: 1.1, Enum: util_message.AllTheTypes_OPTION_1, - Message: &msg, + Message: &exampleMsg, Oneof: &util_message.AllTheTypes_Oneofstring{Oneofstring: "World"}, Any: anyMsg, //Breaks the test as the Golang maps have a non-deterministic order //SimpleMap: map[int32]string{99: "Hello", 100: "World", 98: "How", 101: "Are", 1: "You"}, SimpleMap: map[int32]string{99: "Hello"}, - ComplexMap: map[string]*util_message.ExampleMessage{"complex": &msg}, + ComplexMap: map[string]*util_message.ExampleMessage{"complex": &exampleMsg}, SimpleList: []string{"Hello", "World"}, - ComplexList: []*util_message.ExampleMessage{&msg}, + ComplexList: []*util_message.ExampleMessage{&exampleMsg}, + } + + schema := `schema: + fields: 22 + - str: type=utf8, nullable + - int32: type=int32, nullable + - int64: type=int64, nullable + - sint32: type=int32, nullable + - sin64: type=int64, nullable + - uint32: type=uint32, nullable + - uint64: type=uint64, nullable + - fixed32: type=uint32, nullable + - fixed64: type=uint64, nullable + - sfixed32: type=int32, nullable + - bool: type=bool, nullable + - bytes: type=binary, nullable + - double: type=float64, nullable + - enum: type=dictionary, nullable + - message: type=struct, nullable + - oneofstring: type=utf8, nullable + - oneofmessage: type=struct, nullable + - any: type=struct, nullable + - simple_map: type=map, nullable + - complex_map: type=map, items_nullable>, nullable + - simple_list: type=list, nullable + - complex_list: type=list, nullable>, nullable` + + return Fixture{ + msg: &msg, + schema: schema, + jsonStr: jsonString, } } -func TestGetSchema(t *testing.T) { - msg := SetupTest() +func AllTheTypesNoAnyFixture() Fixture { + exampleMsg := util_message.ExampleMessage{ + Field1: "Example", + } - got := NewProtobufMessageReflection(&msg).Schema().String() - want := `schema: + msg := util_message.AllTheTypesNoAny{ + Str: "Hello", + Int32: 10, + Int64: 100, + Sint32: -10, + Sin64: -100, + Uint32: 10, + Uint64: 100, + Fixed32: 10, + Fixed64: 1000, + Sfixed32: 10, + Bool: false, + Bytes: []byte("Hello, world!"), + Double: 1.1, + Enum: util_message.AllTheTypesNoAny_OPTION_1, + Message: &exampleMsg, + Oneof: &util_message.AllTheTypesNoAny_Oneofstring{Oneofstring: "World"}, + //Breaks the test as the Golang maps have a non-deterministic order + //SimpleMap: map[int32]string{99: "Hello", 100: "World", 98: "How", 101: "Are", 1: "You"}, + SimpleMap: map[int32]string{99: "Hello"}, + ComplexMap: map[string]*util_message.ExampleMessage{"complex": &exampleMsg}, + SimpleList: []string{"Hello", "World"}, + ComplexList: []*util_message.ExampleMessage{&exampleMsg}, + } + + schema := `schema: fields: 22 - str: type=utf8, nullable - int32: type=int32, nullable @@ -87,16 +185,62 @@ func TestGetSchema(t *testing.T) { - message: type=struct, nullable - oneofstring: type=utf8, nullable - oneofmessage: type=struct, nullable - - any: type=struct, nullable - simple_map: type=map, nullable - complex_map: type=map, items_nullable>, nullable - simple_list: type=list, nullable - complex_list: type=list, nullable>, nullable` - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + jsonStr := `{ + "str":"Hello", + "int32":10, + "int64":100, + "sint32":-10, + "sin64":-100, + "uint32":10, + "uint64":100, + "fixed32":10, + "fixed64":1000, + 
"sfixed32":10, + "bool":false, + "bytes":"SGVsbG8sIHdvcmxkIQ==", + "double":1.1, + "enum":"OPTION_1", + "message":{"field1":"Example"}, + "oneofmessage": { "field1": null }, + "oneofstring": "World", + "simple_map":[{"key":99,"value":"Hello"}], + "complex_map":[{"key":"complex","value":{"field1":"Example"}}], + "simple_list":["Hello","World"], + "complex_list":[{"field1":"Example"}] + }` + + return Fixture{ + msg: &msg, + schema: schema, + jsonStr: jsonStr, + } +} - got = NewProtobufMessageReflection(&msg, WithOneOfHandler(OneOfDenseUnion)).Schema().String() - want = `schema: +func CheckSchema(t *testing.T, pmr *ProtobufMessageReflection, want string) { + got := pmr.Schema().String() + require.Equal(t, got, want, "got: %s\nwant: %s", got, want) +} + +func CheckRecord(t *testing.T, pmr *ProtobufMessageReflection, jsonStr string) { + rec := pmr.Record(nil) + got, err := json.Marshal(rec) + assert.NoError(t, err) + assert.JSONEq(t, jsonStr, string(got), "got: %s\nwant: %s", got, jsonStr) +} + +func TestGetSchema(t *testing.T) { + f := AllTheTypesFixture() + + pmr := NewProtobufMessageReflection(f.msg) + CheckSchema(t, pmr, f.schema) + + pmr = NewProtobufMessageReflection(f.msg, WithOneOfHandler(OneOfDenseUnion)) + want := `schema: fields: 21 - str: type=utf8, nullable - int32: type=int32, nullable @@ -119,14 +263,13 @@ func TestGetSchema(t *testing.T) { - complex_map: type=map, items_nullable>, nullable - simple_list: type=list, nullable - complex_list: type=list, nullable>, nullable` - - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + CheckSchema(t, pmr, want) excludeComplex := func(pfr *ProtobufFieldReflection) bool { return pfr.isMap() || pfr.isList() || pfr.isStruct() } - got = NewProtobufMessageReflection(&msg, WithExclusionPolicy(excludeComplex)).Schema().String() + pmr = NewProtobufMessageReflection(f.msg, WithExclusionPolicy(excludeComplex)) want = `schema: fields: 15 - str: type=utf8, nullable @@ -144,14 +287,13 @@ func TestGetSchema(t *testing.T) { - double: type=float64, nullable - enum: type=dictionary, nullable - oneofstring: type=utf8, nullable` + CheckSchema(t, pmr, want) - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) - - got = NewProtobufMessageReflection( - &msg, + pmr = NewProtobufMessageReflection( + f.msg, WithExclusionPolicy(excludeComplex), WithFieldNameFormatter(xstrings.ToCamelCase), - ).Schema().String() + ) want = `schema: fields: 15 - Str: type=utf8, nullable @@ -169,123 +311,168 @@ func TestGetSchema(t *testing.T) { - Double: type=float64, nullable - Enum: type=dictionary, nullable - Oneofstring: type=utf8, nullable` - - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + CheckSchema(t, pmr, want) onlyEnum := func(pfr *ProtobufFieldReflection) bool { return !pfr.isEnum() } - got = NewProtobufMessageReflection( - &msg, + pmr = NewProtobufMessageReflection( + f.msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumNumber), - ).Schema().String() + ) want = `schema: fields: 1 - enum: type=int32, nullable` + CheckSchema(t, pmr, want) - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) - - got = NewProtobufMessageReflection( - &msg, + pmr = NewProtobufMessageReflection( + f.msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumValue), - ).Schema().String() + ) want = `schema: fields: 1 - enum: type=utf8, nullable` - - require.Equal(t, want, got, "got: %s\nwant: %s", got, want) + CheckSchema(t, pmr, want) } func TestRecordFromProtobuf(t *testing.T) { - msg := SetupTest() - - pmr := NewProtobufMessageReflection(&msg, 
WithOneOfHandler(OneOfDenseUnion)) - schema := pmr.Schema() - got := pmr.Record(nil) - jsonStr := `[ - { - "str":"Hello", - "int32":10, - "int64":100, - "sint32":-10, - "sin64":-100, - "uint32":10, - "uint64":100, - "fixed32":10, - "fixed64":1000, - "sfixed32":10, - "bool":false, - "bytes":"SGVsbG8sIHdvcmxkIQ==", - "double":1.1, - "enum":"OPTION_1", - "message":{"field1":"Example"}, - "oneof": [0, "World"], - "any":{"field1":"Example"}, - "simple_map":[{"key":99,"value":"Hello"}], - "complex_map":[{"key":"complex","value":{"field1":"Example"}}], - "simple_list":["Hello","World"], - "complex_list":[{"field1":"Example"}] - } - ]` - want, _, err := array.RecordFromJSON(memory.NewGoAllocator(), schema, strings.NewReader(jsonStr)) + f := AllTheTypesFixture() - require.NoError(t, err) - require.EqualExportedValues(t, got, want, "got: %s\nwant: %s", got, want) + pmr := NewProtobufMessageReflection(f.msg, WithOneOfHandler(OneOfDenseUnion)) + CheckRecord(t, pmr, fmt.Sprintf(`[%s]`, f.jsonStr)) onlyEnum := func(pfr *ProtobufFieldReflection) bool { return !pfr.isEnum() } - pmr = NewProtobufMessageReflection(&msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumValue)) - got = pmr.Record(nil) - jsonStr = `[ { "enum":"OPTION_1" } ]` - want, _, err = array.RecordFromJSON(memory.NewGoAllocator(), pmr.Schema(), strings.NewReader(jsonStr)) - require.NoError(t, err) - require.True(t, array.RecordEqual(got, want), "got: %s\nwant: %s", got, want) - - pmr = NewProtobufMessageReflection(&msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumNumber)) - got = pmr.Record(nil) - jsonStr = `[ { "enum":"1" } ]` - want, _, err = array.RecordFromJSON(memory.NewGoAllocator(), pmr.Schema(), strings.NewReader(jsonStr)) - require.NoError(t, err) - require.True(t, array.RecordEqual(got, want), "got: %s\nwant: %s", got, want) + pmr = NewProtobufMessageReflection(f.msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumValue)) + jsonStr := `[ { "enum":"OPTION_1" } ]` + CheckRecord(t, pmr, jsonStr) + + pmr = NewProtobufMessageReflection(f.msg, WithExclusionPolicy(onlyEnum), WithEnumHandler(EnumNumber)) + jsonStr = `[ { "enum":1 } ]` + CheckRecord(t, pmr, jsonStr) } func TestNullRecordFromProtobuf(t *testing.T) { pmr := NewProtobufMessageReflection(&util_message.AllTheTypes{}) - schema := pmr.Schema() - got := pmr.Record(nil) - _, _ = got.MarshalJSON() - jsonStr := `[ - { - "str":"", - "int32":0, - "int64":0, - "sint32":0, - "sin64":0, - "uint32":0, - "uint64":0, - "fixed32":0, - "fixed64":0, - "sfixed32":0, - "bool":false, - "bytes":"", - "double":0, - "enum":"OPTION_0", - "message":null, - "oneofmessage":{"field1":""}, - "oneofstring":"", - "any":null, - "simple_map":[], - "complex_map":[], - "simple_list":[], - "complex_list":[] - } - ]` - - want, _, err := array.RecordFromJSON(memory.NewGoAllocator(), schema, strings.NewReader(jsonStr)) - - require.NoError(t, err) - require.EqualExportedValues(t, got, want, "got: %s\nwant: %s", got, want) + CheckRecord(t, pmr, `[{ + "str":"", + "int32":0, + "int64":0, + "sint32":0, + "sin64":0, + "uint32":0, + "uint64":0, + "fixed32":0, + "fixed64":0, + "sfixed32":0, + "bool":false, + "bytes":null, + "double":0, + "enum":"OPTION_0", + "message":null, + "oneofmessage":{"field1":""}, + "oneofstring":"", + "any": null, + "simple_map":[], + "complex_map":[], + "simple_list":[], + "complex_list":[] + }]`) +} + +func TestExcludedNested(t *testing.T) { + msg := util_message.ExampleMessage{ + Field1: "Example", + } + schema := `schema: + fields: 2 + - simple_a: type=list, nullable>, nullable 
+ - simple_b: type=list, nullable>, nullable` + + simpleNested := util_message.SimpleNested{ + SimpleA: []*util_message.ExampleMessage{&msg}, + SimpleB: []*util_message.ExampleMessage{&msg}, + } + pmr := NewProtobufMessageReflection(&simpleNested) + jsonStr := `[{ "simple_a":[{"field1":"Example"}], "simple_b":[{"field1":"Example"}] }]` + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + //exclude one value + simpleNested = util_message.SimpleNested{ + SimpleA: []*util_message.ExampleMessage{&msg}, + } + jsonStr = `[{ "simple_a":[{"field1":"Example"}], "simple_b":[]}]` + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + ////exclude both values + simpleNested = util_message.SimpleNested{} + jsonStr = `[{ "simple_a":[], "simple_b":[] }]` + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + f := AllTheTypesNoAnyFixture() + schema = `schema: + fields: 2 + - all_the_types_no_any_a: type=list, message: struct, oneofstring: utf8, oneofmessage: struct, simple_map: map, complex_map: map, items_nullable>, simple_list: list, complex_list: list, nullable>>, nullable>, nullable + - all_the_types_no_any_b: type=list, message: struct, oneofstring: utf8, oneofmessage: struct, simple_map: map, complex_map: map, items_nullable>, simple_list: list, complex_list: list, nullable>>, nullable>, nullable` + + complexNested := util_message.ComplexNested{ + AllTheTypesNoAnyA: []*util_message.AllTheTypesNoAny{f.msg.(*util_message.AllTheTypesNoAny)}, + AllTheTypesNoAnyB: []*util_message.AllTheTypesNoAny{f.msg.(*util_message.AllTheTypesNoAny)}, + } + jsonStr = fmt.Sprintf(`[{ "all_the_types_no_any_a": [%s], "all_the_types_no_any_b": [%s] }]`, f.jsonStr, f.jsonStr) + pmr = NewProtobufMessageReflection(&complexNested) + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + // exclude one value + complexNested = util_message.ComplexNested{ + AllTheTypesNoAnyB: []*util_message.AllTheTypesNoAny{f.msg.(*util_message.AllTheTypesNoAny)}, + } + jsonStr = fmt.Sprintf(`[{ "all_the_types_no_any_a": [], "all_the_types_no_any_b": [%s] }]`, f.jsonStr) + pmr = NewProtobufMessageReflection(&complexNested) + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + // exclude both values + complexNested = util_message.ComplexNested{} + jsonStr = `[{ "all_the_types_no_any_a": [], "all_the_types_no_any_b": [] }]` + pmr = NewProtobufMessageReflection(&complexNested) + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + schema = `schema: + fields: 2 + - complex_nested: type=struct, message: struct, oneofstring: utf8, oneofmessage: struct, simple_map: map, complex_map: map, items_nullable>, simple_list: list, complex_list: list, nullable>>, nullable>, all_the_types_no_any_b: list, message: struct, oneofstring: utf8, oneofmessage: struct, simple_map: map, complex_map: map, items_nullable>, simple_list: list, complex_list: list, nullable>>, nullable>>, nullable + - simple_nested: type=struct, nullable>, simple_b: list, nullable>>, nullable` + + deepNested := util_message.DeepNested{ + ComplexNested: &complexNested, + SimpleNested: &simpleNested, + } + jsonStr = `[{ "simple_nested": {"simple_a":[], "simple_b":[]}, "complex_nested": {"all_the_types_no_any_a": [], "all_the_types_no_any_b": []} }]` + pmr = NewProtobufMessageReflection(&deepNested) + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + // exclude one value + deepNested = util_message.DeepNested{ + ComplexNested: &complexNested, + } + jsonStr = `[{ "simple_nested": null, "complex_nested": 
{"all_the_types_no_any_a": [], "all_the_types_no_any_b": []} }]` + pmr = NewProtobufMessageReflection(&deepNested) + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) + + // exclude both values + deepNested = util_message.DeepNested{} + pmr = NewProtobufMessageReflection(&deepNested) + jsonStr = `[{ "simple_nested": null, "complex_nested": null }]` + CheckSchema(t, pmr, schema) + CheckRecord(t, pmr, jsonStr) } type testProtobufReflection struct { diff --git a/go/arrow/util/util_message/types.pb.go b/go/arrow/util/util_message/types.pb.go index 80e18847c19..6486b2cc87a 100644 --- a/go/arrow/util/util_message/types.pb.go +++ b/go/arrow/util/util_message/types.pb.go @@ -23,12 +23,11 @@ package util_message import ( - reflect "reflect" - sync "sync" - protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" anypb "google.golang.org/protobuf/types/known/anypb" + reflect "reflect" + sync "sync" ) const ( @@ -84,6 +83,52 @@ func (AllTheTypes_ExampleEnum) EnumDescriptor() ([]byte, []int) { return file_messages_types_proto_rawDescGZIP(), []int{1, 0} } +type AllTheTypesNoAny_ExampleEnum int32 + +const ( + AllTheTypesNoAny_OPTION_0 AllTheTypesNoAny_ExampleEnum = 0 + AllTheTypesNoAny_OPTION_1 AllTheTypesNoAny_ExampleEnum = 1 +) + +// Enum value maps for AllTheTypesNoAny_ExampleEnum. +var ( + AllTheTypesNoAny_ExampleEnum_name = map[int32]string{ + 0: "OPTION_0", + 1: "OPTION_1", + } + AllTheTypesNoAny_ExampleEnum_value = map[string]int32{ + "OPTION_0": 0, + "OPTION_1": 1, + } +) + +func (x AllTheTypesNoAny_ExampleEnum) Enum() *AllTheTypesNoAny_ExampleEnum { + p := new(AllTheTypesNoAny_ExampleEnum) + *p = x + return p +} + +func (x AllTheTypesNoAny_ExampleEnum) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (AllTheTypesNoAny_ExampleEnum) Descriptor() protoreflect.EnumDescriptor { + return file_messages_types_proto_enumTypes[1].Descriptor() +} + +func (AllTheTypesNoAny_ExampleEnum) Type() protoreflect.EnumType { + return &file_messages_types_proto_enumTypes[1] +} + +func (x AllTheTypesNoAny_ExampleEnum) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use AllTheTypesNoAny_ExampleEnum.Descriptor instead. 
+func (AllTheTypesNoAny_ExampleEnum) EnumDescriptor() ([]byte, []int) { + return file_messages_types_proto_rawDescGZIP(), []int{2, 0} +} + type ExampleMessage struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -372,6 +417,404 @@ func (*AllTheTypes_Oneofstring) isAllTheTypes_Oneof() {} func (*AllTheTypes_Oneofmessage) isAllTheTypes_Oneof() {} +type AllTheTypesNoAny struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Str string `protobuf:"bytes,1,opt,name=str,proto3" json:"str,omitempty"` + Int32 int32 `protobuf:"varint,2,opt,name=int32,proto3" json:"int32,omitempty"` + Int64 int64 `protobuf:"varint,3,opt,name=int64,proto3" json:"int64,omitempty"` + Sint32 int32 `protobuf:"zigzag32,4,opt,name=sint32,proto3" json:"sint32,omitempty"` + Sin64 int64 `protobuf:"zigzag64,5,opt,name=sin64,proto3" json:"sin64,omitempty"` + Uint32 uint32 `protobuf:"varint,6,opt,name=uint32,proto3" json:"uint32,omitempty"` + Uint64 uint64 `protobuf:"varint,7,opt,name=uint64,proto3" json:"uint64,omitempty"` + Fixed32 uint32 `protobuf:"fixed32,8,opt,name=fixed32,proto3" json:"fixed32,omitempty"` + Fixed64 uint64 `protobuf:"fixed64,9,opt,name=fixed64,proto3" json:"fixed64,omitempty"` + Sfixed32 int32 `protobuf:"fixed32,10,opt,name=sfixed32,proto3" json:"sfixed32,omitempty"` + Bool bool `protobuf:"varint,11,opt,name=bool,proto3" json:"bool,omitempty"` + Bytes []byte `protobuf:"bytes,12,opt,name=bytes,proto3" json:"bytes,omitempty"` + Double float64 `protobuf:"fixed64,13,opt,name=double,proto3" json:"double,omitempty"` + Enum AllTheTypesNoAny_ExampleEnum `protobuf:"varint,14,opt,name=enum,proto3,enum=AllTheTypesNoAny_ExampleEnum" json:"enum,omitempty"` + Message *ExampleMessage `protobuf:"bytes,15,opt,name=message,proto3" json:"message,omitempty"` + // Types that are assignable to Oneof: + // + // *AllTheTypesNoAny_Oneofstring + // *AllTheTypesNoAny_Oneofmessage + Oneof isAllTheTypesNoAny_Oneof `protobuf_oneof:"oneof"` + SimpleMap map[int32]string `protobuf:"bytes,19,rep,name=simple_map,json=simpleMap,proto3" json:"simple_map,omitempty" protobuf_key:"varint,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + ComplexMap map[string]*ExampleMessage `protobuf:"bytes,20,rep,name=complex_map,json=complexMap,proto3" json:"complex_map,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` + SimpleList []string `protobuf:"bytes,21,rep,name=simple_list,json=simpleList,proto3" json:"simple_list,omitempty"` + ComplexList []*ExampleMessage `protobuf:"bytes,22,rep,name=complex_list,json=complexList,proto3" json:"complex_list,omitempty"` +} + +func (x *AllTheTypesNoAny) Reset() { + *x = AllTheTypesNoAny{} + if protoimpl.UnsafeEnabled { + mi := &file_messages_types_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *AllTheTypesNoAny) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AllTheTypesNoAny) ProtoMessage() {} + +func (x *AllTheTypesNoAny) ProtoReflect() protoreflect.Message { + mi := &file_messages_types_proto_msgTypes[2] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AllTheTypesNoAny.ProtoReflect.Descriptor instead. 
+func (*AllTheTypesNoAny) Descriptor() ([]byte, []int) { + return file_messages_types_proto_rawDescGZIP(), []int{2} +} + +func (x *AllTheTypesNoAny) GetStr() string { + if x != nil { + return x.Str + } + return "" +} + +func (x *AllTheTypesNoAny) GetInt32() int32 { + if x != nil { + return x.Int32 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetInt64() int64 { + if x != nil { + return x.Int64 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetSint32() int32 { + if x != nil { + return x.Sint32 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetSin64() int64 { + if x != nil { + return x.Sin64 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetUint32() uint32 { + if x != nil { + return x.Uint32 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetUint64() uint64 { + if x != nil { + return x.Uint64 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetFixed32() uint32 { + if x != nil { + return x.Fixed32 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetFixed64() uint64 { + if x != nil { + return x.Fixed64 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetSfixed32() int32 { + if x != nil { + return x.Sfixed32 + } + return 0 +} + +func (x *AllTheTypesNoAny) GetBool() bool { + if x != nil { + return x.Bool + } + return false +} + +func (x *AllTheTypesNoAny) GetBytes() []byte { + if x != nil { + return x.Bytes + } + return nil +} + +func (x *AllTheTypesNoAny) GetDouble() float64 { + if x != nil { + return x.Double + } + return 0 +} + +func (x *AllTheTypesNoAny) GetEnum() AllTheTypesNoAny_ExampleEnum { + if x != nil { + return x.Enum + } + return AllTheTypesNoAny_OPTION_0 +} + +func (x *AllTheTypesNoAny) GetMessage() *ExampleMessage { + if x != nil { + return x.Message + } + return nil +} + +func (m *AllTheTypesNoAny) GetOneof() isAllTheTypesNoAny_Oneof { + if m != nil { + return m.Oneof + } + return nil +} + +func (x *AllTheTypesNoAny) GetOneofstring() string { + if x, ok := x.GetOneof().(*AllTheTypesNoAny_Oneofstring); ok { + return x.Oneofstring + } + return "" +} + +func (x *AllTheTypesNoAny) GetOneofmessage() *ExampleMessage { + if x, ok := x.GetOneof().(*AllTheTypesNoAny_Oneofmessage); ok { + return x.Oneofmessage + } + return nil +} + +func (x *AllTheTypesNoAny) GetSimpleMap() map[int32]string { + if x != nil { + return x.SimpleMap + } + return nil +} + +func (x *AllTheTypesNoAny) GetComplexMap() map[string]*ExampleMessage { + if x != nil { + return x.ComplexMap + } + return nil +} + +func (x *AllTheTypesNoAny) GetSimpleList() []string { + if x != nil { + return x.SimpleList + } + return nil +} + +func (x *AllTheTypesNoAny) GetComplexList() []*ExampleMessage { + if x != nil { + return x.ComplexList + } + return nil +} + +type isAllTheTypesNoAny_Oneof interface { + isAllTheTypesNoAny_Oneof() +} + +type AllTheTypesNoAny_Oneofstring struct { + Oneofstring string `protobuf:"bytes,16,opt,name=oneofstring,proto3,oneof"` +} + +type AllTheTypesNoAny_Oneofmessage struct { + Oneofmessage *ExampleMessage `protobuf:"bytes,17,opt,name=oneofmessage,proto3,oneof"` +} + +func (*AllTheTypesNoAny_Oneofstring) isAllTheTypesNoAny_Oneof() {} + +func (*AllTheTypesNoAny_Oneofmessage) isAllTheTypesNoAny_Oneof() {} + +type SimpleNested struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SimpleA []*ExampleMessage `protobuf:"bytes,1,rep,name=simple_a,json=simpleA,proto3" json:"simple_a,omitempty"` + SimpleB []*ExampleMessage `protobuf:"bytes,2,rep,name=simple_b,json=simpleB,proto3" json:"simple_b,omitempty"` +} + +func (x *SimpleNested) Reset() { + *x = 
SimpleNested{} + if protoimpl.UnsafeEnabled { + mi := &file_messages_types_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *SimpleNested) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SimpleNested) ProtoMessage() {} + +func (x *SimpleNested) ProtoReflect() protoreflect.Message { + mi := &file_messages_types_proto_msgTypes[3] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SimpleNested.ProtoReflect.Descriptor instead. +func (*SimpleNested) Descriptor() ([]byte, []int) { + return file_messages_types_proto_rawDescGZIP(), []int{3} +} + +func (x *SimpleNested) GetSimpleA() []*ExampleMessage { + if x != nil { + return x.SimpleA + } + return nil +} + +func (x *SimpleNested) GetSimpleB() []*ExampleMessage { + if x != nil { + return x.SimpleB + } + return nil +} + +type ComplexNested struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + AllTheTypesNoAnyA []*AllTheTypesNoAny `protobuf:"bytes,1,rep,name=all_the_types_no_any_a,json=allTheTypesNoAnyA,proto3" json:"all_the_types_no_any_a,omitempty"` + AllTheTypesNoAnyB []*AllTheTypesNoAny `protobuf:"bytes,2,rep,name=all_the_types_no_any_b,json=allTheTypesNoAnyB,proto3" json:"all_the_types_no_any_b,omitempty"` +} + +func (x *ComplexNested) Reset() { + *x = ComplexNested{} + if protoimpl.UnsafeEnabled { + mi := &file_messages_types_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *ComplexNested) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ComplexNested) ProtoMessage() {} + +func (x *ComplexNested) ProtoReflect() protoreflect.Message { + mi := &file_messages_types_proto_msgTypes[4] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ComplexNested.ProtoReflect.Descriptor instead. 
+func (*ComplexNested) Descriptor() ([]byte, []int) { + return file_messages_types_proto_rawDescGZIP(), []int{4} +} + +func (x *ComplexNested) GetAllTheTypesNoAnyA() []*AllTheTypesNoAny { + if x != nil { + return x.AllTheTypesNoAnyA + } + return nil +} + +func (x *ComplexNested) GetAllTheTypesNoAnyB() []*AllTheTypesNoAny { + if x != nil { + return x.AllTheTypesNoAnyB + } + return nil +} + +type DeepNested struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + ComplexNested *ComplexNested `protobuf:"bytes,1,opt,name=complex_nested,json=complexNested,proto3" json:"complex_nested,omitempty"` + SimpleNested *SimpleNested `protobuf:"bytes,2,opt,name=simple_nested,json=simpleNested,proto3" json:"simple_nested,omitempty"` +} + +func (x *DeepNested) Reset() { + *x = DeepNested{} + if protoimpl.UnsafeEnabled { + mi := &file_messages_types_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *DeepNested) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeepNested) ProtoMessage() {} + +func (x *DeepNested) ProtoReflect() protoreflect.Message { + mi := &file_messages_types_proto_msgTypes[5] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeepNested.ProtoReflect.Descriptor instead. +func (*DeepNested) Descriptor() ([]byte, []int) { + return file_messages_types_proto_rawDescGZIP(), []int{5} +} + +func (x *DeepNested) GetComplexNested() *ComplexNested { + if x != nil { + return x.ComplexNested + } + return nil +} + +func (x *DeepNested) GetSimpleNested() *SimpleNested { + if x != nil { + return x.SimpleNested + } + return nil +} + var File_messages_types_proto protoreflect.FileDescriptor var file_messages_types_proto_rawDesc = []byte{ @@ -439,9 +882,90 @@ var file_messages_types_proto_rawDesc = []byte{ 0x02, 0x38, 0x01, 0x22, 0x29, 0x0a, 0x0b, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x45, 0x6e, 0x75, 0x6d, 0x12, 0x0c, 0x0a, 0x08, 0x4f, 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x30, 0x10, 0x00, 0x12, 0x0c, 0x0a, 0x08, 0x4f, 0x50, 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x31, 0x10, 0x01, 0x42, 0x07, - 0x0a, 0x05, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x42, 0x11, 0x5a, 0x0f, 0x2e, 0x2e, 0x2f, 0x75, 0x74, - 0x69, 0x6c, 0x5f, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, + 0x0a, 0x05, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x22, 0x95, 0x07, 0x0a, 0x10, 0x41, 0x6c, 0x6c, 0x54, + 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x12, 0x10, 0x0a, 0x03, + 0x73, 0x74, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x73, 0x74, 0x72, 0x12, 0x14, + 0x0a, 0x05, 0x69, 0x6e, 0x74, 0x33, 0x32, 0x18, 0x02, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x69, + 0x6e, 0x74, 0x33, 0x32, 0x12, 0x14, 0x0a, 0x05, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x05, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x12, 0x16, 0x0a, 0x06, 0x73, 0x69, + 0x6e, 0x74, 0x33, 0x32, 0x18, 0x04, 0x20, 0x01, 0x28, 0x11, 0x52, 0x06, 0x73, 0x69, 0x6e, 0x74, + 0x33, 0x32, 0x12, 0x14, 0x0a, 0x05, 0x73, 0x69, 0x6e, 0x36, 0x34, 0x18, 0x05, 0x20, 0x01, 0x28, + 0x12, 0x52, 0x05, 0x73, 0x69, 0x6e, 0x36, 0x34, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x69, 0x6e, 0x74, + 0x33, 0x32, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x06, 0x75, 0x69, 0x6e, 0x74, 0x33, 0x32, + 0x12, 0x16, 0x0a, 0x06, 
0x75, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x18, 0x07, 0x20, 0x01, 0x28, 0x04, + 0x52, 0x06, 0x75, 0x69, 0x6e, 0x74, 0x36, 0x34, 0x12, 0x18, 0x0a, 0x07, 0x66, 0x69, 0x78, 0x65, + 0x64, 0x33, 0x32, 0x18, 0x08, 0x20, 0x01, 0x28, 0x07, 0x52, 0x07, 0x66, 0x69, 0x78, 0x65, 0x64, + 0x33, 0x32, 0x12, 0x18, 0x0a, 0x07, 0x66, 0x69, 0x78, 0x65, 0x64, 0x36, 0x34, 0x18, 0x09, 0x20, + 0x01, 0x28, 0x06, 0x52, 0x07, 0x66, 0x69, 0x78, 0x65, 0x64, 0x36, 0x34, 0x12, 0x1a, 0x0a, 0x08, + 0x73, 0x66, 0x69, 0x78, 0x65, 0x64, 0x33, 0x32, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0f, 0x52, 0x08, + 0x73, 0x66, 0x69, 0x78, 0x65, 0x64, 0x33, 0x32, 0x12, 0x12, 0x0a, 0x04, 0x62, 0x6f, 0x6f, 0x6c, + 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x62, 0x6f, 0x6f, 0x6c, 0x12, 0x14, 0x0a, 0x05, + 0x62, 0x79, 0x74, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x62, 0x79, 0x74, + 0x65, 0x73, 0x12, 0x16, 0x0a, 0x06, 0x64, 0x6f, 0x75, 0x62, 0x6c, 0x65, 0x18, 0x0d, 0x20, 0x01, + 0x28, 0x01, 0x52, 0x06, 0x64, 0x6f, 0x75, 0x62, 0x6c, 0x65, 0x12, 0x31, 0x0a, 0x04, 0x65, 0x6e, + 0x75, 0x6d, 0x18, 0x0e, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1d, 0x2e, 0x41, 0x6c, 0x6c, 0x54, 0x68, + 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x2e, 0x45, 0x78, 0x61, 0x6d, + 0x70, 0x6c, 0x65, 0x45, 0x6e, 0x75, 0x6d, 0x52, 0x04, 0x65, 0x6e, 0x75, 0x6d, 0x12, 0x29, 0x0a, + 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x0f, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0f, + 0x2e, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, + 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x22, 0x0a, 0x0b, 0x6f, 0x6e, 0x65, 0x6f, + 0x66, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x18, 0x10, 0x20, 0x01, 0x28, 0x09, 0x48, 0x00, 0x52, + 0x0b, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x73, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x12, 0x35, 0x0a, 0x0c, + 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x11, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x0f, 0x2e, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x48, 0x00, 0x52, 0x0c, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x6d, 0x65, 0x73, 0x73, + 0x61, 0x67, 0x65, 0x12, 0x3f, 0x0a, 0x0a, 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x5f, 0x6d, 0x61, + 0x70, 0x18, 0x13, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x20, 0x2e, 0x41, 0x6c, 0x6c, 0x54, 0x68, 0x65, + 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, + 0x65, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x09, 0x73, 0x69, 0x6d, 0x70, 0x6c, + 0x65, 0x4d, 0x61, 0x70, 0x12, 0x42, 0x0a, 0x0b, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x5f, + 0x6d, 0x61, 0x70, 0x18, 0x14, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x21, 0x2e, 0x41, 0x6c, 0x6c, 0x54, + 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x2e, 0x43, 0x6f, 0x6d, + 0x70, 0x6c, 0x65, 0x78, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x0a, 0x63, 0x6f, + 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x4d, 0x61, 0x70, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x69, 0x6d, 0x70, + 0x6c, 0x65, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x15, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0a, 0x73, + 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x32, 0x0a, 0x0c, 0x63, 0x6f, 0x6d, + 0x70, 0x6c, 0x65, 0x78, 0x5f, 0x6c, 0x69, 0x73, 0x74, 0x18, 0x16, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x0f, 0x2e, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x52, 0x0b, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x4c, 0x69, 0x73, 0x74, 0x1a, 0x3c, 0x0a, + 0x0e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 
0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, + 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x03, 0x6b, 0x65, + 0x79, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x1a, 0x4e, 0x0a, 0x0f, 0x43, + 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x4d, 0x61, 0x70, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, + 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, + 0x12, 0x25, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x0f, 0x2e, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, + 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x29, 0x0a, 0x0b, 0x45, + 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x45, 0x6e, 0x75, 0x6d, 0x12, 0x0c, 0x0a, 0x08, 0x4f, 0x50, + 0x54, 0x49, 0x4f, 0x4e, 0x5f, 0x30, 0x10, 0x00, 0x12, 0x0c, 0x0a, 0x08, 0x4f, 0x50, 0x54, 0x49, + 0x4f, 0x4e, 0x5f, 0x31, 0x10, 0x01, 0x42, 0x07, 0x0a, 0x05, 0x6f, 0x6e, 0x65, 0x6f, 0x66, 0x22, + 0x66, 0x0a, 0x0c, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x12, + 0x2a, 0x0a, 0x08, 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x5f, 0x61, 0x18, 0x01, 0x20, 0x03, 0x28, + 0x0b, 0x32, 0x0f, 0x2e, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x52, 0x07, 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x41, 0x12, 0x2a, 0x0a, 0x08, 0x73, + 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x5f, 0x62, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x0f, 0x2e, + 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x52, 0x07, + 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x42, 0x22, 0x9b, 0x01, 0x0a, 0x0d, 0x43, 0x6f, 0x6d, 0x70, + 0x6c, 0x65, 0x78, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x12, 0x44, 0x0a, 0x16, 0x61, 0x6c, 0x6c, + 0x5f, 0x74, 0x68, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x73, 0x5f, 0x6e, 0x6f, 0x5f, 0x61, 0x6e, + 0x79, 0x5f, 0x61, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x41, 0x6c, 0x6c, 0x54, + 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x52, 0x11, 0x61, 0x6c, + 0x6c, 0x54, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, 0x6e, 0x79, 0x41, 0x12, + 0x44, 0x0a, 0x16, 0x61, 0x6c, 0x6c, 0x5f, 0x74, 0x68, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x73, + 0x5f, 0x6e, 0x6f, 0x5f, 0x61, 0x6e, 0x79, 0x5f, 0x62, 0x18, 0x02, 0x20, 0x03, 0x28, 0x0b, 0x32, + 0x11, 0x2e, 0x41, 0x6c, 0x6c, 0x54, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, 0x6f, 0x41, + 0x6e, 0x79, 0x52, 0x11, 0x61, 0x6c, 0x6c, 0x54, 0x68, 0x65, 0x54, 0x79, 0x70, 0x65, 0x73, 0x4e, + 0x6f, 0x41, 0x6e, 0x79, 0x42, 0x22, 0x77, 0x0a, 0x0a, 0x44, 0x65, 0x65, 0x70, 0x4e, 0x65, 0x73, + 0x74, 0x65, 0x64, 0x12, 0x35, 0x0a, 0x0e, 0x63, 0x6f, 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x5f, 0x6e, + 0x65, 0x73, 0x74, 0x65, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x43, 0x6f, + 0x6d, 0x70, 0x6c, 0x65, 0x78, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x52, 0x0d, 0x63, 0x6f, 0x6d, + 0x70, 0x6c, 0x65, 0x78, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x12, 0x32, 0x0a, 0x0d, 0x73, 0x69, + 0x6d, 0x70, 0x6c, 0x65, 0x5f, 0x6e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, + 0x0b, 0x32, 0x0d, 0x2e, 0x53, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, + 0x52, 0x0c, 0x73, 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x4e, 0x65, 0x73, 0x74, 0x65, 0x64, 0x42, 0x11, + 0x5a, 0x0f, 0x2e, 0x2e, 0x2f, 0x75, 0x74, 0x69, 0x6c, 0x5f, 0x6d, 0x65, 
0x73, 0x73, 0x61, 0x67, + 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -456,30 +980,50 @@ func file_messages_types_proto_rawDescGZIP() []byte { return file_messages_types_proto_rawDescData } -var file_messages_types_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_messages_types_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_messages_types_proto_enumTypes = make([]protoimpl.EnumInfo, 2) +var file_messages_types_proto_msgTypes = make([]protoimpl.MessageInfo, 10) var file_messages_types_proto_goTypes = []interface{}{ - (AllTheTypes_ExampleEnum)(0), // 0: AllTheTypes.ExampleEnum - (*ExampleMessage)(nil), // 1: ExampleMessage - (*AllTheTypes)(nil), // 2: AllTheTypes - nil, // 3: AllTheTypes.SimpleMapEntry - nil, // 4: AllTheTypes.ComplexMapEntry - (*anypb.Any)(nil), // 5: google.protobuf.Any + (AllTheTypes_ExampleEnum)(0), // 0: AllTheTypes.ExampleEnum + (AllTheTypesNoAny_ExampleEnum)(0), // 1: AllTheTypesNoAny.ExampleEnum + (*ExampleMessage)(nil), // 2: ExampleMessage + (*AllTheTypes)(nil), // 3: AllTheTypes + (*AllTheTypesNoAny)(nil), // 4: AllTheTypesNoAny + (*SimpleNested)(nil), // 5: SimpleNested + (*ComplexNested)(nil), // 6: ComplexNested + (*DeepNested)(nil), // 7: DeepNested + nil, // 8: AllTheTypes.SimpleMapEntry + nil, // 9: AllTheTypes.ComplexMapEntry + nil, // 10: AllTheTypesNoAny.SimpleMapEntry + nil, // 11: AllTheTypesNoAny.ComplexMapEntry + (*anypb.Any)(nil), // 12: google.protobuf.Any } var file_messages_types_proto_depIdxs = []int32{ - 0, // 0: AllTheTypes.enum:type_name -> AllTheTypes.ExampleEnum - 1, // 1: AllTheTypes.message:type_name -> ExampleMessage - 1, // 2: AllTheTypes.oneofmessage:type_name -> ExampleMessage - 5, // 3: AllTheTypes.any:type_name -> google.protobuf.Any - 3, // 4: AllTheTypes.simple_map:type_name -> AllTheTypes.SimpleMapEntry - 4, // 5: AllTheTypes.complex_map:type_name -> AllTheTypes.ComplexMapEntry - 1, // 6: AllTheTypes.complex_list:type_name -> ExampleMessage - 1, // 7: AllTheTypes.ComplexMapEntry.value:type_name -> ExampleMessage - 8, // [8:8] is the sub-list for method output_type - 8, // [8:8] is the sub-list for method input_type - 8, // [8:8] is the sub-list for extension type_name - 8, // [8:8] is the sub-list for extension extendee - 0, // [0:8] is the sub-list for field type_name + 0, // 0: AllTheTypes.enum:type_name -> AllTheTypes.ExampleEnum + 2, // 1: AllTheTypes.message:type_name -> ExampleMessage + 2, // 2: AllTheTypes.oneofmessage:type_name -> ExampleMessage + 12, // 3: AllTheTypes.any:type_name -> google.protobuf.Any + 8, // 4: AllTheTypes.simple_map:type_name -> AllTheTypes.SimpleMapEntry + 9, // 5: AllTheTypes.complex_map:type_name -> AllTheTypes.ComplexMapEntry + 2, // 6: AllTheTypes.complex_list:type_name -> ExampleMessage + 1, // 7: AllTheTypesNoAny.enum:type_name -> AllTheTypesNoAny.ExampleEnum + 2, // 8: AllTheTypesNoAny.message:type_name -> ExampleMessage + 2, // 9: AllTheTypesNoAny.oneofmessage:type_name -> ExampleMessage + 10, // 10: AllTheTypesNoAny.simple_map:type_name -> AllTheTypesNoAny.SimpleMapEntry + 11, // 11: AllTheTypesNoAny.complex_map:type_name -> AllTheTypesNoAny.ComplexMapEntry + 2, // 12: AllTheTypesNoAny.complex_list:type_name -> ExampleMessage + 2, // 13: SimpleNested.simple_a:type_name -> ExampleMessage + 2, // 14: SimpleNested.simple_b:type_name -> ExampleMessage + 4, // 15: ComplexNested.all_the_types_no_any_a:type_name -> AllTheTypesNoAny + 4, // 16: ComplexNested.all_the_types_no_any_b:type_name -> AllTheTypesNoAny + 6, // 17: 
DeepNested.complex_nested:type_name -> ComplexNested + 5, // 18: DeepNested.simple_nested:type_name -> SimpleNested + 2, // 19: AllTheTypes.ComplexMapEntry.value:type_name -> ExampleMessage + 2, // 20: AllTheTypesNoAny.ComplexMapEntry.value:type_name -> ExampleMessage + 21, // [21:21] is the sub-list for method output_type + 21, // [21:21] is the sub-list for method input_type + 21, // [21:21] is the sub-list for extension type_name + 21, // [21:21] is the sub-list for extension extendee + 0, // [0:21] is the sub-list for field type_name } func init() { file_messages_types_proto_init() } @@ -512,18 +1056,70 @@ func file_messages_types_proto_init() { return nil } } + file_messages_types_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*AllTheTypesNoAny); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_messages_types_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*SimpleNested); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_messages_types_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*ComplexNested); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_messages_types_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*DeepNested); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_messages_types_proto_msgTypes[1].OneofWrappers = []interface{}{ (*AllTheTypes_Oneofstring)(nil), (*AllTheTypes_Oneofmessage)(nil), } + file_messages_types_proto_msgTypes[2].OneofWrappers = []interface{}{ + (*AllTheTypesNoAny_Oneofstring)(nil), + (*AllTheTypesNoAny_Oneofmessage)(nil), + } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_messages_types_proto_rawDesc, - NumEnums: 1, - NumMessages: 4, + NumEnums: 2, + NumMessages: 10, NumExtensions: 0, NumServices: 0, }, From a380d695a6672f6981d1fe36cd1acc8d68ee9c3e Mon Sep 17 00:00:00 2001 From: mwish Date: Tue, 20 Aug 2024 01:27:45 +0800 Subject: [PATCH 036/157] GH-43733: [C++] Fix Scalar boolean handling in row encoder (#43734) ### Rationale for this change See https://github.com/apache/arrow/issues/43733 ### What changes are included in this PR? Separate Null and Valid handling when BooleanKeyEncoder::Encode meets a Null This patch also does a migration: * row_encoder.cc -> row_encoder_internal.cc * move row_encoder_internal{.cc|.h} from `compute/kernel` to `compute/row` ### Are these changes tested? Yes ### Are there any user-facing changes? 
No * GitHub Issue: #43733 Authored-by: mwish Signed-off-by: Antoine Pitrou --- cpp/src/arrow/CMakeLists.txt | 2 +- cpp/src/arrow/acero/asof_join_node_test.cc | 2 +- cpp/src/arrow/acero/hash_join.cc | 2 +- cpp/src/arrow/acero/hash_join_benchmark.cc | 2 +- cpp/src/arrow/acero/hash_join_dict.h | 2 +- cpp/src/arrow/acero/hash_join_node_test.cc | 2 +- cpp/src/arrow/acero/swiss_join.cc | 2 +- cpp/src/arrow/acero/swiss_join_internal.h | 2 +- cpp/src/arrow/acero/tpch_node_test.cc | 2 +- cpp/src/arrow/compute/CMakeLists.txt | 1 + .../arrow/compute/kernels/hash_aggregate.cc | 2 +- cpp/src/arrow/compute/row/grouper.cc | 2 +- .../row_encoder_internal.cc} | 41 ++++++----- .../{kernels => row}/row_encoder_internal.h | 14 ++-- .../compute/row/row_encoder_internal_test.cc | 68 +++++++++++++++++++ cpp/src/arrow/compute/row/row_test.cc | 2 +- 16 files changed, 111 insertions(+), 37 deletions(-) rename cpp/src/arrow/compute/{kernels/row_encoder.cc => row/row_encoder_internal.cc} (93%) rename cpp/src/arrow/compute/{kernels => row}/row_encoder_internal.h (96%) create mode 100644 cpp/src/arrow/compute/row/row_encoder_internal_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 67d2c19f98a..fb785e1e957 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -723,7 +723,6 @@ set(ARROW_COMPUTE_SRCS compute/ordering.cc compute/registry.cc compute/kernels/codegen_internal.cc - compute/kernels/row_encoder.cc compute/kernels/ree_util_internal.cc compute/kernels/scalar_cast_boolean.cc compute/kernels/scalar_cast_dictionary.cc @@ -742,6 +741,7 @@ set(ARROW_COMPUTE_SRCS compute/row/encode_internal.cc compute/row/compare_internal.cc compute/row/grouper.cc + compute/row/row_encoder_internal.cc compute/row/row_internal.cc compute/util.cc compute/util_internal.cc) diff --git a/cpp/src/arrow/acero/asof_join_node_test.cc b/cpp/src/arrow/acero/asof_join_node_test.cc index 051e280a4c5..555f580028f 100644 --- a/cpp/src/arrow/acero/asof_join_node_test.cc +++ b/cpp/src/arrow/acero/asof_join_node_test.cc @@ -41,8 +41,8 @@ #include "arrow/acero/util.h" #include "arrow/api.h" #include "arrow/compute/api_scalar.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/acero/hash_join.cc b/cpp/src/arrow/acero/hash_join.cc index 5aa70a23f7c..ddcd2a09957 100644 --- a/cpp/src/arrow/acero/hash_join.cc +++ b/cpp/src/arrow/acero/hash_join.cc @@ -27,8 +27,8 @@ #include "arrow/acero/hash_join_dict.h" #include "arrow/acero/task_util.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/row/encode_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/util/tracing_internal.h" namespace arrow { diff --git a/cpp/src/arrow/acero/hash_join_benchmark.cc b/cpp/src/arrow/acero/hash_join_benchmark.cc index 1f8e02e9f0f..470960b1c50 100644 --- a/cpp/src/arrow/acero/hash_join_benchmark.cc +++ b/cpp/src/arrow/acero/hash_join_benchmark.cc @@ -23,7 +23,7 @@ #include "arrow/acero/test_util_internal.h" #include "arrow/acero/util.h" #include "arrow/api.h" -#include "arrow/compute/kernels/row_encoder_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/testing/random.h" #include "arrow/util/thread_pool.h" diff --git a/cpp/src/arrow/acero/hash_join_dict.h b/cpp/src/arrow/acero/hash_join_dict.h 
index c7d8d785d07..02454a71462 100644 --- a/cpp/src/arrow/acero/hash_join_dict.h +++ b/cpp/src/arrow/acero/hash_join_dict.h @@ -22,7 +22,7 @@ #include "arrow/acero/schema_util.h" #include "arrow/compute/exec.h" -#include "arrow/compute/kernels/row_encoder_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/result.h" #include "arrow/status.h" #include "arrow/type.h" diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc b/cpp/src/arrow/acero/hash_join_node_test.cc index 88f9a9e71b7..9065e286a22 100644 --- a/cpp/src/arrow/acero/hash_join_node_test.cc +++ b/cpp/src/arrow/acero/hash_join_node_test.cc @@ -26,9 +26,9 @@ #include "arrow/acero/test_util_internal.h" #include "arrow/acero/util.h" #include "arrow/api.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/kernels/test_util.h" #include "arrow/compute/light_array_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/testing/extension_type.h" #include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" diff --git a/cpp/src/arrow/acero/swiss_join.cc b/cpp/src/arrow/acero/swiss_join.cc index 40a4b5886e4..4d0c8187ac6 100644 --- a/cpp/src/arrow/acero/swiss_join.cc +++ b/cpp/src/arrow/acero/swiss_join.cc @@ -24,10 +24,10 @@ #include "arrow/acero/swiss_join_internal.h" #include "arrow/acero/util.h" #include "arrow/array/util.h" // MakeArrayFromScalar -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/key_hash_internal.h" #include "arrow/compute/row/compare_internal.h" #include "arrow/compute/row/encode_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" #include "arrow/util/tracing_internal.h" diff --git a/cpp/src/arrow/acero/swiss_join_internal.h b/cpp/src/arrow/acero/swiss_join_internal.h index dceb74abe4f..4d749c1c529 100644 --- a/cpp/src/arrow/acero/swiss_join_internal.h +++ b/cpp/src/arrow/acero/swiss_join_internal.h @@ -22,10 +22,10 @@ #include "arrow/acero/partition_util.h" #include "arrow/acero/schema_util.h" #include "arrow/acero/task_util.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/key_map_internal.h" #include "arrow/compute/light_array_internal.h" #include "arrow/compute/row/encode_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" namespace arrow { diff --git a/cpp/src/arrow/acero/tpch_node_test.cc b/cpp/src/arrow/acero/tpch_node_test.cc index 076bcf634a6..17fb43452bc 100644 --- a/cpp/src/arrow/acero/tpch_node_test.cc +++ b/cpp/src/arrow/acero/tpch_node_test.cc @@ -27,8 +27,8 @@ #include "arrow/acero/test_util_internal.h" #include "arrow/acero/tpch_node.h" #include "arrow/acero/util.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/testing/random.h" diff --git a/cpp/src/arrow/compute/CMakeLists.txt b/cpp/src/arrow/compute/CMakeLists.txt index e20b45897db..aa2a2d4e9af 100644 --- a/cpp/src/arrow/compute/CMakeLists.txt +++ b/cpp/src/arrow/compute/CMakeLists.txt @@ -92,6 +92,7 @@ add_arrow_test(internals_test key_hash_test.cc row/compare_test.cc row/grouper_test.cc + row/row_encoder_internal_test.cc row/row_test.cc util_internal_test.cc) diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 54cd695421a..4bf6a6106df 
100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -33,9 +33,9 @@ #include "arrow/compute/kernels/aggregate_internal.h" #include "arrow/compute/kernels/aggregate_var_std_internal.h" #include "arrow/compute/kernels/common_internal.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/kernels/util_internal.h" #include "arrow/compute/row/grouper.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/record_batch.h" #include "arrow/stl_allocator.h" #include "arrow/type_traits.h" diff --git a/cpp/src/arrow/compute/row/grouper.cc b/cpp/src/arrow/compute/row/grouper.cc index 45b9ad5971e..5889f94d96c 100644 --- a/cpp/src/arrow/compute/row/grouper.cc +++ b/cpp/src/arrow/compute/row/grouper.cc @@ -25,12 +25,12 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/function.h" -#include "arrow/compute/kernels/row_encoder_internal.h" #include "arrow/compute/key_hash_internal.h" #include "arrow/compute/light_array_internal.h" #include "arrow/compute/registry.h" #include "arrow/compute/row/compare_internal.h" #include "arrow/compute/row/grouper_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/type.h" #include "arrow/type_traits.h" #include "arrow/util/bitmap_ops.h" diff --git a/cpp/src/arrow/compute/kernels/row_encoder.cc b/cpp/src/arrow/compute/row/row_encoder_internal.cc similarity index 93% rename from cpp/src/arrow/compute/kernels/row_encoder.cc rename to cpp/src/arrow/compute/row/row_encoder_internal.cc index 8224eaa6d63..414cc6793a5 100644 --- a/cpp/src/arrow/compute/kernels/row_encoder.cc +++ b/cpp/src/arrow/compute/row/row_encoder_internal.cc @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-#include "arrow/compute/kernels/row_encoder_internal.h" +#include "arrow/compute/row/row_encoder_internal.h" #include "arrow/util/bitmap_writer.h" #include "arrow/util/logging.h" @@ -75,26 +75,31 @@ void BooleanKeyEncoder::AddLengthNull(int32_t* length) { Status BooleanKeyEncoder::Encode(const ExecValue& data, int64_t batch_length, uint8_t** encoded_bytes) { + auto handle_next_valid_value = [&encoded_bytes](bool value) { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kValidByte; + *encoded_ptr++ = value; + }; + auto handle_next_null_value = [&encoded_bytes]() { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kNullByte; + *encoded_ptr++ = 0; + }; + if (data.is_array()) { - VisitArraySpanInline( - data.array, - [&](bool value) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kValidByte; - *encoded_ptr++ = value; - }, - [&] { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kNullByte; - *encoded_ptr++ = 0; - }); + VisitArraySpanInline(data.array, handle_next_valid_value, + handle_next_null_value); } else { const auto& scalar = data.scalar_as(); - bool value = scalar.is_valid && scalar.value; - for (int64_t i = 0; i < batch_length; i++) { - auto& encoded_ptr = *encoded_bytes++; - *encoded_ptr++ = kValidByte; - *encoded_ptr++ = value; + if (!scalar.is_valid) { + for (int64_t i = 0; i < batch_length; i++) { + handle_next_null_value(); + } + } else { + const bool value = scalar.value; + for (int64_t i = 0; i < batch_length; i++) { + handle_next_valid_value(value); + } } } return Status::OK(); diff --git a/cpp/src/arrow/compute/kernels/row_encoder_internal.h b/cpp/src/arrow/compute/row/row_encoder_internal.h similarity index 96% rename from cpp/src/arrow/compute/kernels/row_encoder_internal.h rename to cpp/src/arrow/compute/row/row_encoder_internal.h index 9bf7c1d1c4f..60eb14af504 100644 --- a/cpp/src/arrow/compute/kernels/row_encoder_internal.h +++ b/cpp/src/arrow/compute/row/row_encoder_internal.h @@ -29,7 +29,7 @@ using internal::checked_cast; namespace compute { namespace internal { -struct KeyEncoder { +struct ARROW_EXPORT KeyEncoder { // the first byte of an encoded key is used to indicate nullity static constexpr bool kExtraByteForNull = true; @@ -60,7 +60,7 @@ struct KeyEncoder { } }; -struct BooleanKeyEncoder : KeyEncoder { +struct ARROW_EXPORT BooleanKeyEncoder : KeyEncoder { static constexpr int kByteWidth = 1; void AddLength(const ExecValue& data, int64_t batch_length, int32_t* lengths) override; @@ -76,7 +76,7 @@ struct BooleanKeyEncoder : KeyEncoder { MemoryPool* pool) override; }; -struct FixedWidthKeyEncoder : KeyEncoder { +struct ARROW_EXPORT FixedWidthKeyEncoder : KeyEncoder { explicit FixedWidthKeyEncoder(std::shared_ptr type) : type_(std::move(type)), byte_width_(checked_cast(*type_).bit_width() / 8) {} @@ -97,7 +97,7 @@ struct FixedWidthKeyEncoder : KeyEncoder { int byte_width_; }; -struct DictionaryKeyEncoder : FixedWidthKeyEncoder { +struct ARROW_EXPORT DictionaryKeyEncoder : FixedWidthKeyEncoder { DictionaryKeyEncoder(std::shared_ptr type, MemoryPool* pool) : FixedWidthKeyEncoder(std::move(type)), pool_(pool) {} @@ -112,7 +112,7 @@ struct DictionaryKeyEncoder : FixedWidthKeyEncoder { }; template -struct VarLengthKeyEncoder : KeyEncoder { +struct ARROW_EXPORT VarLengthKeyEncoder : KeyEncoder { using Offset = typename T::offset_type; void AddLength(const ExecValue& data, int64_t batch_length, int32_t* lengths) override { @@ -232,7 +232,7 @@ struct VarLengthKeyEncoder : KeyEncoder { std::shared_ptr type_; }; -struct NullKeyEncoder 
: KeyEncoder { +struct ARROW_EXPORT NullKeyEncoder : KeyEncoder { void AddLength(const ExecValue&, int64_t batch_length, int32_t* lengths) override {} void AddLengthNull(int32_t* length) override {} @@ -274,7 +274,7 @@ class ARROW_EXPORT RowEncoder { } private: - ExecContext* ctx_; + ExecContext* ctx_{nullptr}; std::vector> encoders_; std::vector offsets_; std::vector bytes_; diff --git a/cpp/src/arrow/compute/row/row_encoder_internal_test.cc b/cpp/src/arrow/compute/row/row_encoder_internal_test.cc new file mode 100644 index 00000000000..78839d1ead5 --- /dev/null +++ b/cpp/src/arrow/compute/row/row_encoder_internal_test.cc @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/compute/row/row_encoder_internal.h" + +#include "arrow/array/validate.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" +#include "arrow/type_fwd.h" + +namespace arrow::compute::internal { + +// GH-43733: Test that the key encoder can handle boolean scalar values well. +TEST(TestKeyEncoder, BooleanScalar) { + for (auto scalar : {BooleanScalar{}, BooleanScalar{true}, BooleanScalar{false}}) { + BooleanKeyEncoder key_encoder; + SCOPED_TRACE("scalar " + scalar.ToString()); + constexpr int64_t kBatchLength = 10; + std::array lengths{}; + key_encoder.AddLength(ExecValue{&scalar}, kBatchLength, lengths.data()); + // Check that the lengths are all 2. + constexpr int32_t kPayloadWidth = + BooleanKeyEncoder::kByteWidth + BooleanKeyEncoder::kExtraByteForNull; + for (int i = 0; i < kBatchLength; ++i) { + ASSERT_EQ(kPayloadWidth, lengths[i]); + } + std::array, kBatchLength> payloads{}; + std::array payload_ptrs{}; + // Reset the payload pointers to point to the beginning of each payload. + // This is necessary because the key encoder may have modified the pointers. 
+ auto reset_payload_ptrs = [&payload_ptrs, &payloads]() { + std::transform(payloads.begin(), payloads.end(), payload_ptrs.begin(), + [](auto& payload) -> uint8_t* { return payload.data(); }); + }; + reset_payload_ptrs(); + ASSERT_OK(key_encoder.Encode(ExecValue{&scalar}, kBatchLength, payload_ptrs.data())); + reset_payload_ptrs(); + ASSERT_OK_AND_ASSIGN(auto array_data, + key_encoder.Decode(payload_ptrs.data(), kBatchLength, + ::arrow::default_memory_pool())); + ASSERT_EQ(kBatchLength, array_data->length); + auto boolean_array = std::make_shared(array_data); + ASSERT_OK(arrow::internal::ValidateArrayFull(*array_data)); + ASSERT_OK_AND_ASSIGN( + auto expected_array, + MakeArrayFromScalar(scalar, kBatchLength, ::arrow::default_memory_pool())); + AssertArraysEqual(*expected_array, *boolean_array); + } +} + +} // namespace arrow::compute::internal diff --git a/cpp/src/arrow/compute/row/row_test.cc b/cpp/src/arrow/compute/row/row_test.cc index 6aed9e43278..5057ce91b5b 100644 --- a/cpp/src/arrow/compute/row/row_test.cc +++ b/cpp/src/arrow/compute/row/row_test.cc @@ -155,7 +155,7 @@ TEST(RowTableLarge, LARGE_MEMORY_TEST(Encode)) { auto value, ::arrow::gen::Constant( std::make_shared(std::string(length_per_binary, 'X'))) ->Generate(1)); - values.push_back(std::move(value)); + values.emplace_back(std::move(value)); ExecBatch batch = ExecBatch(std::move(values), 1); ASSERT_OK(ColumnArraysFromExecBatch(batch, &columns)); From 9d4dcc903e84732a6e14d61aece1f9a1d096f7c9 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Mon, 19 Aug 2024 14:01:57 -0500 Subject: [PATCH 037/157] GH-38847: [Documentation][C++] Explicitly note that compute is optional (#43629) ### Rationale for this change A user didn't know from reading just the compute documentation that compute is an optional feature. We can make that explicit ### What changes are included in this PR? Added a cross-reference to the optional features section ### Are these changes tested? No ### Are there any user-facing changes? No * GitHub Issue: #38847 Authored-by: Benjamin Kietzman Signed-off-by: Benjamin Kietzman --- docs/source/cpp/tutorials/compute_tutorial.rst | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/cpp/tutorials/compute_tutorial.rst b/docs/source/cpp/tutorials/compute_tutorial.rst index a650865d75c..72ebc35650d 100644 --- a/docs/source/cpp/tutorials/compute_tutorial.rst +++ b/docs/source/cpp/tutorials/compute_tutorial.rst @@ -39,7 +39,9 @@ Pre-requisites Before continuing, make sure you have: -1. An Arrow installation, which you can set up here: :doc:`/cpp/build_system` +1. An Arrow installation, which you can set up here: :doc:`/cpp/build_system`. + If you're compiling Arrow yourself, be sure you compile with the compute module + enabled (i.e., ``-DARROW_COMPUTE=ON``), see :ref:`cpp_build_optional_components`. 2. An understanding of basic Arrow data structures from :doc:`/cpp/tutorials/basic_arrow` @@ -50,7 +52,7 @@ Before running some computations, we need to fill in a couple gaps: 1. We need to include necessary headers. -2. ``A main()`` is needed to glue things together. +2. A ``main()`` is needed to glue things together. 3. We need data to play with. From 364e01441a1d437c4e833fc00ec76af3a6f342d7 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 09:48:06 +0900 Subject: [PATCH 038/157] MINOR: [Java] Bump org.apache.avro:avro from 1.11.3 to 1.12.0 in /java (#43564) Bumps org.apache.avro:avro from 1.11.3 to 1.12.0. 
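To make the BooleanKeyEncoder change for GH-43733 above concrete, the following is a minimal standalone sketch (not the Arrow code path) of the two-bytes-per-value layout that `BooleanKeyEncoder::Encode` writes: one validity byte followed by one value byte, with a null boolean scalar now broadcast as a null marker plus a padding byte per row instead of being collapsed into a valid `false`. The byte constants here are illustrative placeholders; the real `kValidByte`/`kNullByte` values are defined on `KeyEncoder`.

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Illustrative placeholders; the actual constants are defined on KeyEncoder.
constexpr uint8_t kValidByte = 0xFF;
constexpr uint8_t kNullByte = 0x00;

// Encode each (possibly null) boolean as [validity byte, value byte].
std::vector<uint8_t> EncodeBooleans(const std::vector<std::optional<bool>>& values) {
  std::vector<uint8_t> encoded;
  encoded.reserve(values.size() * 2);
  for (const auto& v : values) {
    if (v.has_value()) {
      encoded.push_back(kValidByte);
      encoded.push_back(*v ? 1 : 0);
    } else {
      // The fix: nulls get the null marker plus a padding byte, so a null
      // scalar broadcast over a batch no longer decodes as a valid "false".
      encoded.push_back(kNullByte);
      encoded.push_back(0);
    }
  }
  return encoded;
}

int main() {
  for (uint8_t byte : EncodeBooleans({true, std::nullopt, false})) {
    std::cout << static_cast<int>(byte) << ' ';
  }
  std::cout << '\n';  // 255 1 0 0 255 0 with the placeholder constants above
  return 0;
}
```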
[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=org.apache.avro:avro&package-manager=maven&previous-version=1.11.3&new-version=1.12.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@ dependabot rebase` will rebase this PR - `@ dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@ dependabot merge` will merge this PR after your CI passes on it - `@ dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@ dependabot cancel merge` will cancel a previously requested merge and block automerging - `@ dependabot reopen` will reopen this PR if it is closed - `@ dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@ dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@ dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@ dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@ dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/dataset/pom.xml | 2 +- java/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml index 74071a6c305..f3384fabbed 100644 --- a/java/dataset/pom.xml +++ b/java/dataset/pom.xml @@ -33,7 +33,7 @@ under the License. ../../../cpp/release-build/ 1.14.1 - 1.11.3 + 1.12.0 diff --git a/java/pom.xml b/java/pom.xml index 0466cad9237..45e9f07174b 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -102,7 +102,7 @@ under the License. 2.17.2 3.4.0 24.3.25 - 1.11.3 + 1.12.0 2 10.17.0 From b5726ea59e9e92dd99c687faf07e4c797b02ce7b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 09:51:27 +0900 Subject: [PATCH 039/157] MINOR: [Java] Bump org.apache.commons:commons-compress from 1.26.2 to 1.27.0 in /java (#43653) Bumps org.apache.commons:commons-compress from 1.26.2 to 1.27.0. [![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=org.apache.commons:commons-compress&package-manager=maven&previous-version=1.26.2&new-version=1.27.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/compression/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/compression/pom.xml b/java/compression/pom.xml index 8774f7cabde..a1f2bc861da 100644 --- a/java/compression/pom.xml +++ b/java/compression/pom.xml @@ -50,7 +50,7 @@ under the License. org.apache.commons commons-compress - 1.26.2 + 1.27.0 com.github.luben From 944d13660c952c05deb58a9a74f562947ef7ec16 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 12:18:10 +0900 Subject: [PATCH 040/157] MINOR: [Java] Bump error_prone_core.version from 2.29.2 to 2.30.0 in /java (#43656) Bumps `error_prone_core.version` from 2.29.2 to 2.30.0. Updates `com.google.errorprone:error_prone_annotations` from 2.29.2 to 2.30.0
Release notes

Sourced from com.google.errorprone:error_prone_annotations's releases.

Error Prone 2.30.0

New checks:

Closed issues: #632, #4487

Full changelog: https://github.com/google/error-prone/compare/v2.29.2...v2.30.0

Commits
  • 5ada179 Release Error Prone 2.30.0
  • af175b0 Don't fire the CanIgnoreReturnValueSuggester for `dagger.producers.Producti...
  • ba8f9a2 Do not update getters that override methods from a superclass.
  • a706e8d Add ability to suppress warning for the entire AutoValue class
  • 86df5cf Convert some simple blocks to return switches using yield
  • 474554a Remove // fall out comments, which are sometimes used to document an empty ...
  • ac7ebf5 Handle var in MustBeClosedChecker
  • ccd3ca6 Add handling of toBuilder()
  • d887307 Omit some unnecessary break statements when translating to -> switches
  • fe07236 Add Error Prone check for unnecessary boxed types in AutoValue classes.
  • Additional commits viewable in compare view

Updates `com.google.errorprone:error_prone_core` from 2.29.2 to 2.30.0

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/pom.xml b/java/pom.xml index 45e9f07174b..1524dc32579 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -107,7 +107,7 @@ under the License. 2 10.17.0 true - 2.29.2 + 2.30.0 5.11.0 5.2.0 3.46.0 From 906934e4af9f2ec8a402cc87d15a11783fc99950 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 20 Aug 2024 12:19:07 +0900 Subject: [PATCH 041/157] MINOR: [Java] Bump com.h2database:h2 from 2.3.230 to 2.3.232 in /java (#43654) Bumps [com.h2database:h2](https://github.com/h2database/h2database) from 2.3.230 to 2.3.232.
Release notes

Sourced from com.h2database:h2's releases.

Version 2.3.232

Commits
  • 2e46a1c Merge remote-tracking branch 'h2database/master'
  • 5badbf9 in preparation for release
  • c0696ef Merge pull request #4113 from katzyn/uuid
  • 8f8e88c Don't cast to long and back
  • e0895be Fix building of documentation
  • 19d4428 Add optional version parameter to RANDOM_UUID function
  • bd9ac2f Merge pull request #4103 from katzyn/map_columns
  • 64f2fbe Pass mapped columns to table filters of subqueries
  • 74ed2b5 Merge pull request #4094 from andreitokar/issue_4075
  • 9d533f1 Merge pull request #4098 from katzyn/fixes
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=com.h2database:h2&package-manager=maven&previous-version=2.3.230&new-version=2.3.232)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/adapter/jdbc/pom.xml | 2 +- java/performance/pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/java/adapter/jdbc/pom.xml b/java/adapter/jdbc/pom.xml index 124cc535c25..099798a95cd 100644 --- a/java/adapter/jdbc/pom.xml +++ b/java/adapter/jdbc/pom.xml @@ -59,7 +59,7 @@ under the License. com.h2database h2 - 2.3.230 + 2.3.232 test diff --git a/java/performance/pom.xml b/java/performance/pom.xml index f6d3a26b4f3..9f4df1ff2e7 100644 --- a/java/performance/pom.xml +++ b/java/performance/pom.xml @@ -75,7 +75,7 @@ under the License. com.h2database h2 - 2.3.230 + 2.3.232 runtime From bd3953f01b2b443a2021027e9beb5e302f74f42d Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 20:20:59 -0700 Subject: [PATCH 042/157] MINOR: [C#] Bump Google.Protobuf from 3.27.0 to 3.27.3 in /csharp (#43754) Bumps Google.Protobuf from 3.27.0 to 3.27.3. [![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=Google.Protobuf&package-manager=nuget&previous-version=3.27.0&new-version=3.27.3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Curt Hagenlocher --- .../Apache.Arrow.Flight.TestWeb.csproj | 1 + 1 file changed, 1 insertion(+) diff --git a/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj b/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj index e6c7e174fa3..14227e2c4eb 100644 --- a/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj +++ b/csharp/test/Apache.Arrow.Flight.TestWeb/Apache.Arrow.Flight.TestWeb.csproj @@ -5,6 +5,7 @@ + From 70a0189f30cbfe9484681f0d407aed5ca3f4467b Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 19 Aug 2024 20:23:45 -0700 Subject: [PATCH 043/157] MINOR: [C#] Bump System.Memory from 4.5.4 to 4.5.5 in /csharp (#43755) Bumps System.Memory from 4.5.4 to 4.5.5. [![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=System.Memory&package-manager=nuget&previous-version=4.5.4&new-version=4.5.5)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Curt Hagenlocher --- csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj | 1 + 1 file changed, 1 insertion(+) diff --git a/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj b/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj index a46f0d91935..9e1866f8416 100644 --- a/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj +++ b/csharp/src/Apache.Arrow.Flight/Apache.Arrow.Flight.csproj @@ -8,6 +8,7 @@ +
From b0317f2b2b62b3be9beb8d834aa51b776fb0179e Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 20 Aug 2024 17:04:33 +0900 Subject: [PATCH 044/157] GH-43707: [Python] Fix compilation on Cython<3 (#43765) ### Rationale for this change Fix compilation on Cython < 3 ### What changes are included in this PR? Add an explicit cast ### Are these changes tested? N/A ### Are there any user-facing changes? No * GitHub Issue: #43707 Authored-by: David Li Signed-off-by: Joris Van den Bossche --- python/pyarrow/types.pxi | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 93d68fb8478..dcd2b61c334 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -5328,8 +5328,9 @@ def opaque(DataType storage_type, str type_name not None, str vendor_name not No cdef: c_string c_type_name = tobytes(type_name) c_string c_vendor_name = tobytes(vendor_name) - shared_ptr[CDataType] c_type = make_shared[COpaqueType]( + shared_ptr[COpaqueType] c_opaque_type = make_shared[COpaqueType]( storage_type.sp_type, c_type_name, c_vendor_name) + shared_ptr[CDataType] c_type = static_pointer_cast[CDataType, COpaqueType](c_opaque_type) OpaqueType out = OpaqueType.__new__(OpaqueType) out.init(c_type) return out From cc3c868aea7317a58447658f1c165ad352cd4865 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Tue, 20 Aug 2024 16:57:57 +0200 Subject: [PATCH 045/157] MINOR: [Documentation] Add installation of ninja-build to Python Development docs (#43600) ### Rationale for this change Otherwise, you get a CMake error: ``` CMake Error: CMake was unable to find a build program corresponding to "Ninja". CMAKE_MAKE_PROGRAM is not set. You probably need to select a different build tool. ``` Authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Signed-off-by: Joris Van den Bossche --- docs/source/developers/python.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/developers/python.rst b/docs/source/developers/python.rst index 2f3e892ce8e..6beea55e66b 100644 --- a/docs/source/developers/python.rst +++ b/docs/source/developers/python.rst @@ -267,7 +267,7 @@ On Debian/Ubuntu, you need the following minimal set of dependencies: .. code-block:: - $ sudo apt-get install build-essential cmake python3-dev + $ sudo apt-get install build-essential ninja-build cmake python3-dev Now, let's create a Python virtual environment with all Python dependencies in the same folder as the repositories, and a target installation folder: From 525881987d0b9b4f464c3e3593a9a7b4e3c767d0 Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Tue, 20 Aug 2024 20:25:19 -0400 Subject: [PATCH 046/157] GH-17682: [C++][Python] Bool8 Extension Type Implementation (#43488) ### Rationale for this change C++ and Python implementations of #43234 ### What changes are included in this PR? - Implement C++ `Bool8Type`, `Bool8Array`, `Bool8Scalar`, and tests - Implement Python bindings to C++, as well as zero-copy numpy conversion methods - TODO: docs waiting for rebase on #43458 ### Are these changes tested? Yes ### Are there any user-facing changes? 
Bool8 extension type will be available in C++ and Python libraries * GitHub Issue: #17682 Authored-by: Joel Lubinitsky Signed-off-by: Felipe Oliveira Carvalho --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/extension/CMakeLists.txt | 6 + cpp/src/arrow/extension/bool8.cc | 61 ++++++++ cpp/src/arrow/extension/bool8.h | 58 ++++++++ cpp/src/arrow/extension/bool8_test.cc | 91 ++++++++++++ cpp/src/arrow/extension_type.cc | 7 +- python/pyarrow/__init__.py | 7 +- python/pyarrow/array.pxi | 114 ++++++++++++++- python/pyarrow/includes/libarrow.pxd | 9 ++ python/pyarrow/lib.pxd | 3 + python/pyarrow/public-api.pxi | 2 + python/pyarrow/scalar.pxi | 23 ++- python/pyarrow/tests/test_extension_type.py | 152 ++++++++++++++++++++ python/pyarrow/tests/test_misc.py | 3 + python/pyarrow/types.pxi | 74 ++++++++++ 15 files changed, 604 insertions(+), 7 deletions(-) create mode 100644 cpp/src/arrow/extension/bool8.cc create mode 100644 cpp/src/arrow/extension/bool8.h create mode 100644 cpp/src/arrow/extension/bool8_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index fb785e1e957..fb7253b6fd6 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -906,6 +906,7 @@ endif() if(ARROW_JSON) arrow_add_object_library(ARROW_JSON + extension/bool8.cc extension/fixed_shape_tensor.cc extension/opaque.cc json/options.cc diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index 6741ab602f5..fcd5fa529ab 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -15,6 +15,12 @@ # specific language governing permissions and limitations # under the License. +add_arrow_test(test + SOURCES + bool8_test.cc + PREFIX + "arrow-extension-bool8") + add_arrow_test(test SOURCES fixed_shape_tensor_test.cc diff --git a/cpp/src/arrow/extension/bool8.cc b/cpp/src/arrow/extension/bool8.cc new file mode 100644 index 00000000000..c081f0c2b28 --- /dev/null +++ b/cpp/src/arrow/extension/bool8.cc @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include + +#include "arrow/extension/bool8.h" +#include "arrow/util/logging.h" + +namespace arrow::extension { + +bool Bool8Type::ExtensionEquals(const ExtensionType& other) const { + return extension_name() == other.extension_name(); +} + +std::string Bool8Type::ToString(bool show_metadata) const { + std::stringstream ss; + ss << "extension<" << this->extension_name() << ">"; + return ss.str(); +} + +std::string Bool8Type::Serialize() const { return ""; } + +Result> Bool8Type::Deserialize( + std::shared_ptr storage_type, const std::string& serialized_data) const { + if (storage_type->id() != Type::INT8) { + return Status::Invalid("Expected INT8 storage type, got ", storage_type->ToString()); + } + if (serialized_data != "") { + return Status::Invalid("Serialize data must be empty, got ", serialized_data); + } + return bool8(); +} + +std::shared_ptr Bool8Type::MakeArray(std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("arrow.bool8", + internal::checked_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +Result> Bool8Type::Make() { + return std::make_shared(); +} + +std::shared_ptr bool8() { return std::make_shared(); } + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/bool8.h b/cpp/src/arrow/extension/bool8.h new file mode 100644 index 00000000000..02e629b28a8 --- /dev/null +++ b/cpp/src/arrow/extension/bool8.h @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension_type.h" + +namespace arrow::extension { + +/// \brief Bool8 is an alternate representation for boolean +/// arrays using 8 bits instead of 1 bit per value. The underlying +/// storage type is int8. +class ARROW_EXPORT Bool8Array : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief Bool8 is an alternate representation for boolean +/// arrays using 8 bits instead of 1 bit per value. The underlying +/// storage type is int8. +class ARROW_EXPORT Bool8Type : public ExtensionType { + public: + /// \brief Construct a Bool8Type. + Bool8Type() : ExtensionType(int8()) {} + + std::string extension_name() const override { return "arrow.bool8"; } + std::string ToString(bool show_metadata = false) const override; + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::string Serialize() const override; + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized_data) const override; + + /// Create a Bool8Array from ArrayData + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + static Result> Make(); +}; + +/// \brief Return a Bool8Type instance. 
+ARROW_EXPORT std::shared_ptr bool8(); + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/bool8_test.cc b/cpp/src/arrow/extension/bool8_test.cc new file mode 100644 index 00000000000..eabcfcf62d3 --- /dev/null +++ b/cpp/src/arrow/extension/bool8_test.cc @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/bool8.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" +#include "arrow/testing/extension_type.h" +#include "arrow/testing/gtest_util.h" + +namespace arrow { + +TEST(Bool8Type, Basics) { + auto type = internal::checked_pointer_cast(extension::bool8()); + auto type2 = internal::checked_pointer_cast(extension::bool8()); + ASSERT_EQ("arrow.bool8", type->extension_name()); + ASSERT_EQ(*type, *type); + ASSERT_NE(*arrow::null(), *type); + ASSERT_EQ(*type, *type2); + ASSERT_EQ(*arrow::int8(), *type->storage_type()); + ASSERT_EQ("", type->Serialize()); + ASSERT_EQ("extension", type->ToString(false)); +} + +TEST(Bool8Type, CreateFromArray) { + auto type = internal::checked_pointer_cast(extension::bool8()); + auto storage = ArrayFromJSON(int8(), "[-1,0,1,2,null]"); + auto array = ExtensionType::WrapArray(type, storage); + ASSERT_EQ(5, array->length()); + ASSERT_EQ(1, array->null_count()); +} + +TEST(Bool8Type, Deserialize) { + auto type = internal::checked_pointer_cast(extension::bool8()); + ASSERT_OK_AND_ASSIGN(auto deserialized, type->Deserialize(type->storage_type(), "")); + ASSERT_EQ(*type, *deserialized); + ASSERT_NOT_OK(type->Deserialize(type->storage_type(), "must be empty")); + ASSERT_EQ(*type, *deserialized); + ASSERT_NOT_OK(type->Deserialize(uint8(), "")); + ASSERT_EQ(*type, *deserialized); +} + +TEST(Bool8Type, MetadataRoundTrip) { + auto type = internal::checked_pointer_cast(extension::bool8()); + std::string serialized = type->Serialize(); + ASSERT_OK_AND_ASSIGN(auto deserialized, + type->Deserialize(type->storage_type(), serialized)); + ASSERT_EQ(*type, *deserialized); +} + +TEST(Bool8Type, BatchRoundTrip) { + auto type = internal::checked_pointer_cast(extension::bool8()); + + auto storage = ArrayFromJSON(int8(), "[-1,0,1,2,null]"); + auto array = ExtensionType::WrapArray(type, storage); + auto batch = + RecordBatch::Make(schema({field("field", type)}), array->length(), {array}); + + std::shared_ptr written; + { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + 
ASSERT_OK(batch_reader->ReadNext(&written)); + } + + ASSERT_EQ(*batch->schema(), *written->schema()); + ASSERT_BATCHES_EQUAL(*batch, *written); +} + +} // namespace arrow diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc index cf8dda7a85d..685018f7de7 100644 --- a/cpp/src/arrow/extension_type.cc +++ b/cpp/src/arrow/extension_type.cc @@ -28,6 +28,7 @@ #include "arrow/chunked_array.h" #include "arrow/config.h" #ifdef ARROW_JSON +#include "arrow/extension/bool8.h" #include "arrow/extension/fixed_shape_tensor.h" #endif #include "arrow/status.h" @@ -146,10 +147,12 @@ static void CreateGlobalRegistry() { #ifdef ARROW_JSON // Register canonical extension types - auto ext_type = + auto fst_ext_type = checked_pointer_cast(extension::fixed_shape_tensor(int64(), {})); + ARROW_CHECK_OK(g_registry->RegisterType(fst_ext_type)); - ARROW_CHECK_OK(g_registry->RegisterType(ext_type)); + auto bool8_ext_type = checked_pointer_cast(extension::bool8()); + ARROW_CHECK_OK(g_registry->RegisterType(bool8_ext_type)); #endif } diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index aa7bab9f97e..807bcdc3150 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -174,6 +174,7 @@ def print_entry(label, value): run_end_encoded, fixed_shape_tensor, opaque, + bool8, field, type_for_alias, DataType, DictionaryType, StructType, @@ -184,7 +185,7 @@ def print_entry(label, value): FixedSizeBinaryType, Decimal128Type, Decimal256Type, BaseExtensionType, ExtensionType, RunEndEncodedType, FixedShapeTensorType, OpaqueType, - PyExtensionType, UnknownExtensionType, + Bool8Type, PyExtensionType, UnknownExtensionType, register_extension_type, unregister_extension_type, DictionaryMemo, KeyValueMetadata, @@ -218,7 +219,7 @@ def print_entry(label, value): MonthDayNanoIntervalArray, Decimal128Array, Decimal256Array, StructArray, ExtensionArray, RunEndEncodedArray, FixedShapeTensorArray, OpaqueArray, - scalar, NA, _NULL as NULL, Scalar, + Bool8Array, scalar, NA, _NULL as NULL, Scalar, NullScalar, BooleanScalar, Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar, UInt8Scalar, UInt16Scalar, UInt32Scalar, UInt64Scalar, @@ -235,7 +236,7 @@ def print_entry(label, value): FixedSizeBinaryScalar, DictionaryScalar, MapScalar, StructScalar, UnionScalar, RunEndEncodedScalar, ExtensionScalar, - FixedShapeTensorScalar, OpaqueScalar) + FixedShapeTensorScalar, OpaqueScalar, Bool8Scalar) # Buffers, allocation from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager, diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 6c40a21db96..4c3eb932326 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1581,7 +1581,7 @@ cdef class Array(_PandasConvertible): def to_numpy(self, zero_copy_only=True, writable=False): """ - Return a NumPy view or copy of this array (experimental). + Return a NumPy view or copy of this array. By default, tries to return a view of this array. This is only supported for primitive arrays with the same memory layout as NumPy @@ -4476,6 +4476,118 @@ cdef class OpaqueArray(ExtensionArray): """ +cdef class Bool8Array(ExtensionArray): + """ + Concrete class for bool8 extension arrays. 
+ + Examples + -------- + Define the extension type for an bool8 array + + >>> import pyarrow as pa + >>> bool8_type = pa.bool8() + + Create an extension array + + >>> arr = [-1, 0, 1, 2, None] + >>> storage = pa.array(arr, pa.int8()) + >>> pa.ExtensionArray.from_storage(bool8_type, storage) + + [ + -1, + 0, + 1, + 2, + null + ] + """ + + def to_numpy(self, zero_copy_only=True, writable=False): + """ + Return a NumPy bool view or copy of this array. + + By default, tries to return a view of this array. This is only + supported for arrays without any nulls. + + Parameters + ---------- + zero_copy_only : bool, default True + If True, an exception will be raised if the conversion to a numpy + array would require copying the underlying data (e.g. in presence + of nulls). + writable : bool, default False + For numpy arrays created with zero copy (view on the Arrow data), + the resulting array is not writable (Arrow data is immutable). + By setting this to True, a copy of the array is made to ensure + it is writable. + + Returns + ------- + array : numpy.ndarray + """ + if not writable: + try: + return self.storage.to_numpy().view(np.bool_) + except ArrowInvalid as e: + if zero_copy_only: + raise e + + return _pc().not_equal(self.storage, 0).to_numpy(zero_copy_only=zero_copy_only, writable=writable) + + @staticmethod + def from_storage(Int8Array storage): + """ + Construct Bool8Array from Int8Array storage. + + Parameters + ---------- + storage : Int8Array + The underlying storage for the result array. + + Returns + ------- + bool8_array : Bool8Array + """ + return ExtensionArray.from_storage(bool8(), storage) + + @staticmethod + def from_numpy(obj): + """ + Convert numpy array to a bool8 extension array without making a copy. + The input array must be 1-dimensional, with either bool_ or int8 dtype. 
+ + Parameters + ---------- + obj : numpy.ndarray + + Returns + ------- + bool8_array : Bool8Array + + Examples + -------- + >>> import pyarrow as pa + >>> import numpy as np + >>> arr = np.array([True, False, True], dtype=np.bool_) + >>> pa.Bool8Array.from_numpy(arr) + + [ + 1, + 0, + 1 + ] + """ + + if obj.ndim != 1: + raise ValueError(f"Cannot convert {obj.ndim}-D array to bool8 array") + + if obj.dtype not in [np.bool_, np.int8]: + raise TypeError(f"Array dtype {obj.dtype} incompatible with bool8 storage") + + storage_arr = array(obj.view(np.int8), type=int8()) + return Bool8Array.from_storage(storage_arr) + + cdef dict _array_classes = { _Type_NA: NullArray, _Type_BOOL: BooleanArray, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 9b008d150f1..a54a1db292f 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2895,6 +2895,15 @@ cdef extern from "arrow/extension/opaque.h" namespace "arrow::extension" nogil: pass +cdef extern from "arrow/extension/bool8.h" namespace "arrow::extension" nogil: + cdef cppclass CBool8Type" arrow::extension::Bool8Type"(CExtensionType): + + @staticmethod + CResult[shared_ptr[CDataType]] Make() + + cdef cppclass CBool8Array" arrow::extension::Bool8Array"(CExtensionArray): + pass + cdef extern from "arrow/util/compression.h" namespace "arrow" nogil: cdef enum CCompressionType" arrow::Compression::type": CCompressionType_UNCOMPRESSED" arrow::Compression::UNCOMPRESSED" diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 2cb302d20a8..e3625c18152 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -214,6 +214,9 @@ cdef class FixedShapeTensorType(BaseExtensionType): cdef: const CFixedShapeTensorType* tensor_ext_type +cdef class Bool8Type(BaseExtensionType): + cdef: + const CBool8Type* bool8_ext_type cdef class OpaqueType(BaseExtensionType): cdef: diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 2f9fc1c5542..19a26bd6c68 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -126,6 +126,8 @@ cdef api object pyarrow_wrap_data_type( out = FixedShapeTensorType.__new__(FixedShapeTensorType) elif ext_type.extension_name() == b"arrow.opaque": out = OpaqueType.__new__(OpaqueType) + elif ext_type.extension_name() == b"arrow.bool8": + out = Bool8Type.__new__(Bool8Type) else: out = BaseExtensionType.__new__(BaseExtensionType) else: diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index 12a99c2aece..72ae2aee5f8 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -1091,6 +1091,18 @@ cdef class OpaqueScalar(ExtensionScalar): """ +cdef class Bool8Scalar(ExtensionScalar): + """ + Concrete class for bool8 extension scalar. + """ + + def as_py(self): + """ + Return this scalar as a Python object. 
+ """ + py_val = super().as_py() + return None if py_val is None else py_val != 0 + cdef dict _scalar_classes = { _Type_BOOL: BooleanScalar, _Type_UINT8: UInt8Scalar, @@ -1199,6 +1211,11 @@ def scalar(value, type=None, *, from_pandas=None, MemoryPool memory_pool=None): type = ensure_type(type, allow_none=True) pool = maybe_unbox_memory_pool(memory_pool) + extension_type = None + if type is not None and type.id == _Type_EXTENSION: + extension_type = type + type = type.storage_type + if _is_array_like(value): value = get_values(value, &is_pandas_object) @@ -1223,4 +1240,8 @@ def scalar(value, type=None, *, from_pandas=None, MemoryPool memory_pool=None): # retrieve the scalar from the first position scalar = GetResultValue(array.get().GetScalar(0)) - return Scalar.wrap(scalar) + result = Scalar.wrap(scalar) + + if extension_type is not None: + result = ExtensionScalar.from_storage(extension_type, result) + return result diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 58c54189f22..b04ee85ec99 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1707,3 +1707,155 @@ def test_opaque_type(pickle_module, storage_type, storage): # cast extension type -> storage type inner = arr.cast(storage_type) assert inner == storage + + +def test_bool8_type(pickle_module): + bool8_type = pa.bool8() + storage_type = pa.int8() + assert bool8_type.extension_name == "arrow.bool8" + assert bool8_type.storage_type == storage_type + assert str(bool8_type) == "extension" + + assert bool8_type == bool8_type + assert bool8_type == pa.bool8() + assert bool8_type != storage_type + + # Pickle roundtrip + result = pickle_module.loads(pickle_module.dumps(bool8_type)) + assert result == bool8_type + + # IPC roundtrip + storage = pa.array([-1, 0, 1, 2, None], storage_type) + arr = pa.ExtensionArray.from_storage(bool8_type, storage) + assert isinstance(arr, pa.Bool8Array) + + # extension is registered by default + buf = ipc_write_batch(pa.RecordBatch.from_arrays([arr], ["ext"])) + batch = ipc_read_batch(buf) + + assert batch.column(0).type.extension_name == "arrow.bool8" + assert isinstance(batch.column(0), pa.Bool8Array) + + # cast storage -> extension type + result = storage.cast(bool8_type) + assert result == arr + + # cast extension type -> storage type + inner = arr.cast(storage_type) + assert inner == storage + + +def test_bool8_to_bool_conversion(): + bool_arr = pa.array([True, False, True, True, None], pa.bool_()) + bool8_arr = pa.ExtensionArray.from_storage( + pa.bool8(), + pa.array([-1, 0, 1, 2, None], pa.int8()), + ) + + # cast extension type -> arrow boolean type + assert bool8_arr.cast(pa.bool_()) == bool_arr + + # cast arrow boolean type -> extension type, expecting canonical values + canonical_storage = pa.array([1, 0, 1, 1, None], pa.int8()) + canonical_bool8_arr = pa.ExtensionArray.from_storage(pa.bool8(), canonical_storage) + assert bool_arr.cast(pa.bool8()) == canonical_bool8_arr + + +def test_bool8_to_numpy_conversion(): + arr = pa.ExtensionArray.from_storage( + pa.bool8(), + pa.array([-1, 0, 1, 2, None], pa.int8()), + ) + + # cannot zero-copy with nulls + with pytest.raises( + pa.ArrowInvalid, + match="Needed to copy 1 chunks with 1 nulls, but zero_copy_only was True", + ): + arr.to_numpy() + + # nullable conversion possible with a copy, but dest dtype is object + assert np.array_equal( + arr.to_numpy(zero_copy_only=False), + np.array([True, False, True, True, None], dtype=np.object_), + ) + 
+ # zero-copy possible with non-null array + np_arr_no_nulls = np.array([True, False, True, True], dtype=np.bool_) + arr_no_nulls = pa.ExtensionArray.from_storage( + pa.bool8(), + pa.array([-1, 0, 1, 2], pa.int8()), + ) + + arr_to_np = arr_no_nulls.to_numpy() + assert np.array_equal(arr_to_np, np_arr_no_nulls) + + # same underlying buffer + assert arr_to_np.ctypes.data == arr_no_nulls.buffers()[1].address + + # if the user requests a writable array, a copy should be performed + arr_to_np_writable = arr_no_nulls.to_numpy(zero_copy_only=False, writable=True) + assert np.array_equal(arr_to_np_writable, np_arr_no_nulls) + + # different underlying buffer + assert arr_to_np_writable.ctypes.data != arr_no_nulls.buffers()[1].address + + +def test_bool8_from_numpy_conversion(): + np_arr_no_nulls = np.array([True, False, True, True], dtype=np.bool_) + canonical_bool8_arr_no_nulls = pa.ExtensionArray.from_storage( + pa.bool8(), + pa.array([1, 0, 1, 1], pa.int8()), + ) + + arr_from_np = pa.Bool8Array.from_numpy(np_arr_no_nulls) + assert arr_from_np == canonical_bool8_arr_no_nulls + + # same underlying buffer + assert arr_from_np.buffers()[1].address == np_arr_no_nulls.ctypes.data + + # conversion only valid for 1-D arrays + with pytest.raises( + ValueError, + match="Cannot convert 2-D array to bool8 array", + ): + pa.Bool8Array.from_numpy( + np.array([[True, False], [False, True]], dtype=np.bool_), + ) + + with pytest.raises( + ValueError, + match="Cannot convert 0-D array to bool8 array", + ): + pa.Bool8Array.from_numpy(np.bool_()) + + # must use compatible storage type + with pytest.raises( + TypeError, + match="Array dtype float64 incompatible with bool8 storage", + ): + pa.Bool8Array.from_numpy(np.array([1, 2, 3], dtype=np.float64)) + + +def test_bool8_scalar(): + assert pa.ExtensionScalar.from_storage(pa.bool8(), -1).as_py() is True + assert pa.ExtensionScalar.from_storage(pa.bool8(), 0).as_py() is False + assert pa.ExtensionScalar.from_storage(pa.bool8(), 1).as_py() is True + assert pa.ExtensionScalar.from_storage(pa.bool8(), 2).as_py() is True + assert pa.ExtensionScalar.from_storage(pa.bool8(), None).as_py() is None + + arr = pa.ExtensionArray.from_storage( + pa.bool8(), + pa.array([-1, 0, 1, 2, None], pa.int8()), + ) + assert arr[0].as_py() is True + assert arr[1].as_py() is False + assert arr[2].as_py() is True + assert arr[3].as_py() is True + assert arr[4].as_py() is None + + assert pa.scalar(-1, type=pa.bool8()).as_py() is True + assert pa.scalar(0, type=pa.bool8()).as_py() is False + assert pa.scalar(1, type=pa.bool8()).as_py() is True + assert pa.scalar(2, type=pa.bool8()).as_py() is True + assert pa.scalar(None, type=pa.bool8()).as_py() is None diff --git a/python/pyarrow/tests/test_misc.py b/python/pyarrow/tests/test_misc.py index 9a55a38177f..5d3471c7c35 100644 --- a/python/pyarrow/tests/test_misc.py +++ b/python/pyarrow/tests/test_misc.py @@ -250,6 +250,9 @@ def test_set_timezone_db_path_non_windows(): pa.OpaqueArray, pa.OpaqueScalar, pa.OpaqueType, + pa.Bool8Array, + pa.Bool8Scalar, + pa.Bool8Type, ]) def test_extension_type_constructor_errors(klass): # ARROW-2638: prevent calling extension class constructors directly diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index dcd2b61c334..563782f0c26 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1837,6 +1837,37 @@ cdef class FixedShapeTensorType(BaseExtensionType): return FixedShapeTensorScalar +cdef class Bool8Type(BaseExtensionType): + """ + Concrete class for bool8 extension type. 
+ + Bool8 is an alternate representation for boolean + arrays using 8 bits instead of 1 bit per value. The underlying + storage type is int8. + + Examples + -------- + Create an instance of bool8 extension type: + + >>> import pyarrow as pa + >>> pa.bool8() + Bool8Type(extension) + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.bool8_ext_type = type.get() + + def __arrow_ext_class__(self): + return Bool8Array + + def __reduce__(self): + return bool8, () + + def __arrow_ext_scalar_class__(self): + return Bool8Scalar + + cdef class OpaqueType(BaseExtensionType): """ Concrete class for opaque extension type. @@ -5278,6 +5309,49 @@ def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=N return out +def bool8(): + """ + Create instance of bool8 extension type. + + Examples + -------- + Create an instance of bool8 extension type: + + >>> import pyarrow as pa + >>> type = pa.bool8() + >>> type + Bool8Type(extension) + + Inspect the data type: + + >>> type.storage_type + DataType(int8) + + Create a table with a bool8 array: + + >>> arr = [-1, 0, 1, 2, None] + >>> storage = pa.array(arr, pa.int8()) + >>> other = pa.ExtensionArray.from_storage(type, storage) + >>> pa.table([other], names=["unknown_col"]) + pyarrow.Table + unknown_col: extension + ---- + unknown_col: [[-1,0,1,2,null]] + + Returns + ------- + type : Bool8Type + """ + + cdef Bool8Type out = Bool8Type.__new__(Bool8Type) + + c_type = GetResultValue(CBool8Type.Make()) + + out.init(c_type) + + return out + + def opaque(DataType storage_type, str type_name not None, str vendor_name not None): """ Create instance of opaque extension type. From 27c22389579dd773d9701f5d3c743bbfca3bdb8e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 14:38:12 +0900 Subject: [PATCH 047/157] MINOR: [Java] Bump org.codehaus.mojo:exec-maven-plugin from 3.3.0 to 3.4.1 in /java (#43692) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [org.codehaus.mojo:exec-maven-plugin](https://github.com/mojohaus/exec-maven-plugin) from 3.3.0 to 3.4.1.
Release notes

Sourced from org.codehaus.mojo:exec-maven-plugin's releases.

3.4.1

🐛 Bug Fixes

📦 Dependency updates

👻 Maintenance

🔧 Build

3.4.0

🚀 New features and improvements

  • Allow <includePluginDependencies> to be specified for the exec:exec goal (#432) @sebthom

🐛 Bug Fixes

📦 Dependency updates

👻 Maintenance

🔧 Build

Commits
  • 7b0be2c [maven-release-plugin] prepare release 3.4.1
  • 5ac4f80 Environment variable Path should be used as case-insensitive
  • cfb3a9f Use Maven4 enabled with GH Action
  • d0ded48 Use shared release drafter GH Action
  • 4c22954 Bump org.codehaus.mojo:mojo-parent from 84 to 85
  • a8c4f94 fix: NPE because declared MavenSession field hides field of superclass
  • a2b735f Remove redundant spotless configuration
  • 8e0e83c [maven-release-plugin] prepare for next development iteration
  • 6c4996f [maven-release-plugin] prepare release 3.4.0
  • c7ad671 Remove Log4j 1.2.x from ITs
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=org.codehaus.mojo:exec-maven-plugin&package-manager=maven&previous-version=3.3.0&new-version=3.4.1)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/pom.xml b/java/pom.xml index 1524dc32579..0f3e5760f2b 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -504,7 +504,7 @@ under the License. org.codehaus.mojo exec-maven-plugin - 3.3.0 + 3.4.1 org.codehaus.mojo From 4af1e491df7ac22217656668b65c3e8d55f5b5ab Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 21 Aug 2024 14:56:44 +0900 Subject: [PATCH 048/157] MINOR: [Java] Bump io.grpc:grpc-bom from 1.65.0 to 1.66.0 in /java (#43657) Bumps [io.grpc:grpc-bom](https://github.com/grpc/grpc-java) from 1.65.0 to 1.66.0.
Release notes

Sourced from io.grpc:grpc-bom's releases.

v1.65.1

What's Changed

  • netty: Restore old behavior of NettyAdaptiveCumulator, but avoid using that class if Netty is on version 4.1.111 or later
Commits
  • cf78406 Bump version to 1.66.0
  • 33af0a7 Update README etc to reference 1.66.0
  • 19c9b99 xds: XdsClient should unsubscribe on last resource (#11264)
  • 752a045 Revert "Start 1.67.0 development cycle (#11416)" (#11428)
  • ef09d94 Revert "Introduce onResult2 in NameResolver Listener2 that returns Status (#1...
  • c37fb18 Start 1.67.0 development cycle
  • 9ba2f9d Introduce onResult2 in NameResolver Listener2 that returns Status (#11313)
  • 786523d xds: WRR rr_fallback should trigger with one endpoint weight
  • b108ed3 api: Give instruments a toString() including their name
  • eb4cdf7 Update MAINTAINERS.md (#11241)
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=io.grpc:grpc-bom&package-manager=maven&previous-version=1.65.0&new-version=1.66.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores)
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/pom.xml b/java/pom.xml index 0f3e5760f2b..a73453df68f 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -97,7 +97,7 @@ under the License. 2.0.13 33.2.1-jre 4.1.112.Final - 1.65.0 + 1.66.0 3.25.4 2.17.2 3.4.0 From 9fc03015463a8f1cb616b088342b104fbc767a0c Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 21 Aug 2024 09:22:53 +0200 Subject: [PATCH 049/157] GH-43069: [Python] Use Py_IsFinalizing from pythoncapi_compat.h (#43767) ### Rationale for this change https://github.com/apache/arrow/pull/43540 already vendored `pythoncapi_compat.h`, so closing https://github.com/apache/arrow/issues/43069 by using this as well for `Py_IsFinalizing` (which was added in https://github.com/apache/arrow/pull/42034, and for which we opened that follow-up issue to use `pythoncapi_compat.h` instead) Authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- python/pyarrow/src/arrow/python/udf.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index 2c1e97c3ea0..74f16899c47 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -24,14 +24,11 @@ #include "arrow/compute/kernel.h" #include "arrow/compute/row/grouper.h" #include "arrow/python/common.h" +#include "arrow/python/vendored/pythoncapi_compat.h" #include "arrow/table.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" -// Py_IsFinalizing added in Python 3.13.0a4 -#if PY_VERSION_HEX < 0x030D00A4 -#define Py_IsFinalizing() _Py_IsFinalizing() -#endif namespace arrow { using compute::ExecSpan; using compute::Grouper; From e1e7c501019ac26c896d61fa0c129eee83da9b55 Mon Sep 17 00:00:00 2001 From: Oliver Layer Date: Wed, 21 Aug 2024 13:22:57 +0200 Subject: [PATCH 050/157] GH-40036: [C++] Azure file system write buffering & async writes (#43096) ### Rationale for this change See #40036. ### What changes are included in this PR? Write buffering and async writes (similar to what the S3 file system does) in the `ObjectAppendStream` for the Azure file system. With write buffering and async writes, the input scenario creation runtime in the tests (which uses the `ObjectAppendStream` against Azurite) decreased from ~25s (see [here](https://github.com/apache/arrow/issues/40036)) to ~800ms: ``` [ RUN ] TestAzuriteFileSystem.OpenInputFileMixedReadVsReadAt [ OK ] TestAzuriteFileSystem.OpenInputFileMixedReadVsReadAt (787 ms) ``` ### Are these changes tested? Added some tests with background writes enabled and disabled (some were taken from the S3 tests). Everything changed should be covered. ### Are there any user-facing changes? `AzureOptions` now allows for `background_writes` to be set (default: true). No breaking changes. ### Notes - The code in `DoWrite` is very similar to [the code in the S3 FS](https://github.com/apache/arrow/blob/edfa343eeca008513f0300924380e1b187cc976b/cpp/src/arrow/filesystem/s3fs.cc#L1753). Maybe this could be unified? I didn't see this in the scope of the PR though. 
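For illustration only, a minimal sketch of how a caller might opt out of the new background writes through `AzureOptions` (assuming the `AzureFileSystem::Make` and `OpenOutputStream` APIs exercised by the tests in this PR; the credential setup placeholder and the `container/file.txt` path are not part of this change):

```
#include "arrow/filesystem/azurefs.h"
#include "arrow/result.h"
#include "arrow/status.h"

arrow::Status WriteSynchronously() {
  arrow::fs::AzureOptions options;
  // ... account/credential configuration would go here ...
  // Stage blocks inline in Write() instead of submitting them to the I/O thread pool.
  options.background_writes = false;
  ARROW_ASSIGN_OR_RAISE(auto fs, arrow::fs::AzureFileSystem::Make(options));
  ARROW_ASSIGN_OR_RAISE(auto out, fs->OpenOutputStream("container/file.txt"));
  ARROW_RETURN_NOT_OK(out->Write("hello azure"));
  return out->Close();
}
```

The same switch is available as a URI query parameter (`...?background_writes=false`), as covered by `TestFromUriDisableBackgroundWrites` in the test changes below.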
* GitHub Issue: #40036 Lead-authored-by: Oliver Layer Co-authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/arrow/filesystem/azurefs.cc | 276 ++++++++++++++++++++--- cpp/src/arrow/filesystem/azurefs.h | 3 + cpp/src/arrow/filesystem/azurefs_test.cc | 264 ++++++++++++++++++---- 3 files changed, 471 insertions(+), 72 deletions(-) diff --git a/cpp/src/arrow/filesystem/azurefs.cc b/cpp/src/arrow/filesystem/azurefs.cc index 9b3c0c0c1d7..0bad8563397 100644 --- a/cpp/src/arrow/filesystem/azurefs.cc +++ b/cpp/src/arrow/filesystem/azurefs.cc @@ -22,6 +22,7 @@ #include "arrow/filesystem/azurefs.h" #include "arrow/filesystem/azurefs_internal.h" +#include "arrow/io/memory.h" // idenfity.hpp triggers -Wattributes warnings cause -Werror builds to fail, // so disable it for this file with pragmas. @@ -144,6 +145,9 @@ Status AzureOptions::ExtractFromUriQuery(const Uri& uri) { blob_storage_scheme = "http"; dfs_storage_scheme = "http"; } + } else if (kv.first == "background_writes") { + ARROW_ASSIGN_OR_RAISE(background_writes, + ::arrow::internal::ParseBoolean(kv.second)); } else { return Status::Invalid( "Unexpected query parameter in Azure Blob File System URI: '", kv.first, "'"); @@ -937,8 +941,8 @@ Status CommitBlockList(std::shared_ptr block_bl const std::vector& block_ids, const Blobs::CommitBlockListOptions& options) { try { - // CommitBlockList puts all block_ids in the latest element. That means in the case of - // overlapping block_ids the newly staged block ids will always replace the + // CommitBlockList puts all block_ids in the latest element. That means in the case + // of overlapping block_ids the newly staged block ids will always replace the // previously committed blocks. // https://learn.microsoft.com/en-us/rest/api/storageservices/put-block-list?tabs=microsoft-entra-id#request-body block_blob_client->CommitBlockList(block_ids, options); @@ -950,7 +954,34 @@ Status CommitBlockList(std::shared_ptr block_bl return Status::OK(); } +Status StageBlock(Blobs::BlockBlobClient* block_blob_client, const std::string& id, + Core::IO::MemoryBodyStream& content) { + try { + block_blob_client->StageBlock(id, content); + } catch (const Storage::StorageException& exception) { + return ExceptionToStatus( + exception, "StageBlock failed for '", block_blob_client->GetUrl(), + "' new_block_id: '", id, + "'. Staging new blocks is fundamental to streaming writes to blob storage."); + } + + return Status::OK(); +} + +/// Writes will be buffered up to this size (in bytes) before actually uploading them. +static constexpr int64_t kBlockUploadSizeBytes = 10 * 1024 * 1024; +/// The maximum size of a block in Azure Blob (as per docs). +static constexpr int64_t kMaxBlockSizeBytes = 4UL * 1024 * 1024 * 1024; + +/// This output stream, similar to other arrow OutputStreams, is not thread-safe. 
class ObjectAppendStream final : public io::OutputStream { + private: + struct UploadState; + + std::shared_ptr Self() { + return std::dynamic_pointer_cast(shared_from_this()); + } + public: ObjectAppendStream(std::shared_ptr block_blob_client, const io::IOContext& io_context, const AzureLocation& location, @@ -958,7 +989,8 @@ class ObjectAppendStream final : public io::OutputStream { const AzureOptions& options) : block_blob_client_(std::move(block_blob_client)), io_context_(io_context), - location_(location) { + location_(location), + background_writes_(options.background_writes) { if (metadata && metadata->size() != 0) { ArrowMetadataToCommitBlockListOptions(metadata, commit_block_list_options_); } else if (options.default_metadata && options.default_metadata->size() != 0) { @@ -1008,10 +1040,13 @@ class ObjectAppendStream final : public io::OutputStream { content_length_ = 0; } } + + upload_state_ = std::make_shared(); + if (content_length_ > 0) { ARROW_ASSIGN_OR_RAISE(auto block_list, GetBlockList(block_blob_client_)); for (auto block : block_list.CommittedBlocks) { - block_ids_.push_back(block.Name); + upload_state_->block_ids.push_back(block.Name); } } initialised_ = true; @@ -1031,12 +1066,34 @@ class ObjectAppendStream final : public io::OutputStream { if (closed_) { return Status::OK(); } + + if (current_block_) { + // Upload remaining buffer + RETURN_NOT_OK(AppendCurrentBlock()); + } + RETURN_NOT_OK(Flush()); block_blob_client_ = nullptr; closed_ = true; return Status::OK(); } + Future<> CloseAsync() override { + if (closed_) { + return Status::OK(); + } + + if (current_block_) { + // Upload remaining buffer + RETURN_NOT_OK(AppendCurrentBlock()); + } + + return FlushAsync().Then([self = Self()]() { + self->block_blob_client_ = nullptr; + self->closed_ = true; + }); + } + bool closed() const override { return closed_; } Status CheckClosed(const char* action) const { @@ -1052,11 +1109,11 @@ class ObjectAppendStream final : public io::OutputStream { } Status Write(const std::shared_ptr& buffer) override { - return DoAppend(buffer->data(), buffer->size(), buffer); + return DoWrite(buffer->data(), buffer->size(), buffer); } Status Write(const void* data, int64_t nbytes) override { - return DoAppend(data, nbytes); + return DoWrite(data, nbytes); } Status Flush() override { @@ -1066,20 +1123,111 @@ class ObjectAppendStream final : public io::OutputStream { // flush. This also avoids some unhandled errors when flushing in the destructor. return Status::OK(); } - return CommitBlockList(block_blob_client_, block_ids_, commit_block_list_options_); + + Future<> pending_blocks_completed; + { + std::unique_lock lock(upload_state_->mutex); + pending_blocks_completed = upload_state_->pending_blocks_completed; + } + + RETURN_NOT_OK(pending_blocks_completed.status()); + std::unique_lock lock(upload_state_->mutex); + return CommitBlockList(block_blob_client_, upload_state_->block_ids, + commit_block_list_options_); } - private: - Status DoAppend(const void* data, int64_t nbytes, - std::shared_ptr owned_buffer = nullptr) { - RETURN_NOT_OK(CheckClosed("append")); - auto append_data = reinterpret_cast(data); - Core::IO::MemoryBodyStream block_content(append_data, nbytes); - if (block_content.Length() == 0) { + Future<> FlushAsync() { + RETURN_NOT_OK(CheckClosed("flush async")); + if (!initialised_) { + // If the stream has not been successfully initialized then there is nothing to + // flush. This also avoids some unhandled errors when flushing in the destructor. 
return Status::OK(); } - const auto n_block_ids = block_ids_.size(); + Future<> pending_blocks_completed; + { + std::unique_lock lock(upload_state_->mutex); + pending_blocks_completed = upload_state_->pending_blocks_completed; + } + + return pending_blocks_completed.Then([self = Self()] { + std::unique_lock lock(self->upload_state_->mutex); + return CommitBlockList(self->block_blob_client_, self->upload_state_->block_ids, + self->commit_block_list_options_); + }); + } + + private: + Status AppendCurrentBlock() { + ARROW_ASSIGN_OR_RAISE(auto buf, current_block_->Finish()); + current_block_.reset(); + current_block_size_ = 0; + return AppendBlock(buf); + } + + Status DoWrite(const void* data, int64_t nbytes, + std::shared_ptr owned_buffer = nullptr) { + if (closed_) { + return Status::Invalid("Operation on closed stream"); + } + + const auto* data_ptr = reinterpret_cast(data); + auto advance_ptr = [this, &data_ptr, &nbytes](const int64_t offset) { + data_ptr += offset; + nbytes -= offset; + pos_ += offset; + content_length_ += offset; + }; + + // Handle case where we have some bytes buffered from prior calls. + if (current_block_size_ > 0) { + // Try to fill current buffer + const int64_t to_copy = + std::min(nbytes, kBlockUploadSizeBytes - current_block_size_); + RETURN_NOT_OK(current_block_->Write(data_ptr, to_copy)); + current_block_size_ += to_copy; + advance_ptr(to_copy); + + // If buffer isn't full, break + if (current_block_size_ < kBlockUploadSizeBytes) { + return Status::OK(); + } + + // Upload current buffer + RETURN_NOT_OK(AppendCurrentBlock()); + } + + // We can upload chunks without copying them into a buffer + while (nbytes >= kBlockUploadSizeBytes) { + const auto upload_size = std::min(nbytes, kMaxBlockSizeBytes); + RETURN_NOT_OK(AppendBlock(data_ptr, upload_size)); + advance_ptr(upload_size); + } + + // Buffer remaining bytes + if (nbytes > 0) { + current_block_size_ = nbytes; + + if (current_block_ == nullptr) { + ARROW_ASSIGN_OR_RAISE( + current_block_, + io::BufferOutputStream::Create(kBlockUploadSizeBytes, io_context_.pool())); + } else { + // Re-use the allocation from before. + RETURN_NOT_OK(current_block_->Reset(kBlockUploadSizeBytes, io_context_.pool())); + } + + RETURN_NOT_OK(current_block_->Write(data_ptr, current_block_size_)); + pos_ += current_block_size_; + content_length_ += current_block_size_; + } + + return Status::OK(); + } + + std::string CreateBlock() { + std::unique_lock lock(upload_state_->mutex); + const auto n_block_ids = upload_state_->block_ids.size(); // New block ID must always be distinct from the existing block IDs. Otherwise we // will accidentally replace the content of existing blocks, causing corruption. @@ -1093,36 +1241,106 @@ class ObjectAppendStream final : public io::OutputStream { new_block_id.insert(0, required_padding_digits, '0'); // There is a small risk when appending to a blob created by another client that // `new_block_id` may overlapping with an existing block id. Adding the `-arrow` - // suffix significantly reduces the risk, but does not 100% eliminate it. For example - // if the blob was previously created with one block, with id `00001-arrow` then the - // next block we append will conflict with that, and cause corruption. + // suffix significantly reduces the risk, but does not 100% eliminate it. For + // example if the blob was previously created with one block, with id `00001-arrow` + // then the next block we append will conflict with that, and cause corruption. 
new_block_id += "-arrow"; new_block_id = Core::Convert::Base64Encode( std::vector(new_block_id.begin(), new_block_id.end())); - try { - block_blob_client_->StageBlock(new_block_id, block_content); - } catch (const Storage::StorageException& exception) { - return ExceptionToStatus( - exception, "StageBlock failed for '", block_blob_client_->GetUrl(), - "' new_block_id: '", new_block_id, - "'. Staging new blocks is fundamental to streaming writes to blob storage."); + upload_state_->block_ids.push_back(new_block_id); + + // We only use the future if we have background writes enabled. Without background + // writes the future is initialized as finished and not mutated any more. + if (background_writes_ && upload_state_->blocks_in_progress++ == 0) { + upload_state_->pending_blocks_completed = Future<>::Make(); } - block_ids_.push_back(new_block_id); - pos_ += nbytes; - content_length_ += nbytes; + + return new_block_id; + } + + Status AppendBlock(const void* data, int64_t nbytes, + std::shared_ptr owned_buffer = nullptr) { + RETURN_NOT_OK(CheckClosed("append")); + + if (nbytes == 0) { + return Status::OK(); + } + + const auto block_id = CreateBlock(); + + if (background_writes_) { + if (owned_buffer == nullptr) { + ARROW_ASSIGN_OR_RAISE(owned_buffer, AllocateBuffer(nbytes, io_context_.pool())); + memcpy(owned_buffer->mutable_data(), data, nbytes); + } else { + DCHECK_EQ(data, owned_buffer->data()); + DCHECK_EQ(nbytes, owned_buffer->size()); + } + + // The closure keeps the buffer and the upload state alive + auto deferred = [owned_buffer, block_id, block_blob_client = block_blob_client_, + state = upload_state_]() mutable -> Status { + Core::IO::MemoryBodyStream block_content(owned_buffer->data(), + owned_buffer->size()); + + auto status = StageBlock(block_blob_client.get(), block_id, block_content); + HandleUploadOutcome(state, status); + return Status::OK(); + }; + RETURN_NOT_OK(io::internal::SubmitIO(io_context_, std::move(deferred))); + } else { + auto append_data = reinterpret_cast(data); + Core::IO::MemoryBodyStream block_content(append_data, nbytes); + + RETURN_NOT_OK(StageBlock(block_blob_client_.get(), block_id, block_content)); + } + return Status::OK(); } + Status AppendBlock(std::shared_ptr buffer) { + return AppendBlock(buffer->data(), buffer->size(), buffer); + } + + static void HandleUploadOutcome(const std::shared_ptr& state, + const Status& status) { + std::unique_lock lock(state->mutex); + if (!status.ok()) { + state->status &= status; + } + // Notify completion + if (--state->blocks_in_progress == 0) { + auto fut = state->pending_blocks_completed; + lock.unlock(); + fut.MarkFinished(state->status); + } + } + std::shared_ptr block_blob_client_; const io::IOContext io_context_; const AzureLocation location_; + const bool background_writes_; int64_t content_length_ = kNoSize; + std::shared_ptr current_block_; + int64_t current_block_size_ = 0; + bool closed_ = false; bool initialised_ = false; int64_t pos_ = 0; - std::vector block_ids_; + + // This struct is kept alive through background writes to avoid problems + // in the completion handler. 
+ struct UploadState { + std::mutex mutex; + std::vector block_ids; + int64_t blocks_in_progress = 0; + Status status; + Future<> pending_blocks_completed = Future<>::MakeFinished(Status::OK()); + }; + std::shared_ptr upload_state_; + Blobs::CommitBlockListOptions commit_block_list_options_; }; diff --git a/cpp/src/arrow/filesystem/azurefs.h b/cpp/src/arrow/filesystem/azurefs.h index 072b061eeb2..ebbe00c4ee7 100644 --- a/cpp/src/arrow/filesystem/azurefs.h +++ b/cpp/src/arrow/filesystem/azurefs.h @@ -112,6 +112,9 @@ struct ARROW_EXPORT AzureOptions { /// This will be ignored if non-empty metadata is passed to OpenOutputStream. std::shared_ptr default_metadata; + /// Whether OutputStream writes will be issued in the background, without blocking. + bool background_writes = true; + private: enum class CredentialKind { kDefault, diff --git a/cpp/src/arrow/filesystem/azurefs_test.cc b/cpp/src/arrow/filesystem/azurefs_test.cc index 5ff241b17ff..9d437d1f83a 100644 --- a/cpp/src/arrow/filesystem/azurefs_test.cc +++ b/cpp/src/arrow/filesystem/azurefs_test.cc @@ -39,6 +39,7 @@ #include #include #include +#include #include #include @@ -53,6 +54,7 @@ #include "arrow/status.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" +#include "arrow/util/future.h" #include "arrow/util/io_util.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" @@ -566,6 +568,7 @@ class TestAzureOptions : public ::testing::Test { ASSERT_EQ(options.dfs_storage_scheme, default_options.dfs_storage_scheme); ASSERT_EQ(options.credential_kind_, AzureOptions::CredentialKind::kDefault); ASSERT_EQ(path, "container/dir/blob"); + ASSERT_EQ(options.background_writes, true); } void TestFromUriDfsStorage() { @@ -582,6 +585,7 @@ class TestAzureOptions : public ::testing::Test { ASSERT_EQ(options.dfs_storage_scheme, default_options.dfs_storage_scheme); ASSERT_EQ(options.credential_kind_, AzureOptions::CredentialKind::kDefault); ASSERT_EQ(path, "file_system/dir/file"); + ASSERT_EQ(options.background_writes, true); } void TestFromUriAbfs() { @@ -597,6 +601,7 @@ class TestAzureOptions : public ::testing::Test { ASSERT_EQ(options.dfs_storage_scheme, "https"); ASSERT_EQ(options.credential_kind_, AzureOptions::CredentialKind::kStorageSharedKey); ASSERT_EQ(path, "container/dir/blob"); + ASSERT_EQ(options.background_writes, true); } void TestFromUriAbfss() { @@ -612,6 +617,7 @@ class TestAzureOptions : public ::testing::Test { ASSERT_EQ(options.dfs_storage_scheme, "https"); ASSERT_EQ(options.credential_kind_, AzureOptions::CredentialKind::kStorageSharedKey); ASSERT_EQ(path, "container/dir/blob"); + ASSERT_EQ(options.background_writes, true); } void TestFromUriEnableTls() { @@ -628,6 +634,17 @@ class TestAzureOptions : public ::testing::Test { ASSERT_EQ(options.dfs_storage_scheme, "http"); ASSERT_EQ(options.credential_kind_, AzureOptions::CredentialKind::kStorageSharedKey); ASSERT_EQ(path, "container/dir/blob"); + ASSERT_EQ(options.background_writes, true); + } + + void TestFromUriDisableBackgroundWrites() { + std::string path; + ASSERT_OK_AND_ASSIGN(auto options, + AzureOptions::FromUri( + "abfs://account:password@127.0.0.1:10000/container/dir/blob?" 
+ "background_writes=false", + &path)); + ASSERT_EQ(options.background_writes, false); } void TestFromUriCredentialDefault() { @@ -773,6 +790,9 @@ TEST_F(TestAzureOptions, FromUriDfsStorage) { TestFromUriDfsStorage(); } TEST_F(TestAzureOptions, FromUriAbfs) { TestFromUriAbfs(); } TEST_F(TestAzureOptions, FromUriAbfss) { TestFromUriAbfss(); } TEST_F(TestAzureOptions, FromUriEnableTls) { TestFromUriEnableTls(); } +TEST_F(TestAzureOptions, FromUriDisableBackgroundWrites) { + TestFromUriDisableBackgroundWrites(); +} TEST_F(TestAzureOptions, FromUriCredentialDefault) { TestFromUriCredentialDefault(); } TEST_F(TestAzureOptions, FromUriCredentialAnonymous) { TestFromUriCredentialAnonymous(); } TEST_F(TestAzureOptions, FromUriCredentialStorageSharedKey) { @@ -929,8 +949,9 @@ class TestAzureFileSystem : public ::testing::Test { void UploadLines(const std::vector& lines, const std::string& path, int total_size) { ASSERT_OK_AND_ASSIGN(auto output, fs()->OpenOutputStream(path, {})); - const auto all_lines = std::accumulate(lines.begin(), lines.end(), std::string("")); - ASSERT_OK(output->Write(all_lines)); + for (auto const& line : lines) { + ASSERT_OK(output->Write(line.data(), line.size())); + } ASSERT_OK(output->Close()); } @@ -1474,6 +1495,162 @@ class TestAzureFileSystem : public ::testing::Test { arrow::fs::AssertFileInfo(fs(), data.Path("dir/file0"), FileType::File); } + void AssertObjectContents(AzureFileSystem* fs, std::string_view path, + std::string_view expected) { + ASSERT_OK_AND_ASSIGN(auto input, fs->OpenInputStream(std::string{path})); + std::string contents; + std::shared_ptr buffer; + do { + ASSERT_OK_AND_ASSIGN(buffer, input->Read(128 * 1024)); + contents.append(buffer->ToString()); + } while (buffer->size() != 0); + + EXPECT_EQ(expected, contents); + } + + void TestOpenOutputStreamSmall() { + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + + auto data = SetUpPreexistingData(); + const auto path = data.ContainerPath("test-write-object"); + ASSERT_OK_AND_ASSIGN(auto output, fs->OpenOutputStream(path, {})); + const std::string_view expected(PreexistingData::kLoremIpsum); + ASSERT_OK(output->Write(expected)); + ASSERT_OK(output->Close()); + + // Verify we can read the object back. 
+ AssertObjectContents(fs.get(), path, expected); + } + + void TestOpenOutputStreamLarge() { + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + + auto data = SetUpPreexistingData(); + const auto path = data.ContainerPath("test-write-object"); + ASSERT_OK_AND_ASSIGN(auto output, fs->OpenOutputStream(path, {})); + + // Upload 5 MB, 4 MB und 2 MB and a very small write to test varying sizes + std::vector sizes{5 * 1024 * 1024, 4 * 1024 * 1024, 2 * 1024 * 1024, + 2000}; + + std::vector buffers{}; + char current_char = 'A'; + for (const auto size : sizes) { + buffers.emplace_back(size, current_char++); + } + + auto expected_size = std::int64_t{0}; + for (size_t i = 0; i < buffers.size(); ++i) { + ASSERT_OK(output->Write(buffers[i])); + expected_size += sizes[i]; + ASSERT_EQ(expected_size, output->Tell()); + } + ASSERT_OK(output->Close()); + + AssertObjectContents(fs.get(), path, + buffers[0] + buffers[1] + buffers[2] + buffers[3]); + } + + void TestOpenOutputStreamLargeSingleWrite() { + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + + auto data = SetUpPreexistingData(); + const auto path = data.ContainerPath("test-write-object"); + ASSERT_OK_AND_ASSIGN(auto output, fs->OpenOutputStream(path, {})); + + constexpr std::int64_t size{12 * 1024 * 1024}; + const std::string large_string(size, 'X'); + + ASSERT_OK(output->Write(large_string)); + ASSERT_EQ(size, output->Tell()); + ASSERT_OK(output->Close()); + + AssertObjectContents(fs.get(), path, large_string); + } + + void TestOpenOutputStreamCloseAsync() { +#if defined(ADDRESS_SANITIZER) || defined(ARROW_VALGRIND) + // This false positive leak is similar to the one pinpointed in the + // have_false_positive_memory_leak_with_generator() comments above, + // though the stack trace is different. It happens when a block list + // is committed from a background thread. + // + // clang-format off + // Direct leak of 968 byte(s) in 1 object(s) allocated from: + // #0 calloc + // #1 (/lib/x86_64-linux-gnu/libxml2.so.2+0xe25a4) + // #2 __xmlDefaultBufferSize + // #3 xmlBufferCreate + // #4 Azure::Storage::_internal::XmlWriter::XmlWriter() + // #5 Azure::Storage::Blobs::_detail::BlockBlobClient::CommitBlockList + // #6 Azure::Storage::Blobs::BlockBlobClient::CommitBlockList + // #7 arrow::fs::(anonymous namespace)::CommitBlockList + // #8 arrow::fs::(anonymous namespace)::ObjectAppendStream::FlushAsync()::'lambda' + // clang-format on + // + // TODO perhaps remove this skip once we can rely on + // https://github.com/Azure/azure-sdk-for-cpp/pull/5767 + // + // Also note that ClickHouse has a workaround for a similar issue: + // https://github.com/ClickHouse/ClickHouse/pull/45796 + if (options_.background_writes) { + GTEST_SKIP() << "False positive memory leak in libxml2 with CloseAsync"; + } +#endif + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + auto data = SetUpPreexistingData(); + const std::string path = data.ContainerPath("test-write-object"); + constexpr auto payload = PreexistingData::kLoremIpsum; + + ASSERT_OK_AND_ASSIGN(auto stream, fs->OpenOutputStream(path)); + ASSERT_OK(stream->Write(payload)); + auto close_fut = stream->CloseAsync(); + + ASSERT_OK(close_fut.MoveResult()); + + AssertObjectContents(fs.get(), path, payload); + } + + void TestOpenOutputStreamCloseAsyncDestructor() { +#if defined(ADDRESS_SANITIZER) || defined(ARROW_VALGRIND) + // See above. 
+ if (options_.background_writes) { + GTEST_SKIP() << "False positive memory leak in libxml2 with CloseAsync"; + } +#endif + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + auto data = SetUpPreexistingData(); + const std::string path = data.ContainerPath("test-write-object"); + constexpr auto payload = PreexistingData::kLoremIpsum; + + ASSERT_OK_AND_ASSIGN(auto stream, fs->OpenOutputStream(path)); + ASSERT_OK(stream->Write(payload)); + // Destructor implicitly closes stream and completes the upload. + // Testing it doesn't matter whether flush is triggered asynchronously + // after CloseAsync or synchronously after stream.reset() since we're just + // checking that the future keeps the stream alive until completion + // rather than segfaulting on a dangling stream. + auto close_fut = stream->CloseAsync(); + stream.reset(); + ASSERT_OK(close_fut.MoveResult()); + + AssertObjectContents(fs.get(), path, payload); + } + + void TestOpenOutputStreamDestructor() { + ASSERT_OK_AND_ASSIGN(auto fs, AzureFileSystem::Make(options_)); + constexpr auto* payload = "new data"; + auto data = SetUpPreexistingData(); + const std::string path = data.ContainerPath("test-write-object"); + + ASSERT_OK_AND_ASSIGN(auto stream, fs->OpenOutputStream(path)); + ASSERT_OK(stream->Write(payload)); + // Destructor implicitly closes stream and completes the multipart upload. + stream.reset(); + + AssertObjectContents(fs.get(), path, payload); + } + private: using StringMatcher = ::testing::PolymorphicMatcher<::testing::internal::HasSubstrMatcher>; @@ -2704,53 +2881,27 @@ TEST_F(TestAzuriteFileSystem, WriteMetadataHttpHeaders) { ASSERT_EQ("text/plain", content_type); } -TEST_F(TestAzuriteFileSystem, OpenOutputStreamSmall) { - auto data = SetUpPreexistingData(); - const auto path = data.ContainerPath("test-write-object"); - ASSERT_OK_AND_ASSIGN(auto output, fs()->OpenOutputStream(path, {})); - const std::string_view expected(PreexistingData::kLoremIpsum); - ASSERT_OK(output->Write(expected)); - ASSERT_OK(output->Close()); - - // Verify we can read the object back. - ASSERT_OK_AND_ASSIGN(auto input, fs()->OpenInputStream(path)); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamSmallNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamSmall(); +} - std::array inbuf{}; - ASSERT_OK_AND_ASSIGN(auto size, input->Read(inbuf.size(), inbuf.data())); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamSmall) { TestOpenOutputStreamSmall(); } - EXPECT_EQ(expected, std::string_view(inbuf.data(), size)); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamLargeNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamLarge(); } -TEST_F(TestAzuriteFileSystem, OpenOutputStreamLarge) { - auto data = SetUpPreexistingData(); - const auto path = data.ContainerPath("test-write-object"); - ASSERT_OK_AND_ASSIGN(auto output, fs()->OpenOutputStream(path, {})); - std::array sizes{257 * 1024, 258 * 1024, 259 * 1024}; - std::array buffers{ - std::string(sizes[0], 'A'), - std::string(sizes[1], 'B'), - std::string(sizes[2], 'C'), - }; - auto expected = std::int64_t{0}; - for (auto i = 0; i != 3; ++i) { - ASSERT_OK(output->Write(buffers[i])); - expected += sizes[i]; - ASSERT_EQ(expected, output->Tell()); - } - ASSERT_OK(output->Close()); - - // Verify we can read the object back. 
- ASSERT_OK_AND_ASSIGN(auto input, fs()->OpenInputStream(path)); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamLarge) { TestOpenOutputStreamLarge(); } - std::string contents; - std::shared_ptr buffer; - do { - ASSERT_OK_AND_ASSIGN(buffer, input->Read(128 * 1024)); - ASSERT_TRUE(buffer); - contents.append(buffer->ToString()); - } while (buffer->size() != 0); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamLargeSingleWriteNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamLargeSingleWrite(); +} - EXPECT_EQ(contents, buffers[0] + buffers[1] + buffers[2]); +TEST_F(TestAzuriteFileSystem, OpenOutputStreamLargeSingleWrite) { + TestOpenOutputStreamLargeSingleWrite(); } TEST_F(TestAzuriteFileSystem, OpenOutputStreamTruncatesExistingFile) { @@ -2820,6 +2971,33 @@ TEST_F(TestAzuriteFileSystem, OpenOutputStreamClosed) { ASSERT_RAISES(Invalid, output->Tell()); } +TEST_F(TestAzuriteFileSystem, OpenOutputStreamCloseAsync) { + TestOpenOutputStreamCloseAsync(); +} + +TEST_F(TestAzuriteFileSystem, OpenOutputStreamCloseAsyncNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamCloseAsync(); +} + +TEST_F(TestAzuriteFileSystem, OpenOutputStreamAsyncDestructor) { + TestOpenOutputStreamCloseAsyncDestructor(); +} + +TEST_F(TestAzuriteFileSystem, OpenOutputStreamAsyncDestructorNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamCloseAsyncDestructor(); +} + +TEST_F(TestAzuriteFileSystem, OpenOutputStreamDestructor) { + TestOpenOutputStreamDestructor(); +} + +TEST_F(TestAzuriteFileSystem, OpenOutputStreamDestructorNoBackgroundWrites) { + options_.background_writes = false; + TestOpenOutputStreamDestructor(); +} + TEST_F(TestAzuriteFileSystem, OpenOutputStreamUri) { auto data = SetUpPreexistingData(); const auto path = data.ContainerPath("open-output-stream-uri.txt"); From ffee537d88ab6d26614e2a1e85d4d18152695020 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 21 Aug 2024 14:18:45 +0200 Subject: [PATCH 051/157] GH-42222: [Python] Add bindings for CopyTo on RecordBatch and Array classes (#42223) ### Rationale for this change We have added bindings for the Device and MemoryManager classes (https://github.com/apache/arrow/issues/41126), and as a next step we can expose the functionality to copy a full Array or RecordBatch to a specific memory manager. ### What changes are included in this PR? This adds a `copy_to` method on pyarrow Array and RecordBatch. ### Are these changes tested? Yes * GitHub Issue: #42222 Authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- python/pyarrow/array.pxi | 36 ++++++++++++ python/pyarrow/device.pxi | 6 ++ python/pyarrow/includes/libarrow.pxd | 4 ++ python/pyarrow/lib.pxd | 4 ++ python/pyarrow/table.pxi | 35 ++++++++++++ python/pyarrow/tests/test_cuda.py | 82 +++++++++++----------------- python/pyarrow/tests/test_device.py | 26 +++++++++ 7 files changed, 143 insertions(+), 50 deletions(-) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 4c3eb932326..77d6c9c06d2 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1702,6 +1702,42 @@ cdef class Array(_PandasConvertible): _append_array_buffers(self.sp_array.get().data().get(), res) return res + def copy_to(self, destination): + """ + Construct a copy of the array with all buffers on destination + device. + + This method recursively copies the array's buffers and those of its + children onto the destination MemoryManager device and returns the + new Array. 
+ + Parameters + ---------- + destination : pyarrow.MemoryManager or pyarrow.Device + The destination device to copy the array to. + + Returns + ------- + Array + """ + cdef: + shared_ptr[CArray] c_array + shared_ptr[CMemoryManager] c_memory_manager + + if isinstance(destination, Device): + c_memory_manager = (destination).unwrap().get().default_memory_manager() + elif isinstance(destination, MemoryManager): + c_memory_manager = (destination).unwrap() + else: + raise TypeError( + "Argument 'destination' has incorrect type (expected a " + f"pyarrow Device or MemoryManager, got {type(destination)})" + ) + + with nogil: + c_array = GetResultValue(self.ap.CopyTo(c_memory_manager)) + return pyarrow_wrap_array(c_array) + def _export_to_c(self, out_ptr, out_schema_ptr=0): """ Export to a C ArrowArray struct, given its pointer. diff --git a/python/pyarrow/device.pxi b/python/pyarrow/device.pxi index 6e603475208..26256de6209 100644 --- a/python/pyarrow/device.pxi +++ b/python/pyarrow/device.pxi @@ -64,6 +64,9 @@ cdef class Device(_Weakrefable): self.init(device) return self + cdef inline shared_ptr[CDevice] unwrap(self) nogil: + return self.device + def __eq__(self, other): if not isinstance(other, Device): return False @@ -130,6 +133,9 @@ cdef class MemoryManager(_Weakrefable): self.init(mm) return self + cdef inline shared_ptr[CMemoryManager] unwrap(self) nogil: + return self.memory_manager + def __repr__(self): return "".format( frombytes(self.memory_manager.get().device().get().ToString()) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index a54a1db292f..6f510cfc0c0 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -234,7 +234,9 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: CStatus Validate() const CStatus ValidateFull() const CResult[shared_ptr[CArray]] View(const shared_ptr[CDataType]& type) + CDeviceAllocationType device_type() + CResult[shared_ptr[CArray]] CopyTo(const shared_ptr[CMemoryManager]& to) const shared_ptr[CArray] MakeArray(const shared_ptr[CArrayData]& data) CResult[shared_ptr[CArray]] MakeArrayOfNull( @@ -1027,6 +1029,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: shared_ptr[CRecordBatch] Slice(int64_t offset) shared_ptr[CRecordBatch] Slice(int64_t offset, int64_t length) + CResult[shared_ptr[CRecordBatch]] CopyTo(const shared_ptr[CMemoryManager]& to) const + CResult[shared_ptr[CTensor]] ToTensor(c_bool null_to_nan, c_bool row_major, CMemoryPool* pool) const diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index e3625c18152..a7c3b496a00 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -542,6 +542,8 @@ cdef class Device(_Weakrefable): @staticmethod cdef wrap(const shared_ptr[CDevice]& device) + cdef inline shared_ptr[CDevice] unwrap(self) nogil + cdef class MemoryManager(_Weakrefable): cdef: @@ -552,6 +554,8 @@ cdef class MemoryManager(_Weakrefable): @staticmethod cdef wrap(const shared_ptr[CMemoryManager]& mm) + cdef inline shared_ptr[CMemoryManager] unwrap(self) nogil + cdef class Buffer(_Weakrefable): cdef: diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 8f7c44e55dc..6d34c71c9df 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -3569,6 +3569,41 @@ cdef class RecordBatch(_Tabular): row_major, pool)) return pyarrow_wrap_tensor(c_tensor) + def copy_to(self, destination): + """ + Copy the entire RecordBatch to destination device. 
+ + This copies each column of the record batch to create + a new record batch where all underlying buffers for the columns have + been copied to the destination MemoryManager. + + Parameters + ---------- + destination : pyarrow.MemoryManager or pyarrow.Device + The destination device to copy the array to. + + Returns + ------- + RecordBatch + """ + cdef: + shared_ptr[CRecordBatch] c_batch + shared_ptr[CMemoryManager] c_memory_manager + + if isinstance(destination, Device): + c_memory_manager = (destination).unwrap().get().default_memory_manager() + elif isinstance(destination, MemoryManager): + c_memory_manager = (destination).unwrap() + else: + raise TypeError( + "Argument 'destination' has incorrect type (expected a " + f"pyarrow Device or MemoryManager, got {type(destination)})" + ) + + with nogil: + c_batch = GetResultValue(self.batch.CopyTo(c_memory_manager)) + return pyarrow_wrap_batch(c_batch) + def _export_to_c(self, out_ptr, out_schema_ptr=0): """ Export to a C ArrowArray struct, given its pointer. diff --git a/python/pyarrow/tests/test_cuda.py b/python/pyarrow/tests/test_cuda.py index 36b97a62064..d55be651b15 100644 --- a/python/pyarrow/tests/test_cuda.py +++ b/python/pyarrow/tests/test_cuda.py @@ -827,21 +827,29 @@ def test_IPC(size): assert p.exitcode == 0 -def _arr_copy_to_host(carr): - # TODO replace below with copy to device when exposed in python - buffers = [] - for cbuf in carr.buffers(): - if cbuf is None: - buffers.append(None) - else: - buf = global_context.foreign_buffer( - cbuf.address, cbuf.size, cbuf - ).copy_to_host() - buffers.append(buf) - - child = pa.Array.from_buffers(carr.type.value_type, 3, buffers[2:]) - new = pa.Array.from_buffers(carr.type, 2, buffers[:2], children=[child]) - return new +def test_copy_to(): + _, buf = make_random_buffer(size=10, target='device') + mm_cuda = buf.memory_manager + + for dest in [mm_cuda, mm_cuda.device]: + arr = pa.array([0, 1, 2]) + arr_cuda = arr.copy_to(dest) + assert not arr_cuda.buffers()[1].is_cpu + assert arr_cuda.buffers()[1].device_type == pa.DeviceAllocationType.CUDA + assert arr_cuda.buffers()[1].device == mm_cuda.device + + arr_roundtrip = arr_cuda.copy_to(pa.default_cpu_memory_manager()) + assert arr_roundtrip.equals(arr) + + batch = pa.record_batch({"col": arr}) + batch_cuda = batch.copy_to(dest) + buf_cuda = batch_cuda["col"].buffers()[1] + assert not buf_cuda.is_cpu + assert buf_cuda.device_type == pa.DeviceAllocationType.CUDA + assert buf_cuda.device == mm_cuda.device + + batch_roundtrip = batch_cuda.copy_to(pa.default_cpu_memory_manager()) + assert batch_roundtrip.equals(batch) def test_device_interface_array(): @@ -856,19 +864,10 @@ def test_device_interface_array(): typ = pa.list_(pa.int32()) arr = pa.array([[1], [2, 42]], type=typ) - # TODO replace below with copy to device when exposed in python - cbuffers = [] - for buf in arr.buffers(): - if buf is None: - cbuffers.append(None) - else: - cbuf = global_context.new_buffer(buf.size) - cbuf.copy_from_host(buf, position=0, nbytes=buf.size) - cbuffers.append(cbuf) - - carr = pa.Array.from_buffers(typ, 2, cbuffers[:2], children=[ - pa.Array.from_buffers(typ.value_type, 3, cbuffers[2:]) - ]) + # copy to device + _, buf = make_random_buffer(size=10, target='device') + mm_cuda = buf.memory_manager + carr = arr.copy_to(mm_cuda) # Type is known up front carr._export_to_c_device(ptr_array) @@ -882,7 +881,7 @@ def test_device_interface_array(): del carr carr_new = pa.Array._import_from_c_device(ptr_array, typ) assert carr_new.type == pa.list_(pa.int32()) - 
arr_new = _arr_copy_to_host(carr_new) + arr_new = carr_new.copy_to(pa.default_cpu_memory_manager()) assert arr_new.equals(arr) del carr_new @@ -891,15 +890,13 @@ def test_device_interface_array(): pa.Array._import_from_c_device(ptr_array, typ) # Schema is exported and imported at the same time - carr = pa.Array.from_buffers(typ, 2, cbuffers[:2], children=[ - pa.Array.from_buffers(typ.value_type, 3, cbuffers[2:]) - ]) + carr = arr.copy_to(mm_cuda) carr._export_to_c_device(ptr_array, ptr_schema) # Delete and recreate C++ objects from exported pointers del carr carr_new = pa.Array._import_from_c_device(ptr_array, ptr_schema) assert carr_new.type == pa.list_(pa.int32()) - arr_new = _arr_copy_to_host(carr_new) + arr_new = carr_new.copy_to(pa.default_cpu_memory_manager()) assert arr_new.equals(arr) del carr_new @@ -908,21 +905,6 @@ def test_device_interface_array(): pa.Array._import_from_c_device(ptr_array, ptr_schema) -def _batch_copy_to_host(cbatch): - # TODO replace below with copy to device when exposed in python - arrs = [] - for col in cbatch.columns: - buffers = [ - global_context.foreign_buffer(buf.address, buf.size, buf).copy_to_host() - if buf is not None else None - for buf in col.buffers() - ] - new = pa.Array.from_buffers(col.type, len(col), buffers) - arrs.append(new) - - return pa.RecordBatch.from_arrays(arrs, schema=cbatch.schema) - - def test_device_interface_batch_array(): cffi = pytest.importorskip("pyarrow.cffi") ffi = cffi.ffi @@ -949,7 +931,7 @@ def test_device_interface_batch_array(): del cbatch cbatch_new = pa.RecordBatch._import_from_c_device(ptr_array, schema) assert cbatch_new.schema == schema - batch_new = _batch_copy_to_host(cbatch_new) + batch_new = cbatch_new.copy_to(pa.default_cpu_memory_manager()) assert batch_new.equals(batch) del cbatch_new @@ -964,7 +946,7 @@ def test_device_interface_batch_array(): del cbatch cbatch_new = pa.RecordBatch._import_from_c_device(ptr_array, ptr_schema) assert cbatch_new.schema == schema - batch_new = _batch_copy_to_host(cbatch_new) + batch_new = cbatch_new.copy_to(pa.default_cpu_memory_manager()) assert batch_new.equals(batch) del cbatch_new diff --git a/python/pyarrow/tests/test_device.py b/python/pyarrow/tests/test_device.py index 6bdb015be1a..dc1a51e6d00 100644 --- a/python/pyarrow/tests/test_device.py +++ b/python/pyarrow/tests/test_device.py @@ -17,6 +17,8 @@ import pyarrow as pa +import pytest + def test_device_memory_manager(): mm = pa.default_cpu_memory_manager() @@ -41,3 +43,27 @@ def test_buffer_device(): assert buf.device.is_cpu assert buf.device == pa.default_cpu_memory_manager().device assert buf.memory_manager.is_cpu + + +def test_copy_to(): + mm = pa.default_cpu_memory_manager() + + arr = pa.array([0, 1, 2]) + batch = pa.record_batch({"col": arr}) + + for dest in [mm, mm.device]: + arr_copied = arr.copy_to(dest) + assert arr_copied.equals(arr) + assert arr_copied.buffers()[1].device == mm.device + assert arr_copied.buffers()[1].address != arr.buffers()[1].address + + batch_copied = batch.copy_to(dest) + assert batch_copied.equals(batch) + assert batch_copied["col"].buffers()[1].device == mm.device + assert batch_copied["col"].buffers()[1].address != arr.buffers()[1].address + + with pytest.raises(TypeError, match="Argument 'destination' has incorrect type"): + arr.copy_to(mm.device.device_type) + + with pytest.raises(TypeError, match="Argument 'destination' has incorrect type"): + batch.copy_to(mm.device.device_type) From f9911ee2ffc62fa946b2e1198bcdd13a757181fe Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: 
Wed, 21 Aug 2024 14:37:47 +0200 Subject: [PATCH 052/157] GH-43776: [C++] Add chunked Take benchmarks with a small selection factor (#43772) This should help exercise the performance of chunked Take implementation on more use cases. * GitHub Issue: #43776 Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- .../kernels/vector_selection_benchmark.cc | 91 ++++++++++++++++--- 1 file changed, 80 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc index c2a27dfe434..75affd32560 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_benchmark.cc @@ -17,6 +17,7 @@ #include "benchmark/benchmark.h" +#include #include #include @@ -42,6 +43,9 @@ struct FilterParams { const double filter_null_proportion; }; +constexpr double kDefaultTakeSelectionFactor = 1.0; +constexpr double kSmallTakeSelectionFactor = 0.05; + std::vector g_data_sizes = {kL2Size}; // The benchmark state parameter references this vector of cases. Test high and @@ -104,14 +108,21 @@ struct TakeBenchmark { benchmark::State& state; RegressionArgs args; random::RandomArrayGenerator rand; + double selection_factor; bool indices_have_nulls; bool monotonic_indices = false; TakeBenchmark(benchmark::State& state, bool indices_have_nulls, bool monotonic_indices = false) + : TakeBenchmark(state, /*selection_factor=*/kDefaultTakeSelectionFactor, + indices_have_nulls, monotonic_indices) {} + + TakeBenchmark(benchmark::State& state, double selection_factor, bool indices_have_nulls, + bool monotonic_indices = false) : state(state), args(state, /*size_is_bytes=*/false), rand(kSeed), + selection_factor(selection_factor), indices_have_nulls(indices_have_nulls), monotonic_indices(monotonic_indices) {} @@ -185,10 +196,10 @@ struct TakeBenchmark { } void Bench(const std::shared_ptr& values) { - double indices_null_proportion = indices_have_nulls ? args.null_proportion : 0; - auto indices = - rand.Int32(values->length(), 0, static_cast(values->length() - 1), - indices_null_proportion); + const double indices_null_proportion = indices_have_nulls ? args.null_proportion : 0; + const int64_t num_indices = static_cast(selection_factor * values->length()); + auto indices = rand.Int32(num_indices, 0, static_cast(values->length() - 1), + indices_null_proportion); if (monotonic_indices) { auto arg_sorter = *SortIndices(*indices); @@ -198,14 +209,15 @@ struct TakeBenchmark { for (auto _ : state) { ABORT_NOT_OK(Take(values, indices).status()); } - state.SetItemsProcessed(state.iterations() * values->length()); + state.SetItemsProcessed(state.iterations() * num_indices); + state.counters["selection_factor"] = selection_factor; } void BenchChunked(const std::shared_ptr& values, bool chunk_indices_too) { double indices_null_proportion = indices_have_nulls ? 
args.null_proportion : 0; - auto indices = - rand.Int32(values->length(), 0, static_cast(values->length() - 1), - indices_null_proportion); + const int64_t num_indices = static_cast(selection_factor * values->length()); + auto indices = rand.Int32(num_indices, 0, static_cast(values->length() - 1), + indices_null_proportion); if (monotonic_indices) { auto arg_sorter = *SortIndices(*indices); @@ -213,14 +225,26 @@ struct TakeBenchmark { } std::shared_ptr chunked_indices; if (chunk_indices_too) { + // Here we choose for indices chunks to have roughly the same length + // as values chunks, but there may be less of them if selection_factor < 1.0. + // The alternative is to have the same number of chunks, but with a potentially + // much smaller (and irrealistic) length. std::vector> indices_chunks; + // Make sure there are at least two chunks of indices + const auto max_chunk_length = indices->length() / 2 + 1; int64_t offset = 0; for (int i = 0; i < values->num_chunks(); ++i) { - auto chunk = indices->Slice(offset, values->chunk(i)->length()); + const auto chunk_length = std::min(max_chunk_length, values->chunk(i)->length()); + auto chunk = indices->Slice(offset, chunk_length); indices_chunks.push_back(std::move(chunk)); - offset += values->chunk(i)->length(); + offset += chunk_length; + if (offset >= indices->length()) { + break; + } } chunked_indices = std::make_shared(std::move(indices_chunks)); + ARROW_CHECK_EQ(chunked_indices->length(), num_indices); + ARROW_CHECK_GT(chunked_indices->num_chunks(), 1); } if (chunk_indices_too) { @@ -232,7 +256,8 @@ struct TakeBenchmark { ABORT_NOT_OK(Take(values, indices).status()); } } - state.SetItemsProcessed(state.iterations() * values->length()); + state.SetItemsProcessed(state.iterations() * num_indices); + state.counters["selection_factor"] = selection_factor; } }; @@ -432,12 +457,25 @@ static void TakeChunkedChunkedInt64RandomIndicesWithNulls(benchmark::State& stat .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/true); } +static void TakeChunkedChunkedInt64FewRandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/true) + .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + static void TakeChunkedChunkedInt64MonotonicIndices(benchmark::State& state) { TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true) .ChunkedInt64( /*num_chunks=*/100, /*chunk_indices_too=*/true); } +static void TakeChunkedChunkedInt64FewMonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/false, /*monotonic=*/true) + .ChunkedInt64( + /*num_chunks=*/100, /*chunk_indices_too=*/true); +} + static void TakeChunkedChunkedFSBRandomIndicesNoNulls(benchmark::State& state) { TakeBenchmark(state, /*indices_with_nulls=*/false) .ChunkedFSB(/*num_chunks=*/100, /*chunk_indices_too=*/true); @@ -463,11 +501,23 @@ static void TakeChunkedChunkedStringRandomIndicesWithNulls(benchmark::State& sta .ChunkedString(/*num_chunks=*/100, /*chunk_indices_too=*/true); } +static void TakeChunkedChunkedStringFewRandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/true) + .ChunkedString(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + static void TakeChunkedChunkedStringMonotonicIndices(benchmark::State& state) { TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true) 
.ChunkedString(/*num_chunks=*/100, /*chunk_indices_too=*/true); } +static void TakeChunkedChunkedStringFewMonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/false, /*monotonic=*/true) + .ChunkedString(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + static void TakeChunkedFlatInt64RandomIndicesNoNulls(benchmark::State& state) { TakeBenchmark(state, /*indices_with_nulls=*/false) .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/false); @@ -478,12 +528,25 @@ static void TakeChunkedFlatInt64RandomIndicesWithNulls(benchmark::State& state) .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/false); } +static void TakeChunkedFlatInt64FewRandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/true) + .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/false); +} + static void TakeChunkedFlatInt64MonotonicIndices(benchmark::State& state) { TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true) .ChunkedInt64( /*num_chunks=*/100, /*chunk_indices_too=*/false); } +static void TakeChunkedFlatInt64FewMonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*selection_factor=*/kSmallTakeSelectionFactor, + /*indices_with_nulls=*/false, /*monotonic=*/true) + .ChunkedInt64( + /*num_chunks=*/100, /*chunk_indices_too=*/false); +} + void FilterSetArgs(benchmark::internal::Benchmark* bench) { for (int64_t size : g_data_sizes) { for (int i = 0; i < static_cast(g_filter_params.size()); ++i) { @@ -560,18 +623,24 @@ BENCHMARK(TakeStringMonotonicIndices)->Apply(TakeSetArgs); // Chunked values x Chunked indices BENCHMARK(TakeChunkedChunkedInt64RandomIndicesNoNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedChunkedInt64RandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedInt64FewRandomIndicesWithNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedChunkedInt64MonotonicIndices)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedInt64FewMonotonicIndices)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedChunkedFSBRandomIndicesNoNulls)->Apply(TakeFSBSetArgs); BENCHMARK(TakeChunkedChunkedFSBRandomIndicesWithNulls)->Apply(TakeFSBSetArgs); BENCHMARK(TakeChunkedChunkedFSBMonotonicIndices)->Apply(TakeFSBSetArgs); BENCHMARK(TakeChunkedChunkedStringRandomIndicesNoNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedChunkedStringRandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedStringFewRandomIndicesWithNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedChunkedStringMonotonicIndices)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedStringFewMonotonicIndices)->Apply(TakeSetArgs); // Chunked values x Flat indices BENCHMARK(TakeChunkedFlatInt64RandomIndicesNoNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedFlatInt64RandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedFlatInt64FewRandomIndicesWithNulls)->Apply(TakeSetArgs); BENCHMARK(TakeChunkedFlatInt64MonotonicIndices)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedFlatInt64FewMonotonicIndices)->Apply(TakeSetArgs); } // namespace compute } // namespace arrow From f078942ce2df68de8f48c3b4233132133601ca53 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 22 Aug 2024 02:59:04 +1200 Subject: [PATCH 053/157] GH-43141: [C++][Parquet] Replace use of int with int32_t in the internal Parquet encryption APIs (#43413) ### Rationale for this change See #43141 ### What changes are included in this PR? 
* Changes uses of int to int32_t in the Encryptor and Decryptor APIs, except where interfacing with OpenSSL. * Also change RandBytes to use size_t instead of int and check for overflow. * Check the return code from OpenSSL's Rand_bytes in case there is a failure generating random bytes ### Are these changes tested? Yes, this doesn't change behaviour and is covered by existing tests. ### Are there any user-facing changes? No * GitHub Issue: #43141 Authored-by: Adam Reeve Signed-off-by: Antoine Pitrou --- cpp/src/parquet/column_reader.cc | 4 +- cpp/src/parquet/encryption/crypto_factory.cc | 6 +- .../parquet/encryption/encryption_internal.cc | 251 ++++++++++-------- .../parquet/encryption/encryption_internal.h | 46 ++-- .../encryption/encryption_internal_nossl.cc | 47 ++-- .../encryption/encryption_internal_test.cc | 22 +- .../parquet/encryption/file_key_wrapper.cc | 4 +- .../encryption/internal_file_decryptor.cc | 12 +- .../encryption/internal_file_decryptor.h | 8 +- .../encryption/internal_file_encryptor.cc | 10 +- .../encryption/internal_file_encryptor.h | 6 +- .../encryption/key_toolkit_internal.cc | 2 +- cpp/src/parquet/metadata.cc | 6 +- cpp/src/parquet/thrift_internal.h | 2 +- 14 files changed, 233 insertions(+), 193 deletions(-) diff --git a/cpp/src/parquet/column_reader.cc b/cpp/src/parquet/column_reader.cc index 05ee6a16c54..60a8a2176b0 100644 --- a/cpp/src/parquet/column_reader.cc +++ b/cpp/src/parquet/column_reader.cc @@ -468,8 +468,8 @@ std::shared_ptr SerializedPageReader::NextPage() { // Advance the stream offset PARQUET_THROW_NOT_OK(stream_->Advance(header_size)); - int compressed_len = current_page_header_.compressed_page_size; - int uncompressed_len = current_page_header_.uncompressed_page_size; + int32_t compressed_len = current_page_header_.compressed_page_size; + int32_t uncompressed_len = current_page_header_.uncompressed_page_size; if (compressed_len < 0 || uncompressed_len < 0) { throw ParquetException("Invalid page header"); } diff --git a/cpp/src/parquet/encryption/crypto_factory.cc b/cpp/src/parquet/encryption/crypto_factory.cc index 72506bdc014..56069d55977 100644 --- a/cpp/src/parquet/encryption/crypto_factory.cc +++ b/cpp/src/parquet/encryption/crypto_factory.cc @@ -72,8 +72,7 @@ std::shared_ptr CryptoFactory::GetFileEncryptionProper int dek_length = dek_length_bits / 8; std::string footer_key(dek_length, '\0'); - RandBytes(reinterpret_cast(&footer_key[0]), - static_cast(footer_key.size())); + RandBytes(reinterpret_cast(footer_key.data()), footer_key.size()); std::string footer_key_metadata = key_wrapper.GetEncryptionKeyMetadata(footer_key, footer_key_id, true); @@ -148,8 +147,7 @@ ColumnPathToEncryptionPropertiesMap CryptoFactory::GetColumnEncryptionProperties } std::string column_key(dek_length, '\0'); - RandBytes(reinterpret_cast(&column_key[0]), - static_cast(column_key.size())); + RandBytes(reinterpret_cast(column_key.data()), column_key.size()); std::string column_key_key_metadata = key_wrapper->GetEncryptionKeyMetadata(column_key, column_key_id, false); diff --git a/cpp/src/parquet/encryption/encryption_internal.cc b/cpp/src/parquet/encryption/encryption_internal.cc index 99d1707f4a8..a0d9367b619 100644 --- a/cpp/src/parquet/encryption/encryption_internal.cc +++ b/cpp/src/parquet/encryption/encryption_internal.cc @@ -18,6 +18,7 @@ #include "parquet/encryption/encryption_internal.h" #include +#include #include #include @@ -36,10 +37,10 @@ using parquet::ParquetException; namespace parquet::encryption { -constexpr int kGcmMode = 0; -constexpr int kCtrMode = 
1; -constexpr int kCtrIvLength = 16; -constexpr int kBufferSizeLength = 4; +constexpr int32_t kGcmMode = 0; +constexpr int32_t kCtrMode = 1; +constexpr int32_t kCtrIvLength = 16; +constexpr int32_t kBufferSizeLength = 4; #define ENCRYPT_INIT(CTX, ALG) \ if (1 != EVP_EncryptInit_ex(CTX, ALG, nullptr, nullptr, nullptr)) { \ @@ -53,17 +54,17 @@ constexpr int kBufferSizeLength = 4; class AesEncryptor::AesEncryptorImpl { public: - explicit AesEncryptorImpl(ParquetCipher::type alg_id, int key_len, bool metadata, + explicit AesEncryptorImpl(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool write_length); ~AesEncryptorImpl() { WipeOut(); } - int Encrypt(span plaintext, span key, - span aad, span ciphertext); + int32_t Encrypt(span plaintext, span key, + span aad, span ciphertext); - int SignedFooterEncrypt(span footer, span key, - span aad, span nonce, - span encrypted_footer); + int32_t SignedFooterEncrypt(span footer, span key, + span aad, span nonce, + span encrypted_footer); void WipeOut() { if (nullptr != ctx_) { EVP_CIPHER_CTX_free(ctx_); @@ -89,21 +90,22 @@ class AesEncryptor::AesEncryptorImpl { private: EVP_CIPHER_CTX* ctx_; - int aes_mode_; - int key_length_; - int ciphertext_size_delta_; - int length_buffer_length_; + int32_t aes_mode_; + int32_t key_length_; + int32_t ciphertext_size_delta_; + int32_t length_buffer_length_; - int GcmEncrypt(span plaintext, span key, - span nonce, span aad, - span ciphertext); + int32_t GcmEncrypt(span plaintext, span key, + span nonce, span aad, + span ciphertext); - int CtrEncrypt(span plaintext, span key, - span nonce, span ciphertext); + int32_t CtrEncrypt(span plaintext, span key, + span nonce, span ciphertext); }; -AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, int key_len, - bool metadata, bool write_length) { +AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, + int32_t key_len, bool metadata, + bool write_length) { openssl::EnsureInitialized(); ctx_ = nullptr; @@ -151,11 +153,9 @@ AesEncryptor::AesEncryptorImpl::AesEncryptorImpl(ParquetCipher::type alg_id, int } } -int AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt(span footer, - span key, - span aad, - span nonce, - span encrypted_footer) { +int32_t AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt( + span footer, span key, span aad, + span nonce, span encrypted_footer) { if (static_cast(key_length_) != key.size()) { std::stringstream ss; ss << "Wrong key length " << key.size() << ". Should be " << key_length_; @@ -176,10 +176,10 @@ int AesEncryptor::AesEncryptorImpl::SignedFooterEncrypt(span foot return GcmEncrypt(footer, key, nonce, aad, encrypted_footer); } -int AesEncryptor::AesEncryptorImpl::Encrypt(span plaintext, - span key, - span aad, - span ciphertext) { +int32_t AesEncryptor::AesEncryptorImpl::Encrypt(span plaintext, + span key, + span aad, + span ciphertext) { if (static_cast(key_length_) != key.size()) { std::stringstream ss; ss << "Wrong key length " << key.size() << ". 
Should be " << key_length_; @@ -205,13 +205,13 @@ int AesEncryptor::AesEncryptorImpl::Encrypt(span plaintext, return CtrEncrypt(plaintext, key, nonce, ciphertext); } -int AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext, - span key, - span nonce, - span aad, - span ciphertext) { +int32_t AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext, + span key, + span nonce, + span aad, + span ciphertext) { int len; - int ciphertext_len; + int32_t ciphertext_len; std::array tag{}; @@ -227,12 +227,22 @@ int AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext, } // Setting additional authenticated data + if (aad.size() > static_cast(std::numeric_limits::max())) { + std::stringstream ss; + ss << "AAD size " << aad.size() << " overflows int"; + throw ParquetException(ss.str()); + } if ((!aad.empty()) && (1 != EVP_EncryptUpdate(ctx_, nullptr, &len, aad.data(), static_cast(aad.size())))) { throw ParquetException("Couldn't set AAD"); } // Encryption + if (plaintext.size() > static_cast(std::numeric_limits::max())) { + std::stringstream ss; + ss << "Plaintext size " << plaintext.size() << " overflows int"; + throw ParquetException(ss.str()); + } if (1 != EVP_EncryptUpdate(ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength, &len, plaintext.data(), static_cast(plaintext.size()))) { @@ -256,7 +266,7 @@ int AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext, } // Copying the buffer size, nonce and tag to ciphertext - int buffer_size = kNonceLength + ciphertext_len + kGcmTagLength; + int32_t buffer_size = kNonceLength + ciphertext_len + kGcmTagLength; if (length_buffer_length_ > 0) { ciphertext[3] = static_cast(0xff & (buffer_size >> 24)); ciphertext[2] = static_cast(0xff & (buffer_size >> 16)); @@ -271,12 +281,12 @@ int AesEncryptor::AesEncryptorImpl::GcmEncrypt(span plaintext, return length_buffer_length_ + buffer_size; } -int AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext, - span key, - span nonce, - span ciphertext) { +int32_t AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext, + span key, + span nonce, + span ciphertext) { int len; - int ciphertext_len; + int32_t ciphertext_len; if (nonce.size() != static_cast(kNonceLength)) { std::stringstream ss; @@ -298,6 +308,11 @@ int AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext, } // Encryption + if (plaintext.size() > static_cast(std::numeric_limits::max())) { + std::stringstream ss; + ss << "Plaintext size " << plaintext.size() << " overflows int"; + throw ParquetException(ss.str()); + } if (1 != EVP_EncryptUpdate(ctx_, ciphertext.data() + length_buffer_length_ + kNonceLength, &len, plaintext.data(), static_cast(plaintext.size()))) { @@ -316,7 +331,7 @@ int AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext, ciphertext_len += len; // Copying the buffer size and nonce to ciphertext - int buffer_size = kNonceLength + ciphertext_len; + int32_t buffer_size = kNonceLength + ciphertext_len; if (length_buffer_length_ > 0) { ciphertext[3] = static_cast(0xff & (buffer_size >> 24)); ciphertext[2] = static_cast(0xff & (buffer_size >> 16)); @@ -331,9 +346,11 @@ int AesEncryptor::AesEncryptorImpl::CtrEncrypt(span plaintext, AesEncryptor::~AesEncryptor() {} -int AesEncryptor::SignedFooterEncrypt(span footer, span key, - span aad, span nonce, - span encrypted_footer) { +int32_t AesEncryptor::SignedFooterEncrypt(span footer, + span key, + span aad, + span nonce, + span encrypted_footer) { return impl_->SignedFooterEncrypt(footer, key, aad, nonce, encrypted_footer); } @@ -343,25 +360,25 @@ int32_t 
AesEncryptor::CiphertextLength(int64_t plaintext_len) const { return impl_->CiphertextLength(plaintext_len); } -int AesEncryptor::Encrypt(span plaintext, span key, - span aad, span ciphertext) { +int32_t AesEncryptor::Encrypt(span plaintext, span key, + span aad, span ciphertext) { return impl_->Encrypt(plaintext, key, aad, ciphertext); } -AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata, +AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool write_length) : impl_{std::unique_ptr( new AesEncryptorImpl(alg_id, key_len, metadata, write_length))} {} class AesDecryptor::AesDecryptorImpl { public: - explicit AesDecryptorImpl(ParquetCipher::type alg_id, int key_len, bool metadata, + explicit AesDecryptorImpl(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool contains_length); ~AesDecryptorImpl() { WipeOut(); } - int Decrypt(span ciphertext, span key, - span aad, span plaintext); + int32_t Decrypt(span ciphertext, span key, + span aad, span plaintext); void WipeOut() { if (nullptr != ctx_) { @@ -370,7 +387,7 @@ class AesDecryptor::AesDecryptorImpl { } } - [[nodiscard]] int PlaintextLength(int ciphertext_len) const { + [[nodiscard]] int32_t PlaintextLength(int32_t ciphertext_len) const { if (ciphertext_len < ciphertext_size_delta_) { std::stringstream ss; ss << "Ciphertext length " << ciphertext_len << " is invalid, expected at least " @@ -380,12 +397,13 @@ class AesDecryptor::AesDecryptorImpl { return ciphertext_len - ciphertext_size_delta_; } - [[nodiscard]] int CiphertextLength(int plaintext_len) const { + [[nodiscard]] int32_t CiphertextLength(int32_t plaintext_len) const { if (plaintext_len < 0) { std::stringstream ss; ss << "Negative plaintext length " << plaintext_len; throw ParquetException(ss.str()); - } else if (plaintext_len > std::numeric_limits::max() - ciphertext_size_delta_) { + } else if (plaintext_len > + std::numeric_limits::max() - ciphertext_size_delta_) { std::stringstream ss; ss << "Plaintext length " << plaintext_len << " plus ciphertext size delta " << ciphertext_size_delta_ << " overflows int32"; @@ -396,24 +414,24 @@ class AesDecryptor::AesDecryptorImpl { private: EVP_CIPHER_CTX* ctx_; - int aes_mode_; - int key_length_; - int ciphertext_size_delta_; - int length_buffer_length_; + int32_t aes_mode_; + int32_t key_length_; + int32_t ciphertext_size_delta_; + int32_t length_buffer_length_; /// Get the actual ciphertext length, inclusive of the length buffer length, /// and validate that the provided buffer size is large enough. 
- [[nodiscard]] int GetCiphertextLength(span ciphertext) const; + [[nodiscard]] int32_t GetCiphertextLength(span ciphertext) const; - int GcmDecrypt(span ciphertext, span key, - span aad, span plaintext); + int32_t GcmDecrypt(span ciphertext, span key, + span aad, span plaintext); - int CtrDecrypt(span ciphertext, span key, - span plaintext); + int32_t CtrDecrypt(span ciphertext, span key, + span plaintext); }; -int AesDecryptor::Decrypt(span ciphertext, span key, - span aad, span plaintext) { +int32_t AesDecryptor::Decrypt(span ciphertext, span key, + span aad, span plaintext) { return impl_->Decrypt(ciphertext, key, aad, plaintext); } @@ -421,8 +439,9 @@ void AesDecryptor::WipeOut() { impl_->WipeOut(); } AesDecryptor::~AesDecryptor() {} -AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id, int key_len, - bool metadata, bool contains_length) { +AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id, + int32_t key_len, bool metadata, + bool contains_length) { openssl::EnsureInitialized(); ctx_ = nullptr; @@ -469,13 +488,14 @@ AesDecryptor::AesDecryptorImpl::AesDecryptorImpl(ParquetCipher::type alg_id, int } } -std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, - bool metadata) { +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, + int32_t key_len, bool metadata) { return Make(alg_id, key_len, metadata, true /*write_length*/); } -std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, - bool metadata, bool write_length) { +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, + int32_t key_len, bool metadata, + bool write_length) { if (ParquetCipher::AES_GCM_V1 != alg_id && ParquetCipher::AES_GCM_CTR_V1 != alg_id) { std::stringstream ss; ss << "Crypto algorithm " << alg_id << " is not supported"; @@ -485,13 +505,13 @@ std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int return std::make_unique(alg_id, key_len, metadata, write_length); } -AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata, +AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool contains_length) : impl_{std::unique_ptr( new AesDecryptorImpl(alg_id, key_len, metadata, contains_length))} {} std::shared_ptr AesDecryptor::Make( - ParquetCipher::type alg_id, int key_len, bool metadata, + ParquetCipher::type alg_id, int32_t key_len, bool metadata, std::vector>* all_decryptors) { if (ParquetCipher::AES_GCM_V1 != alg_id && ParquetCipher::AES_GCM_CTR_V1 != alg_id) { std::stringstream ss; @@ -506,15 +526,15 @@ std::shared_ptr AesDecryptor::Make( return decryptor; } -int AesDecryptor::PlaintextLength(int ciphertext_len) const { +int32_t AesDecryptor::PlaintextLength(int32_t ciphertext_len) const { return impl_->PlaintextLength(ciphertext_len); } -int AesDecryptor::CiphertextLength(int plaintext_len) const { +int32_t AesDecryptor::CiphertextLength(int32_t plaintext_len) const { return impl_->CiphertextLength(plaintext_len); } -int AesDecryptor::AesDecryptorImpl::GetCiphertextLength( +int32_t AesDecryptor::AesDecryptorImpl::GetCiphertextLength( span ciphertext) const { if (length_buffer_length_ > 0) { // Note: length_buffer_length_ must be either 0 or kBufferSizeLength @@ -533,10 +553,11 @@ int AesDecryptor::AesDecryptorImpl::GetCiphertextLength( (static_cast(ciphertext[0])); if (written_ciphertext_len > - static_cast(std::numeric_limits::max() - length_buffer_length_)) { + static_cast(std::numeric_limits::max() - + length_buffer_length_)) { 
std::stringstream ss; ss << "Written ciphertext length " << written_ciphertext_len - << " plus length buffer length " << length_buffer_length_ << " overflows int"; + << " plus length buffer length " << length_buffer_length_ << " overflows int32"; throw ParquetException(ss.str()); } else if (ciphertext.size() < static_cast(written_ciphertext_len) + length_buffer_length_) { @@ -548,28 +569,28 @@ int AesDecryptor::AesDecryptorImpl::GetCiphertextLength( throw ParquetException(ss.str()); } - return static_cast(written_ciphertext_len) + length_buffer_length_; + return static_cast(written_ciphertext_len) + length_buffer_length_; } else { - if (ciphertext.size() > static_cast(std::numeric_limits::max())) { + if (ciphertext.size() > static_cast(std::numeric_limits::max())) { std::stringstream ss; - ss << "Ciphertext buffer length " << ciphertext.size() << " overflows int"; + ss << "Ciphertext buffer length " << ciphertext.size() << " overflows int32"; throw ParquetException(ss.str()); } - return static_cast(ciphertext.size()); + return static_cast(ciphertext.size()); } } -int AesDecryptor::AesDecryptorImpl::GcmDecrypt(span ciphertext, - span key, - span aad, - span plaintext) { +int32_t AesDecryptor::AesDecryptorImpl::GcmDecrypt(span ciphertext, + span key, + span aad, + span plaintext) { int len; - int plaintext_len; + int32_t plaintext_len; std::array tag{}; std::array nonce{}; - int ciphertext_len = GetCiphertextLength(ciphertext); + int32_t ciphertext_len = GetCiphertextLength(ciphertext); if (plaintext.size() < static_cast(ciphertext_len) - ciphertext_size_delta_) { std::stringstream ss; @@ -597,16 +618,22 @@ int AesDecryptor::AesDecryptorImpl::GcmDecrypt(span ciphertext, } // Setting additional authenticated data + if (aad.size() > static_cast(std::numeric_limits::max())) { + std::stringstream ss; + ss << "AAD size " << aad.size() << " overflows int"; + throw ParquetException(ss.str()); + } if ((!aad.empty()) && (1 != EVP_DecryptUpdate(ctx_, nullptr, &len, aad.data(), static_cast(aad.size())))) { throw ParquetException("Couldn't set AAD"); } // Decryption - if (!EVP_DecryptUpdate( - ctx_, plaintext.data(), &len, - ciphertext.data() + length_buffer_length_ + kNonceLength, - ciphertext_len - length_buffer_length_ - kNonceLength - kGcmTagLength)) { + int decryption_length = + ciphertext_len - length_buffer_length_ - kNonceLength - kGcmTagLength; + if (!EVP_DecryptUpdate(ctx_, plaintext.data(), &len, + ciphertext.data() + length_buffer_length_ + kNonceLength, + decryption_length)) { throw ParquetException("Failed decryption update"); } @@ -626,15 +653,15 @@ int AesDecryptor::AesDecryptorImpl::GcmDecrypt(span ciphertext, return plaintext_len; } -int AesDecryptor::AesDecryptorImpl::CtrDecrypt(span ciphertext, - span key, - span plaintext) { +int32_t AesDecryptor::AesDecryptorImpl::CtrDecrypt(span ciphertext, + span key, + span plaintext) { int len; - int plaintext_len; + int32_t plaintext_len; std::array iv{}; - int ciphertext_len = GetCiphertextLength(ciphertext); + int32_t ciphertext_len = GetCiphertextLength(ciphertext); if (plaintext.size() < static_cast(ciphertext_len) - ciphertext_size_delta_) { std::stringstream ss; @@ -665,9 +692,10 @@ int AesDecryptor::AesDecryptorImpl::CtrDecrypt(span ciphertext, } // Decryption + int decryption_length = ciphertext_len - length_buffer_length_ - kNonceLength; if (!EVP_DecryptUpdate(ctx_, plaintext.data(), &len, ciphertext.data() + length_buffer_length_ + kNonceLength, - ciphertext_len - length_buffer_length_ - kNonceLength)) { + decryption_length)) { throw 
ParquetException("Failed decryption update"); } @@ -682,10 +710,10 @@ int AesDecryptor::AesDecryptorImpl::CtrDecrypt(span ciphertext, return plaintext_len; } -int AesDecryptor::AesDecryptorImpl::Decrypt(span ciphertext, - span key, - span aad, - span plaintext) { +int32_t AesDecryptor::AesDecryptorImpl::Decrypt(span ciphertext, + span key, + span aad, + span plaintext) { if (static_cast(key_length_) != key.size()) { std::stringstream ss; ss << "Wrong key length " << key.size() << ". Should be " << key_length_; @@ -758,9 +786,22 @@ void QuickUpdatePageAad(int32_t new_page_ordinal, std::string* AAD) { std::memcpy(AAD->data() + AAD->length() - 2, page_ordinal_bytes.data(), 2); } -void RandBytes(unsigned char* buf, int num) { +void RandBytes(unsigned char* buf, size_t num) { + if (num > static_cast(std::numeric_limits::max())) { + std::stringstream ss; + ss << "Length " << num << " for RandBytes overflows int"; + throw ParquetException(ss.str()); + } openssl::EnsureInitialized(); - RAND_bytes(buf, num); + int status = RAND_bytes(buf, static_cast(num)); + if (status != 1) { + const auto error_code = ERR_get_error(); + char buffer[256]; + ERR_error_string_n(error_code, buffer, sizeof(buffer)); + std::stringstream ss; + ss << "Failed to generate random bytes: " << buffer; + throw ParquetException(ss.str()); + } } void EnsureBackendInitialized() { openssl::EnsureInitialized(); } diff --git a/cpp/src/parquet/encryption/encryption_internal.h b/cpp/src/parquet/encryption/encryption_internal.h index c874b137ad1..d79ff56ad49 100644 --- a/cpp/src/parquet/encryption/encryption_internal.h +++ b/cpp/src/parquet/encryption/encryption_internal.h @@ -29,8 +29,8 @@ using parquet::ParquetCipher; namespace parquet::encryption { -constexpr int kGcmTagLength = 16; -constexpr int kNonceLength = 12; +constexpr int32_t kGcmTagLength = 16; +constexpr int32_t kNonceLength = 12; // Module types constexpr int8_t kFooter = 0; @@ -49,13 +49,13 @@ class PARQUET_EXPORT AesEncryptor { public: /// Can serve one key length only. Possible values: 16, 24, 32 bytes. /// If write_length is true, prepend ciphertext length to the ciphertext - explicit AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata, + explicit AesEncryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool write_length = true); - static std::unique_ptr Make(ParquetCipher::type alg_id, int key_len, + static std::unique_ptr Make(ParquetCipher::type alg_id, int32_t key_len, bool metadata); - static std::unique_ptr Make(ParquetCipher::type alg_id, int key_len, + static std::unique_ptr Make(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool write_length); ~AesEncryptor(); @@ -65,17 +65,17 @@ class PARQUET_EXPORT AesEncryptor { /// Encrypts plaintext with the key and aad. Key length is passed only for validation. /// If different from value in constructor, exception will be thrown. - int Encrypt(::arrow::util::span plaintext, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span ciphertext); + int32_t Encrypt(::arrow::util::span plaintext, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span ciphertext); /// Encrypts plaintext footer, in order to compute footer signature (tag). 
- int SignedFooterEncrypt(::arrow::util::span footer, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span nonce, - ::arrow::util::span encrypted_footer); + int32_t SignedFooterEncrypt(::arrow::util::span footer, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span nonce, + ::arrow::util::span encrypted_footer); void WipeOut(); @@ -90,7 +90,7 @@ class PARQUET_EXPORT AesDecryptor { public: /// Can serve one key length only. Possible values: 16, 24, 32 bytes. /// If contains_length is true, expect ciphertext length prepended to the ciphertext - explicit AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata, + explicit AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool contains_length = true); /// \brief Factory function to create an AesDecryptor @@ -102,26 +102,26 @@ class PARQUET_EXPORT AesDecryptor { /// out when decryption is finished /// \return shared pointer to a new AesDecryptor static std::shared_ptr Make( - ParquetCipher::type alg_id, int key_len, bool metadata, + ParquetCipher::type alg_id, int32_t key_len, bool metadata, std::vector>* all_decryptors); ~AesDecryptor(); void WipeOut(); /// The size of the plaintext, for this cipher and the specified ciphertext length. - [[nodiscard]] int PlaintextLength(int ciphertext_len) const; + [[nodiscard]] int32_t PlaintextLength(int32_t ciphertext_len) const; /// The size of the ciphertext, for this cipher and the specified plaintext length. - [[nodiscard]] int CiphertextLength(int plaintext_len) const; + [[nodiscard]] int32_t CiphertextLength(int32_t plaintext_len) const; /// Decrypts ciphertext with the key and aad. Key length is passed only for /// validation. If different from value in constructor, exception will be thrown. /// The caller is responsible for ensuring that the plaintext buffer is at least as /// large as PlaintextLength(ciphertext_len). - int Decrypt(::arrow::util::span ciphertext, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span plaintext); + int32_t Decrypt(::arrow::util::span ciphertext, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span plaintext); private: // PIMPL Idiom @@ -139,7 +139,7 @@ std::string CreateFooterAad(const std::string& aad_prefix_bytes); void QuickUpdatePageAad(int32_t new_page_ordinal, std::string* AAD); // Wraps OpenSSL RAND_bytes function -void RandBytes(unsigned char* buf, int num); +void RandBytes(unsigned char* buf, size_t num); // Ensure OpenSSL is initialized. 
// diff --git a/cpp/src/parquet/encryption/encryption_internal_nossl.cc b/cpp/src/parquet/encryption/encryption_internal_nossl.cc index 2cce83915d7..2a8162ed396 100644 --- a/cpp/src/parquet/encryption/encryption_internal_nossl.cc +++ b/cpp/src/parquet/encryption/encryption_internal_nossl.cc @@ -29,11 +29,11 @@ class AesEncryptor::AesEncryptorImpl {}; AesEncryptor::~AesEncryptor() {} -int AesEncryptor::SignedFooterEncrypt(::arrow::util::span footer, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span nonce, - ::arrow::util::span encrypted_footer) { +int32_t AesEncryptor::SignedFooterEncrypt(::arrow::util::span footer, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span nonce, + ::arrow::util::span encrypted_footer) { ThrowOpenSSLRequiredException(); return -1; } @@ -45,25 +45,25 @@ int32_t AesEncryptor::CiphertextLength(int64_t plaintext_len) const { return -1; } -int AesEncryptor::Encrypt(::arrow::util::span plaintext, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span ciphertext) { +int32_t AesEncryptor::Encrypt(::arrow::util::span plaintext, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span ciphertext) { ThrowOpenSSLRequiredException(); return -1; } -AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata, +AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool write_length) { ThrowOpenSSLRequiredException(); } class AesDecryptor::AesDecryptorImpl {}; -int AesDecryptor::Decrypt(::arrow::util::span ciphertext, - ::arrow::util::span key, - ::arrow::util::span aad, - ::arrow::util::span plaintext) { +int32_t AesDecryptor::Decrypt(::arrow::util::span ciphertext, + ::arrow::util::span key, + ::arrow::util::span aad, + ::arrow::util::span plaintext) { ThrowOpenSSLRequiredException(); return -1; } @@ -72,36 +72,37 @@ void AesDecryptor::WipeOut() { ThrowOpenSSLRequiredException(); } AesDecryptor::~AesDecryptor() {} -std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, - bool metadata) { +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, + int32_t key_len, bool metadata) { ThrowOpenSSLRequiredException(); return NULLPTR; } -std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, int key_len, - bool metadata, bool write_length) { +std::unique_ptr AesEncryptor::Make(ParquetCipher::type alg_id, + int32_t key_len, bool metadata, + bool write_length) { ThrowOpenSSLRequiredException(); return NULLPTR; } -AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata, +AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata, bool contains_length) { ThrowOpenSSLRequiredException(); } std::shared_ptr AesDecryptor::Make( - ParquetCipher::type alg_id, int key_len, bool metadata, + ParquetCipher::type alg_id, int32_t key_len, bool metadata, std::vector>* all_decryptors) { ThrowOpenSSLRequiredException(); return NULLPTR; } -int AesDecryptor::PlaintextLength(int ciphertext_len) const { +int32_t AesDecryptor::PlaintextLength(int32_t ciphertext_len) const { ThrowOpenSSLRequiredException(); return -1; } -int AesDecryptor::CiphertextLength(int plaintext_len) const { +int32_t AesDecryptor::CiphertextLength(int32_t plaintext_len) const { ThrowOpenSSLRequiredException(); return -1; } @@ -122,7 +123,7 @@ void QuickUpdatePageAad(int32_t new_page_ordinal, std::string* AAD) { ThrowOpenSSLRequiredException(); } -void RandBytes(unsigned char* buf, int num) { 
ThrowOpenSSLRequiredException(); } +void RandBytes(unsigned char* buf, size_t num) { ThrowOpenSSLRequiredException(); } void EnsureBackendInitialized() {} diff --git a/cpp/src/parquet/encryption/encryption_internal_test.cc b/cpp/src/parquet/encryption/encryption_internal_test.cc index 22e14663ea8..bf6607e3287 100644 --- a/cpp/src/parquet/encryption/encryption_internal_test.cc +++ b/cpp/src/parquet/encryption/encryption_internal_test.cc @@ -41,22 +41,22 @@ class TestAesEncryption : public ::testing::Test { encryptor.CiphertextLength(static_cast(plain_text_.size())); std::vector ciphertext(expected_ciphertext_len, '\0'); - int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_), - str2span(aad_), ciphertext); + int32_t ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_), + str2span(aad_), ciphertext); ASSERT_EQ(ciphertext_length, expected_ciphertext_len); AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length); - int expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); + int32_t expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); std::vector decrypted_text(expected_plaintext_length, '\0'); - int plaintext_length = + int32_t plaintext_length = decryptor.Decrypt(ciphertext, str2span(key_), str2span(aad_), decrypted_text); std::string decrypted_text_str(decrypted_text.begin(), decrypted_text.end()); - ASSERT_EQ(plaintext_length, static_cast(plain_text_.size())); + ASSERT_EQ(plaintext_length, static_cast(plain_text_.size())); ASSERT_EQ(plaintext_length, expected_plaintext_length); ASSERT_EQ(decrypted_text_str, plain_text_); } @@ -68,10 +68,10 @@ class TestAesEncryption : public ::testing::Test { AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length); // Create ciphertext of all zeros, so the ciphertext length will be read as zero - const int ciphertext_length = 100; + constexpr int32_t ciphertext_length = 100; std::vector ciphertext(ciphertext_length, '\0'); - int expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); + int32_t expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); std::vector decrypted_text(expected_plaintext_length, '\0'); EXPECT_THROW( @@ -89,12 +89,12 @@ class TestAesEncryption : public ::testing::Test { encryptor.CiphertextLength(static_cast(plain_text_.size())); std::vector ciphertext(expected_ciphertext_len, '\0'); - int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_), - str2span(aad_), ciphertext); + int32_t ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_), + str2span(aad_), ciphertext); AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length); - int expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); + int32_t expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length); std::vector decrypted_text(expected_plaintext_length, '\0'); ::arrow::util::span truncated_ciphertext(ciphertext.data(), @@ -105,7 +105,7 @@ class TestAesEncryption : public ::testing::Test { } private: - int key_length_ = 0; + int32_t key_length_ = 0; std::string key_; std::string aad_; std::string plain_text_; diff --git a/cpp/src/parquet/encryption/file_key_wrapper.cc b/cpp/src/parquet/encryption/file_key_wrapper.cc index 032ae45821a..8ce563e60d7 100644 --- a/cpp/src/parquet/encryption/file_key_wrapper.cc +++ b/cpp/src/parquet/encryption/file_key_wrapper.cc @@ -112,10 +112,10 @@ std::string FileKeyWrapper::GetEncryptionKeyMetadata(const 
std::string& data_key KeyEncryptionKey FileKeyWrapper::CreateKeyEncryptionKey( const std::string& master_key_id) { std::string kek_bytes(kKeyEncryptionKeyLength, '\0'); - RandBytes(reinterpret_cast(&kek_bytes[0]), kKeyEncryptionKeyLength); + RandBytes(reinterpret_cast(kek_bytes.data()), kKeyEncryptionKeyLength); std::string kek_id(kKeyEncryptionKeyIdLength, '\0'); - RandBytes(reinterpret_cast(&kek_id[0]), kKeyEncryptionKeyIdLength); + RandBytes(reinterpret_cast(kek_id.data()), kKeyEncryptionKeyIdLength); // Encrypt KEK with Master key std::string encoded_wrapped_kek = kms_client_->WrapKey(kek_bytes, master_key_id); diff --git a/cpp/src/parquet/encryption/internal_file_decryptor.cc b/cpp/src/parquet/encryption/internal_file_decryptor.cc index fae5ce1f7a8..53a2f8c0216 100644 --- a/cpp/src/parquet/encryption/internal_file_decryptor.cc +++ b/cpp/src/parquet/encryption/internal_file_decryptor.cc @@ -33,16 +33,16 @@ Decryptor::Decryptor(std::shared_ptr aes_decryptor, aad_(aad), pool_(pool) {} -int Decryptor::PlaintextLength(int ciphertext_len) const { +int32_t Decryptor::PlaintextLength(int32_t ciphertext_len) const { return aes_decryptor_->PlaintextLength(ciphertext_len); } -int Decryptor::CiphertextLength(int plaintext_len) const { +int32_t Decryptor::CiphertextLength(int32_t plaintext_len) const { return aes_decryptor_->CiphertextLength(plaintext_len); } -int Decryptor::Decrypt(::arrow::util::span ciphertext, - ::arrow::util::span plaintext) { +int32_t Decryptor::Decrypt(::arrow::util::span ciphertext, + ::arrow::util::span plaintext) { return aes_decryptor_->Decrypt(ciphertext, str2span(key_), str2span(aad_), plaintext); } @@ -143,7 +143,7 @@ std::shared_ptr InternalFileDecryptor::GetFooterDecryptor( // Create both data and metadata decryptors to avoid redundant retrieval of key // from the key_retriever. 
- int key_len = static_cast(footer_key.size()); + auto key_len = static_cast(footer_key.size()); std::shared_ptr aes_metadata_decryptor; std::shared_ptr aes_data_decryptor; @@ -197,7 +197,7 @@ std::shared_ptr InternalFileDecryptor::GetColumnDecryptor( throw HiddenColumnException("HiddenColumnException, path=" + column_path); } - int key_len = static_cast(column_key.size()); + auto key_len = static_cast(column_key.size()); std::lock_guard lock(mutex_); auto aes_decryptor = encryption::AesDecryptor::Make(algorithm_, key_len, metadata, &all_decryptors_); diff --git a/cpp/src/parquet/encryption/internal_file_decryptor.h b/cpp/src/parquet/encryption/internal_file_decryptor.h index 8af3587acf8..08423de7fe9 100644 --- a/cpp/src/parquet/encryption/internal_file_decryptor.h +++ b/cpp/src/parquet/encryption/internal_file_decryptor.h @@ -45,10 +45,10 @@ class PARQUET_EXPORT Decryptor { void UpdateAad(const std::string& aad) { aad_ = aad; } ::arrow::MemoryPool* pool() { return pool_; } - [[nodiscard]] int PlaintextLength(int ciphertext_len) const; - [[nodiscard]] int CiphertextLength(int plaintext_len) const; - int Decrypt(::arrow::util::span ciphertext, - ::arrow::util::span plaintext); + [[nodiscard]] int32_t PlaintextLength(int32_t ciphertext_len) const; + [[nodiscard]] int32_t CiphertextLength(int32_t plaintext_len) const; + int32_t Decrypt(::arrow::util::span ciphertext, + ::arrow::util::span plaintext); private: std::shared_ptr aes_decryptor_; diff --git a/cpp/src/parquet/encryption/internal_file_encryptor.cc b/cpp/src/parquet/encryption/internal_file_encryptor.cc index 285c2100be8..94094e6aca2 100644 --- a/cpp/src/parquet/encryption/internal_file_encryptor.cc +++ b/cpp/src/parquet/encryption/internal_file_encryptor.cc @@ -35,8 +35,8 @@ int32_t Encryptor::CiphertextLength(int64_t plaintext_len) const { return aes_encryptor_->CiphertextLength(plaintext_len); } -int Encryptor::Encrypt(::arrow::util::span plaintext, - ::arrow::util::span ciphertext) { +int32_t Encryptor::Encrypt(::arrow::util::span plaintext, + ::arrow::util::span ciphertext) { return aes_encryptor_->Encrypt(plaintext, str2span(key_), str2span(aad_), ciphertext); } @@ -143,7 +143,7 @@ InternalFileEncryptor::InternalFileEncryptor::GetColumnEncryptor( return encryptor; } -int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int key_len) const { +int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int32_t key_len) const { if (key_len == 16) return 0; else if (key_len == 24) @@ -155,7 +155,7 @@ int InternalFileEncryptor::MapKeyLenToEncryptorArrayIndex(int key_len) const { encryption::AesEncryptor* InternalFileEncryptor::GetMetaAesEncryptor( ParquetCipher::type algorithm, size_t key_size) { - int key_len = static_cast(key_size); + auto key_len = static_cast(key_size); int index = MapKeyLenToEncryptorArrayIndex(key_len); if (meta_encryptor_[index] == nullptr) { meta_encryptor_[index] = encryption::AesEncryptor::Make(algorithm, key_len, true); @@ -165,7 +165,7 @@ encryption::AesEncryptor* InternalFileEncryptor::GetMetaAesEncryptor( encryption::AesEncryptor* InternalFileEncryptor::GetDataAesEncryptor( ParquetCipher::type algorithm, size_t key_size) { - int key_len = static_cast(key_size); + auto key_len = static_cast(key_size); int index = MapKeyLenToEncryptorArrayIndex(key_len); if (data_encryptor_[index] == nullptr) { data_encryptor_[index] = encryption::AesEncryptor::Make(algorithm, key_len, false); diff --git a/cpp/src/parquet/encryption/internal_file_encryptor.h b/cpp/src/parquet/encryption/internal_file_encryptor.h index 
91b6e9fe5aa..5a3d743ce53 100644 --- a/cpp/src/parquet/encryption/internal_file_encryptor.h +++ b/cpp/src/parquet/encryption/internal_file_encryptor.h @@ -45,8 +45,8 @@ class PARQUET_EXPORT Encryptor { [[nodiscard]] int32_t CiphertextLength(int64_t plaintext_len) const; - int Encrypt(::arrow::util::span plaintext, - ::arrow::util::span ciphertext); + int32_t Encrypt(::arrow::util::span plaintext, + ::arrow::util::span ciphertext); bool EncryptColumnMetaData( bool encrypted_footer, @@ -103,7 +103,7 @@ class InternalFileEncryptor { encryption::AesEncryptor* GetDataAesEncryptor(ParquetCipher::type algorithm, size_t key_len); - int MapKeyLenToEncryptorArrayIndex(int key_len) const; + int MapKeyLenToEncryptorArrayIndex(int32_t key_len) const; }; } // namespace parquet diff --git a/cpp/src/parquet/encryption/key_toolkit_internal.cc b/cpp/src/parquet/encryption/key_toolkit_internal.cc index 5d7925aa031..89a52a2bcd6 100644 --- a/cpp/src/parquet/encryption/key_toolkit_internal.cc +++ b/cpp/src/parquet/encryption/key_toolkit_internal.cc @@ -53,7 +53,7 @@ std::string DecryptKeyLocally(const std::string& encoded_encrypted_key, static_cast(master_key.size()), false, false /*contains_length*/); - int decrypted_key_len = + int32_t decrypted_key_len = key_decryptor.PlaintextLength(static_cast(encrypted_key.size())); std::string decrypted_key(decrypted_key_len, '\0'); ::arrow::util::span decrypted_key_span( diff --git a/cpp/src/parquet/metadata.cc b/cpp/src/parquet/metadata.cc index 4f2aa6e3732..423154f8641 100644 --- a/cpp/src/parquet/metadata.cc +++ b/cpp/src/parquet/metadata.cc @@ -751,7 +751,7 @@ class FileMetaData::FileMetaDataImpl { std::shared_ptr encrypted_buffer = AllocateBuffer( file_decryptor_->pool(), aes_encryptor->CiphertextLength(serialized_len)); - uint32_t encrypted_len = aes_encryptor->SignedFooterEncrypt( + int32_t encrypted_len = aes_encryptor->SignedFooterEncrypt( serialized_data_span, str2span(key), str2span(aad), nonce, encrypted_buffer->mutable_span_as()); // Delete AES encryptor object. It was created only to verify the footer signature. 
@@ -799,7 +799,7 @@ class FileMetaData::FileMetaDataImpl { // encrypt the footer key std::vector encrypted_data(encryptor->CiphertextLength(serialized_len)); - int encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data); + int32_t encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data); // write unencrypted footer PARQUET_THROW_NOT_OK(dst->Write(serialized_data, serialized_len)); @@ -1672,7 +1672,7 @@ class ColumnChunkMetaDataBuilder::ColumnChunkMetaDataBuilderImpl { serialized_len); std::vector encrypted_data(encryptor->CiphertextLength(serialized_len)); - int encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data); + int32_t encrypted_len = encryptor->Encrypt(serialized_data_span, encrypted_data); const char* temp = const_cast(reinterpret_cast(encrypted_data.data())); diff --git a/cpp/src/parquet/thrift_internal.h b/cpp/src/parquet/thrift_internal.h index b21b0e07afb..e7bfd434c81 100644 --- a/cpp/src/parquet/thrift_internal.h +++ b/cpp/src/parquet/thrift_internal.h @@ -530,7 +530,7 @@ class ThriftSerializer { auto cipher_buffer = AllocateBuffer(encryptor->pool(), encryptor->CiphertextLength(out_length)); ::arrow::util::span out_span(out_buffer, out_length); - int cipher_buffer_len = + int32_t cipher_buffer_len = encryptor->Encrypt(out_span, cipher_buffer->mutable_span_as()); PARQUET_THROW_NOT_OK(out->Write(cipher_buffer->data(), cipher_buffer_len)); From 6a1d0520974355a749557c993841732d4fcf894c Mon Sep 17 00:00:00 2001 From: Devin Smith Date: Wed, 21 Aug 2024 18:12:45 -0700 Subject: [PATCH 054/157] GH-43717: [Java][FlightSQL] Add all ActionTypes to FlightSqlUtils.FLIGHT_SQL_ACTIONS (#43718) This adds all of the FlightSQL ActionTypes to FlightSqlUtils.FLIGHT_SQL_ACTIONS * GitHub Issue: #43717 Authored-by: Devin Smith Signed-off-by: David Li --- .../org/apache/arrow/flight/sql/FlightSqlUtils.java | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java index 9bb95047691..9e13e57d66c 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java @@ -82,7 +82,15 @@ public final class FlightSqlUtils { + "Response Message: N/A"); public static final List FLIGHT_SQL_ACTIONS = - ImmutableList.of(FLIGHT_SQL_CREATE_PREPARED_STATEMENT, FLIGHT_SQL_CLOSE_PREPARED_STATEMENT); + ImmutableList.of( + FLIGHT_SQL_BEGIN_SAVEPOINT, + FLIGHT_SQL_BEGIN_TRANSACTION, + FLIGHT_SQL_CREATE_PREPARED_STATEMENT, + FLIGHT_SQL_CLOSE_PREPARED_STATEMENT, + FLIGHT_SQL_CREATE_PREPARED_SUBSTRAIT_PLAN, + FLIGHT_SQL_CANCEL_QUERY, + FLIGHT_SQL_END_SAVEPOINT, + FLIGHT_SQL_END_TRANSACTION); /** * Helper to parse {@link com.google.protobuf.Any} objects to the specific protobuf object. From 2e83aa63d95a6fa380efdd5e5cb720a3154f9c93 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 22 Aug 2024 09:57:02 +0200 Subject: [PATCH 055/157] GH-43690: [Python][CI] Simplify python/requirements-wheel-test.txt file (#43691) ### Rationale for this change The current [requirements-wheel-test.txt](https://github.com/apache/arrow/blob/7c8909a144f2e8d593dc8ad363ac95b2865b04ca/python/requirements-wheel-test.txt) file has quite complex and detailed version pinning, varying per architecture. 
I think this can be simplified because we just want to test with some older version of numpy and pandas (and the exact version isn't that important). * GitHub Issue: #43690 Authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- python/requirements-wheel-test.txt | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/python/requirements-wheel-test.txt b/python/requirements-wheel-test.txt index 46bedc13ba1..c7ff63e3395 100644 --- a/python/requirements-wheel-test.txt +++ b/python/requirements-wheel-test.txt @@ -5,22 +5,12 @@ pytest pytz tzdata; sys_platform == 'win32' -numpy==1.21.3; platform_system == "Linux" and platform_machine == "aarch64" and python_version < "3.11" -numpy==1.23.4; python_version == "3.11" -numpy==1.26.0; python_version >= "3.12" -numpy==1.19.5; platform_system == "Linux" and platform_machine != "aarch64" and python_version < "3.9" -numpy==1.21.3; platform_system == "Linux" and platform_machine != "aarch64" and python_version >= "3.9" and python_version < "3.11" -numpy==1.21.3; platform_system == "Darwin" and platform_machine == "arm64" and python_version < "3.11" -numpy==1.19.5; platform_system == "Darwin" and platform_machine != "arm64" and python_version < "3.9" -numpy==1.21.3; platform_system == "Darwin" and platform_machine != "arm64" and python_version >= "3.9" and python_version < "3.11" -numpy==1.19.5; platform_system == "Windows" and python_version < "3.9" -numpy==1.21.3; platform_system == "Windows" and python_version >= "3.9" and python_version < "3.11" +# We generally test with the oldest numpy version that supports a given Python +# version. However, there is no need to make this strictly the oldest version, +# so it can be broadened to have a single version specification across platforms. +# (`~=x.y.z` specifies a compatible release as `>=x.y.z, == x.y.*`) +numpy~=1.21.3; python_version < "3.11" +numpy~=1.23.2; python_version == "3.11" +numpy~=1.26.0; python_version == "3.12" -pandas<1.1.0; platform_system == "Linux" and platform_machine != "aarch64" and python_version < "3.8" -pandas; platform_system == "Linux" and platform_machine != "aarch64" and python_version >= "3.8" -pandas; platform_system == "Linux" and platform_machine == "aarch64" -pandas<1.1.0; platform_system == "Darwin" and platform_machine != "arm64" and python_version < "3.8" -pandas; platform_system == "Darwin" and platform_machine != "arm64" and python_version >= "3.8" -pandas; platform_system == "Darwin" and platform_machine == "arm64" -pandas<1.1.0; platform_system == "Windows" and python_version < "3.8" -pandas; platform_system == "Windows" and python_version >= "3.8" +pandas From fc54eadb72791288fc9681bbcc6c8a9d8d6fff1d Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Thu, 22 Aug 2024 11:28:01 +0200 Subject: [PATCH 056/157] GH-43785: [Python][CI] Correct PARQUET_TEST_DATA path in wheel tests (#43786) ### Rationale for this change Starting with https://github.com/apache/arrow/pull/41580, the pyarrow tests now also rely on a file in the parquet-testing submodule. And the path to that directory is controlled by `PARQUET_TEST_DATA`, which appears to be set wrongly in the wheel test scripts, causing all wheel builds to fail at the moment. 
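For reference, a minimal sketch of how a test helper might resolve that variable (the helper name and the fallback path are illustrative assumptions, not the actual pyarrow test code; the corrected checkout-relative path matches the one used in the scripts below):

```python
import os
import pathlib


def parquet_test_data_dir() -> pathlib.Path:
    # Illustrative only: prefer the PARQUET_TEST_DATA environment variable,
    # which the wheel test scripts are expected to export.
    env = os.environ.get("PARQUET_TEST_DATA")
    # Assumed fallback relative to an Arrow source checkout; the submodule
    # lives under cpp/, not at the repository root.
    path = pathlib.Path(env) if env else pathlib.Path("cpp/submodules/parquet-testing/data")
    if not path.is_dir():
        raise FileNotFoundError(f"parquet-testing data not found at {path}")
    return path
```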
* GitHub Issue: #43785 Authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- ci/scripts/python_wheel_unix_test.sh | 2 +- ci/scripts/python_wheel_windows_test.bat | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ci/scripts/python_wheel_unix_test.sh b/ci/scripts/python_wheel_unix_test.sh index a25e5c51bdd..cf87a170567 100755 --- a/ci/scripts/python_wheel_unix_test.sh +++ b/ci/scripts/python_wheel_unix_test.sh @@ -54,7 +54,7 @@ export PYARROW_TEST_S3=${ARROW_S3} export PYARROW_TEST_TENSORFLOW=ON export ARROW_TEST_DATA=${source_dir}/testing/data -export PARQUET_TEST_DATA=${source_dir}/submodules/parquet-testing/data +export PARQUET_TEST_DATA=${source_dir}/cpp/submodules/parquet-testing/data if [ "${INSTALL_PYARROW}" == "ON" ]; then # Install the built wheels diff --git a/ci/scripts/python_wheel_windows_test.bat b/ci/scripts/python_wheel_windows_test.bat index a928c3571d0..87c0bb12520 100755 --- a/ci/scripts/python_wheel_windows_test.bat +++ b/ci/scripts/python_wheel_windows_test.bat @@ -35,7 +35,7 @@ set PYARROW_TEST_TENSORFLOW=ON @REM set PYARROW_TEST_PANDAS=ON set ARROW_TEST_DATA=C:\arrow\testing\data -set PARQUET_TEST_DATA=C:\arrow\submodules\parquet-testing\data +set PARQUET_TEST_DATA=C:\arrow\cpp\submodules\parquet-testing\data @REM Install testing dependencies pip install -r C:\arrow\python\requirements-wheel-test.txt || exit /B 1 From b4f7efe5bdc2218bb595b130b4f65237caecfa76 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Thu, 22 Aug 2024 14:45:00 +0200 Subject: [PATCH 057/157] GH-43787: [C++] Register the new Opaque extension type by default (#43788) This is to resolve #43787 > The Opaque extension type implementation for C++ (plus python bindings) was added in https://github.com/apache/arrow/pull/43458, but it was not registered by default, which we should do for canonical extension types (see https://github.com/apache/arrow/pull/43458#issuecomment-2302551404) Additionally, this adds `bool8` extension type builds with `ARROW_JSON=false` as discussed [here](https://github.com/apache/arrow/commit/525881987d0b9b4f464c3e3593a9a7b4e3c767d0#r145613657) ### Rationale for this change Canonical types should be registered by default if possible (except e.g. if they can't be compiled due to `ARROW_JSON=false`). ### What changes are included in this PR? This adds default registration for `opaque`, changes when `bool8` is built and moves all canonical tests under the same test target. ### Are these changes tested? Changes are tested by previously existing tests. ### Are there any user-facing changes? `opaue` will now be registered by default and `bool8` will be present in case `ARROW_JSON=false` at build time. 
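To illustrate the user-facing effect, here is a small sketch (assuming the `pa.opaque()` factory exposed by the Python bindings; this is not the test code in this PR) of an `arrow.opaque` column surviving an IPC round trip with no manual registration step:

```python
import pyarrow as pa

# Assumption: pa.opaque(storage_type, type_name, vendor_name) is the Python
# binding for the canonical Opaque extension type.
opaque_type = pa.opaque(pa.binary(), type_name="geometry", vendor_name="adbc.postgresql")
arr = pa.ExtensionArray.from_storage(opaque_type, pa.array([b"foobar", None], pa.binary()))
batch = pa.record_batch([arr], names=["ext"])

# IPC round trip through an in-memory stream; no registered_extension_type()
# context manager should be needed once the type is registered by default.
sink = pa.BufferOutputStream()
with pa.ipc.new_stream(sink, batch.schema) as writer:
    writer.write_batch(batch)
with pa.ipc.open_stream(sink.getvalue()) as reader:
    roundtripped = reader.read_next_batch()

assert roundtripped.column(0).type.extension_name == "arrow.opaque"
```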
* GitHub Issue: #43787 Authored-by: Rok Mihevc Signed-off-by: Rok Mihevc --- cpp/src/arrow/CMakeLists.txt | 2 +- cpp/src/arrow/extension/CMakeLists.txt | 18 ++++++----------- cpp/src/arrow/extension/bool8.h | 2 ++ cpp/src/arrow/extension/bool8_test.cc | 1 - cpp/src/arrow/extension/fixed_shape_tensor.h | 2 ++ cpp/src/arrow/extension/opaque.h | 2 ++ cpp/src/arrow/extension/opaque_test.cc | 2 -- cpp/src/arrow/extension_type.cc | 21 ++++++++++++-------- python/pyarrow/tests/test_extension_type.py | 5 ++--- 9 files changed, 28 insertions(+), 27 deletions(-) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index fb7253b6fd6..89f28ee416e 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -374,6 +374,7 @@ set(ARROW_SRCS datum.cc device.cc extension_type.cc + extension/bool8.cc pretty_print.cc record_batch.cc result.cc @@ -906,7 +907,6 @@ endif() if(ARROW_JSON) arrow_add_object_library(ARROW_JSON - extension/bool8.cc extension/fixed_shape_tensor.cc extension/opaque.cc json/options.cc diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index fcd5fa529ab..5cb4bc77af2 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -15,22 +15,16 @@ # specific language governing permissions and limitations # under the License. -add_arrow_test(test - SOURCES - bool8_test.cc - PREFIX - "arrow-extension-bool8") +set(CANONICAL_EXTENSION_TESTS bool8_test.cc) -add_arrow_test(test - SOURCES - fixed_shape_tensor_test.cc - PREFIX - "arrow-fixed-shape-tensor") +if(ARROW_JSON) + list(APPEND CANONICAL_EXTENSION_TESTS fixed_shape_tensor_test.cc opaque_test.cc) +endif() add_arrow_test(test SOURCES - opaque_test.cc + ${CANONICAL_EXTENSION_TESTS} PREFIX - "arrow-extension-opaque") + "arrow-canonical-extensions") arrow_install_all_headers("arrow/extension") diff --git a/cpp/src/arrow/extension/bool8.h b/cpp/src/arrow/extension/bool8.h index 02e629b28a8..fbb507639e2 100644 --- a/cpp/src/arrow/extension/bool8.h +++ b/cpp/src/arrow/extension/bool8.h @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#pragma once + #include "arrow/extension_type.h" namespace arrow::extension { diff --git a/cpp/src/arrow/extension/bool8_test.cc b/cpp/src/arrow/extension/bool8_test.cc index eabcfcf62d3..ee77332bc32 100644 --- a/cpp/src/arrow/extension/bool8_test.cc +++ b/cpp/src/arrow/extension/bool8_test.cc @@ -19,7 +19,6 @@ #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" -#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_util.h" namespace arrow { diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h b/cpp/src/arrow/extension/fixed_shape_tensor.h index 20ec20a64c2..80a602021c6 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor.h +++ b/cpp/src/arrow/extension/fixed_shape_tensor.h @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#pragma once + #include "arrow/extension_type.h" namespace arrow { diff --git a/cpp/src/arrow/extension/opaque.h b/cpp/src/arrow/extension/opaque.h index 9814b391cba..5d3411798f8 100644 --- a/cpp/src/arrow/extension/opaque.h +++ b/cpp/src/arrow/extension/opaque.h @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. 
+#pragma once + #include "arrow/extension_type.h" #include "arrow/type.h" diff --git a/cpp/src/arrow/extension/opaque_test.cc b/cpp/src/arrow/extension/opaque_test.cc index 1629cdb3965..16fcba3fa6b 100644 --- a/cpp/src/arrow/extension/opaque_test.cc +++ b/cpp/src/arrow/extension/opaque_test.cc @@ -25,7 +25,6 @@ #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" #include "arrow/ipc/writer.h" #include "arrow/record_batch.h" -#include "arrow/testing/extension_type.h" #include "arrow/testing/gtest_util.h" #include "arrow/type_fwd.h" #include "arrow/util/checked_cast.h" @@ -169,7 +168,6 @@ TEST(OpaqueType, MetadataRoundTrip) { TEST(OpaqueType, BatchRoundTrip) { auto type = internal::checked_pointer_cast( extension::opaque(binary(), "geometry", "adbc.postgresql")); - ExtensionTypeGuard guard(type); auto storage = ArrayFromJSON(binary(), R"(["foobar", null])"); auto array = ExtensionType::WrapArray(type, storage); diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc index 685018f7de7..83c7ebed4f3 100644 --- a/cpp/src/arrow/extension_type.cc +++ b/cpp/src/arrow/extension_type.cc @@ -27,9 +27,10 @@ #include "arrow/array/util.h" #include "arrow/chunked_array.h" #include "arrow/config.h" -#ifdef ARROW_JSON #include "arrow/extension/bool8.h" +#ifdef ARROW_JSON #include "arrow/extension/fixed_shape_tensor.h" +#include "arrow/extension/opaque.h" #endif #include "arrow/status.h" #include "arrow/type.h" @@ -143,17 +144,21 @@ static std::once_flag registry_initialized; namespace internal { static void CreateGlobalRegistry() { + // Register canonical extension types + g_registry = std::make_shared(); + std::vector> ext_types{extension::bool8()}; #ifdef ARROW_JSON - // Register canonical extension types - auto fst_ext_type = - checked_pointer_cast(extension::fixed_shape_tensor(int64(), {})); - ARROW_CHECK_OK(g_registry->RegisterType(fst_ext_type)); - - auto bool8_ext_type = checked_pointer_cast(extension::bool8()); - ARROW_CHECK_OK(g_registry->RegisterType(bool8_ext_type)); + ext_types.push_back(extension::fixed_shape_tensor(int64(), {})); + ext_types.push_back(extension::opaque(null(), "", "")); #endif + + // Register canonical extension types + for (const auto& ext_type : ext_types) { + ARROW_CHECK_OK( + g_registry->RegisterType(checked_pointer_cast(ext_type))); + } } } // namespace internal diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index b04ee85ec99..0d50c467e96 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -1693,9 +1693,8 @@ def test_opaque_type(pickle_module, storage_type, storage): arr = pa.ExtensionArray.from_storage(opaque_type, storage) assert isinstance(arr, opaque_arr_class) - with registered_extension_type(opaque_type): - buf = ipc_write_batch(pa.RecordBatch.from_arrays([arr], ["ext"])) - batch = ipc_read_batch(buf) + buf = ipc_write_batch(pa.RecordBatch.from_arrays([arr], ["ext"])) + batch = ipc_read_batch(buf) assert batch.column(0).type.extension_name == "arrow.opaque" assert isinstance(batch.column(0), opaque_arr_class) From 3e9384bbf4162ea060e867a753bce464b31e5e1c Mon Sep 17 00:00:00 2001 From: Lysandros Nikolaou Date: Thu, 22 Aug 2024 15:27:40 +0200 Subject: [PATCH 058/157] GH-43519: [Python] Set up wheel building for Python 3.13 (#43539) ### Rationale for this change Like #43519 mentions, now that the first `rc` is out, it's probably time to add CI coverage for Python 3.13 (and also start building wheels). ### What changes are included in this PR?
I'm fairly new to the build/CI processes of the project, but I tried to follow the same template as #37901. I'll follow up afterwards with adding CI coverage for the free-threaded build as well. * GitHub Issue: #43519 Lead-authored-by: Lysandros Nikolaou Co-authored-by: Joris Van den Bossche Signed-off-by: Joris Van den Bossche --- .env | 2 +- ci/docker/python-wheel-manylinux-test.dockerfile | 7 ++++--- ci/docker/python-wheel-manylinux.dockerfile | 2 +- .../python-wheel-windows-test-vs2019.dockerfile | 7 ++++--- ci/docker/python-wheel-windows-vs2019.dockerfile | 7 ++++--- ci/scripts/install_gcs_testbench.sh | 10 +++++++--- ci/scripts/install_python.sh | 14 +++++++++++--- ci/scripts/python_wheel_macos_build.sh | 2 -- dev/release/verify-release-candidate.sh | 6 +++--- dev/tasks/python-wheels/github.linux.yml | 5 +++++ dev/tasks/python-wheels/github.osx.yml | 2 +- dev/tasks/tasks.yml | 3 ++- docker-compose.yml | 9 ++++++--- python/pyproject.toml | 1 + python/requirements-wheel-build.txt | 5 +++++ python/requirements-wheel-test.txt | 7 +++++++ 16 files changed, 62 insertions(+), 27 deletions(-) diff --git a/.env b/.env index 1358aafe824..21f904c3208 100644 --- a/.env +++ b/.env @@ -95,7 +95,7 @@ VCPKG="943c5ef1c8f6b5e6ced092b242c8299caae2ff01" # 2024.04.26 Release # ci/docker/python-wheel-windows-vs2019.dockerfile. # This is a workaround for our CI problem that "archery docker build" doesn't # use pulled built images in dev/tasks/python-wheels/github.windows.yml. -PYTHON_WHEEL_WINDOWS_IMAGE_REVISION=2024-06-18 +PYTHON_WHEEL_WINDOWS_IMAGE_REVISION=2024-08-06 # Use conanio/${CONAN_BASE}:{CONAN_VERSION} for "docker-compose run --rm conan". # See https://github.com/conan-io/conan-docker-tools#readme and diff --git a/ci/docker/python-wheel-manylinux-test.dockerfile b/ci/docker/python-wheel-manylinux-test.dockerfile index cdd0ae3ced7..443ff9c53cb 100644 --- a/ci/docker/python-wheel-manylinux-test.dockerfile +++ b/ci/docker/python-wheel-manylinux-test.dockerfile @@ -16,8 +16,8 @@ # under the License. 
ARG arch -ARG python -FROM ${arch}/python:${python} +ARG python_image_tag +FROM ${arch}/python:${python_image_tag} # RUN pip install --upgrade pip @@ -27,4 +27,5 @@ COPY python/requirements-wheel-test.txt /arrow/python/ RUN pip install -r /arrow/python/requirements-wheel-test.txt COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts/ -RUN PYTHON=python /arrow/ci/scripts/install_gcs_testbench.sh default +ARG python +RUN PYTHON_VERSION=${python} /arrow/ci/scripts/install_gcs_testbench.sh default diff --git a/ci/docker/python-wheel-manylinux.dockerfile b/ci/docker/python-wheel-manylinux.dockerfile index cb39667af1e..42f088fd8a2 100644 --- a/ci/docker/python-wheel-manylinux.dockerfile +++ b/ci/docker/python-wheel-manylinux.dockerfile @@ -103,7 +103,7 @@ RUN vcpkg install \ # Configure Python for applications running in the bash shell of this Dockerfile ARG python=3.8 ENV PYTHON_VERSION=${python} -RUN PYTHON_ROOT=$(find /opt/python -name cp${PYTHON_VERSION/./}-*) && \ +RUN PYTHON_ROOT=$(find /opt/python -name cp${PYTHON_VERSION/./}-cp${PYTHON_VERSION/./}) && \ echo "export PATH=$PYTHON_ROOT/bin:\$PATH" >> /etc/profile.d/python.sh SHELL ["/bin/bash", "-i", "-c"] diff --git a/ci/docker/python-wheel-windows-test-vs2019.dockerfile b/ci/docker/python-wheel-windows-test-vs2019.dockerfile index 32bbb55e826..5f488a4c285 100644 --- a/ci/docker/python-wheel-windows-test-vs2019.dockerfile +++ b/ci/docker/python-wheel-windows-test-vs2019.dockerfile @@ -40,10 +40,11 @@ ARG python=3.8 RUN (if "%python%"=="3.8" setx PYTHON_VERSION "3.8.10" && setx PATH "%PATH%;C:\Python38;C:\Python38\Scripts") & \ (if "%python%"=="3.9" setx PYTHON_VERSION "3.9.13" && setx PATH "%PATH%;C:\Python39;C:\Python39\Scripts") & \ (if "%python%"=="3.10" setx PYTHON_VERSION "3.10.11" && setx PATH "%PATH%;C:\Python310;C:\Python310\Scripts") & \ - (if "%python%"=="3.11" setx PYTHON_VERSION "3.11.5" && setx PATH "%PATH%;C:\Python311;C:\Python311\Scripts") & \ - (if "%python%"=="3.12" setx PYTHON_VERSION "3.12.0" && setx PATH "%PATH%;C:\Python312;C:\Python312\Scripts") + (if "%python%"=="3.11" setx PYTHON_VERSION "3.11.9" && setx PATH "%PATH%;C:\Python311;C:\Python311\Scripts") & \ + (if "%python%"=="3.12" setx PYTHON_VERSION "3.12.4" && setx PATH "%PATH%;C:\Python312;C:\Python312\Scripts") & \ + (if "%python%"=="3.13" setx PYTHON_VERSION "3.13.0-rc1" && setx PATH "%PATH%;C:\Python313;C:\Python313\Scripts") # Install archiver to extract xz archives -RUN choco install -r -y --no-progress python --version=%PYTHON_VERSION% & \ +RUN choco install -r -y --pre --no-progress python --version=%PYTHON_VERSION% & \ python -m pip install --no-cache-dir -U pip setuptools & \ choco install --no-progress -r -y archiver diff --git a/ci/docker/python-wheel-windows-vs2019.dockerfile b/ci/docker/python-wheel-windows-vs2019.dockerfile index ff42de939d9..5a17e3e4c52 100644 --- a/ci/docker/python-wheel-windows-vs2019.dockerfile +++ b/ci/docker/python-wheel-windows-vs2019.dockerfile @@ -83,9 +83,10 @@ ARG python=3.8 RUN (if "%python%"=="3.8" setx PYTHON_VERSION "3.8.10" && setx PATH "%PATH%;C:\Python38;C:\Python38\Scripts") & \ (if "%python%"=="3.9" setx PYTHON_VERSION "3.9.13" && setx PATH "%PATH%;C:\Python39;C:\Python39\Scripts") & \ (if "%python%"=="3.10" setx PYTHON_VERSION "3.10.11" && setx PATH "%PATH%;C:\Python310;C:\Python310\Scripts") & \ - (if "%python%"=="3.11" setx PYTHON_VERSION "3.11.5" && setx PATH "%PATH%;C:\Python311;C:\Python311\Scripts") & \ - (if "%python%"=="3.12" setx PYTHON_VERSION "3.12.0" && setx PATH 
"%PATH%;C:\Python312;C:\Python312\Scripts") -RUN choco install -r -y --no-progress python --version=%PYTHON_VERSION% + (if "%python%"=="3.11" setx PYTHON_VERSION "3.11.9" && setx PATH "%PATH%;C:\Python311;C:\Python311\Scripts") & \ + (if "%python%"=="3.12" setx PYTHON_VERSION "3.12.4" && setx PATH "%PATH%;C:\Python312;C:\Python312\Scripts") & \ + (if "%python%"=="3.13" setx PYTHON_VERSION "3.13.0-rc1" && setx PATH "%PATH%;C:\Python313;C:\Python313\Scripts") +RUN choco install -r -y --pre --no-progress python --version=%PYTHON_VERSION% RUN python -m pip install -U pip setuptools COPY python/requirements-wheel-build.txt arrow/python/ diff --git a/ci/scripts/install_gcs_testbench.sh b/ci/scripts/install_gcs_testbench.sh index 2090290c993..5471b3cc238 100755 --- a/ci/scripts/install_gcs_testbench.sh +++ b/ci/scripts/install_gcs_testbench.sh @@ -41,8 +41,12 @@ version=$1 if [[ "${version}" -eq "default" ]]; then version="v0.39.0" # Latests versions of Testbench require newer setuptools - ${PYTHON:-python3} -m pip install --upgrade setuptools + python3 -m pip install --upgrade setuptools fi -${PYTHON:-python3} -m pip install \ - "https://github.com/googleapis/storage-testbench/archive/${version}.tar.gz" +# This script is run with PYTHON undefined in some places, +# but those only use older pythons. +if [[ -z "${PYTHON_VERSION}" ]] || [[ "${PYTHON_VERSION}" != "3.13" ]]; then + python3 -m pip install \ + "https://github.com/googleapis/storage-testbench/archive/${version}.tar.gz" +fi diff --git a/ci/scripts/install_python.sh b/ci/scripts/install_python.sh index 5f962f02b91..42d0e9ca179 100755 --- a/ci/scripts/install_python.sh +++ b/ci/scripts/install_python.sh @@ -28,8 +28,9 @@ declare -A versions versions=([3.8]=3.8.10 [3.9]=3.9.13 [3.10]=3.10.11 - [3.11]=3.11.5 - [3.12]=3.12.0) + [3.11]=3.11.9 + [3.12]=3.12.4 + [3.13]=3.13.0) if [ "$#" -ne 2 ]; then echo "Usage: $0 " @@ -46,7 +47,14 @@ full_version=${versions[$2]} if [ $platform = "macOS" ]; then echo "Downloading Python installer..." 
- if [ "$(uname -m)" = "arm64" ] || [ "$version" = "3.10" ] || [ "$version" = "3.11" ] || [ "$version" = "3.12" ]; then + if [ "$version" = "3.13" ]; + then + fname="python-${full_version}rc1-macos11.pkg" + elif [ "$(uname -m)" = "arm64" ] || \ + [ "$version" = "3.10" ] || \ + [ "$version" = "3.11" ] || \ + [ "$version" = "3.12" ]; + then fname="python-${full_version}-macos11.pkg" else fname="python-${full_version}-macosx10.9.pkg" diff --git a/ci/scripts/python_wheel_macos_build.sh b/ci/scripts/python_wheel_macos_build.sh index 3ed9d5d8dd1..d5430f26748 100755 --- a/ci/scripts/python_wheel_macos_build.sh +++ b/ci/scripts/python_wheel_macos_build.sh @@ -48,13 +48,11 @@ fi echo "=== (${PYTHON_VERSION}) Install Python build dependencies ===" export PIP_SITE_PACKAGES=$(python -c 'import site; print(site.getsitepackages()[0])') -export PIP_TARGET_PLATFORM="macosx_${MACOSX_DEPLOYMENT_TARGET//./_}_${arch}" pip install \ --upgrade \ --only-binary=:all: \ --target $PIP_SITE_PACKAGES \ - --platform $PIP_TARGET_PLATFORM \ -r ${source_dir}/python/requirements-wheel-build.txt pip install "delocate>=0.10.3" diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index 6a36109dc2f..07e765a759e 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -1146,7 +1146,7 @@ test_linux_wheels() { local arch="x86_64" fi - local python_versions="${TEST_PYTHON_VERSIONS:-3.8 3.9 3.10 3.11 3.12}" + local python_versions="${TEST_PYTHON_VERSIONS:-3.8 3.9 3.10 3.11 3.12 3.13}" local platform_tags="${TEST_WHEEL_PLATFORM_TAGS:-manylinux_2_17_${arch}.manylinux2014_${arch} manylinux_2_28_${arch}}" for python in ${python_versions}; do @@ -1170,11 +1170,11 @@ test_macos_wheels() { # apple silicon processor if [ "$(uname -m)" = "arm64" ]; then - local python_versions="3.8 3.9 3.10 3.11 3.12" + local python_versions="3.8 3.9 3.10 3.11 3.12 3.13" local platform_tags="macosx_11_0_arm64" local check_flight=OFF else - local python_versions="3.8 3.9 3.10 3.11 3.12" + local python_versions="3.8 3.9 3.10 3.11 3.12 3.13" local platform_tags="macosx_10_15_x86_64" fi diff --git a/dev/tasks/python-wheels/github.linux.yml b/dev/tasks/python-wheels/github.linux.yml index 968c5da2189..2854d4349fb 100644 --- a/dev/tasks/python-wheels/github.linux.yml +++ b/dev/tasks/python-wheels/github.linux.yml @@ -36,6 +36,11 @@ jobs: ARCHERY_USE_LEGACY_DOCKER_COMPOSE: 1 {% endif %} PYTHON: "{{ python_version }}" + {% if python_version == "3.13" %} + PYTHON_IMAGE_TAG: "3.13-rc" + {% else %} + PYTHON_IMAGE_TAG: "{{ python_version }}" + {% endif %} steps: {{ macros.github_checkout_arrow()|indent }} diff --git a/dev/tasks/python-wheels/github.osx.yml b/dev/tasks/python-wheels/github.osx.yml index 8ceb468af89..b26aeba32b7 100644 --- a/dev/tasks/python-wheels/github.osx.yml +++ b/dev/tasks/python-wheels/github.osx.yml @@ -121,7 +121,7 @@ jobs: source test-env/bin/activate pip install --upgrade pip wheel arch -{{ arch }} pip install -r arrow/python/requirements-wheel-test.txt - PYTHON=python arch -{{ arch }} arrow/ci/scripts/install_gcs_testbench.sh default + PYTHON_VERSION={{ python_version }} arch -{{ arch }} arrow/ci/scripts/install_gcs_testbench.sh default arch -{{ arch }} arrow/ci/scripts/python_wheel_unix_test.sh $(pwd)/arrow {{ macros.github_upload_releases("arrow/python/repaired_wheels/*.whl")|indent }} diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index fe02fe9ce68..60114d69308 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -389,7 +389,8 @@ tasks: 
("3.9", "cp39", "cp39"), ("3.10", "cp310", "cp310"), ("3.11", "cp311", "cp311"), - ("3.12", "cp312", "cp312")] %} + ("3.12", "cp312", "cp312"), + ("3.13", "cp313", "cp313")] %} {############################## Wheel Linux ##################################} diff --git a/docker-compose.yml b/docker-compose.yml index 14eeeeee6e5..3045cf015bc 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1096,9 +1096,10 @@ services: args: arch: ${ARCH} arch_short: ${ARCH_SHORT} - base: quay.io/pypa/manylinux2014_${ARCH_ALIAS}:2024-02-04-ea37246 + base: quay.io/pypa/manylinux2014_${ARCH_ALIAS}:2024-08-03-32dfa47 vcpkg: ${VCPKG} python: ${PYTHON} + python_image_tag: ${PYTHON_IMAGE_TAG} manylinux: 2014 context: . dockerfile: ci/docker/python-wheel-manylinux.dockerfile @@ -1119,9 +1120,10 @@ services: args: arch: ${ARCH} arch_short: ${ARCH_SHORT} - base: quay.io/pypa/manylinux_2_28_${ARCH_ALIAS}:2024-02-04-ea37246 + base: quay.io/pypa/manylinux_2_28_${ARCH_ALIAS}:2024-08-03-32dfa47 vcpkg: ${VCPKG} python: ${PYTHON} + python_image_tag: ${PYTHON_IMAGE_TAG} manylinux: 2_28 context: . dockerfile: ci/docker/python-wheel-manylinux.dockerfile @@ -1135,7 +1137,7 @@ services: command: /arrow/ci/scripts/python_wheel_manylinux_build.sh python-wheel-manylinux-test-imports: - image: ${ARCH}/python:${PYTHON} + image: ${ARCH}/python:${PYTHON_IMAGE_TAG} shm_size: 2G volumes: - .:/arrow:delegated @@ -1151,6 +1153,7 @@ services: args: arch: ${ARCH} python: ${PYTHON} + python_image_tag: ${PYTHON_IMAGE_TAG} context: . dockerfile: ci/docker/python-wheel-manylinux-test.dockerfile cache_from: diff --git a/python/pyproject.toml b/python/pyproject.toml index d863bb3e5f0..8ece65dd467 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -48,6 +48,7 @@ classifiers = [ 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', 'Programming Language :: Python :: 3.12', + 'Programming Language :: Python :: 3.13', ] maintainers = [ {name = "Apache Arrow Developers", email = "dev@arrow.apache.org"} diff --git a/python/requirements-wheel-build.txt b/python/requirements-wheel-build.txt index faa078d3d7f..2d448004768 100644 --- a/python/requirements-wheel-build.txt +++ b/python/requirements-wheel-build.txt @@ -1,3 +1,8 @@ +# Remove pre and extra index url once there's NumPy and Cython wheels for 3.13 +# on PyPI +--pre +--extra-index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" + cython>=0.29.31 oldest-supported-numpy>=0.14; python_version<'3.9' numpy>=2.0.0; python_version>='3.9' diff --git a/python/requirements-wheel-test.txt b/python/requirements-wheel-test.txt index c7ff63e3395..98ec2bd4fd4 100644 --- a/python/requirements-wheel-test.txt +++ b/python/requirements-wheel-test.txt @@ -1,3 +1,9 @@ +# Remove pre and extra index url once there's NumPy and Cython wheels for 3.13 +# on PyPI +--pre +--prefer-binary +--extra-index-url "https://pypi.anaconda.org/scientific-python-nightly-wheels/simple" + cffi cython hypothesis @@ -12,5 +18,6 @@ tzdata; sys_platform == 'win32' numpy~=1.21.3; python_version < "3.11" numpy~=1.23.2; python_version == "3.11" numpy~=1.26.0; python_version == "3.12" +numpy~=2.1.0; python_version >= "3.13" pandas From 88d57cf41fde20adf14adca02e02d2cb92c83443 Mon Sep 17 00:00:00 2001 From: Jonathan Keane Date: Thu, 22 Aug 2024 08:45:19 -0500 Subject: [PATCH 059/157] MINOR: [CI][R] Undo #43636 now that the action is approved (#43730) Undo the pinning in #43636 now that INFRA has approved the quarto-dev action Authored-by: Jonathan Keane Signed-off-by: 
Antoine Pitrou --- .github/workflows/r.yml | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index bf7eb99e7e9..2820d42470b 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -86,19 +86,18 @@ jobs: run: | sudo apt-get install devscripts - # replace the SHA with v2 once INFRA-26031 is resolved - - uses: r-lib/actions/setup-r@732fb28088814627972f1ccbacc02561178cf391 + - uses: r-lib/actions/setup-r@v2 with: use-public-rspm: true install-r: false - - uses: r-lib/actions/setup-r-dependencies@732fb28088814627972f1ccbacc02561178cf391 + - uses: r-lib/actions/setup-r-dependencies@v2 with: extra-packages: any::rcmdcheck needs: check working-directory: src/r - - uses: r-lib/actions/check-r-package@732fb28088814627972f1ccbacc02561178cf391 + - uses: r-lib/actions/check-r-package@v2 with: working-directory: src/r env: @@ -341,11 +340,11 @@ jobs: cd r/windows ls *.zip | xargs -n 1 unzip -uo rm -rf *.zip - - uses: r-lib/actions/setup-r@732fb28088814627972f1ccbacc02561178cf391 + - uses: r-lib/actions/setup-r@v2 with: r-version: ${{ matrix.config.rversion }} Ncpus: 2 - - uses: r-lib/actions/setup-r-dependencies@732fb28088814627972f1ccbacc02561178cf391 + - uses: r-lib/actions/setup-r-dependencies@v2 env: GITHUB_PAT: "${{ github.token }}" with: From 2e33e98f583035cd686455870e9cbf5fb6dc9966 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Thu, 22 Aug 2024 08:26:37 -0800 Subject: [PATCH 060/157] MINOR: [GO] fixup test case name in cast_test.go (#43780) --- go/arrow/compute/cast_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/arrow/compute/cast_test.go b/go/arrow/compute/cast_test.go index 2e748a2fee9..fa08467dd39 100644 --- a/go/arrow/compute/cast_test.go +++ b/go/arrow/compute/cast_test.go @@ -2636,7 +2636,7 @@ func (c *CastSuite) TestStructToDifferentNullabilityStruct() { defer dest3Nullable.Release() checkCast(c.T(), srcNonNull, dest3Nullable, *compute.DefaultCastOptions(true)) }) - c.Run("non-nullable to nullable", func() { + c.Run("nullable to non-nullable", func() { fieldsSrcNullable := []arrow.Field{ {Name: "a", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, {Name: "b", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, From 76e0f6254b75509d83e44fe8997bd14007907c4f Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Thu, 22 Aug 2024 15:37:09 -0400 Subject: [PATCH 061/157] GH-43764: [Go][FlightSQL] Add NewPreparedStatement function (#43781) ### Rationale for this change Allowing creation of the prepared statement object outside of the client allows for logging, proxying, and handing off prepared statements if necessary. ### Are these changes tested? Yes * GitHub Issue: #43764 Authored-by: Matt Topol Signed-off-by: Matt Topol --- go/arrow/flight/flightsql/client.go | 9 +++++++++ go/arrow/flight/flightsql/client_test.go | 21 +++++++++++++++++---- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/go/arrow/flight/flightsql/client.go b/go/arrow/flight/flightsql/client.go index 4a600e5253e..4c9dc501351 100644 --- a/go/arrow/flight/flightsql/client.go +++ b/go/arrow/flight/flightsql/client.go @@ -1102,6 +1102,15 @@ type PreparedStatement struct { closed bool } +// NewPreparedStatement creates a prepared statement object bound to the provided +// client using the given handle. In general, it should be sufficient to use the +// Prepare function a client and this wouldn't be needed. 
But this can be used +// to propagate a prepared statement from one client to another if needed or if +// proxying requests. +func NewPreparedStatement(client *Client, handle []byte) *PreparedStatement { + return &PreparedStatement{client: client, handle: handle} +} + // Execute executes the prepared statement on the server and returns a FlightInfo // indicating where to retrieve the response. If SetParameters has been called // then the parameter bindings will be sent before execution. diff --git a/go/arrow/flight/flightsql/client_test.go b/go/arrow/flight/flightsql/client_test.go index 7604b554cbc..d060161f94f 100644 --- a/go/arrow/flight/flightsql/client_test.go +++ b/go/arrow/flight/flightsql/client_test.go @@ -378,8 +378,10 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecute() { createRsp := &mockDoActionClient{} defer createRsp.AssertExpectations(s.T()) createRsp.On("Recv").Return(&pb.Result{Body: data}, nil).Once() - createRsp.On("Recv").Return(&pb.Result{}, io.EOF) - createRsp.On("CloseSend").Return(nil) + createRsp.On("Recv").Return(&pb.Result{}, io.EOF).Once() + createRsp.On("Recv").Return(&pb.Result{Body: data}, nil).Once() + createRsp.On("Recv").Return(&pb.Result{}, io.EOF).Once() + createRsp.On("CloseSend").Return(nil).Twice() closeRsp := &mockDoActionClient{} defer closeRsp.AssertExpectations(s.T()) @@ -387,13 +389,13 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecute() { closeRsp.On("CloseSend").Return(nil) s.mockClient.On("DoAction", flightsql.CreatePreparedStatementActionType, action.Body, s.callOpts). - Return(createRsp, nil) + Return(createRsp, nil).Twice() s.mockClient.On("DoAction", flightsql.ClosePreparedStatementActionType, closeAct.Body, s.callOpts). Return(closeRsp, nil) infoCmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: []byte(query)} desc := getDesc(infoCmd) - s.mockClient.On("GetFlightInfo", desc.Type, desc.Cmd, s.callOpts).Return(&emptyFlightInfo, nil) + s.mockClient.On("GetFlightInfo", desc.Type, desc.Cmd, s.callOpts).Return(&emptyFlightInfo, nil).Twice() prepared, err := s.sqlClient.Prepare(context.TODO(), query, s.callOpts...) s.NoError(err) @@ -404,6 +406,17 @@ func (s *FlightSqlClientSuite) TestPreparedStatementExecute() { info, err := prepared.Execute(context.TODO(), s.callOpts...) s.NoError(err) s.Equal(&emptyFlightInfo, info) + + prepared, err = s.sqlClient.Prepare(context.TODO(), query, s.callOpts...) + s.NoError(err) + + secondPrepare := flightsql.NewPreparedStatement(&s.sqlClient, prepared.Handle()) + s.Equal(string(secondPrepare.Handle()), "query") + defer secondPrepare.Close(context.TODO(), s.callOpts...) + + info, err = secondPrepare.Execute(context.TODO(), s.callOpts...) + s.NoError(err) + s.Equal(&emptyFlightInfo, info) } func (s *FlightSqlClientSuite) TestPreparedStatementExecuteParamBinding() { From d47b305bbce037af18ce65dc968074fe1681b4d4 Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Thu, 22 Aug 2024 16:04:59 -0400 Subject: [PATCH 062/157] GH-43624: [Go] Add JSON/UUID extension types, extend arrow -> parquet logical type mapping (#43679) ### Rationale for this change - Missing `JSON` extension type implementation. - Current precedent in C++ (and thereby PyArrow) is that canonical extension types do not require manual registration. - Issues like #43640 and #43624 suggest that we need to expose ways of configuring parquet types written from arrow records, but casting the underlying data presents challenges for a generalized approach. 
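To illustrate the last point above: rather than casting storage data, an extension type can declare the Parquet logical type it wants through the new interface added by this PR. The sketch below is a hypothetical user-defined type (`BSONType`, the `myorg.bson` name, and the Binary storage are illustrative, not part of this change); only the `ParquetLogicalType` method is the hook introduced here, and `schema.BSONLogicalType` is assumed to be the existing logical type in `parquet/schema`.

```go
package myext

import (
	"fmt"
	"reflect"

	"github.com/apache/arrow/go/v18/arrow"
	"github.com/apache/arrow/go/v18/arrow/array"
	"github.com/apache/arrow/go/v18/parquet/schema"
)

// BSONType is a hypothetical extension type: logically BSON documents,
// physically stored as a plain Binary column.
type BSONType struct {
	arrow.ExtensionBase
}

// BSONArray exists so ArrayType has something to point at; values are read
// through the underlying Binary storage array.
type BSONArray struct {
	array.ExtensionArrayBase
}

func NewBSONType() *BSONType {
	return &BSONType{ExtensionBase: arrow.ExtensionBase{Storage: arrow.BinaryTypes.Binary}}
}

func (*BSONType) ExtensionName() string   { return "myorg.bson" }
func (*BSONType) Serialize() string       { return "" }
func (*BSONType) ArrayType() reflect.Type { return reflect.TypeOf(BSONArray{}) }

func (t *BSONType) String() string {
	return fmt.Sprintf("extension<%s>", t.ExtensionName())
}

func (t *BSONType) ExtensionEquals(other arrow.ExtensionType) bool {
	return t.ExtensionName() == other.ExtensionName()
}

func (*BSONType) Deserialize(storageType arrow.DataType, _ string) (arrow.ExtensionType, error) {
	if !arrow.TypeEqual(storageType, arrow.BinaryTypes.Binary) {
		return nil, fmt.Errorf("invalid storage type for BSONType: %s", storageType)
	}
	return NewBSONType(), nil
}

// ParquetLogicalType is the hook added by this patch: when pqarrow maps a
// field of this type to a Parquet schema node, the Binary storage column is
// annotated as BSON instead of falling back to the storage default.
func (*BSONType) ParquetLogicalType() schema.LogicalType {
	return schema.BSONLogicalType{}
}
```

Writing a column of this type through `pqarrow` would then produce a Binary column annotated as BSON, with no cast of the underlying data.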
### What changes are included in this PR? - Move `UUIDType` from `internal` to `arrow/extensions` - Implement `JSON` canonical extension type - Automatically register all canonical extension types at initialization - remove register/unregister from various locations these extension types are used - Add new `CustomParquetType` interface so extension types can specify their target `LogicalType` in Parquet - Refactor parquet `fieldToNode` to split up `PrimitiveNode` type mapping for leaves from `GroupNode` composition - Simplify parquet `LogicalType` to use only value receivers ### Are these changes tested? Yes ### Are there any user-facing changes? - `UUID` and `JSON` extension types are available to end users. - Canonical extension types will automatically be recognized in IPC without registration. - Users with their own extension type implementations may use the `CustomParquetType` interface to control Parquet conversion without needing to fork or upstream the change. * GitHub Issue: #43624 Authored-by: Joel Lubinitsky Signed-off-by: Joel Lubinitsky --- docs/source/status.rst | 6 + go/arrow/array/array_test.go | 4 +- go/arrow/array/diff_test.go | 4 +- go/arrow/array/extension_test.go | 10 - go/arrow/avro/reader_types.go | 4 +- go/arrow/avro/schema.go | 4 +- go/arrow/compute/exec/span_test.go | 6 +- go/arrow/csv/reader_test.go | 4 +- go/arrow/csv/writer_test.go | 6 +- go/arrow/datatype_extension_test.go | 18 +- go/arrow/extensions/bool8_test.go | 3 - go/arrow/extensions/extensions.go | 36 +++ go/arrow/extensions/json.go | 148 ++++++++++ go/arrow/extensions/json_test.go | 268 ++++++++++++++++++ go/arrow/extensions/opaque_test.go | 3 - go/arrow/extensions/uuid.go | 265 +++++++++++++++++ go/arrow/extensions/uuid_test.go | 257 +++++++++++++++++ .../internal/flight_integration/scenario.go | 4 - .../cmd/arrow-json-integration-test/main.go | 4 - go/arrow/ipc/metadata_test.go | 11 +- go/internal/types/extension_types.go | 227 +-------------- go/internal/types/extension_types_test.go | 95 ------- go/parquet/cmd/parquet_reader/main.go | 2 +- go/parquet/metadata/app_version.go | 2 +- go/parquet/pqarrow/encode_arrow_test.go | 82 ++++-- go/parquet/pqarrow/path_builder_test.go | 6 +- go/parquet/pqarrow/schema.go | 228 +++++++-------- go/parquet/pqarrow/schema_test.go | 15 +- go/parquet/schema/converted_types.go | 8 +- go/parquet/schema/logical_types.go | 30 +- go/parquet/schema/logical_types_test.go | 40 +-- go/parquet/schema/schema_element_test.go | 4 +- 32 files changed, 1221 insertions(+), 583 deletions(-) create mode 100644 go/arrow/extensions/extensions.go create mode 100644 go/arrow/extensions/json.go create mode 100644 go/arrow/extensions/json_test.go create mode 100644 go/arrow/extensions/uuid.go create mode 100644 go/arrow/extensions/uuid_test.go delete mode 100644 go/internal/types/extension_types_test.go diff --git a/docs/source/status.rst b/docs/source/status.rst index c232aa280be..5e2c2cc19c8 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -119,6 +119,12 @@ Data Types +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | Variable shape tensor | | | | | | | | | +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| JSON | | | ✓ | | | | | | ++-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| UUID | | | ✓ | | | | | | ++-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ +| 8-bit Boolean | ✓ | | ✓ | | | 
| | | ++-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ Notes: diff --git a/go/arrow/array/array_test.go b/go/arrow/array/array_test.go index 4d83766b4fa..4f0627c6000 100644 --- a/go/arrow/array/array_test.go +++ b/go/arrow/array/array_test.go @@ -21,9 +21,9 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/internal/testing/tools" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/stretchr/testify/assert" ) @@ -122,7 +122,7 @@ func TestMakeFromData(t *testing.T) { {name: "dictionary", d: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint64, ValueType: &testDataType{arrow.TIMESTAMP}}, dict: array.NewData(&testDataType{arrow.TIMESTAMP}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */)}, {name: "extension", d: &testDataType{arrow.EXTENSION}, expPanic: true, expError: "arrow/array: DataType for ExtensionArray must implement arrow.ExtensionType"}, - {name: "extension", d: types.NewUUIDType()}, + {name: "extension", d: extensions.NewUUIDType()}, {name: "run end encoded", d: arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int64, arrow.PrimitiveTypes.Int64), child: []arrow.ArrayData{ array.NewData(&testDataType{arrow.INT64}, 0 /* length */, make([]*memory.Buffer, 2 /*null bitmap, values*/), nil /* childData */, 0 /* nulls */, 0 /* offset */), diff --git a/go/arrow/array/diff_test.go b/go/arrow/array/diff_test.go index 65d212be118..9c9ce6a53ae 100644 --- a/go/arrow/array/diff_test.go +++ b/go/arrow/array/diff_test.go @@ -25,9 +25,9 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/internal/json" - "github.com/apache/arrow/go/v18/internal/types" ) type diffTestCase struct { @@ -861,7 +861,7 @@ func TestEdits_UnifiedDiff(t *testing.T) { }, { name: "extensions", - dataType: types.NewUUIDType(), + dataType: extensions.NewUUIDType(), baseJSON: `["00000000-0000-0000-0000-000000000000", "00000000-0000-0000-0000-000000000001"]`, targetJSON: `["00000000-0000-0000-0000-000000000001", "00000000-0000-0000-0000-000000000002"]`, want: `@@ -0, +0 @@ diff --git a/go/arrow/array/extension_test.go b/go/arrow/array/extension_test.go index 71ea9f105af..26245cf015d 100644 --- a/go/arrow/array/extension_test.go +++ b/go/arrow/array/extension_test.go @@ -30,16 +30,6 @@ type ExtensionTypeTestSuite struct { suite.Suite } -func (e *ExtensionTypeTestSuite) SetupTest() { - e.NoError(arrow.RegisterExtensionType(types.NewUUIDType())) -} - -func (e *ExtensionTypeTestSuite) TearDownTest() { - if arrow.GetExtensionType("uuid") != nil { - e.NoError(arrow.UnregisterExtensionType("uuid")) - } -} - func (e *ExtensionTypeTestSuite) TestParametricEquals() { p1Type := types.NewParametric1Type(6) p2Type := types.NewParametric1Type(6) diff --git a/go/arrow/avro/reader_types.go b/go/arrow/avro/reader_types.go index e07cd380d51..dab2b33dce6 100644 --- a/go/arrow/avro/reader_types.go +++ b/go/arrow/avro/reader_types.go @@ -27,8 +27,8 @@ import ( "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" 
"github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" ) type dataLoader struct { @@ -436,7 +436,7 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { } return nil } - case *types.UUIDBuilder: + case *extensions.UUIDBuilder: f.appendFunc = func(data interface{}) error { switch dt := data.(type) { case nil: diff --git a/go/arrow/avro/schema.go b/go/arrow/avro/schema.go index 007dad06c19..a6de3718d3c 100644 --- a/go/arrow/avro/schema.go +++ b/go/arrow/avro/schema.go @@ -24,7 +24,7 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/decimal128" - "github.com/apache/arrow/go/v18/internal/types" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/internal/utils" avro "github.com/hamba/avro/v2" ) @@ -349,7 +349,7 @@ func avroLogicalToArrowField(n *schemaNode) { // The uuid logical type represents a random generated universally unique identifier (UUID). // A uuid logical type annotates an Avro string. The string has to conform with RFC-4122 case "uuid": - dt = types.NewUUIDType() + dt = extensions.NewUUIDType() // The date logical type represents a date within the calendar, with no reference to a particular // time zone or time of day. diff --git a/go/arrow/compute/exec/span_test.go b/go/arrow/compute/exec/span_test.go index f5beb45ee14..018fbb7d623 100644 --- a/go/arrow/compute/exec/span_test.go +++ b/go/arrow/compute/exec/span_test.go @@ -29,6 +29,7 @@ import ( "github.com/apache/arrow/go/v18/arrow/compute/exec" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/endian" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/arrow/scalar" "github.com/apache/arrow/go/v18/internal/types" @@ -192,9 +193,6 @@ func TestArraySpan_NumBuffers(t *testing.T) { Children []exec.ArraySpan } - arrow.RegisterExtensionType(types.NewUUIDType()) - defer arrow.UnregisterExtensionType("uuid") - tests := []struct { name string fields fields @@ -207,7 +205,7 @@ func TestArraySpan_NumBuffers(t *testing.T) { {"large binary", fields{Type: arrow.BinaryTypes.LargeBinary}, 3}, {"string", fields{Type: arrow.BinaryTypes.String}, 3}, {"large string", fields{Type: arrow.BinaryTypes.LargeString}, 3}, - {"extension", fields{Type: types.NewUUIDType()}, 2}, + {"extension", fields{Type: extensions.NewUUIDType()}, 2}, {"int32", fields{Type: arrow.PrimitiveTypes.Int32}, 2}, } for _, tt := range tests { diff --git a/go/arrow/csv/reader_test.go b/go/arrow/csv/reader_test.go index b0775b9b11a..6a89d497042 100644 --- a/go/arrow/csv/reader_test.go +++ b/go/arrow/csv/reader_test.go @@ -30,8 +30,8 @@ import ( "github.com/apache/arrow/go/v18/arrow/csv" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -356,7 +356,7 @@ func testCSVReader(t *testing.T, filepath string, withHeader bool, stringsCanBeN {Name: "binary", Type: arrow.BinaryTypes.Binary}, {Name: "large_binary", Type: arrow.BinaryTypes.LargeBinary}, {Name: "fixed_size_binary", Type: &arrow.FixedSizeBinaryType{ByteWidth: 3}}, - {Name: "uuid", Type: types.NewUUIDType()}, + {Name: "uuid", Type: extensions.NewUUIDType()}, {Name: "date32", Type: 
arrow.PrimitiveTypes.Date32}, {Name: "date64", Type: arrow.PrimitiveTypes.Date64}, }, diff --git a/go/arrow/csv/writer_test.go b/go/arrow/csv/writer_test.go index be9ab961c3e..2ae01a6d490 100644 --- a/go/arrow/csv/writer_test.go +++ b/go/arrow/csv/writer_test.go @@ -31,9 +31,9 @@ import ( "github.com/apache/arrow/go/v18/arrow/csv" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/float16" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/google/uuid" ) @@ -230,7 +230,7 @@ func testCSVWriter(t *testing.T, data [][]string, writeHeader bool, fmtr func(bo {Name: "binary", Type: arrow.BinaryTypes.Binary}, {Name: "large_binary", Type: arrow.BinaryTypes.LargeBinary}, {Name: "fixed_size_binary", Type: &arrow.FixedSizeBinaryType{ByteWidth: 3}}, - {Name: "uuid", Type: types.NewUUIDType()}, + {Name: "uuid", Type: extensions.NewUUIDType()}, {Name: "null", Type: arrow.Null}, }, nil, @@ -285,7 +285,7 @@ func testCSVWriter(t *testing.T, data [][]string, writeHeader bool, fmtr func(bo b.Field(22).(*array.BinaryBuilder).AppendValues([][]byte{{0, 1, 2}, {3, 4, 5}, {}}, nil) b.Field(23).(*array.BinaryBuilder).AppendValues([][]byte{{0, 1, 2}, {3, 4, 5}, {}}, nil) b.Field(24).(*array.FixedSizeBinaryBuilder).AppendValues([][]byte{{0, 1, 2}, {3, 4, 5}, {}}, nil) - b.Field(25).(*types.UUIDBuilder).AppendValues([]uuid.UUID{uuid.MustParse("00000000-0000-0000-0000-000000000001"), uuid.MustParse("00000000-0000-0000-0000-000000000002"), uuid.MustParse("00000000-0000-0000-0000-000000000003")}, nil) + b.Field(25).(*extensions.UUIDBuilder).AppendValues([]uuid.UUID{uuid.MustParse("00000000-0000-0000-0000-000000000001"), uuid.MustParse("00000000-0000-0000-0000-000000000002"), uuid.MustParse("00000000-0000-0000-0000-000000000003")}, nil) b.Field(26).(*array.NullBuilder).AppendEmptyValues(3) for _, field := range b.Fields() { diff --git a/go/arrow/datatype_extension_test.go b/go/arrow/datatype_extension_test.go index c3e595f523e..7244d377bd2 100644 --- a/go/arrow/datatype_extension_test.go +++ b/go/arrow/datatype_extension_test.go @@ -21,7 +21,7 @@ import ( "testing" "github.com/apache/arrow/go/v18/arrow" - "github.com/apache/arrow/go/v18/internal/types" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" ) @@ -50,24 +50,14 @@ type ExtensionTypeTestSuite struct { suite.Suite } -func (e *ExtensionTypeTestSuite) SetupTest() { - e.NoError(arrow.RegisterExtensionType(types.NewUUIDType())) -} - -func (e *ExtensionTypeTestSuite) TearDownTest() { - if arrow.GetExtensionType("uuid") != nil { - e.NoError(arrow.UnregisterExtensionType("uuid")) - } -} - func (e *ExtensionTypeTestSuite) TestExtensionType() { e.Nil(arrow.GetExtensionType("uuid-unknown")) - e.NotNil(arrow.GetExtensionType("uuid")) + e.NotNil(arrow.GetExtensionType("arrow.uuid")) - e.Error(arrow.RegisterExtensionType(types.NewUUIDType())) + e.Error(arrow.RegisterExtensionType(extensions.NewUUIDType())) e.Error(arrow.UnregisterExtensionType("uuid-unknown")) - typ := types.NewUUIDType() + typ := extensions.NewUUIDType() e.Implements((*arrow.ExtensionType)(nil), typ) e.Equal(arrow.EXTENSION, typ.ID()) e.Equal("extension", typ.Name()) diff --git a/go/arrow/extensions/bool8_test.go b/go/arrow/extensions/bool8_test.go index 9f7365d1555..ff129e24bc8 100644 --- a/go/arrow/extensions/bool8_test.go +++ 
b/go/arrow/extensions/bool8_test.go @@ -178,9 +178,6 @@ func TestReinterpretStorageEqualToValues(t *testing.T) { func TestBool8TypeBatchIPCRoundTrip(t *testing.T) { typ := extensions.NewBool8Type() - arrow.RegisterExtensionType(typ) - defer arrow.UnregisterExtensionType(typ.ExtensionName()) - storage, _, err := array.FromJSON(memory.DefaultAllocator, arrow.PrimitiveTypes.Int8, strings.NewReader(`[-1, 0, 1, 2, null]`)) require.NoError(t, err) diff --git a/go/arrow/extensions/extensions.go b/go/arrow/extensions/extensions.go new file mode 100644 index 00000000000..03c6923e95f --- /dev/null +++ b/go/arrow/extensions/extensions.go @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "github.com/apache/arrow/go/v18/arrow" +) + +var canonicalExtensionTypes = []arrow.ExtensionType{ + &Bool8Type{}, + &UUIDType{}, + &OpaqueType{}, + &JSONType{}, +} + +func init() { + for _, extType := range canonicalExtensionTypes { + if err := arrow.RegisterExtensionType(extType); err != nil { + panic(err) + } + } +} diff --git a/go/arrow/extensions/json.go b/go/arrow/extensions/json.go new file mode 100644 index 00000000000..12c49f9c0a7 --- /dev/null +++ b/go/arrow/extensions/json.go @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "fmt" + "reflect" + "slices" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/internal/json" + "github.com/apache/arrow/go/v18/parquet/schema" +) + +var jsonSupportedStorageTypes = []arrow.DataType{ + arrow.BinaryTypes.String, + arrow.BinaryTypes.LargeString, + arrow.BinaryTypes.StringView, +} + +// JSONType represents a UTF-8 encoded JSON string as specified in RFC8259. +type JSONType struct { + arrow.ExtensionBase +} + +// ParquetLogicalType implements pqarrow.ExtensionCustomParquetType. 
+func (b *JSONType) ParquetLogicalType() schema.LogicalType { + return schema.JSONLogicalType{} +} + +// NewJSONType creates a new JSONType with the specified storage type. +// storageType must be one of String, LargeString, StringView. +func NewJSONType(storageType arrow.DataType) (*JSONType, error) { + if !slices.Contains(jsonSupportedStorageTypes, storageType) { + return nil, fmt.Errorf("unsupported storage type for JSON extension type: %s", storageType) + } + return &JSONType{ExtensionBase: arrow.ExtensionBase{Storage: storageType}}, nil +} + +func (b *JSONType) ArrayType() reflect.Type { return reflect.TypeOf(JSONArray{}) } + +func (b *JSONType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { + if !(data == "" || data == "{}") { + return nil, fmt.Errorf("serialized metadata for JSON extension type must be '' or '{}', found: %s", data) + } + return NewJSONType(storageType) +} + +func (b *JSONType) ExtensionEquals(other arrow.ExtensionType) bool { + return b.ExtensionName() == other.ExtensionName() && arrow.TypeEqual(b.Storage, other.StorageType()) +} + +func (b *JSONType) ExtensionName() string { return "arrow.json" } + +func (b *JSONType) Serialize() string { return "" } + +func (b *JSONType) String() string { + return fmt.Sprintf("extension<%s[storage_type=%s]>", b.ExtensionName(), b.Storage) +} + +// JSONArray is logically an array of UTF-8 encoded JSON strings. +// Its values are unmarshaled to native Go values. +type JSONArray struct { + array.ExtensionArrayBase +} + +func (a *JSONArray) String() string { + b, err := a.MarshalJSON() + if err != nil { + panic(fmt.Sprintf("failed marshal JSONArray: %s", err)) + } + + return string(b) +} + +func (a *JSONArray) Value(i int) any { + val := a.ValueBytes(i) + + var res any + if err := json.Unmarshal(val, &res); err != nil { + panic(err) + } + + return res +} + +func (a *JSONArray) ValueStr(i int) string { + return string(a.ValueBytes(i)) +} + +func (a *JSONArray) ValueBytes(i int) []byte { + // convert to json.RawMessage, set to nil if elem isNull. + val := a.ValueJSON(i) + + // simply returns wrapped bytes, or null if val is nil. + b, err := val.MarshalJSON() + if err != nil { + panic(err) + } + + return b +} + +// ValueJSON wraps the underlying string value as a json.RawMessage, +// or returns nil if the array value is null. +func (a *JSONArray) ValueJSON(i int) json.RawMessage { + var val json.RawMessage + if a.IsValid(i) { + val = json.RawMessage(a.Storage().(array.StringLike).Value(i)) + } + return val +} + +// MarshalJSON implements json.Marshaler. +// Marshaling json.RawMessage is a no-op, except that nil values will +// be marshaled as a JSON null. +func (a *JSONArray) MarshalJSON() ([]byte, error) { + values := make([]json.RawMessage, a.Len()) + for i := 0; i < a.Len(); i++ { + values[i] = a.ValueJSON(i) + } + return json.Marshal(values) +} + +// GetOneForMarshal implements arrow.Array. +func (a *JSONArray) GetOneForMarshal(i int) interface{} { + return a.ValueJSON(i) +} + +var ( + _ arrow.ExtensionType = (*JSONType)(nil) + _ array.ExtensionArray = (*JSONArray)(nil) +) diff --git a/go/arrow/extensions/json_test.go b/go/arrow/extensions/json_test.go new file mode 100644 index 00000000000..21acc58f939 --- /dev/null +++ b/go/arrow/extensions/json_test.go @@ -0,0 +1,268 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions_test + +import ( + "bytes" + "strings" + "testing" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" + "github.com/apache/arrow/go/v18/arrow/ipc" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestJSONTypeBasics(t *testing.T) { + typ, err := extensions.NewJSONType(arrow.BinaryTypes.String) + require.NoError(t, err) + + typLarge, err := extensions.NewJSONType(arrow.BinaryTypes.LargeString) + require.NoError(t, err) + + typView, err := extensions.NewJSONType(arrow.BinaryTypes.StringView) + require.NoError(t, err) + + assert.Equal(t, "arrow.json", typ.ExtensionName()) + assert.Equal(t, "arrow.json", typLarge.ExtensionName()) + assert.Equal(t, "arrow.json", typView.ExtensionName()) + + assert.True(t, typ.ExtensionEquals(typ)) + assert.True(t, typLarge.ExtensionEquals(typLarge)) + assert.True(t, typView.ExtensionEquals(typView)) + + assert.False(t, arrow.TypeEqual(arrow.BinaryTypes.String, typ)) + assert.False(t, arrow.TypeEqual(typ, typLarge)) + assert.False(t, arrow.TypeEqual(typ, typView)) + assert.False(t, arrow.TypeEqual(typLarge, typView)) + + assert.True(t, arrow.TypeEqual(arrow.BinaryTypes.String, typ.StorageType())) + assert.True(t, arrow.TypeEqual(arrow.BinaryTypes.LargeString, typLarge.StorageType())) + assert.True(t, arrow.TypeEqual(arrow.BinaryTypes.StringView, typView.StorageType())) + + assert.Equal(t, "extension", typ.String()) + assert.Equal(t, "extension", typLarge.String()) + assert.Equal(t, "extension", typView.String()) +} + +var jsonTestCases = []struct { + Name string + StorageType arrow.DataType + StorageBuilder func(mem memory.Allocator) array.Builder +}{ + { + Name: "string", + StorageType: arrow.BinaryTypes.String, + StorageBuilder: func(mem memory.Allocator) array.Builder { return array.NewStringBuilder(mem) }, + }, + { + Name: "large_string", + StorageType: arrow.BinaryTypes.LargeString, + StorageBuilder: func(mem memory.Allocator) array.Builder { return array.NewLargeStringBuilder(mem) }, + }, + { + Name: "string_view", + StorageType: arrow.BinaryTypes.StringView, + StorageBuilder: func(mem memory.Allocator) array.Builder { return array.NewStringViewBuilder(mem) }, + }, +} + +func TestJSONTypeCreateFromArray(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + bldr.AppendValueFromString(`true`) + bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := 
array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 6, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + jsonArr, ok := arr.(*extensions.JSONArray) + require.True(t, ok) + + require.Equal(t, "foobar", jsonArr.Value(0)) + require.Equal(t, nil, jsonArr.Value(1)) + require.Equal(t, map[string]any{"foo": "bar"}, jsonArr.Value(2)) + require.Equal(t, float64(42), jsonArr.Value(3)) + require.Equal(t, true, jsonArr.Value(4)) + require.Equal(t, []any{float64(1), true, "3", nil, map[string]any{"five": float64(5)}}, jsonArr.Value(5)) + }) + } +} + +func TestJSONTypeBatchIPCRoundTrip(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + bldr.AppendValueFromString(`true`) + bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + batch := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "field", Type: typ, Nullable: true}}, nil), + []arrow.Array{arr}, -1) + defer batch.Release() + + var written arrow.Record + { + var buf bytes.Buffer + wr := ipc.NewWriter(&buf, ipc.WithSchema(batch.Schema())) + require.NoError(t, wr.Write(batch)) + require.NoError(t, wr.Close()) + + rdr, err := ipc.NewReader(&buf) + require.NoError(t, err) + written, err = rdr.Read() + require.NoError(t, err) + written.Retain() + defer written.Release() + rdr.Release() + } + + assert.Truef(t, batch.Schema().Equal(written.Schema()), "expected: %s, got: %s", + batch.Schema(), written.Schema()) + + assert.Truef(t, array.RecordEqual(batch, written), "expected: %s, got: %s", + batch, written) + }) + } +} + +func TestMarshallJSONArray(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + bldr.AppendValueFromString(`true`) + bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 6, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + jsonArr, ok := arr.(*extensions.JSONArray) + require.True(t, ok) + + b, err := jsonArr.MarshalJSON() + require.NoError(t, err) + + expectedJSON := `["foobar",null,{"foo":"bar"},42,true,[1,true,"3",null,{"five":5}]]` + require.Equal(t, expectedJSON, string(b)) + require.Equal(t, expectedJSON, jsonArr.String()) + }) + } +} + +func TestJSONRecordToJSON(t *testing.T) { + for _, tc := range jsonTestCases { + t.Run(tc.Name, func(t *testing.T) { + typ, err := extensions.NewJSONType(tc.StorageType) + require.NoError(t, err) + + bldr := tc.StorageBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.AppendValueFromString(`"foobar"`) + bldr.AppendNull() + bldr.AppendValueFromString(`{"foo": "bar"}`) + bldr.AppendValueFromString(`42`) + bldr.AppendValueFromString(`true`) + 
bldr.AppendValueFromString(`[1, true, "3", null, {"five": 5}]`) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 6, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + jsonArr, ok := arr.(*extensions.JSONArray) + require.True(t, ok) + + rec := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "json", Type: typ, Nullable: true}}, nil), []arrow.Array{jsonArr}, 6) + defer rec.Release() + + buf := bytes.NewBuffer([]byte("\n")) // expected output has leading newline for clearer formatting + require.NoError(t, array.RecordToJSON(rec, buf)) + + expectedJSON := ` + {"json":"foobar"} + {"json":null} + {"json":{"foo":"bar"}} + {"json":42} + {"json":true} + {"json":[1,true,"3",null,{"five":5}]} + ` + + expectedJSONLines := strings.Split(expectedJSON, "\n") + actualJSONLines := strings.Split(buf.String(), "\n") + + require.Equal(t, len(expectedJSONLines), len(actualJSONLines)) + for i := range expectedJSONLines { + if strings.TrimSpace(expectedJSONLines[i]) != "" { + require.JSONEq(t, expectedJSONLines[i], actualJSONLines[i]) + } + } + }) + } +} diff --git a/go/arrow/extensions/opaque_test.go b/go/arrow/extensions/opaque_test.go index b6686e97bc0..a0fc8962ce5 100644 --- a/go/arrow/extensions/opaque_test.go +++ b/go/arrow/extensions/opaque_test.go @@ -161,9 +161,6 @@ func TestOpaqueTypeMetadataRoundTrip(t *testing.T) { func TestOpaqueTypeBatchRoundTrip(t *testing.T) { typ := extensions.NewOpaqueType(arrow.BinaryTypes.String, "geometry", "adbc.postgresql") - arrow.RegisterExtensionType(typ) - defer arrow.UnregisterExtensionType(typ.ExtensionName()) - storage, _, err := array.FromJSON(memory.DefaultAllocator, arrow.BinaryTypes.String, strings.NewReader(`["foobar", null]`)) require.NoError(t, err) diff --git a/go/arrow/extensions/uuid.go b/go/arrow/extensions/uuid.go new file mode 100644 index 00000000000..422b9ea1188 --- /dev/null +++ b/go/arrow/extensions/uuid.go @@ -0,0 +1,265 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package extensions + +import ( + "bytes" + "fmt" + "reflect" + "strings" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/apache/arrow/go/v18/internal/json" + "github.com/apache/arrow/go/v18/parquet/schema" + "github.com/google/uuid" +) + +type UUIDBuilder struct { + *array.ExtensionBuilder +} + +// NewUUIDBuilder creates a new UUIDBuilder, exposing a convenient and efficient interface +// for writing uuid.UUID (or [16]byte) values to the underlying FixedSizeBinary storage array. 
+func NewUUIDBuilder(mem memory.Allocator) *UUIDBuilder { + return &UUIDBuilder{ExtensionBuilder: array.NewExtensionBuilder(mem, NewUUIDType())} +} + +func (b *UUIDBuilder) Append(v uuid.UUID) { + b.AppendBytes(v) +} + +func (b *UUIDBuilder) AppendBytes(v [16]byte) { + b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).Append(v[:]) +} + +func (b *UUIDBuilder) UnsafeAppend(v uuid.UUID) { + b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).UnsafeAppend(v[:]) +} + +func (b *UUIDBuilder) AppendValueFromString(s string) error { + if s == array.NullValueStr { + b.AppendNull() + return nil + } + + uid, err := uuid.Parse(s) + if err != nil { + return err + } + + b.Append(uid) + return nil +} + +func (b *UUIDBuilder) AppendValues(v []uuid.UUID, valid []bool) { + if len(v) != len(valid) && len(valid) != 0 { + panic("len(v) != len(valid) && len(valid) != 0") + } + + data := make([][]byte, len(v)) + for i := range v { + if len(valid) > 0 && !valid[i] { + continue + } + data[i] = v[i][:] + } + b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).AppendValues(data, valid) +} + +func (b *UUIDBuilder) UnmarshalOne(dec *json.Decoder) error { + t, err := dec.Token() + if err != nil { + return err + } + + var val uuid.UUID + switch v := t.(type) { + case string: + val, err = uuid.Parse(v) + if err != nil { + return err + } + case []byte: + val, err = uuid.ParseBytes(v) + if err != nil { + return err + } + case nil: + b.AppendNull() + return nil + default: + return &json.UnmarshalTypeError{ + Value: fmt.Sprint(t), + Type: reflect.TypeOf([]byte{}), + Offset: dec.InputOffset(), + Struct: fmt.Sprintf("FixedSizeBinary[%d]", 16), + } + } + + b.Append(val) + return nil +} + +func (b *UUIDBuilder) Unmarshal(dec *json.Decoder) error { + for dec.More() { + if err := b.UnmarshalOne(dec); err != nil { + return err + } + } + return nil +} + +func (b *UUIDBuilder) UnmarshalJSON(data []byte) error { + dec := json.NewDecoder(bytes.NewReader(data)) + t, err := dec.Token() + if err != nil { + return err + } + + if delim, ok := t.(json.Delim); !ok || delim != '[' { + return fmt.Errorf("uuid builder must unpack from json array, found %s", delim) + } + + return b.Unmarshal(dec) +} + +// UUIDArray is a simple array which is a FixedSizeBinary(16) +type UUIDArray struct { + array.ExtensionArrayBase +} + +func (a *UUIDArray) String() string { + arr := a.Storage().(*array.FixedSizeBinary) + o := new(strings.Builder) + o.WriteString("[") + for i := 0; i < arr.Len(); i++ { + if i > 0 { + o.WriteString(" ") + } + switch { + case a.IsNull(i): + o.WriteString(array.NullValueStr) + default: + fmt.Fprintf(o, "%q", a.Value(i)) + } + } + o.WriteString("]") + return o.String() +} + +func (a *UUIDArray) Value(i int) uuid.UUID { + if a.IsNull(i) { + return uuid.Nil + } + return uuid.Must(uuid.FromBytes(a.Storage().(*array.FixedSizeBinary).Value(i))) +} + +func (a *UUIDArray) Values() []uuid.UUID { + values := make([]uuid.UUID, a.Len()) + for i := range values { + values[i] = a.Value(i) + } + return values +} + +func (a *UUIDArray) ValueStr(i int) string { + switch { + case a.IsNull(i): + return array.NullValueStr + default: + return a.Value(i).String() + } +} + +func (a *UUIDArray) MarshalJSON() ([]byte, error) { + vals := make([]any, a.Len()) + for i := range vals { + vals[i] = a.GetOneForMarshal(i) + } + return json.Marshal(vals) +} + +func (a *UUIDArray) GetOneForMarshal(i int) interface{} { + if a.IsValid(i) { + return a.Value(i) + } + return nil +} + +// UUIDType is a simple extension type that represents a 
FixedSizeBinary(16) +// to be used for representing UUIDs +type UUIDType struct { + arrow.ExtensionBase +} + +// ParquetLogicalType implements pqarrow.ExtensionCustomParquetType. +func (e *UUIDType) ParquetLogicalType() schema.LogicalType { + return schema.UUIDLogicalType{} +} + +// NewUUIDType is a convenience function to create an instance of UUIDType +// with the correct storage type +func NewUUIDType() *UUIDType { + return &UUIDType{ExtensionBase: arrow.ExtensionBase{Storage: &arrow.FixedSizeBinaryType{ByteWidth: 16}}} +} + +// ArrayType returns TypeOf(UUIDArray{}) for constructing UUID arrays +func (*UUIDType) ArrayType() reflect.Type { + return reflect.TypeOf(UUIDArray{}) +} + +func (*UUIDType) ExtensionName() string { + return "arrow.uuid" +} + +func (e *UUIDType) String() string { + return fmt.Sprintf("extension<%s>", e.ExtensionName()) +} + +func (e *UUIDType) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`{"name":"%s","metadata":%s}`, e.ExtensionName(), e.Serialize())), nil +} + +func (*UUIDType) Serialize() string { + return "" +} + +// Deserialize expects storageType to be FixedSizeBinaryType{ByteWidth: 16} +func (*UUIDType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { + if !arrow.TypeEqual(storageType, &arrow.FixedSizeBinaryType{ByteWidth: 16}) { + return nil, fmt.Errorf("invalid storage type for UUIDType: %s", storageType.Name()) + } + return NewUUIDType(), nil +} + +// ExtensionEquals returns true if both extensions have the same name +func (e *UUIDType) ExtensionEquals(other arrow.ExtensionType) bool { + return e.ExtensionName() == other.ExtensionName() +} + +func (*UUIDType) NewBuilder(mem memory.Allocator) array.Builder { + return NewUUIDBuilder(mem) +} + +var ( + _ arrow.ExtensionType = (*UUIDType)(nil) + _ array.CustomExtensionBuilder = (*UUIDType)(nil) + _ array.ExtensionArray = (*UUIDArray)(nil) + _ array.Builder = (*UUIDBuilder)(nil) +) diff --git a/go/arrow/extensions/uuid_test.go b/go/arrow/extensions/uuid_test.go new file mode 100644 index 00000000000..80c621db2a0 --- /dev/null +++ b/go/arrow/extensions/uuid_test.go @@ -0,0 +1,257 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package extensions_test + +import ( + "bytes" + "fmt" + "strings" + "testing" + + "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" + "github.com/apache/arrow/go/v18/arrow/ipc" + "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/apache/arrow/go/v18/internal/json" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var testUUID = uuid.New() + +func TestUUIDExtensionBuilder(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + builder := extensions.NewUUIDBuilder(mem) + builder.Append(testUUID) + builder.AppendNull() + builder.AppendBytes(testUUID) + arr := builder.NewArray() + defer arr.Release() + arrStr := arr.String() + assert.Equal(t, fmt.Sprintf(`["%[1]s" (null) "%[1]s"]`, testUUID), arrStr) + jsonStr, err := json.Marshal(arr) + assert.NoError(t, err) + + arr1, _, err := array.FromJSON(mem, extensions.NewUUIDType(), bytes.NewReader(jsonStr)) + defer arr1.Release() + assert.NoError(t, err) + assert.True(t, array.Equal(arr1, arr)) + + require.NoError(t, json.Unmarshal(jsonStr, builder)) + arr2 := builder.NewArray() + defer arr2.Release() + assert.True(t, array.Equal(arr2, arr)) +} + +func TestUUIDExtensionRecordBuilder(t *testing.T) { + schema := arrow.NewSchema([]arrow.Field{ + {Name: "uuid", Type: extensions.NewUUIDType()}, + }, nil) + builder := array.NewRecordBuilder(memory.DefaultAllocator, schema) + builder.Field(0).(*extensions.UUIDBuilder).Append(testUUID) + builder.Field(0).(*extensions.UUIDBuilder).AppendNull() + builder.Field(0).(*extensions.UUIDBuilder).Append(testUUID) + record := builder.NewRecord() + b, err := record.MarshalJSON() + require.NoError(t, err) + require.Equal(t, "[{\"uuid\":\""+testUUID.String()+"\"}\n,{\"uuid\":null}\n,{\"uuid\":\""+testUUID.String()+"\"}\n]", string(b)) + record1, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, bytes.NewReader(b)) + require.NoError(t, err) + require.Equal(t, record, record1) +} + +func TestUUIDStringRoundTrip(t *testing.T) { + // 1. create array + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + b := extensions.NewUUIDBuilder(mem) + b.Append(uuid.Nil) + b.AppendNull() + b.Append(uuid.NameSpaceURL) + b.AppendNull() + b.Append(testUUID) + + arr := b.NewArray() + defer arr.Release() + + // 2. 
create array via AppendValueFromString + b1 := extensions.NewUUIDBuilder(mem) + defer b1.Release() + + for i := 0; i < arr.Len(); i++ { + assert.NoError(t, b1.AppendValueFromString(arr.ValueStr(i))) + } + + arr1 := b1.NewArray() + defer arr1.Release() + + assert.True(t, array.Equal(arr, arr1)) +} + +func TestUUIDTypeBasics(t *testing.T) { + typ := extensions.NewUUIDType() + + assert.Equal(t, "arrow.uuid", typ.ExtensionName()) + assert.True(t, typ.ExtensionEquals(typ)) + + assert.True(t, arrow.TypeEqual(typ, typ)) + assert.False(t, arrow.TypeEqual(&arrow.FixedSizeBinaryType{ByteWidth: 16}, typ)) + assert.True(t, arrow.TypeEqual(&arrow.FixedSizeBinaryType{ByteWidth: 16}, typ.StorageType())) + + assert.Equal(t, "extension", typ.String()) +} + +func TestUUIDTypeCreateFromArray(t *testing.T) { + typ := extensions.NewUUIDType() + + bldr := array.NewFixedSizeBinaryBuilder(memory.DefaultAllocator, &arrow.FixedSizeBinaryType{ByteWidth: 16}) + defer bldr.Release() + + bldr.Append(testUUID[:]) + bldr.AppendNull() + bldr.Append(testUUID[:]) + + storage := bldr.NewArray() + defer storage.Release() + + arr := array.NewExtensionArrayWithStorage(typ, storage) + defer arr.Release() + + assert.Equal(t, 3, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + uuidArr, ok := arr.(*extensions.UUIDArray) + require.True(t, ok) + + require.Equal(t, testUUID, uuidArr.Value(0)) + require.Equal(t, uuid.Nil, uuidArr.Value(1)) + require.Equal(t, testUUID, uuidArr.Value(2)) +} + +func TestUUIDTypeBatchIPCRoundTrip(t *testing.T) { + typ := extensions.NewUUIDType() + + bldr := extensions.NewUUIDBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.Append(testUUID) + bldr.AppendNull() + bldr.AppendBytes(testUUID) + + arr := bldr.NewArray() + defer arr.Release() + + batch := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "field", Type: typ, Nullable: true}}, nil), + []arrow.Array{arr}, -1) + defer batch.Release() + + var written arrow.Record + { + var buf bytes.Buffer + wr := ipc.NewWriter(&buf, ipc.WithSchema(batch.Schema())) + require.NoError(t, wr.Write(batch)) + require.NoError(t, wr.Close()) + + rdr, err := ipc.NewReader(&buf) + require.NoError(t, err) + written, err = rdr.Read() + require.NoError(t, err) + written.Retain() + defer written.Release() + rdr.Release() + } + + assert.Truef(t, batch.Schema().Equal(written.Schema()), "expected: %s, got: %s", + batch.Schema(), written.Schema()) + + assert.Truef(t, array.RecordEqual(batch, written), "expected: %s, got: %s", + batch, written) +} + +func TestMarshallUUIDArray(t *testing.T) { + bldr := extensions.NewUUIDBuilder(memory.DefaultAllocator) + defer bldr.Release() + + bldr.Append(testUUID) + bldr.AppendNull() + bldr.AppendBytes(testUUID) + + arr := bldr.NewArray() + defer arr.Release() + + assert.Equal(t, 3, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + uuidArr, ok := arr.(*extensions.UUIDArray) + require.True(t, ok) + + b, err := uuidArr.MarshalJSON() + require.NoError(t, err) + + expectedJSON := fmt.Sprintf(`["%[1]s",null,"%[1]s"]`, testUUID) + require.Equal(t, expectedJSON, string(b)) +} + +func TestUUIDRecordToJSON(t *testing.T) { + typ := extensions.NewUUIDType() + + bldr := extensions.NewUUIDBuilder(memory.DefaultAllocator) + defer bldr.Release() + + uuid1 := uuid.MustParse("8c607ed4-07b2-4b9c-b5eb-c0387357f9ae") + + bldr.Append(uuid1) + bldr.AppendNull() + + // c5f2cbd9-7094-491a-b267-167bb62efe02 + bldr.AppendBytes([16]byte{197, 242, 203, 217, 112, 148, 73, 26, 178, 103, 22, 123, 182, 46, 254, 2}) + + arr := bldr.NewArray() + defer arr.Release() 
+ + assert.Equal(t, 3, arr.Len()) + assert.Equal(t, 1, arr.NullN()) + + uuidArr, ok := arr.(*extensions.UUIDArray) + require.True(t, ok) + + rec := array.NewRecord(arrow.NewSchema([]arrow.Field{{Name: "uuid", Type: typ, Nullable: true}}, nil), []arrow.Array{uuidArr}, 3) + defer rec.Release() + + buf := bytes.NewBuffer([]byte("\n")) // expected output has leading newline for clearer formatting + require.NoError(t, array.RecordToJSON(rec, buf)) + + expectedJSON := ` + {"uuid":"8c607ed4-07b2-4b9c-b5eb-c0387357f9ae"} + {"uuid":null} + {"uuid":"c5f2cbd9-7094-491a-b267-167bb62efe02"} + ` + + expectedJSONLines := strings.Split(expectedJSON, "\n") + actualJSONLines := strings.Split(buf.String(), "\n") + + require.Equal(t, len(expectedJSONLines), len(actualJSONLines)) + for i := range expectedJSONLines { + if strings.TrimSpace(expectedJSONLines[i]) != "" { + require.JSONEq(t, expectedJSONLines[i], actualJSONLines[i]) + } + } +} diff --git a/go/arrow/internal/flight_integration/scenario.go b/go/arrow/internal/flight_integration/scenario.go index 1528bb05d9d..b9535002a0a 100644 --- a/go/arrow/internal/flight_integration/scenario.go +++ b/go/arrow/internal/flight_integration/scenario.go @@ -40,7 +40,6 @@ import ( "github.com/apache/arrow/go/v18/arrow/internal/arrjson" "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "golang.org/x/xerrors" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -161,9 +160,6 @@ func (s *defaultIntegrationTester) RunClient(addr string, opts ...grpc.DialOptio ctx := context.Background() - arrow.RegisterExtensionType(types.NewUUIDType()) - defer arrow.UnregisterExtensionType("uuid") - descr := &flight.FlightDescriptor{ Type: flight.DescriptorPATH, Path: []string{s.path}, diff --git a/go/arrow/ipc/cmd/arrow-json-integration-test/main.go b/go/arrow/ipc/cmd/arrow-json-integration-test/main.go index b3e1dcac141..c47a091268b 100644 --- a/go/arrow/ipc/cmd/arrow-json-integration-test/main.go +++ b/go/arrow/ipc/cmd/arrow-json-integration-test/main.go @@ -22,12 +22,10 @@ import ( "log" "os" - "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/arrio" "github.com/apache/arrow/go/v18/arrow/internal/arrjson" "github.com/apache/arrow/go/v18/arrow/ipc" - "github.com/apache/arrow/go/v18/internal/types" ) func main() { @@ -50,8 +48,6 @@ func main() { } func runCommand(jsonName, arrowName, mode string, verbose bool) error { - arrow.RegisterExtensionType(types.NewUUIDType()) - if jsonName == "" { return fmt.Errorf("must specify json file name") } diff --git a/go/arrow/ipc/metadata_test.go b/go/arrow/ipc/metadata_test.go index 33bc63c2a00..14b8da2cf7c 100644 --- a/go/arrow/ipc/metadata_test.go +++ b/go/arrow/ipc/metadata_test.go @@ -23,10 +23,10 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/internal/dictutils" "github.com/apache/arrow/go/v18/arrow/internal/flatbuf" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" flatbuffers "github.com/google/flatbuffers/go" "github.com/stretchr/testify/assert" ) @@ -169,7 +169,7 @@ func TestRWFooter(t *testing.T) { } func exampleUUID(mem memory.Allocator) arrow.Array { - extType := types.NewUUIDType() + extType := extensions.NewUUIDType() bldr := array.NewExtensionBuilder(mem, extType) defer bldr.Release() @@ 
-184,9 +184,6 @@ func TestUnrecognizedExtensionType(t *testing.T) { pool := memory.NewCheckedAllocator(memory.NewGoAllocator()) defer pool.AssertSize(t, 0) - // register the uuid type - assert.NoError(t, arrow.RegisterExtensionType(types.NewUUIDType())) - extArr := exampleUUID(pool) defer extArr.Release() @@ -205,7 +202,9 @@ func TestUnrecognizedExtensionType(t *testing.T) { // unregister the uuid type before we read back the buffer so it is // unrecognized when reading back the record batch. - assert.NoError(t, arrow.UnregisterExtensionType("uuid")) + assert.NoError(t, arrow.UnregisterExtensionType("arrow.uuid")) + // re-register once the test is complete + defer arrow.RegisterExtensionType(extensions.NewUUIDType()) rdr, err := NewReader(&buf, WithAllocator(pool)) defer rdr.Release() diff --git a/go/internal/types/extension_types.go b/go/internal/types/extension_types.go index 85c64d86bff..33ada2d488f 100644 --- a/go/internal/types/extension_types.go +++ b/go/internal/types/extension_types.go @@ -18,238 +18,15 @@ package types import ( - "bytes" "encoding/binary" "fmt" "reflect" - "strings" "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" - "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/json" - "github.com/google/uuid" "golang.org/x/xerrors" ) -var UUID = NewUUIDType() - -type UUIDBuilder struct { - *array.ExtensionBuilder -} - -func NewUUIDBuilder(mem memory.Allocator) *UUIDBuilder { - return &UUIDBuilder{ExtensionBuilder: array.NewExtensionBuilder(mem, NewUUIDType())} -} - -func (b *UUIDBuilder) Append(v uuid.UUID) { - b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).Append(v[:]) -} - -func (b *UUIDBuilder) UnsafeAppend(v uuid.UUID) { - b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).UnsafeAppend(v[:]) -} - -func (b *UUIDBuilder) AppendValueFromString(s string) error { - if s == array.NullValueStr { - b.AppendNull() - return nil - } - - uid, err := uuid.Parse(s) - if err != nil { - return err - } - - b.Append(uid) - return nil -} - -func (b *UUIDBuilder) AppendValues(v []uuid.UUID, valid []bool) { - if len(v) != len(valid) && len(valid) != 0 { - panic("len(v) != len(valid) && len(valid) != 0") - } - - data := make([][]byte, len(v)) - for i := range v { - if len(valid) > 0 && !valid[i] { - continue - } - data[i] = v[i][:] - } - b.ExtensionBuilder.Builder.(*array.FixedSizeBinaryBuilder).AppendValues(data, valid) -} - -func (b *UUIDBuilder) UnmarshalOne(dec *json.Decoder) error { - t, err := dec.Token() - if err != nil { - return err - } - - var val uuid.UUID - switch v := t.(type) { - case string: - val, err = uuid.Parse(v) - if err != nil { - return err - } - case []byte: - val, err = uuid.ParseBytes(v) - if err != nil { - return err - } - case nil: - b.AppendNull() - return nil - default: - return &json.UnmarshalTypeError{ - Value: fmt.Sprint(t), - Type: reflect.TypeOf([]byte{}), - Offset: dec.InputOffset(), - Struct: fmt.Sprintf("FixedSizeBinary[%d]", 16), - } - } - - b.Append(val) - return nil -} - -func (b *UUIDBuilder) Unmarshal(dec *json.Decoder) error { - for dec.More() { - if err := b.UnmarshalOne(dec); err != nil { - return err - } - } - return nil -} - -func (b *UUIDBuilder) UnmarshalJSON(data []byte) error { - dec := json.NewDecoder(bytes.NewReader(data)) - t, err := dec.Token() - if err != nil { - return err - } - - if delim, ok := t.(json.Delim); !ok || delim != '[' { - return fmt.Errorf("uuid builder must unpack from json array, found %s", delim) - } - - return b.Unmarshal(dec) 
-} - -// UUIDArray is a simple array which is a FixedSizeBinary(16) -type UUIDArray struct { - array.ExtensionArrayBase -} - -func (a *UUIDArray) String() string { - arr := a.Storage().(*array.FixedSizeBinary) - o := new(strings.Builder) - o.WriteString("[") - for i := 0; i < arr.Len(); i++ { - if i > 0 { - o.WriteString(" ") - } - switch { - case a.IsNull(i): - o.WriteString(array.NullValueStr) - default: - fmt.Fprintf(o, "%q", a.Value(i)) - } - } - o.WriteString("]") - return o.String() -} - -func (a *UUIDArray) Value(i int) uuid.UUID { - if a.IsNull(i) { - return uuid.Nil - } - return uuid.Must(uuid.FromBytes(a.Storage().(*array.FixedSizeBinary).Value(i))) -} - -func (a *UUIDArray) ValueStr(i int) string { - switch { - case a.IsNull(i): - return array.NullValueStr - default: - return a.Value(i).String() - } -} - -func (a *UUIDArray) MarshalJSON() ([]byte, error) { - arr := a.Storage().(*array.FixedSizeBinary) - values := make([]interface{}, a.Len()) - for i := 0; i < a.Len(); i++ { - if a.IsValid(i) { - values[i] = uuid.Must(uuid.FromBytes(arr.Value(i))).String() - } - } - return json.Marshal(values) -} - -func (a *UUIDArray) GetOneForMarshal(i int) interface{} { - if a.IsNull(i) { - return nil - } - return a.Value(i) -} - -// UUIDType is a simple extension type that represents a FixedSizeBinary(16) -// to be used for representing UUIDs -type UUIDType struct { - arrow.ExtensionBase -} - -// NewUUIDType is a convenience function to create an instance of UUIDType -// with the correct storage type -func NewUUIDType() *UUIDType { - return &UUIDType{ExtensionBase: arrow.ExtensionBase{Storage: &arrow.FixedSizeBinaryType{ByteWidth: 16}}} -} - -// ArrayType returns TypeOf(UUIDArray{}) for constructing UUID arrays -func (*UUIDType) ArrayType() reflect.Type { - return reflect.TypeOf(UUIDArray{}) -} - -func (*UUIDType) ExtensionName() string { - return "uuid" -} - -func (e *UUIDType) String() string { - return fmt.Sprintf("extension_type", e.Storage) -} - -func (e *UUIDType) MarshalJSON() ([]byte, error) { - return []byte(fmt.Sprintf(`{"name":"%s","metadata":%s}`, e.ExtensionName(), e.Serialize())), nil -} - -// Serialize returns "uuid-serialized" for testing proper metadata passing -func (*UUIDType) Serialize() string { - return "uuid-serialized" -} - -// Deserialize expects storageType to be FixedSizeBinaryType{ByteWidth: 16} and the data to be -// "uuid-serialized" in order to correctly create a UUIDType for testing deserialize. -func (*UUIDType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { - if data != "uuid-serialized" { - return nil, fmt.Errorf("type identifier did not match: '%s'", data) - } - if !arrow.TypeEqual(storageType, &arrow.FixedSizeBinaryType{ByteWidth: 16}) { - return nil, fmt.Errorf("invalid storage type for UUIDType: %s", storageType.Name()) - } - return NewUUIDType(), nil -} - -// ExtensionEquals returns true if both extensions have the same name -func (e *UUIDType) ExtensionEquals(other arrow.ExtensionType) bool { - return e.ExtensionName() == other.ExtensionName() -} - -func (*UUIDType) NewBuilder(mem memory.Allocator) array.Builder { - return NewUUIDBuilder(mem) -} - // Parametric1Array is a simple int32 array for use with the Parametric1Type // in testing a parameterized user-defined extension type. 
type Parametric1Array struct { @@ -518,14 +295,14 @@ func (SmallintType) ArrayType() reflect.Type { return reflect.TypeOf(SmallintArr func (SmallintType) ExtensionName() string { return "smallint" } -func (SmallintType) Serialize() string { return "smallint" } +func (SmallintType) Serialize() string { return "smallint-serialized" } func (s *SmallintType) ExtensionEquals(other arrow.ExtensionType) bool { return s.Name() == other.Name() } func (SmallintType) Deserialize(storageType arrow.DataType, data string) (arrow.ExtensionType, error) { - if data != "smallint" { + if data != "smallint-serialized" { return nil, fmt.Errorf("type identifier did not match: '%s'", data) } if !arrow.TypeEqual(storageType, arrow.PrimitiveTypes.Int16) { diff --git a/go/internal/types/extension_types_test.go b/go/internal/types/extension_types_test.go deleted file mode 100644 index 65f6353d01b..00000000000 --- a/go/internal/types/extension_types_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package types_test - -import ( - "bytes" - "testing" - - "github.com/apache/arrow/go/v18/arrow" - "github.com/apache/arrow/go/v18/arrow/array" - "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/json" - "github.com/apache/arrow/go/v18/internal/types" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -var testUUID = uuid.New() - -func TestUUIDExtensionBuilder(t *testing.T) { - mem := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer mem.AssertSize(t, 0) - builder := types.NewUUIDBuilder(mem) - builder.Append(testUUID) - arr := builder.NewArray() - defer arr.Release() - arrStr := arr.String() - assert.Equal(t, "[\""+testUUID.String()+"\"]", arrStr) - jsonStr, err := json.Marshal(arr) - assert.NoError(t, err) - - arr1, _, err := array.FromJSON(mem, types.NewUUIDType(), bytes.NewReader(jsonStr)) - defer arr1.Release() - assert.NoError(t, err) - assert.Equal(t, arr, arr1) -} - -func TestUUIDExtensionRecordBuilder(t *testing.T) { - schema := arrow.NewSchema([]arrow.Field{ - {Name: "uuid", Type: types.NewUUIDType()}, - }, nil) - builder := array.NewRecordBuilder(memory.DefaultAllocator, schema) - builder.Field(0).(*types.UUIDBuilder).Append(testUUID) - record := builder.NewRecord() - b, err := record.MarshalJSON() - require.NoError(t, err) - require.Equal(t, "[{\"uuid\":\""+testUUID.String()+"\"}\n]", string(b)) - record1, _, err := array.RecordFromJSON(memory.DefaultAllocator, schema, bytes.NewReader(b)) - require.NoError(t, err) - require.Equal(t, record, record1) -} - -func TestUUIDStringRoundTrip(t *testing.T) { - // 1. 
create array - mem := memory.NewCheckedAllocator(memory.DefaultAllocator) - defer mem.AssertSize(t, 0) - - b := types.NewUUIDBuilder(mem) - b.Append(uuid.Nil) - b.AppendNull() - b.Append(uuid.NameSpaceURL) - b.AppendNull() - b.Append(testUUID) - - arr := b.NewArray() - defer arr.Release() - - // 2. create array via AppendValueFromString - b1 := types.NewUUIDBuilder(mem) - defer b1.Release() - - for i := 0; i < arr.Len(); i++ { - assert.NoError(t, b1.AppendValueFromString(arr.ValueStr(i))) - } - - arr1 := b1.NewArray() - defer arr1.Release() - - assert.True(t, array.Equal(arr, arr1)) -} diff --git a/go/parquet/cmd/parquet_reader/main.go b/go/parquet/cmd/parquet_reader/main.go index 6e04f4254f9..4e480aeb866 100644 --- a/go/parquet/cmd/parquet_reader/main.go +++ b/go/parquet/cmd/parquet_reader/main.go @@ -154,7 +154,7 @@ func main() { if descr.ConvertedType() != schema.ConvertedTypes.None { fmt.Printf("/%s", descr.ConvertedType()) if descr.ConvertedType() == schema.ConvertedTypes.Decimal { - dec := descr.LogicalType().(*schema.DecimalLogicalType) + dec := descr.LogicalType().(schema.DecimalLogicalType) fmt.Printf("(%d,%d)", dec.Precision(), dec.Scale()) } } diff --git a/go/parquet/metadata/app_version.go b/go/parquet/metadata/app_version.go index 887ed79343a..345e9d440a1 100644 --- a/go/parquet/metadata/app_version.go +++ b/go/parquet/metadata/app_version.go @@ -164,7 +164,7 @@ func (v AppVersion) HasCorrectStatistics(coltype parquet.Type, logicalType schem // parquet-cpp-arrow version 4.0.0 fixed Decimal comparisons for creating min/max stats // parquet-cpp also becomes parquet-cpp-arrow as of version 4.0.0 if v.App == "parquet-cpp" || (v.App == "parquet-cpp-arrow" && v.LessThan(parquet1655FixedVersion)) { - if _, ok := logicalType.(*schema.DecimalLogicalType); ok && coltype == parquet.Types.FixedLenByteArray { + if _, ok := logicalType.(schema.DecimalLogicalType); ok && coltype == parquet.Types.FixedLenByteArray { return false } } diff --git a/go/parquet/pqarrow/encode_arrow_test.go b/go/parquet/pqarrow/encode_arrow_test.go index 16282173a68..a238a78133e 100644 --- a/go/parquet/pqarrow/encode_arrow_test.go +++ b/go/parquet/pqarrow/encode_arrow_test.go @@ -30,6 +30,7 @@ import ( "github.com/apache/arrow/go/v18/arrow/bitutil" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/decimal256" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/internal/types" @@ -715,16 +716,6 @@ type ParquetIOTestSuite struct { suite.Suite } -func (ps *ParquetIOTestSuite) SetupTest() { - ps.NoError(arrow.RegisterExtensionType(types.NewUUIDType())) -} - -func (ps *ParquetIOTestSuite) TearDownTest() { - if arrow.GetExtensionType("uuid") != nil { - ps.NoError(arrow.UnregisterExtensionType("uuid")) - } -} - func (ps *ParquetIOTestSuite) makeSimpleSchema(typ arrow.DataType, rep parquet.Repetition) *schema.GroupNode { byteWidth := int32(-1) @@ -2053,7 +2044,7 @@ func (ps *ParquetIOTestSuite) TestArrowExtensionTypeRoundTrip() { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(ps.T(), 0) - builder := types.NewUUIDBuilder(mem) + builder := extensions.NewUUIDBuilder(mem) builder.Append(uuid.New()) arr := builder.NewArray() defer arr.Release() @@ -2076,22 +2067,23 @@ func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() { { // Prepare `written` table with the extension type registered. 
- extType := types.NewUUIDType() + extType := types.NewSmallintType() bldr := array.NewExtensionBuilder(mem, extType) defer bldr.Release() - bldr.Builder.(*array.FixedSizeBinaryBuilder).AppendValues( - [][]byte{nil, []byte("abcdefghijklmno0"), []byte("abcdefghijklmno1"), []byte("abcdefghijklmno2")}, + bldr.Builder.(*array.Int16Builder).AppendValues( + []int16{0, 0, 1, 2}, []bool{false, true, true, true}) arr := bldr.NewArray() defer arr.Release() - if arrow.GetExtensionType("uuid") != nil { - ps.NoError(arrow.UnregisterExtensionType("uuid")) + if arrow.GetExtensionType("smallint") != nil { + ps.NoError(arrow.UnregisterExtensionType("smallint")) + defer arrow.RegisterExtensionType(extType) } - fld := arrow.Field{Name: "uuid", Type: arr.DataType(), Nullable: true} + fld := arrow.Field{Name: "smallint", Type: arr.DataType(), Nullable: true} cnk := arrow.NewChunked(arr.DataType(), []arrow.Array{arr}) defer arr.Release() // NewChunked written = array.NewTable(arrow.NewSchema([]arrow.Field{fld}, nil), []arrow.Column{*arrow.NewColumn(fld, cnk)}, -1) @@ -2101,16 +2093,16 @@ func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() { { // Prepare `expected` table with the extension type unregistered in the underlying type. - bldr := array.NewFixedSizeBinaryBuilder(mem, &arrow.FixedSizeBinaryType{ByteWidth: 16}) + bldr := array.NewInt16Builder(mem) defer bldr.Release() bldr.AppendValues( - [][]byte{nil, []byte("abcdefghijklmno0"), []byte("abcdefghijklmno1"), []byte("abcdefghijklmno2")}, + []int16{0, 0, 1, 2}, []bool{false, true, true, true}) arr := bldr.NewArray() defer arr.Release() - fld := arrow.Field{Name: "uuid", Type: arr.DataType(), Nullable: true} + fld := arrow.Field{Name: "smallint", Type: arr.DataType(), Nullable: true} cnk := arrow.NewChunked(arr.DataType(), []arrow.Array{arr}) defer arr.Release() // NewChunked expected = array.NewTable(arrow.NewSchema([]arrow.Field{fld}, nil), []arrow.Column{*arrow.NewColumn(fld, cnk)}, -1) @@ -2147,13 +2139,55 @@ func (ps *ParquetIOTestSuite) TestArrowUnknownExtensionTypeRoundTrip() { ps.Truef(array.Equal(exc, tbc), "expected: %T %s\ngot: %T %s", exc, exc, tbc, tbc) expectedMd := arrow.MetadataFrom(map[string]string{ - ipc.ExtensionTypeKeyName: "uuid", - ipc.ExtensionMetadataKeyName: "uuid-serialized", + ipc.ExtensionTypeKeyName: "smallint", + ipc.ExtensionMetadataKeyName: "smallint-serialized", "PARQUET:field_id": "-1", }) ps.Truef(expectedMd.Equal(tbl.Column(0).Field().Metadata), "expected: %v\ngot: %v", expectedMd, tbl.Column(0).Field().Metadata) } +func (ps *ParquetIOTestSuite) TestArrowExtensionTypeLogicalType() { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(ps.T(), 0) + + jsonType, err := extensions.NewJSONType(arrow.BinaryTypes.String) + ps.NoError(err) + + sch := arrow.NewSchema([]arrow.Field{ + {Name: "uuid", Type: extensions.NewUUIDType()}, + {Name: "json", Type: jsonType}, + }, + nil, + ) + bldr := array.NewRecordBuilder(mem, sch) + defer bldr.Release() + + bldr.Field(0).(*extensions.UUIDBuilder).Append(uuid.New()) + bldr.Field(1).(*array.ExtensionBuilder).AppendValueFromString(`{"hello": ["world", 2, true], "world": null}`) + rec := bldr.NewRecord() + defer rec.Release() + + var buf bytes.Buffer + wr, err := pqarrow.NewFileWriter( + sch, + &buf, + parquet.NewWriterProperties(), + pqarrow.DefaultWriterProps(), + ) + ps.Require().NoError(err) + + ps.Require().NoError(wr.Write(rec)) + ps.Require().NoError(wr.Close()) + + rdr, err := file.NewParquetReader(bytes.NewReader(buf.Bytes())) + 
ps.Require().NoError(err) + defer rdr.Close() + + pqSchema := rdr.MetaData().Schema + ps.True(pqSchema.Column(0).LogicalType().Equals(schema.UUIDLogicalType{})) + ps.True(pqSchema.Column(1).LogicalType().Equals(schema.JSONLogicalType{})) +} + func TestWriteTableMemoryAllocation(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) sc := arrow.NewSchema([]arrow.Field{ @@ -2163,7 +2197,7 @@ func TestWriteTableMemoryAllocation(t *testing.T) { arrow.Field{Name: "i64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, arrow.Field{Name: "f64", Type: arrow.PrimitiveTypes.Float64, Nullable: true})}, {Name: "arr_i64", Type: arrow.ListOf(arrow.PrimitiveTypes.Int64)}, - {Name: "uuid", Type: types.NewUUIDType(), Nullable: true}, + {Name: "uuid", Type: extensions.NewUUIDType(), Nullable: true}, }, nil) bld := array.NewRecordBuilder(mem, sc) @@ -2176,7 +2210,7 @@ func TestWriteTableMemoryAllocation(t *testing.T) { abld := bld.Field(3).(*array.ListBuilder) abld.Append(true) abld.ValueBuilder().(*array.Int64Builder).Append(2) - bld.Field(4).(*types.UUIDBuilder).Append(uuid.MustParse("00000000-0000-0000-0000-000000000001")) + bld.Field(4).(*extensions.UUIDBuilder).Append(uuid.MustParse("00000000-0000-0000-0000-000000000001")) rec := bld.NewRecord() bld.Release() diff --git a/go/parquet/pqarrow/path_builder_test.go b/go/parquet/pqarrow/path_builder_test.go index 9bbae426b8a..364f836d0bb 100644 --- a/go/parquet/pqarrow/path_builder_test.go +++ b/go/parquet/pqarrow/path_builder_test.go @@ -22,8 +22,8 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/array" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -364,12 +364,12 @@ func TestNestedExtensionListsWithSomeNulls(t *testing.T) { mem := memory.NewCheckedAllocator(memory.DefaultAllocator) defer mem.AssertSize(t, 0) - listType := arrow.ListOf(types.NewUUIDType()) + listType := arrow.ListOf(extensions.NewUUIDType()) bldr := array.NewListBuilder(mem, listType) defer bldr.Release() nestedBldr := bldr.ValueBuilder().(*array.ListBuilder) - vb := nestedBldr.ValueBuilder().(*types.UUIDBuilder) + vb := nestedBldr.ValueBuilder().(*extensions.UUIDBuilder) uuid1 := uuid.New() uuid3 := uuid.New() diff --git a/go/parquet/pqarrow/schema.go b/go/parquet/pqarrow/schema.go index ce5cc6f9050..4882077671f 100644 --- a/go/parquet/pqarrow/schema.go +++ b/go/parquet/pqarrow/schema.go @@ -25,7 +25,6 @@ import ( "github.com/apache/arrow/go/v18/arrow" "github.com/apache/arrow/go/v18/arrow/decimal128" "github.com/apache/arrow/go/v18/arrow/flight" - "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/parquet" "github.com/apache/arrow/go/v18/parquet/file" @@ -120,6 +119,15 @@ func (sm *SchemaManifest) GetFieldIndices(indices []int) ([]int, error) { return ret, nil } +// ExtensionCustomParquetType is an interface that Arrow ExtensionTypes may implement +// to specify the target LogicalType to use when converting to Parquet. +// +// The PrimitiveType is not configurable, and is determined by a fixed mapping from +// the extension's StorageType to a Parquet type (see getParquetType in pqarrow source). 
+type ExtensionCustomParquetType interface { + ParquetLogicalType() schema.LogicalType +} + func isDictionaryReadSupported(dt arrow.DataType) bool { return arrow.IsBinaryLike(dt.ID()) } @@ -250,104 +258,14 @@ func structToNode(typ *arrow.StructType, name string, nullable bool, props *parq } func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (schema.Node, error) { - var ( - logicalType schema.LogicalType = schema.NoLogicalType{} - typ parquet.Type - repType = repFromNullable(field.Nullable) - length = -1 - precision = -1 - scale = -1 - err error - ) + repType := repFromNullable(field.Nullable) + // Handle complex types i.e. GroupNodes switch field.Type.ID() { case arrow.NULL: - typ = parquet.Types.Int32 - logicalType = &schema.NullLogicalType{} if repType != parquet.Repetitions.Optional { return nil, xerrors.New("nulltype arrow field must be nullable") } - case arrow.BOOL: - typ = parquet.Types.Boolean - case arrow.UINT8: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(8, false) - case arrow.INT8: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(8, true) - case arrow.UINT16: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(16, false) - case arrow.INT16: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(16, true) - case arrow.UINT32: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(32, false) - case arrow.INT32: - typ = parquet.Types.Int32 - logicalType = schema.NewIntLogicalType(32, true) - case arrow.UINT64: - typ = parquet.Types.Int64 - logicalType = schema.NewIntLogicalType(64, false) - case arrow.INT64: - typ = parquet.Types.Int64 - logicalType = schema.NewIntLogicalType(64, true) - case arrow.FLOAT32: - typ = parquet.Types.Float - case arrow.FLOAT64: - typ = parquet.Types.Double - case arrow.STRING, arrow.LARGE_STRING: - logicalType = schema.StringLogicalType{} - fallthrough - case arrow.BINARY, arrow.LARGE_BINARY: - typ = parquet.Types.ByteArray - case arrow.FIXED_SIZE_BINARY: - typ = parquet.Types.FixedLenByteArray - length = field.Type.(*arrow.FixedSizeBinaryType).ByteWidth - case arrow.DECIMAL, arrow.DECIMAL256: - dectype := field.Type.(arrow.DecimalType) - precision = int(dectype.GetPrecision()) - scale = int(dectype.GetScale()) - - if props.StoreDecimalAsInteger() && 1 <= precision && precision <= 18 { - if precision <= 9 { - typ = parquet.Types.Int32 - } else { - typ = parquet.Types.Int64 - } - } else { - typ = parquet.Types.FixedLenByteArray - length = int(DecimalSize(int32(precision))) - } - - logicalType = schema.NewDecimalLogicalType(int32(precision), int32(scale)) - case arrow.DATE32: - typ = parquet.Types.Int32 - logicalType = schema.DateLogicalType{} - case arrow.DATE64: - typ = parquet.Types.Int32 - logicalType = schema.DateLogicalType{} - case arrow.TIMESTAMP: - typ, logicalType, err = getTimestampMeta(field.Type.(*arrow.TimestampType), props, arrprops) - if err != nil { - return nil, err - } - case arrow.TIME32: - typ = parquet.Types.Int32 - logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitMillis) - case arrow.TIME64: - typ = parquet.Types.Int64 - timeType := field.Type.(*arrow.Time64Type) - if timeType.Unit == arrow.Nanosecond { - logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitNanos) - } else { - logicalType = schema.NewTimeLogicalType(true, schema.TimeUnitMicros) - } - case arrow.FLOAT16: - typ = parquet.Types.FixedLenByteArray - length = arrow.Float16SizeBytes - logicalType = 
schema.Float16LogicalType{} case arrow.STRUCT: return structToNode(field.Type.(*arrow.StructType), field.Name, field.Nullable, props, arrprops) case arrow.FIXED_SIZE_LIST, arrow.LIST: @@ -369,16 +287,6 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties dictType := field.Type.(*arrow.DictionaryType) return fieldToNode(name, arrow.Field{Name: name, Type: dictType.ValueType, Nullable: field.Nullable, Metadata: field.Metadata}, props, arrprops) - case arrow.EXTENSION: - return fieldToNode(name, arrow.Field{ - Name: name, - Type: field.Type.(arrow.ExtensionType).StorageType(), - Nullable: field.Nullable, - Metadata: arrow.MetadataFrom(map[string]string{ - ipc.ExtensionTypeKeyName: field.Type.(arrow.ExtensionType).ExtensionName(), - ipc.ExtensionMetadataKeyName: field.Type.(arrow.ExtensionType).Serialize(), - }), - }, props, arrprops) case arrow.MAP: mapType := field.Type.(*arrow.MapType) keyNode, err := fieldToNode("key", mapType.KeyField(), props, arrprops) @@ -402,8 +310,12 @@ func fieldToNode(name string, field arrow.Field, props *parquet.WriterProperties }, -1) } return schema.MapOf(field.Name, keyNode, valueNode, repFromNullable(field.Nullable), -1) - default: - return nil, fmt.Errorf("%w: support for %s", arrow.ErrNotImplemented, field.Type.ID()) + } + + // Not a GroupNode + typ, logicalType, length, err := getParquetType(field.Type, props, arrprops) + if err != nil { + return nil, err } return schema.NewPrimitiveNodeLogical(name, repType, logicalType, typ, length, fieldIDFromMeta(field.Metadata)) @@ -472,7 +384,7 @@ func (s schemaTree) RecordLeaf(leaf *SchemaField) { s.manifest.ColIndexToField[leaf.ColIndex] = leaf } -func arrowInt(log *schema.IntLogicalType) (arrow.DataType, error) { +func arrowInt(log schema.IntLogicalType) (arrow.DataType, error) { switch log.BitWidth() { case 8: if log.IsSigned() { @@ -499,7 +411,7 @@ func arrowInt(log *schema.IntLogicalType) (arrow.DataType, error) { } } -func arrowTime32(logical *schema.TimeLogicalType) (arrow.DataType, error) { +func arrowTime32(logical schema.TimeLogicalType) (arrow.DataType, error) { if logical.TimeUnit() == schema.TimeUnitMillis { return arrow.FixedWidthTypes.Time32ms, nil } @@ -507,7 +419,7 @@ func arrowTime32(logical *schema.TimeLogicalType) (arrow.DataType, error) { return nil, xerrors.New(logical.String() + " cannot annotate a time32") } -func arrowTime64(logical *schema.TimeLogicalType) (arrow.DataType, error) { +func arrowTime64(logical schema.TimeLogicalType) (arrow.DataType, error) { switch logical.TimeUnit() { case schema.TimeUnitMicros: return arrow.FixedWidthTypes.Time64us, nil @@ -518,7 +430,7 @@ func arrowTime64(logical *schema.TimeLogicalType) (arrow.DataType, error) { } } -func arrowTimestamp(logical *schema.TimestampLogicalType) (arrow.DataType, error) { +func arrowTimestamp(logical schema.TimestampLogicalType) (arrow.DataType, error) { tz := "" // ConvertedTypes are adjusted to UTC per backward compatibility guidelines @@ -539,7 +451,7 @@ func arrowTimestamp(logical *schema.TimestampLogicalType) (arrow.DataType, error } } -func arrowDecimal(logical *schema.DecimalLogicalType) arrow.DataType { +func arrowDecimal(logical schema.DecimalLogicalType) arrow.DataType { if logical.Precision() <= decimal128.MaxPrecision { return &arrow.Decimal128Type{Precision: logical.Precision(), Scale: logical.Scale()} } @@ -550,11 +462,11 @@ func arrowFromInt32(logical schema.LogicalType) (arrow.DataType, error) { switch logtype := logical.(type) { case schema.NoLogicalType: return 
arrow.PrimitiveTypes.Int32, nil - case *schema.TimeLogicalType: + case schema.TimeLogicalType: return arrowTime32(logtype) - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil - case *schema.IntLogicalType: + case schema.IntLogicalType: return arrowInt(logtype) case schema.DateLogicalType: return arrow.FixedWidthTypes.Date32, nil @@ -569,13 +481,13 @@ func arrowFromInt64(logical schema.LogicalType) (arrow.DataType, error) { } switch logtype := logical.(type) { - case *schema.IntLogicalType: + case schema.IntLogicalType: return arrowInt(logtype) - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil - case *schema.TimeLogicalType: + case schema.TimeLogicalType: return arrowTime64(logtype) - case *schema.TimestampLogicalType: + case schema.TimestampLogicalType: return arrowTimestamp(logtype) default: return nil, xerrors.New(logical.String() + " cannot annotate int64") @@ -586,7 +498,7 @@ func arrowFromByteArray(logical schema.LogicalType) (arrow.DataType, error) { switch logtype := logical.(type) { case schema.StringLogicalType: return arrow.BinaryTypes.String, nil - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil case schema.NoLogicalType, schema.EnumLogicalType, @@ -600,7 +512,7 @@ func arrowFromByteArray(logical schema.LogicalType) (arrow.DataType, error) { func arrowFromFLBA(logical schema.LogicalType, length int) (arrow.DataType, error) { switch logtype := logical.(type) { - case *schema.DecimalLogicalType: + case schema.DecimalLogicalType: return arrowDecimal(logtype), nil case schema.NoLogicalType, schema.IntervalLogicalType, schema.UUIDLogicalType: return &arrow.FixedSizeBinaryType{ByteWidth: int(length)}, nil @@ -611,6 +523,84 @@ func arrowFromFLBA(logical schema.LogicalType, length int) (arrow.DataType, erro } } +func getParquetType(typ arrow.DataType, props *parquet.WriterProperties, arrprops ArrowWriterProperties) (parquet.Type, schema.LogicalType, int, error) { + switch typ.ID() { + case arrow.NULL: + return parquet.Types.Int32, schema.NullLogicalType{}, -1, nil + case arrow.BOOL: + return parquet.Types.Boolean, schema.NoLogicalType{}, -1, nil + case arrow.UINT8: + return parquet.Types.Int32, schema.NewIntLogicalType(8, false), -1, nil + case arrow.INT8: + return parquet.Types.Int32, schema.NewIntLogicalType(8, true), -1, nil + case arrow.UINT16: + return parquet.Types.Int32, schema.NewIntLogicalType(16, false), -1, nil + case arrow.INT16: + return parquet.Types.Int32, schema.NewIntLogicalType(16, true), -1, nil + case arrow.UINT32: + return parquet.Types.Int32, schema.NewIntLogicalType(32, false), -1, nil + case arrow.INT32: + return parquet.Types.Int32, schema.NewIntLogicalType(32, true), -1, nil + case arrow.UINT64: + return parquet.Types.Int64, schema.NewIntLogicalType(64, false), -1, nil + case arrow.INT64: + return parquet.Types.Int64, schema.NewIntLogicalType(64, true), -1, nil + case arrow.FLOAT32: + return parquet.Types.Float, schema.NoLogicalType{}, -1, nil + case arrow.FLOAT64: + return parquet.Types.Double, schema.NoLogicalType{}, -1, nil + case arrow.STRING, arrow.LARGE_STRING: + return parquet.Types.ByteArray, schema.StringLogicalType{}, -1, nil + case arrow.BINARY, arrow.LARGE_BINARY: + return parquet.Types.ByteArray, schema.NoLogicalType{}, -1, nil + case arrow.FIXED_SIZE_BINARY: + return parquet.Types.FixedLenByteArray, schema.NoLogicalType{}, typ.(*arrow.FixedSizeBinaryType).ByteWidth, nil + case arrow.DECIMAL, 
arrow.DECIMAL256: + dectype := typ.(arrow.DecimalType) + precision := int(dectype.GetPrecision()) + scale := int(dectype.GetScale()) + + if !props.StoreDecimalAsInteger() || precision > 18 { + return parquet.Types.FixedLenByteArray, schema.NewDecimalLogicalType(int32(precision), int32(scale)), int(DecimalSize(int32(precision))), nil + } + + pqType := parquet.Types.Int32 + if precision > 9 { + pqType = parquet.Types.Int64 + } + + return pqType, schema.NoLogicalType{}, -1, nil + case arrow.DATE32: + return parquet.Types.Int32, schema.DateLogicalType{}, -1, nil + case arrow.DATE64: + return parquet.Types.Int32, schema.DateLogicalType{}, -1, nil + case arrow.TIMESTAMP: + pqType, logicalType, err := getTimestampMeta(typ.(*arrow.TimestampType), props, arrprops) + return pqType, logicalType, -1, err + case arrow.TIME32: + return parquet.Types.Int32, schema.NewTimeLogicalType(true, schema.TimeUnitMillis), -1, nil + case arrow.TIME64: + pqTimeUnit := schema.TimeUnitMicros + if typ.(*arrow.Time64Type).Unit == arrow.Nanosecond { + pqTimeUnit = schema.TimeUnitNanos + } + + return parquet.Types.Int64, schema.NewTimeLogicalType(true, pqTimeUnit), -1, nil + case arrow.FLOAT16: + return parquet.Types.FixedLenByteArray, schema.Float16LogicalType{}, arrow.Float16SizeBytes, nil + case arrow.EXTENSION: + storageType := typ.(arrow.ExtensionType).StorageType() + pqType, logicalType, length, err := getParquetType(storageType, props, arrprops) + if withCustomType, ok := typ.(ExtensionCustomParquetType); ok { + logicalType = withCustomType.ParquetLogicalType() + } + + return pqType, logicalType, length, err + default: + return parquet.Type(0), nil, 0, fmt.Errorf("%w: support for %s", arrow.ErrNotImplemented, typ.ID()) + } +} + func getArrowType(physical parquet.Type, logical schema.LogicalType, typeLen int) (arrow.DataType, error) { if !logical.IsValid() || logical.Equals(schema.NullLogicalType{}) { return arrow.Null, nil diff --git a/go/parquet/pqarrow/schema_test.go b/go/parquet/pqarrow/schema_test.go index 24b031c174b..528200fd0e7 100644 --- a/go/parquet/pqarrow/schema_test.go +++ b/go/parquet/pqarrow/schema_test.go @@ -21,10 +21,10 @@ import ( "testing" "github.com/apache/arrow/go/v18/arrow" + "github.com/apache/arrow/go/v18/arrow/extensions" "github.com/apache/arrow/go/v18/arrow/flight" "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" - "github.com/apache/arrow/go/v18/internal/types" "github.com/apache/arrow/go/v18/parquet" "github.com/apache/arrow/go/v18/parquet/metadata" "github.com/apache/arrow/go/v18/parquet/pqarrow" @@ -34,7 +34,7 @@ import ( ) func TestGetOriginSchemaBase64(t *testing.T) { - uuidType := types.NewUUIDType() + uuidType := extensions.NewUUIDType() md := arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"-1"}) extMd := arrow.NewMetadata([]string{ipc.ExtensionMetadataKeyName, ipc.ExtensionTypeKeyName, "PARQUET:field_id"}, []string{uuidType.Serialize(), uuidType.ExtensionName(), "-1"}) origArrSc := arrow.NewSchema([]arrow.Field{ @@ -44,10 +44,6 @@ func TestGetOriginSchemaBase64(t *testing.T) { }, nil) arrSerializedSc := flight.SerializeSchema(origArrSc, memory.DefaultAllocator) - if err := arrow.RegisterExtensionType(uuidType); err != nil { - t.Fatal(err) - } - defer arrow.UnregisterExtensionType(uuidType.ExtensionName()) pqschema, err := pqarrow.ToParquet(origArrSc, nil, pqarrow.DefaultWriterProps()) require.NoError(t, err) @@ -71,11 +67,7 @@ func TestGetOriginSchemaBase64(t *testing.T) { } func TestGetOriginSchemaUnregisteredExtension(t 
*testing.T) { - uuidType := types.NewUUIDType() - if err := arrow.RegisterExtensionType(uuidType); err != nil { - t.Fatal(err) - } - + uuidType := extensions.NewUUIDType() md := arrow.NewMetadata([]string{"PARQUET:field_id"}, []string{"-1"}) origArrSc := arrow.NewSchema([]arrow.Field{ {Name: "f1", Type: arrow.BinaryTypes.String, Metadata: md}, @@ -90,6 +82,7 @@ func TestGetOriginSchemaUnregisteredExtension(t *testing.T) { kv.Append("ARROW:schema", base64.StdEncoding.EncodeToString(arrSerializedSc)) arrow.UnregisterExtensionType(uuidType.ExtensionName()) + defer arrow.RegisterExtensionType(uuidType) arrsc, err := pqarrow.FromParquet(pqschema, nil, kv) require.NoError(t, err) diff --git a/go/parquet/schema/converted_types.go b/go/parquet/schema/converted_types.go index 5fc10f61ceb..b2b6f50cbf6 100644 --- a/go/parquet/schema/converted_types.go +++ b/go/parquet/schema/converted_types.go @@ -113,13 +113,9 @@ func (p ConvertedType) ToLogicalType(convertedDecimal DecimalMetadata) LogicalTy case ConvertedTypes.TimeMicros: return NewTimeLogicalType(true /* adjustedToUTC */, TimeUnitMicros) case ConvertedTypes.TimestampMillis: - t := NewTimestampLogicalType(true /* adjustedToUTC */, TimeUnitMillis) - t.(*TimestampLogicalType).fromConverted = true - return t + return NewTimestampLogicalTypeWithOpts(WithTSIsAdjustedToUTC(), WithTSTimeUnitType(TimeUnitMillis), WithTSFromConverted()) case ConvertedTypes.TimestampMicros: - t := NewTimestampLogicalType(true /* adjustedToUTC */, TimeUnitMicros) - t.(*TimestampLogicalType).fromConverted = true - return t + return NewTimestampLogicalTypeWithOpts(WithTSIsAdjustedToUTC(), WithTSTimeUnitType(TimeUnitMicros), WithTSFromConverted()) case ConvertedTypes.Interval: return IntervalLogicalType{} case ConvertedTypes.Int8: diff --git a/go/parquet/schema/logical_types.go b/go/parquet/schema/logical_types.go index e8adce1ca14..fa46ea0172f 100644 --- a/go/parquet/schema/logical_types.go +++ b/go/parquet/schema/logical_types.go @@ -45,21 +45,21 @@ func getLogicalType(l *format.LogicalType) LogicalType { case l.IsSetENUM(): return EnumLogicalType{} case l.IsSetDECIMAL(): - return &DecimalLogicalType{typ: l.DECIMAL} + return DecimalLogicalType{typ: l.DECIMAL} case l.IsSetDATE(): return DateLogicalType{} case l.IsSetTIME(): if timeUnitFromThrift(l.TIME.Unit) == TimeUnitUnknown { panic("parquet: TimeUnit must be one of MILLIS, MICROS, or NANOS for Time logical type") } - return &TimeLogicalType{typ: l.TIME} + return TimeLogicalType{typ: l.TIME} case l.IsSetTIMESTAMP(): if timeUnitFromThrift(l.TIMESTAMP.Unit) == TimeUnitUnknown { panic("parquet: TimeUnit must be one of MILLIS, MICROS, or NANOS for Timestamp logical type") } - return &TimestampLogicalType{typ: l.TIMESTAMP} + return TimestampLogicalType{typ: l.TIMESTAMP} case l.IsSetINTEGER(): - return &IntLogicalType{typ: l.INTEGER} + return IntLogicalType{typ: l.INTEGER} case l.IsSetUNKNOWN(): return NullLogicalType{} case l.IsSetJSON(): @@ -344,7 +344,7 @@ func NewDecimalLogicalType(precision int32, scale int32) LogicalType { if scale < 0 || scale > precision { panic("parquet: scale must be a non-negative integer that does not exceed precision for decimal logical type") } - return &DecimalLogicalType{typ: &format.DecimalType{Precision: precision, Scale: scale}} + return DecimalLogicalType{typ: &format.DecimalType{Precision: precision, Scale: scale}} } // DecimalLogicalType is used to represent a decimal value of a given @@ -405,7 +405,7 @@ func (t DecimalLogicalType) toThrift() *format.LogicalType { } func (t 
DecimalLogicalType) Equals(rhs LogicalType) bool { - other, ok := rhs.(*DecimalLogicalType) + other, ok := rhs.(DecimalLogicalType) if !ok { return false } @@ -509,7 +509,7 @@ func createTimeUnit(unit TimeUnitType) *format.TimeUnit { // NewTimeLogicalType returns a time type of the given unit. func NewTimeLogicalType(isAdjustedToUTC bool, unit TimeUnitType) LogicalType { - return &TimeLogicalType{typ: &format.TimeType{ + return TimeLogicalType{typ: &format.TimeType{ IsAdjustedToUTC: isAdjustedToUTC, Unit: createTimeUnit(unit), }} @@ -584,7 +584,7 @@ func (t TimeLogicalType) toThrift() *format.LogicalType { } func (t TimeLogicalType) Equals(rhs LogicalType) bool { - other, ok := rhs.(*TimeLogicalType) + other, ok := rhs.(TimeLogicalType) if !ok { return false } @@ -595,7 +595,7 @@ func (t TimeLogicalType) Equals(rhs LogicalType) bool { // NewTimestampLogicalType returns a logical timestamp type with "forceConverted" // set to false func NewTimestampLogicalType(isAdjustedToUTC bool, unit TimeUnitType) LogicalType { - return &TimestampLogicalType{ + return TimestampLogicalType{ typ: &format.TimestampType{ IsAdjustedToUTC: isAdjustedToUTC, Unit: createTimeUnit(unit), @@ -608,7 +608,7 @@ func NewTimestampLogicalType(isAdjustedToUTC bool, unit TimeUnitType) LogicalTyp // NewTimestampLogicalTypeForce returns a timestamp logical type with // "forceConverted" set to true func NewTimestampLogicalTypeForce(isAdjustedToUTC bool, unit TimeUnitType) LogicalType { - return &TimestampLogicalType{ + return TimestampLogicalType{ typ: &format.TimestampType{ IsAdjustedToUTC: isAdjustedToUTC, Unit: createTimeUnit(unit), @@ -654,14 +654,14 @@ func WithTSFromConverted() TimestampOpt { // // TimestampType Unit defaults to milliseconds (TimeUnitMillis) func NewTimestampLogicalTypeWithOpts(opts ...TimestampOpt) LogicalType { - ts := &TimestampLogicalType{ + ts := TimestampLogicalType{ typ: &format.TimestampType{ Unit: createTimeUnit(TimeUnitMillis), // default to milliseconds }, } for _, o := range opts { - o(ts) + o(&ts) } return ts @@ -760,7 +760,7 @@ func (t TimestampLogicalType) toThrift() *format.LogicalType { } func (t TimestampLogicalType) Equals(rhs LogicalType) bool { - other, ok := rhs.(*TimestampLogicalType) + other, ok := rhs.(TimestampLogicalType) if !ok { return false } @@ -778,7 +778,7 @@ func NewIntLogicalType(bitWidth int8, signed bool) LogicalType { default: panic("parquet: bit width must be exactly 8, 16, 32, or 64 for Int logical type") } - return &IntLogicalType{ + return IntLogicalType{ typ: &format.IntType{ BitWidth: bitWidth, IsSigned: signed, @@ -864,7 +864,7 @@ func (t IntLogicalType) toThrift() *format.LogicalType { } func (t IntLogicalType) Equals(rhs LogicalType) bool { - other, ok := rhs.(*IntLogicalType) + other, ok := rhs.(IntLogicalType) if !ok { return false } diff --git a/go/parquet/schema/logical_types_test.go b/go/parquet/schema/logical_types_test.go index e33925966e1..395d1504182 100644 --- a/go/parquet/schema/logical_types_test.go +++ b/go/parquet/schema/logical_types_test.go @@ -38,18 +38,18 @@ func TestConvertedLogicalEquivalences(t *testing.T) { {"list", schema.ConvertedTypes.List, schema.NewListLogicalType(), schema.NewListLogicalType()}, {"enum", schema.ConvertedTypes.Enum, schema.EnumLogicalType{}, schema.EnumLogicalType{}}, {"date", schema.ConvertedTypes.Date, schema.DateLogicalType{}, schema.DateLogicalType{}}, - {"timemilli", schema.ConvertedTypes.TimeMillis, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), &schema.TimeLogicalType{}}, - 
{"timemicro", schema.ConvertedTypes.TimeMicros, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), &schema.TimeLogicalType{}}, - {"timestampmilli", schema.ConvertedTypes.TimestampMillis, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), &schema.TimestampLogicalType{}}, - {"timestampmicro", schema.ConvertedTypes.TimestampMicros, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), &schema.TimestampLogicalType{}}, - {"uint8", schema.ConvertedTypes.Uint8, schema.NewIntLogicalType(8 /* bitWidth */, false /* signed */), &schema.IntLogicalType{}}, - {"uint16", schema.ConvertedTypes.Uint16, schema.NewIntLogicalType(16 /* bitWidth */, false /* signed */), &schema.IntLogicalType{}}, - {"uint32", schema.ConvertedTypes.Uint32, schema.NewIntLogicalType(32 /* bitWidth */, false /* signed */), &schema.IntLogicalType{}}, - {"uint64", schema.ConvertedTypes.Uint64, schema.NewIntLogicalType(64 /* bitWidth */, false /* signed */), &schema.IntLogicalType{}}, - {"int8", schema.ConvertedTypes.Int8, schema.NewIntLogicalType(8 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, - {"int16", schema.ConvertedTypes.Int16, schema.NewIntLogicalType(16 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, - {"int32", schema.ConvertedTypes.Int32, schema.NewIntLogicalType(32 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, - {"int64", schema.ConvertedTypes.Int64, schema.NewIntLogicalType(64 /* bitWidth */, true /* signed */), &schema.IntLogicalType{}}, + {"timemilli", schema.ConvertedTypes.TimeMillis, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), schema.TimeLogicalType{}}, + {"timemicro", schema.ConvertedTypes.TimeMicros, schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), schema.TimeLogicalType{}}, + {"timestampmilli", schema.ConvertedTypes.TimestampMillis, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMillis), schema.TimestampLogicalType{}}, + {"timestampmicro", schema.ConvertedTypes.TimestampMicros, schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitMicros), schema.TimestampLogicalType{}}, + {"uint8", schema.ConvertedTypes.Uint8, schema.NewIntLogicalType(8 /* bitWidth */, false /* signed */), schema.IntLogicalType{}}, + {"uint16", schema.ConvertedTypes.Uint16, schema.NewIntLogicalType(16 /* bitWidth */, false /* signed */), schema.IntLogicalType{}}, + {"uint32", schema.ConvertedTypes.Uint32, schema.NewIntLogicalType(32 /* bitWidth */, false /* signed */), schema.IntLogicalType{}}, + {"uint64", schema.ConvertedTypes.Uint64, schema.NewIntLogicalType(64 /* bitWidth */, false /* signed */), schema.IntLogicalType{}}, + {"int8", schema.ConvertedTypes.Int8, schema.NewIntLogicalType(8 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, + {"int16", schema.ConvertedTypes.Int16, schema.NewIntLogicalType(16 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, + {"int32", schema.ConvertedTypes.Int32, schema.NewIntLogicalType(32 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, + {"int64", schema.ConvertedTypes.Int64, schema.NewIntLogicalType(64 /* bitWidth */, true /* signed */), schema.IntLogicalType{}}, {"json", schema.ConvertedTypes.JSON, schema.JSONLogicalType{}, schema.JSONLogicalType{}}, {"bson", schema.ConvertedTypes.BSON, schema.BSONLogicalType{}, schema.BSONLogicalType{}}, {"interval", schema.ConvertedTypes.Interval, schema.IntervalLogicalType{}, schema.IntervalLogicalType{}}, 
@@ -72,8 +72,8 @@ func TestConvertedLogicalEquivalences(t *testing.T) { fromMake := schema.NewDecimalLogicalType(10, 4) assert.IsType(t, fromMake, fromConverted) assert.True(t, fromConverted.Equals(fromMake)) - assert.IsType(t, &schema.DecimalLogicalType{}, fromConverted) - assert.IsType(t, &schema.DecimalLogicalType{}, fromMake) + assert.IsType(t, schema.DecimalLogicalType{}, fromConverted) + assert.IsType(t, schema.DecimalLogicalType{}, fromMake) assert.True(t, schema.NewDecimalLogicalType(16, 0).Equals(schema.NewDecimalLogicalType(16, 0))) }) } @@ -160,12 +160,12 @@ func TestNewTypeIncompatibility(t *testing.T) { {"uuid", schema.UUIDLogicalType{}, schema.UUIDLogicalType{}}, {"float16", schema.Float16LogicalType{}, schema.Float16LogicalType{}}, {"null", schema.NullLogicalType{}, schema.NullLogicalType{}}, - {"not-utc-time_milli", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMillis), &schema.TimeLogicalType{}}, - {"not-utc-time-micro", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMicros), &schema.TimeLogicalType{}}, - {"not-utc-time-nano", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimeLogicalType{}}, - {"utc-time-nano", schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimeLogicalType{}}, - {"not-utc-timestamp-nano", schema.NewTimestampLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimestampLogicalType{}}, - {"utc-timestamp-nano", schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), &schema.TimestampLogicalType{}}, + {"not-utc-time_milli", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMillis), schema.TimeLogicalType{}}, + {"not-utc-time-micro", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitMicros), schema.TimeLogicalType{}}, + {"not-utc-time-nano", schema.NewTimeLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimeLogicalType{}}, + {"utc-time-nano", schema.NewTimeLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimeLogicalType{}}, + {"not-utc-timestamp-nano", schema.NewTimestampLogicalType(false /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimestampLogicalType{}}, + {"utc-timestamp-nano", schema.NewTimestampLogicalType(true /* adjustedToUTC */, schema.TimeUnitNanos), schema.TimestampLogicalType{}}, } for _, tt := range tests { diff --git a/go/parquet/schema/schema_element_test.go b/go/parquet/schema/schema_element_test.go index 7da55ce93ab..e427ba6485e 100644 --- a/go/parquet/schema/schema_element_test.go +++ b/go/parquet/schema/schema_element_test.go @@ -192,7 +192,7 @@ func (s *SchemaElementConstructionSuite) TestSimple() { func (s *SchemaElementConstructionSuite) reconstructDecimal(c schemaElementConstructArgs) *decimalSchemaElementConstruction { ret := s.reconstruct(c) - dec := c.logical.(*DecimalLogicalType) + dec := c.logical.(DecimalLogicalType) return &decimalSchemaElementConstruction{*ret, int(dec.Precision()), int(dec.Scale())} } @@ -359,7 +359,7 @@ func (s *SchemaElementConstructionSuite) TestTemporal() { func (s *SchemaElementConstructionSuite) reconstructInteger(c schemaElementConstructArgs) *intSchemaElementConstruction { base := s.reconstruct(c) - l := c.logical.(*IntLogicalType) + l := c.logical.(IntLogicalType) return &intSchemaElementConstruction{ *base, l.BitWidth(), From 82ecf3e6ed8cb58a08d600041617ce85c9bdb7c1 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Thu, 22 Aug 2024 22:57:14 +0200 Subject: [PATCH 
063/157] MINOR: [CI][C++][Python] Fix Cuda builds on git main (#43789) On the Cuda self-hosted runners, we need to use legacy `docker-compose` on all Archery Docker invocations, including the "image push" step. This is because the Docker client version on those runners is too old to accept the `--file` option to the `compose` subcommand. This is a followup to https://github.com/apache/arrow/pull/43586 . The image push step cannot easily be verified in a PR, hence this second PR. Authored-by: Antoine Pitrou Signed-off-by: Sutou Kouhei --- dev/tasks/docker-tests/github.cuda.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dev/tasks/docker-tests/github.cuda.yml b/dev/tasks/docker-tests/github.cuda.yml index 9c7adf53a6f..8c04da8a91a 100644 --- a/dev/tasks/docker-tests/github.cuda.yml +++ b/dev/tasks/docker-tests/github.cuda.yml @@ -26,6 +26,8 @@ jobs: runs-on: ['self-hosted', 'cuda'] {{ macros.github_set_env(env) }} timeout-minutes: {{ timeout|default(60) }} + env: + ARCHERY_USE_LEGACY_DOCKER_COMPOSE: 1 steps: {{ macros.github_checkout_arrow(fetch_depth=fetch_depth|default(1))|indent }} # python 3.8 is installed on the runner, no need to install @@ -34,7 +36,6 @@ jobs: - name: Execute Docker Build shell: bash env: - ARCHERY_USE_LEGACY_DOCKER_COMPOSE: 1 {{ macros.github_set_sccache_envvars()|indent(8) }} run: | archery docker run \ From bad064f705ec9fc72efac2d13a1fc3fac6d3d137 Mon Sep 17 00:00:00 2001 From: Vyas Ramasubramani Date: Thu, 22 Aug 2024 14:08:26 -0700 Subject: [PATCH 064/157] MINOR: [C++] Ensure setting the default CMAKE_BUILD_TYPE (#43794) ### Rationale for this change The current logic for detecting whether the `CMAKE_BUILD_TYPE` is set is incorrect. That variable is never fully undefined; by default, in cases where it is unset is actually set to the empty string. Therefore, the condition that must be checked is not whether the variable is defined, but whether it tests to a truthy value (i.e. is a non-empty string). I consider this a minor change so I have not opened an associated issue. ### What changes are included in this PR? This PR changes `if(NOT DEFINED CMAKE_BUILD_TYPE)` to `if(NOT CMAKE_BUILD_TYPE)`. ### Are these changes tested? Since this fixes a particular CMake build scenario I am not sure if a test is merited, or where one would be added. ### Are there any user-facing changes? No. 
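As a purely illustrative sketch (not part of this patch), the difference between the two checks can be seen in a standalone CMakeLists.txt; the project name and messages below are placeholders:

cmake_minimum_required(VERSION 3.16)
project(build_type_check_demo)

# With a plain `cmake -S . -B build` invocation and no -DCMAKE_BUILD_TYPE=...,
# CMAKE_BUILD_TYPE exists as an empty cache entry, so this branch is never taken.
if(NOT DEFINED CMAKE_BUILD_TYPE)
  message(STATUS "never reached: the variable is defined (but empty)")
endif()

# The empty string is falsy in CMake, so this branch correctly applies the default.
if(NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build.")
endif()
message(STATUS "CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")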
Authored-by: Vyas Ramasubramani Signed-off-by: Sutou Kouhei --- cpp/CMakeLists.txt | 2 +- cpp/examples/minimal_build/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a1e3138da9e..5ead9e4b063 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -84,7 +84,7 @@ set(ARROW_VERSION "18.0.0-SNAPSHOT") string(REGEX MATCH "^[0-9]+\\.[0-9]+\\.[0-9]+" ARROW_BASE_VERSION "${ARROW_VERSION}") # if no build type is specified, default to release builds -if(NOT DEFINED CMAKE_BUILD_TYPE) +if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build.") diff --git a/cpp/examples/minimal_build/CMakeLists.txt b/cpp/examples/minimal_build/CMakeLists.txt index b4a7cde938c..95dad34221a 100644 --- a/cpp/examples/minimal_build/CMakeLists.txt +++ b/cpp/examples/minimal_build/CMakeLists.txt @@ -30,7 +30,7 @@ endif() # We require a C++17 compliant compiler set(CMAKE_CXX_STANDARD_REQUIRED ON) -if(NOT DEFINED CMAKE_BUILD_TYPE) +if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) endif() From 53b15b61691dde1ea86e14b7a2216fa0a26f8054 Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Fri, 23 Aug 2024 16:17:29 -0400 Subject: [PATCH 065/157] MINOR: [Go] Fix Flakey TestRowsPrematureCloseDuringNextLoop Test (#43804) ### Rationale for this change Fixes a race condition in rows initialization that has been causing intermittent test failures. ### What changes are included in this PR? Split query and init context. Update test to check for failure _after_ reading rows. ### Are these changes tested? Yes. ### Are there any user-facing changes? No. Authored-by: Joel Lubinitsky Signed-off-by: Joel Lubinitsky --- go/arrow/flight/flightsql/driver/driver.go | 10 ++++++---- go/arrow/flight/flightsql/driver/driver_test.go | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/go/arrow/flight/flightsql/driver/driver.go b/go/arrow/flight/flightsql/driver/driver.go index 0f2b02deaca..0513fe1ecd3 100644 --- a/go/arrow/flight/flightsql/driver/driver.go +++ b/go/arrow/flight/flightsql/driver/driver.go @@ -266,13 +266,14 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv return nil, err } + execCtx := ctx if _, set := ctx.Deadline(); !set && s.timeout > 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, s.timeout) + execCtx, cancel = context.WithTimeout(ctx, s.timeout) defer cancel() } - info, err := s.stmt.Execute(ctx) + info, err := s.stmt.Execute(execCtx) if err != nil { return nil, err } @@ -497,13 +498,14 @@ func (c *Connection) QueryContext(ctx context.Context, query string, args []driv return nil, driver.ErrSkip } + execCtx := ctx if _, set := ctx.Deadline(); !set && c.timeout > 0 { var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, c.timeout) + execCtx, cancel = context.WithTimeout(ctx, c.timeout) defer cancel() } - info, err := c.client.Execute(ctx, query) + info, err := c.client.Execute(execCtx, query) if err != nil { return nil, err } diff --git a/go/arrow/flight/flightsql/driver/driver_test.go b/go/arrow/flight/flightsql/driver/driver_test.go index e5060ccbe33..c00dfe3c5d9 100644 --- a/go/arrow/flight/flightsql/driver/driver_test.go +++ b/go/arrow/flight/flightsql/driver/driver_test.go @@ -626,7 +626,6 @@ func (s *SqlTestSuite) TestRowsPrematureCloseDuringNextLoop() { rows, err := db.QueryContext(context.TODO(), sqlSelectAll) require.NoError(t, err) require.NotNil(t, rows) - 
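// Illustrative sketch only (not part of this diff): a minimal, self-contained
// version of the pattern the driver fix above applies. The statement timeout
// is attached to a derived execCtx that is used only for executing the query,
// while the caller's ctx stays free of that deadline so the returned rows can
// still be initialized and read afterwards. The names sketch, runWithTimeout
// and exec are hypothetical and not part of the Flight SQL driver API.
package sketch

import (
	"context"
	"time"
)

func runWithTimeout(ctx context.Context, timeout time.Duration, exec func(context.Context) error) error {
	execCtx := ctx
	if _, set := ctx.Deadline(); !set && timeout > 0 {
		var cancel context.CancelFunc
		execCtx, cancel = context.WithTimeout(ctx, timeout)
		defer cancel()
	}
	// Only the execution call observes the deadline; ctx remains usable later.
	return exec(execCtx)
}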
require.NoError(t, rows.Err()) const closeAfterNRows = 10 var ( @@ -645,6 +644,7 @@ func (s *SqlTestSuite) TestRowsPrematureCloseDuringNextLoop() { require.NoError(t, rows.Close()) } } + require.NoError(t, rows.Err()) require.Equal(t, closeAfterNRows, i) From cb645a1b27dd66fddb88458c939e2851f9dadf35 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sat, 24 Aug 2024 06:08:18 +0900 Subject: [PATCH 066/157] GH-43802: [GLib] Add `GAFlightRecordBatchWriter` (#43803) ### Rationale for this change This is needed to implement `DoPut`. ### What changes are included in this PR? We can't add tests for it because it's an abstract class. I'm not sure `is_owner` is needed like `GAFlightRecordBatchReader`. `is_owner` may be removed later if we find that it's needless. ### Are these changes tested? No. ### Are there any user-facing changes? Yes. `GAFlightRecordBatchWriter` is a new public API. * GitHub Issue: #43802 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- c_glib/arrow-flight-glib/common.cpp | 198 ++++++++++++++++++++++++++-- c_glib/arrow-flight-glib/common.h | 32 +++++ c_glib/arrow-flight-glib/common.hpp | 4 + 3 files changed, 224 insertions(+), 10 deletions(-) diff --git a/c_glib/arrow-flight-glib/common.cpp b/c_glib/arrow-flight-glib/common.cpp index efc544f10cf..f7eea08c264 100644 --- a/c_glib/arrow-flight-glib/common.cpp +++ b/c_glib/arrow-flight-glib/common.cpp @@ -48,7 +48,11 @@ G_BEGIN_DECLS * * #GAFlightStreamChunk is a class for a chunk in stream. * - * #GAFlightRecordBatchReader is a class for reading record batches. + * #GAFlightRecordBatchReader is an abstract class for reading record + * batches with metadata. + * + * #GAFlightRecordBatchWeriter is an abstract class for + * writing record batches with metadata. * * Since: 5.0.0 */ @@ -1172,13 +1176,13 @@ typedef struct GAFlightRecordBatchReaderPrivate_ } GAFlightRecordBatchReaderPrivate; enum { - PROP_READER = 1, - PROP_IS_OWNER, + PROP_RECORD_BATCH_READER_READER = 1, + PROP_RECORD_BATCH_READER_IS_OWNER, }; -G_DEFINE_TYPE_WITH_PRIVATE(GAFlightRecordBatchReader, - gaflight_record_batch_reader, - G_TYPE_OBJECT) +G_DEFINE_ABSTRACT_TYPE_WITH_PRIVATE(GAFlightRecordBatchReader, + gaflight_record_batch_reader, + G_TYPE_OBJECT) #define GAFLIGHT_RECORD_BATCH_READER_GET_PRIVATE(obj) \ static_cast( \ @@ -1204,11 +1208,11 @@ gaflight_record_batch_reader_set_property(GObject *object, auto priv = GAFLIGHT_RECORD_BATCH_READER_GET_PRIVATE(object); switch (prop_id) { - case PROP_READER: + case PROP_RECORD_BATCH_READER_READER: priv->reader = static_cast(g_value_get_pointer(value)); break; - case PROP_IS_OWNER: + case PROP_RECORD_BATCH_READER_IS_OWNER: priv->is_owner = g_value_get_boolean(value); break; default: @@ -1236,7 +1240,7 @@ gaflight_record_batch_reader_class_init(GAFlightRecordBatchReaderClass *klass) nullptr, nullptr, static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); - g_object_class_install_property(gobject_class, PROP_READER, spec); + g_object_class_install_property(gobject_class, PROP_RECORD_BATCH_READER_READER, spec); spec = g_param_spec_boolean( "is-owner", @@ -1244,7 +1248,7 @@ gaflight_record_batch_reader_class_init(GAFlightRecordBatchReaderClass *klass) nullptr, TRUE, static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); - g_object_class_install_property(gobject_class, PROP_IS_OWNER, spec); + g_object_class_install_property(gobject_class, PROP_RECORD_BATCH_READER_IS_OWNER, spec); } /** @@ -1296,6 +1300,173 @@ gaflight_record_batch_reader_read_all(GAFlightRecordBatchReader *reader, GError } } +typedef struct 
GAFlightRecordBatchWriterPrivate_ +{ + arrow::flight::MetadataRecordBatchWriter *writer; + bool is_owner; +} GAFlightRecordBatchWriterPrivate; + +enum { + PROP_RECORD_BATCH_WRITER_WRITER = 1, + PROP_RECORD_BATCH_WRITER_IS_OWNER, +}; + +G_DEFINE_ABSTRACT_TYPE_WITH_PRIVATE(GAFlightRecordBatchWriter, + gaflight_record_batch_writer, + GARROW_TYPE_RECORD_BATCH_WRITER) + +#define GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(object) \ + static_cast( \ + gaflight_record_batch_writer_get_instance_private( \ + GAFLIGHT_RECORD_BATCH_WRITER(object))) + +static void +gaflight_record_batch_writer_finalize(GObject *object) +{ + auto priv = GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(object); + if (priv->is_owner) { + delete priv->writer; + } + G_OBJECT_CLASS(gaflight_info_parent_class)->finalize(object); +} + +static void +gaflight_record_batch_writer_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_RECORD_BATCH_WRITER_WRITER: + priv->writer = + static_cast(g_value_get_pointer(value)); + break; + case PROP_RECORD_BATCH_WRITER_IS_OWNER: + priv->is_owner = g_value_get_boolean(value); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +gaflight_record_batch_writer_init(GAFlightRecordBatchWriter *object) +{ +} + +static void +gaflight_record_batch_writer_class_init(GAFlightRecordBatchWriterClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + + gobject_class->finalize = gaflight_record_batch_writer_finalize; + gobject_class->set_property = gaflight_record_batch_writer_set_property; + + GParamSpec *spec; + spec = g_param_spec_pointer( + "writer", + nullptr, + nullptr, + static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_RECORD_BATCH_WRITER_WRITER, spec); + + spec = g_param_spec_boolean( + "is-owner", + nullptr, + nullptr, + TRUE, + static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_RECORD_BATCH_WRITER_IS_OWNER, spec); +} + +/** + * gaflight_record_batch_writer_begin: + * @writer: A #GAFlightRecordBatchWriter. + * @schema: A #GArrowSchema. + * @options: (nullable): A #GArrowWriteOptions. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Begins writing data with the given schema. Only used with + * `DoExchange`. + * + * Returns: %TRUE on success, %FALSE on error. + * + * Since: 18.0.0 + */ +gboolean +gaflight_record_batch_writer_begin(GAFlightRecordBatchWriter *writer, + GArrowSchema *schema, + GArrowWriteOptions *options, + GError **error) +{ + auto flight_writer = gaflight_record_batch_writer_get_raw(writer); + auto arrow_schema = garrow_schema_get_raw(schema); + arrow::ipc::IpcWriteOptions arrow_write_options; + if (options) { + arrow_write_options = *garrow_write_options_get_raw(options); + } else { + arrow_write_options = arrow::ipc::IpcWriteOptions::Defaults(); + } + return garrow::check(error, + flight_writer->Begin(arrow_schema, arrow_write_options), + "[flight-record-batch-writer][begin]"); +} + +/** + * gaflight_record_batch_writer_write_metadata: + * @writer: A #GAFlightRecordBatchWriter. + * @metadata: A #GArrowBuffer. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Write metadata. + * + * Returns: %TRUE on success, %FALSE on error. 
+ * + * Since: 18.0.0 + */ +gboolean +gaflight_record_batch_writer_write_metadata(GAFlightRecordBatchWriter *writer, + GArrowBuffer *metadata, + GError **error) +{ + auto flight_writer = gaflight_record_batch_writer_get_raw(writer); + auto arrow_metadata = garrow_buffer_get_raw(metadata); + return garrow::check(error, + flight_writer->WriteMetadata(arrow_metadata), + "[flight-record-batch-writer][write-metadata]"); +} + +/** + * gaflight_record_batch_writer_write: + * @writer: A #GAFlightRecordBatchWriter. + * @record_batch: A #GArrowRecordBatch. + * @metadata: (nullable): A #GArrowBuffer. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Write a record batch with metadata. + * + * Returns: %TRUE on success, %FALSE on error. + * + * Since: 18.0.0 + */ +gboolean +gaflight_record_batch_writer_write(GAFlightRecordBatchWriter *writer, + GArrowRecordBatch *record_batch, + GArrowBuffer *metadata, + GError **error) +{ + auto flight_writer = gaflight_record_batch_writer_get_raw(writer); + auto arrow_record_batch = garrow_record_batch_get_raw(record_batch); + auto arrow_metadata = garrow_buffer_get_raw(metadata); + return garrow::check( + error, + flight_writer->WriteWithMetadata(*arrow_record_batch, arrow_metadata), + "[flight-record-batch-writer][write]"); +} + G_END_DECLS GAFlightCriteria * @@ -1428,3 +1599,10 @@ gaflight_record_batch_reader_get_raw(GAFlightRecordBatchReader *reader) auto priv = GAFLIGHT_RECORD_BATCH_READER_GET_PRIVATE(reader); return priv->reader; } + +arrow::flight::MetadataRecordBatchWriter * +gaflight_record_batch_writer_get_raw(GAFlightRecordBatchWriter *writer) +{ + auto priv = GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(writer); + return priv->writer; +} diff --git a/c_glib/arrow-flight-glib/common.h b/c_glib/arrow-flight-glib/common.h index b1d89f79c35..91c828caabb 100644 --- a/c_glib/arrow-flight-glib/common.h +++ b/c_glib/arrow-flight-glib/common.h @@ -232,4 +232,36 @@ GAFLIGHT_AVAILABLE_IN_6_0 GArrowTable * gaflight_record_batch_reader_read_all(GAFlightRecordBatchReader *reader, GError **error); +#define GAFLIGHT_TYPE_RECORD_BATCH_WRITER (gaflight_record_batch_writer_get_type()) +GAFLIGHT_AVAILABLE_IN_18_0 +G_DECLARE_DERIVABLE_TYPE(GAFlightRecordBatchWriter, + gaflight_record_batch_writer, + GAFLIGHT, + RECORD_BATCH_WRITER, + GArrowRecordBatchWriter) +struct _GAFlightRecordBatchWriterClass +{ + GArrowRecordBatchWriterClass parent_class; +}; + +GAFLIGHT_AVAILABLE_IN_18_0 +gboolean +gaflight_record_batch_writer_begin(GAFlightRecordBatchWriter *writer, + GArrowSchema *schema, + GArrowWriteOptions *options, + GError **error); + +GAFLIGHT_AVAILABLE_IN_18_0 +gboolean +gaflight_record_batch_writer_write_metadata(GAFlightRecordBatchWriter *writer, + GArrowBuffer *metadata, + GError **error); + +GAFLIGHT_AVAILABLE_IN_18_0 +gboolean +gaflight_record_batch_writer_write(GAFlightRecordBatchWriter *writer, + GArrowRecordBatch *record_batch, + GArrowBuffer *metadata, + GError **error); + G_END_DECLS diff --git a/c_glib/arrow-flight-glib/common.hpp b/c_glib/arrow-flight-glib/common.hpp index db56fff579b..ae5a7703397 100644 --- a/c_glib/arrow-flight-glib/common.hpp +++ b/c_glib/arrow-flight-glib/common.hpp @@ -79,3 +79,7 @@ gaflight_stream_chunk_get_raw(GAFlightStreamChunk *chunk); GAFLIGHT_EXTERN arrow::flight::MetadataRecordBatchReader * gaflight_record_batch_reader_get_raw(GAFlightRecordBatchReader *reader); + +GAFLIGHT_EXTERN +arrow::flight::MetadataRecordBatchWriter * +gaflight_record_batch_writer_get_raw(GAFlightRecordBatchWriter *writer); From 
146b4e9669071984c883ec5791676638014bd655 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sat, 24 Aug 2024 06:22:26 +0900 Subject: [PATCH 067/157] GH-43743: [CI][Docs] Ensure creating build directory (#43744) ### Rationale for this change It's used as a volume. If it doesn't exist, `docker compose` reports an error: Error response from daemon: invalid mount config for type "bind": bind source path does not exist: /home/runner/work/crossbow/crossbow/build/ ### What changes are included in this PR? * Create build directory * Move required `-v $PWD/build/:/build/` to `docs/github.linux.yml` ### Are these changes tested? Yes. ### Are there any user-facing changes? No. * GitHub Issue: #43743 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- dev/tasks/docs/github.linux.yml | 4 +++- dev/tasks/tasks.yml | 4 +--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dev/tasks/docs/github.linux.yml b/dev/tasks/docs/github.linux.yml index 8ab8a593c3e..5863d68d2c8 100644 --- a/dev/tasks/docs/github.linux.yml +++ b/dev/tasks/docs/github.linux.yml @@ -34,8 +34,10 @@ jobs: env: ARROW_JAVA_SKIP_GIT_PLUGIN: true run: | + mkdir -p build archery docker run \ -e SETUPTOOLS_SCM_PRETEND_VERSION="{{ arrow.no_rc_version }}" \ + -v $PWD/build/:/build/ \ {{ flags|default("") }} \ {{ image }} \ {{ command|default("") }} @@ -45,7 +47,7 @@ jobs: ref: {{ default_branch|default("main") }} path: crossbow fetch-depth: 1 - {% if publish %} + {% if publish %} - name: Prepare Docs Preview run: | # build files are created by the docker user diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index 60114d69308..cae34c32313 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -1487,7 +1487,7 @@ tasks: image: debian-go {% endfor %} - # be sure to update binary-task.rb when upgrading ubuntu + # be sure to update binary-task.rb when upgrading Debian test-debian-12-docs: ci: github template: docs/github.linux.yml @@ -1495,7 +1495,6 @@ tasks: env: JDK: 17 pr_number: Unset - flags: "-v $PWD/build/:/build/" image: debian-docs publish: false artifacts: @@ -1621,6 +1620,5 @@ tasks: env: JDK: 17 pr_number: Unset - flags: "-v $PWD/build/:/build/" image: debian-docs publish: true From e61c105c73dfabb51d5afc972ff21cc5326b3d93 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Sat, 24 Aug 2024 07:07:09 +0530 Subject: [PATCH 068/157] GH-41584: [Java] ListView Implementation for C Data Interface (#43686) ### Rationale for this change C Data Interface is missing `ListView` and `LargeListView` after recently merging core functionalities. Also closes; - [x] https://github.com/apache/arrow/issues/41585 ### What changes are included in this PR? This PR includes C Data interface related component additions to `ListView` and `LargeListView` along with the corresponding test cases. ### Are these changes tested? Yes ### Are there any user-facing changes? 
No * GitHub Issue: #41584 Authored-by: Vibhatha Abeykoon Signed-off-by: David Li --- dev/archery/archery/integration/datagen.py | 1 - .../arrow/c/BufferImportTypeVisitor.java | 14 +- .../main/java/org/apache/arrow/c/Format.java | 8 ++ .../org/apache/arrow/c/RoundtripTest.java | 42 ++++++ java/c/src/test/python/integration_tests.py | 47 ++++++ .../BaseLargeRepeatedValueViewVector.java | 29 ++-- .../complex/BaseRepeatedValueViewVector.java | 30 ++-- .../vector/complex/LargeListViewVector.java | 10 +- .../arrow/vector/complex/ListViewVector.java | 6 +- .../arrow/vector/TestLargeListViewVector.java | 134 ++++++++++++++++++ .../arrow/vector/TestListViewVector.java | 132 +++++++++++++++++ .../testing/ValueVectorDataPopulator.java | 34 +++++ 12 files changed, 451 insertions(+), 36 deletions(-) diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index 47310c905a9..d395d26cb71 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -1936,7 +1936,6 @@ def _temp_path(): generate_list_view_case() .skip_tester('C#') # Doesn't support large list views - .skip_tester('Java') .skip_tester('JS') .skip_tester('nanoarrow') .skip_tester('Rust'), diff --git a/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java b/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java index 633ecd43bd5..93fef6d7ca8 100644 --- a/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java +++ b/java/c/src/main/java/org/apache/arrow/c/BufferImportTypeVisitor.java @@ -47,7 +47,9 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.UnionVector; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; @@ -400,13 +402,17 @@ public List visit(ArrowType.Duration type) { @Override public List visit(ArrowType.ListView type) { - throw new UnsupportedOperationException( - "Importing buffers for view type: " + type + " not supported"); + return Arrays.asList( + maybeImportBitmap(type), + importFixedBytes(type, 1, ListViewVector.OFFSET_WIDTH), + importFixedBytes(type, 2, ListViewVector.SIZE_WIDTH)); } @Override public List visit(ArrowType.LargeListView type) { - throw new UnsupportedOperationException( - "Importing buffers for view type: " + type + " not supported"); + return Arrays.asList( + maybeImportBitmap(type), + importFixedBytes(type, 1, LargeListViewVector.OFFSET_WIDTH), + importFixedBytes(type, 2, LargeListViewVector.SIZE_WIDTH)); } } diff --git a/java/c/src/main/java/org/apache/arrow/c/Format.java b/java/c/src/main/java/org/apache/arrow/c/Format.java index aff51e7b734..f77a555d184 100644 --- a/java/c/src/main/java/org/apache/arrow/c/Format.java +++ b/java/c/src/main/java/org/apache/arrow/c/Format.java @@ -229,6 +229,10 @@ static String asString(ArrowType arrowType) { return "vu"; case BinaryView: return "vz"; + case ListView: + return "+vl"; + case LargeListView: + return "+vL"; case NONE: throw new IllegalArgumentException("Arrow type ID is NONE"); default: @@ -313,6 +317,10 @@ static ArrowType asType(String format, long flags) return new ArrowType.Utf8View(); case "vz": return new ArrowType.BinaryView(); + case "+vl": + return new 
ArrowType.ListView(); + case "+vL": + return new ArrowType.LargeListView(); default: String[] parts = format.split(":", 2); if (parts.length == 2) { diff --git a/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java b/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java index 6591d1f7309..18b2e94adde 100644 --- a/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java +++ b/java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java @@ -84,7 +84,9 @@ import org.apache.arrow.vector.compare.VectorEqualsVisitor; import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; @@ -683,6 +685,46 @@ public void testFixedSizeListVector() { } } + @Test + public void testListViewVector() { + try (final ListViewVector vector = ListViewVector.empty("v", allocator)) { + setVector( + vector, + Arrays.stream(new int[] {1, 2}).boxed().collect(Collectors.toList()), + Arrays.stream(new int[] {3, 4}).boxed().collect(Collectors.toList()), + new ArrayList()); + assertTrue(roundtrip(vector, ListViewVector.class)); + } + } + + @Test + public void testEmptyListViewVector() { + try (final ListViewVector vector = ListViewVector.empty("v", allocator)) { + setVector(vector, new ArrayList()); + assertTrue(roundtrip(vector, ListViewVector.class)); + } + } + + @Test + public void testLargeListViewVector() { + try (final LargeListViewVector vector = LargeListViewVector.empty("v", allocator)) { + setVector( + vector, + Arrays.stream(new int[] {1, 2}).boxed().collect(Collectors.toList()), + Arrays.stream(new int[] {3, 4}).boxed().collect(Collectors.toList()), + new ArrayList()); + assertTrue(roundtrip(vector, LargeListViewVector.class)); + } + } + + @Test + public void testEmptyLargeListViewVector() { + try (final LargeListViewVector vector = LargeListViewVector.empty("v", allocator)) { + setVector(vector, new ArrayList()); + assertTrue(roundtrip(vector, LargeListViewVector.class)); + } + } + @Test public void testMapVector() { int count = 5; diff --git a/java/c/src/test/python/integration_tests.py b/java/c/src/test/python/integration_tests.py index ab2ee1742f3..b0a86e9c66e 100644 --- a/java/c/src/test/python/integration_tests.py +++ b/java/c/src/test/python/integration_tests.py @@ -352,6 +352,53 @@ def test_reader_complex_roundtrip(self): ] self.round_trip_reader(schema, data) + def test_listview_array(self): + self.round_trip_array(lambda: pa.array( + [[], [0], [1, 2], [4, 5, 6]], pa.list_view(pa.int64()) + # disabled check_metadata since in Java API the listview + # internal field name ("item") is not preserved + # during round trips (it becomes "$data$"). 
+ ), check_metadata=False) + + def test_empty_listview_array(self): + with pa.BufferOutputStream() as bos: + schema = pa.schema([pa.field("f0", pa.list_view(pa.int32()), True)]) + with ipc.new_stream(bos, schema) as writer: + src = pa.RecordBatch.from_arrays( + [pa.array([[]], pa.list_view(pa.int32()))], schema=schema) + writer.write(src) + data_bytes = bos.getvalue() + + def recreate_batch(): + with pa.input_stream(data_bytes) as ios: + with ipc.open_stream(ios) as reader: + return reader.read_next_batch() + + self.round_trip_record_batch(recreate_batch) + + def test_largelistview_array(self): + self.round_trip_array(lambda: pa.array( + [[], [0], [1, 2], [4, 5, 6]], pa.large_list_view(pa.int64()) + # disabled check_metadata since in Java API the listview + # internal field name ("item") is not preserved + # during round trips (it becomes "$data$"). + ), check_metadata=False) + + def test_empty_largelistview_array(self): + with pa.BufferOutputStream() as bos: + schema = pa.schema([pa.field("f0", pa.large_list_view(pa.int32()), True)]) + with ipc.new_stream(bos, schema) as writer: + src = pa.RecordBatch.from_arrays( + [pa.array([[]], pa.large_list_view(pa.int32()))], schema=schema) + writer.write(src) + data_bytes = bos.getvalue() + + def recreate_batch(): + with pa.input_stream(data_bytes) as ios: + with ipc.open_stream(ios) as reader: + return reader.read_next_batch() + + self.round_trip_record_batch(recreate_batch) if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java index f643306cfdc..12edd6557bd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseLargeRepeatedValueViewVector.java @@ -305,38 +305,43 @@ public void setValueCount(int valueCount) { while (valueCount > getOffsetBufferValueCapacity()) { reallocateBuffers(); } - final int childValueCount = valueCount == 0 ? 0 : getLengthOfChildVector(); + final int childValueCount = valueCount == 0 ? 0 : getMaxViewEndChildVector(); vector.setValueCount(childValueCount); } - protected int getLengthOfChildVector() { + /** + * Get the end of the child vector via the maximum view length. This method deduces the length by + * considering the condition i.e., argmax_i(offsets[i] + size[i]). + * + * @return the end of the child vector. + */ + protected int getMaxViewEndChildVector() { int maxOffsetSizeSum = offsetBuffer.getInt(0) + sizeBuffer.getInt(0); - int minOffset = offsetBuffer.getInt(0); for (int i = 0; i < valueCount; i++) { int currentOffset = offsetBuffer.getInt((long) i * OFFSET_WIDTH); int currentSize = sizeBuffer.getInt((long) i * SIZE_WIDTH); int currentSum = currentOffset + currentSize; - maxOffsetSizeSum = Math.max(maxOffsetSizeSum, currentSum); - minOffset = Math.min(minOffset, currentOffset); } - return maxOffsetSizeSum - minOffset; + return maxOffsetSizeSum; } - protected int getLengthOfChildVectorByIndex(int index) { + /** + * Get the end of the child vector via the maximum view length of the child vector by index. 
+ * + * @return the end of the child vector by index + */ + protected int getMaxViewEndChildVectorByIndex(int index) { int maxOffsetSizeSum = offsetBuffer.getInt(0) + sizeBuffer.getInt(0); - int minOffset = offsetBuffer.getInt(0); for (int i = 0; i < index; i++) { int currentOffset = offsetBuffer.getInt((long) i * OFFSET_WIDTH); int currentSize = sizeBuffer.getInt((long) i * SIZE_WIDTH); int currentSum = currentOffset + currentSize; - maxOffsetSizeSum = Math.max(maxOffsetSizeSum, currentSum); - minOffset = Math.min(minOffset, currentOffset); } - return maxOffsetSizeSum - minOffset; + return maxOffsetSizeSum; } /** @@ -390,7 +395,7 @@ public int startNewValue(int index) { } if (index > 0) { - final int prevOffset = getLengthOfChildVectorByIndex(index); + final int prevOffset = getMaxViewEndChildVectorByIndex(index); offsetBuffer.setInt((long) index * OFFSET_WIDTH, prevOffset); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueViewVector.java index 031cc8037bb..e6213316b55 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueViewVector.java @@ -304,38 +304,44 @@ public void setValueCount(int valueCount) { while (valueCount > getOffsetBufferValueCapacity()) { reallocateBuffers(); } - final int childValueCount = valueCount == 0 ? 0 : getLengthOfChildVector(); + final int childValueCount = valueCount == 0 ? 0 : getMaxViewEndChildVector(); vector.setValueCount(childValueCount); } - protected int getLengthOfChildVector() { + /** + * Get the end of the child vector via the maximum view length. This method deduces the length by + * considering the condition i.e., argmax_i(offsets[i] + size[i]). + * + * @return the end of the child vector. + */ + protected int getMaxViewEndChildVector() { int maxOffsetSizeSum = offsetBuffer.getInt(0) + sizeBuffer.getInt(0); - int minOffset = offsetBuffer.getInt(0); for (int i = 0; i < valueCount; i++) { int currentOffset = offsetBuffer.getInt(i * OFFSET_WIDTH); int currentSize = sizeBuffer.getInt(i * SIZE_WIDTH); int currentSum = currentOffset + currentSize; - maxOffsetSizeSum = Math.max(maxOffsetSizeSum, currentSum); - minOffset = Math.min(minOffset, currentOffset); } - return maxOffsetSizeSum - minOffset; + return maxOffsetSizeSum; } - protected int getLengthOfChildVectorByIndex(int index) { + /** + * Get the end of the child vector via the maximum view length of the child vector by index. 
+ * + * @return the end of the child vector by index + */ + protected int getMaxViewEndChildVectorByIndex(int index) { int maxOffsetSizeSum = offsetBuffer.getInt(0) + sizeBuffer.getInt(0); - int minOffset = offsetBuffer.getInt(0); + // int minOffset = offsetBuffer.getInt(0); for (int i = 0; i < index; i++) { int currentOffset = offsetBuffer.getInt(i * OFFSET_WIDTH); int currentSize = sizeBuffer.getInt(i * SIZE_WIDTH); int currentSum = currentOffset + currentSize; - maxOffsetSizeSum = Math.max(maxOffsetSizeSum, currentSum); - minOffset = Math.min(minOffset, currentOffset); } - return maxOffsetSizeSum - minOffset; + return maxOffsetSizeSum; } /** @@ -389,7 +395,7 @@ public int startNewValue(int index) { } if (index > 0) { - final int prevOffset = getLengthOfChildVectorByIndex(index); + final int prevOffset = getMaxViewEndChildVectorByIndex(index); offsetBuffer.setInt(index * OFFSET_WIDTH, prevOffset); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java index 2c61f799a4c..84c6f03edb2 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/LargeListViewVector.java @@ -250,7 +250,9 @@ public List getFieldBuffers() { */ @Override public void exportCDataBuffers(List buffers, ArrowBuf buffersPtr, long nullValue) { - throw new UnsupportedOperationException("exportCDataBuffers Not implemented yet"); + exportBuffer(validityBuffer, buffers, buffersPtr, nullValue, true); + exportBuffer(offsetBuffer, buffers, buffersPtr, nullValue, true); + exportBuffer(sizeBuffer, buffers, buffersPtr, nullValue, true); } @Override @@ -851,7 +853,7 @@ public int startNewValue(int index) { } if (index > 0) { - final int prevOffset = getLengthOfChildVectorByIndex(index); + final int prevOffset = getMaxViewEndChildVectorByIndex(index); offsetBuffer.setInt(index * OFFSET_WIDTH, prevOffset); } @@ -943,7 +945,7 @@ public void setValueCount(int valueCount) { } } /* valueCount for the data vector is the current end offset */ - final long childValueCount = (valueCount == 0) ? 0 : getLengthOfChildVector(); + final long childValueCount = (valueCount == 0) ? 0 : getMaxViewEndChildVector(); /* set the value count of data vector and this will take care of * checking whether data buffer needs to be reallocated. * TODO: revisit when 64-bit vectors are supported @@ -1001,7 +1003,7 @@ public double getDensity() { if (valueCount == 0) { return 0.0D; } - final double totalListSize = getLengthOfChildVector(); + final double totalListSize = getMaxViewEndChildVector(); return totalListSize / valueCount; } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java index 7f6d92f3be9..9b4e6b4c0cd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListViewVector.java @@ -858,7 +858,7 @@ public int startNewValue(int index) { } if (index > 0) { - final int prevOffset = getLengthOfChildVectorByIndex(index); + final int prevOffset = getMaxViewEndChildVectorByIndex(index); offsetBuffer.setInt(index * OFFSET_WIDTH, prevOffset); } @@ -942,7 +942,7 @@ public void setValueCount(int valueCount) { } } /* valueCount for the data vector is the current end offset */ - final int childValueCount = (valueCount == 0) ? 
0 : getLengthOfChildVector(); + final int childValueCount = (valueCount == 0) ? 0 : getMaxViewEndChildVector(); /* set the value count of data vector and this will take care of * checking whether data buffer needs to be reallocated. */ @@ -1005,7 +1005,7 @@ public double getDensity() { if (valueCount == 0) { return 0.0D; } - final double totalListSize = getLengthOfChildVector(); + final double totalListSize = getMaxViewEndChildVector(); return totalListSize / valueCount; } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java index 2ed8d4d7005..26e7bb4a0d3 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestLargeListViewVector.java @@ -2095,6 +2095,140 @@ public void testOutOfOrderOffsetSplitAndTransfer() { } } + @Test + public void testRangeChildVector1() { + /* + * Non-overlapping ranges + * offsets: [0, 2] + * sizes: [4, 1] + * values: [0, 1, 2, 3] + * + * vector: [[0, 1, 2, 3], [2]] + * */ + try (LargeListViewVector largeListViewVector = + LargeListViewVector.empty("largelistview", allocator)) { + // Allocate buffers in listViewVector by calling `allocateNew` method. + largeListViewVector.allocateNew(); + + // Initialize the child vector using `initializeChildrenFromFields` method. + + FieldType fieldType = new FieldType(true, new ArrowType.Int(32, true), null, null); + Field field = new Field("child-vector", fieldType, null); + largeListViewVector.initializeChildrenFromFields(Collections.singletonList(field)); + + // Set values in the child vector. + FieldVector fieldVector = largeListViewVector.getDataVector(); + fieldVector.clear(); + + IntVector childVector = (IntVector) fieldVector; + + childVector.allocateNew(8); + + childVector.set(0, 0); + childVector.set(1, 1); + childVector.set(2, 2); + childVector.set(3, 3); + childVector.set(4, 4); + childVector.set(5, 5); + childVector.set(6, 6); + childVector.set(7, 7); + + childVector.setValueCount(8); + + // Set validity, offset and size buffers using `setValidity`, + // `setOffset` and `setSize` methods. + largeListViewVector.setValidity(0, 1); + largeListViewVector.setValidity(1, 1); + + largeListViewVector.setOffset(0, 0); + largeListViewVector.setOffset(1, 2); + + largeListViewVector.setSize(0, 4); + largeListViewVector.setSize(1, 1); + + assertEquals(8, largeListViewVector.getDataVector().getValueCount()); + + largeListViewVector.setValueCount(2); + assertEquals(4, largeListViewVector.getDataVector().getValueCount()); + + IntVector childVector1 = (IntVector) largeListViewVector.getDataVector(); + final ArrowBuf dataBuffer = childVector1.getDataBuffer(); + final ArrowBuf validityBuffer = childVector1.getValidityBuffer(); + + // yet the underneath buffer contains the original buffer + for (int i = 0; i < validityBuffer.capacity(); i++) { + assertEquals(i, dataBuffer.getInt((long) i * IntVector.TYPE_WIDTH)); + } + } + } + + @Test + public void testRangeChildVector2() { + /* + * Overlapping ranges + * offsets: [0, 2] + * sizes: [3, 1] + * values: [0, 1, 2, 3] + * + * vector: [[1, 2, 3], [2]] + * */ + try (LargeListViewVector largeListViewVector = + LargeListViewVector.empty("largelistview", allocator)) { + // Allocate buffers in listViewVector by calling `allocateNew` method. + largeListViewVector.allocateNew(); + + // Initialize the child vector using `initializeChildrenFromFields` method. 
+ + FieldType fieldType = new FieldType(true, new ArrowType.Int(32, true), null, null); + Field field = new Field("child-vector", fieldType, null); + largeListViewVector.initializeChildrenFromFields(Collections.singletonList(field)); + + // Set values in the child vector. + FieldVector fieldVector = largeListViewVector.getDataVector(); + fieldVector.clear(); + + IntVector childVector = (IntVector) fieldVector; + + childVector.allocateNew(8); + + childVector.set(0, 0); + childVector.set(1, 1); + childVector.set(2, 2); + childVector.set(3, 3); + childVector.set(4, 4); + childVector.set(5, 5); + childVector.set(6, 6); + childVector.set(7, 7); + + childVector.setValueCount(8); + + // Set validity, offset and size buffers using `setValidity`, + // `setOffset` and `setSize` methods. + largeListViewVector.setValidity(0, 1); + largeListViewVector.setValidity(1, 1); + + largeListViewVector.setOffset(0, 1); + largeListViewVector.setOffset(1, 2); + + largeListViewVector.setSize(0, 3); + largeListViewVector.setSize(1, 1); + + assertEquals(8, largeListViewVector.getDataVector().getValueCount()); + + largeListViewVector.setValueCount(2); + assertEquals(4, largeListViewVector.getDataVector().getValueCount()); + + IntVector childVector1 = (IntVector) largeListViewVector.getDataVector(); + final ArrowBuf dataBuffer = childVector1.getDataBuffer(); + final ArrowBuf validityBuffer = childVector1.getValidityBuffer(); + + // yet the underneath buffer contains the original buffer + for (int i = 0; i < validityBuffer.capacity(); i++) { + assertEquals(i, dataBuffer.getInt((long) i * IntVector.TYPE_WIDTH)); + } + } + } + private void writeIntValues(UnionLargeListViewWriter writer, int[] values) { writer.startListView(); for (int v : values) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestListViewVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestListViewVector.java index 4fa808c18ae..639585fc48d 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestListViewVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestListViewVector.java @@ -2084,6 +2084,138 @@ public void testOutOfOrderOffsetSplitAndTransfer() { } } + @Test + public void testRangeChildVector1() { + /* + * Non-overlapping ranges + * offsets: [0, 2] + * sizes: [4, 1] + * values: [0, 1, 2, 3] + * + * vector: [[0, 1, 2, 3], [2]] + * */ + try (ListViewVector listViewVector = ListViewVector.empty("listview", allocator)) { + // Allocate buffers in listViewVector by calling `allocateNew` method. + listViewVector.allocateNew(); + + // Initialize the child vector using `initializeChildrenFromFields` method. + + FieldType fieldType = new FieldType(true, new ArrowType.Int(32, true), null, null); + Field field = new Field("child-vector", fieldType, null); + listViewVector.initializeChildrenFromFields(Collections.singletonList(field)); + + // Set values in the child vector. + FieldVector fieldVector = listViewVector.getDataVector(); + fieldVector.clear(); + + IntVector childVector = (IntVector) fieldVector; + + childVector.allocateNew(8); + + childVector.set(0, 0); + childVector.set(1, 1); + childVector.set(2, 2); + childVector.set(3, 3); + childVector.set(4, 4); + childVector.set(5, 5); + childVector.set(6, 6); + childVector.set(7, 7); + + childVector.setValueCount(8); + + // Set validity, offset and size buffers using `setValidity`, + // `setOffset` and `setSize` methods. 
+ listViewVector.setValidity(0, 1); + listViewVector.setValidity(1, 1); + + listViewVector.setOffset(0, 0); + listViewVector.setOffset(1, 2); + + listViewVector.setSize(0, 4); + listViewVector.setSize(1, 1); + + assertEquals(8, listViewVector.getDataVector().getValueCount()); + + listViewVector.setValueCount(2); + assertEquals(4, listViewVector.getDataVector().getValueCount()); + + IntVector childVector1 = (IntVector) listViewVector.getDataVector(); + final ArrowBuf dataBuffer = childVector1.getDataBuffer(); + final ArrowBuf validityBuffer = childVector1.getValidityBuffer(); + + // yet the underneath buffer contains the original buffer + for (int i = 0; i < validityBuffer.capacity(); i++) { + assertEquals(i, dataBuffer.getInt((long) i * IntVector.TYPE_WIDTH)); + } + } + } + + @Test + public void testRangeChildVector2() { + /* + * Overlapping ranges + * offsets: [0, 2] + * sizes: [3, 1] + * values: [0, 1, 2, 3] + * + * vector: [[1, 2, 3], [2]] + * */ + try (ListViewVector listViewVector = ListViewVector.empty("listview", allocator)) { + // Allocate buffers in listViewVector by calling `allocateNew` method. + listViewVector.allocateNew(); + + // Initialize the child vector using `initializeChildrenFromFields` method. + + FieldType fieldType = new FieldType(true, new ArrowType.Int(32, true), null, null); + Field field = new Field("child-vector", fieldType, null); + listViewVector.initializeChildrenFromFields(Collections.singletonList(field)); + + // Set values in the child vector. + FieldVector fieldVector = listViewVector.getDataVector(); + fieldVector.clear(); + + IntVector childVector = (IntVector) fieldVector; + + childVector.allocateNew(8); + + childVector.set(0, 0); + childVector.set(1, 1); + childVector.set(2, 2); + childVector.set(3, 3); + childVector.set(4, 4); + childVector.set(5, 5); + childVector.set(6, 6); + childVector.set(7, 7); + + childVector.setValueCount(8); + + // Set validity, offset and size buffers using `setValidity`, + // `setOffset` and `setSize` methods. 
+ listViewVector.setValidity(0, 1); + listViewVector.setValidity(1, 1); + + listViewVector.setOffset(0, 1); + listViewVector.setOffset(1, 2); + + listViewVector.setSize(0, 3); + listViewVector.setSize(1, 1); + + assertEquals(8, listViewVector.getDataVector().getValueCount()); + + listViewVector.setValueCount(2); + assertEquals(4, listViewVector.getDataVector().getValueCount()); + + IntVector childVector1 = (IntVector) listViewVector.getDataVector(); + final ArrowBuf dataBuffer = childVector1.getDataBuffer(); + final ArrowBuf validityBuffer = childVector1.getValidityBuffer(); + + // yet the underneath buffer contains the original buffer + for (int i = 0; i < validityBuffer.capacity(); i++) { + assertEquals(i, dataBuffer.getInt((long) i * IntVector.TYPE_WIDTH)); + } + } + } + private void writeIntValues(UnionListViewWriter writer, int[] values) { writer.startListView(); for (int v : values) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java b/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java index 69e16dc4703..afbc30f019e 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java @@ -60,10 +60,12 @@ import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VariableWidthFieldVector; +import org.apache.arrow.vector.complex.BaseLargeRepeatedValueViewVector; import org.apache.arrow.vector.complex.BaseRepeatedValueVector; import org.apache.arrow.vector.complex.BaseRepeatedValueViewVector; import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.LargeListViewVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.ListViewVector; import org.apache.arrow.vector.complex.StructVector; @@ -760,4 +762,36 @@ public static void setVector(ListViewVector vector, List... values) { dataVector.setValueCount(curPos); vector.setValueCount(values.length); } + + /** Populate values for {@link ListViewVector}. */ + public static void setVector(LargeListViewVector vector, List... values) { + vector.allocateNewSafe(); + Types.MinorType type = Types.MinorType.INT; + vector.addOrGetVector(FieldType.nullable(type.getType())); + + IntVector dataVector = (IntVector) vector.getDataVector(); + dataVector.allocateNew(); + + // set underlying vectors + int curPos = 0; + for (int i = 0; i < values.length; i++) { + vector + .getOffsetBuffer() + .setInt((long) i * BaseLargeRepeatedValueViewVector.OFFSET_WIDTH, curPos); + if (values[i] == null) { + BitVectorHelper.unsetBit(vector.getValidityBuffer(), i); + } else { + BitVectorHelper.setBit(vector.getValidityBuffer(), i); + for (int value : values[i]) { + dataVector.setSafe(curPos, value); + curPos += 1; + } + } + vector + .getSizeBuffer() + .setInt((long) i * BaseRepeatedValueViewVector.SIZE_WIDTH, values[i].size()); + } + dataVector.setValueCount(curPos); + vector.setValueCount(values.length); + } } From 83d915a3d2ac2acecbb2cb2dc0dd7f5a213dd625 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 12:38:38 +0900 Subject: [PATCH 069/157] MINOR: [Java] Bump dep.slf4j.version from 2.0.13 to 2.0.16 in /java (#43652) Bumps `dep.slf4j.version` from 2.0.13 to 2.0.16. 
Updates `org.slf4j:slf4j-api` from 2.0.13 to 2.0.16 Updates `org.slf4j:slf4j-jdk14` from 2.0.13 to 2.0.16 Updates `org.slf4j:jul-to-slf4j` from 2.0.13 to 2.0.16 Updates `org.slf4j:jcl-over-slf4j` from 2.0.13 to 2.0.16 Updates `org.slf4j:log4j-over-slf4j` from 2.0.13 to 2.0.16 Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@ dependabot rebase` will rebase this PR - `@ dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@ dependabot merge` will merge this PR after your CI passes on it - `@ dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@ dependabot cancel merge` will cancel a previously requested merge and block automerging - `@ dependabot reopen` will reopen this PR if it is closed - `@ dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@ dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@ dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@ dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@ dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/pom.xml b/java/pom.xml index a73453df68f..54bb7a0ae0e 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -94,7 +94,7 @@ under the License. ${project.build.directory}/generated-sources 1.9.0 5.10.3 - 2.0.13 + 2.0.16 33.2.1-jre 4.1.112.Final 1.66.0 From cbb5f96306972aa236750602aba4b40ceb4219c4 Mon Sep 17 00:00:00 2001 From: Bryce Mecum Date: Sun, 25 Aug 2024 21:33:51 -0700 Subject: [PATCH 070/157] MINOR: [R] Add missing PR num to news.md item (#43811) ### Rationale for this change We normally link to somewhere to give the user more context on news items. I noticed the link was missing for this one. ### What changes are included in this PR? Added PR number to news item. ### Are these changes tested? No. ### Are there any user-facing changes? No. Authored-by: Bryce Mecum Signed-off-by: Jacob Wujciak-Jens --- r/NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/r/NEWS.md b/r/NEWS.md index 0e6e4634a0a..b9568afe665 100644 --- a/r/NEWS.md +++ b/r/NEWS.md @@ -32,7 +32,7 @@ functions (UDFs); for UDFs, see `register_scalar_function()`. (#41223) * `mutate()` expressions can now include aggregations, such as `x - mean(x)`. (#41350) * `summarize()` supports more complex expressions, and correctly handles cases - where column names are reused in expressions. + where column names are reused in expressions. (#41223) * The `na_matches` argument to the `dplyr::*_join()` functions is now supported. This argument controls whether `NA` values are considered equal when joining. (#41358) * R metadata, stored in the Arrow schema to support round-tripping data between From 51e9f70f94cd09a0a08196afdd2f4fc644666b5e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 16:20:20 +0900 Subject: [PATCH 071/157] MINOR: [Java] Bump dep.junit.jupiter.version from 5.10.3 to 5.11.0 in /java (#43751) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps `dep.junit.jupiter.version` from 5.10.3 to 5.11.0. Updates `org.junit.jupiter:junit-jupiter-engine` from 5.10.3 to 5.11.0
Release notes

Sourced from org.junit.jupiter:junit-jupiter-engine's releases.

JUnit 5.11.0 = Platform 1.11.0 + Jupiter 5.11.0 + Vintage 5.11.0

See Release Notes.

New Contributors

Full Changelog: https://github.com/junit-team/junit5/compare/r5.10.3...r5.11.0

JUnit 5.11.0-RC1 = Platform 1.11.0-RC1 + Jupiter 5.11.0-RC1 + Vintage 5.11.0-RC1

See Release Notes.

New Contributors

Full Changelog: https://github.com/junit-team/junit5/compare/r5.11.0-M2...r5.11.0-RC1

JUnit 5.11.0-M2 = Platform 1.11.0-M2 + Jupiter 5.11.0-M2 + Vintage 5.11.0-M2

See Release Notes.

New Contributors

Full Changelog: https://github.com/junit-team/junit5/compare/r5.11.0-M1...r5.11.0-M2

JUnit 5.11.0-M1 = Platform 1.11.0-M1 + Jupiter 5.11.0-M1 + Vintage 5.11.0-M1

... (truncated)

Commits
  • 6b8e42b Release 5.11
  • 9430ece Allow potentially unlimited maxCharsPerColumn in Csv{File}Source (#3924)
  • 0b10f86 Polish release notes
  • 4dbd0f9 Let @ TempDir fail fast with File annotated element and non-default file s...
  • 57f1ad4 Fix syntax
  • d78730a Prioritize tasks on critical path of task graph
  • b6719e2 Remove obsolete directory
  • d8ec757 Apply Spotless formatting to Gradle script plugins
  • dae525d Disable caching of some Spotless tasks due to negative avoidance savings
  • c63d118 Re-enable caching verifyOSGi tasks (issue was fixed in bnd 7.0.0)
  • Additional commits viewable in compare view

Updates `org.junit.jupiter:junit-jupiter-api` from 5.10.3 to 5.11.0

Updates `org.junit.jupiter:junit-jupiter-params` from 5.10.3 to 5.11.0

Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@ dependabot rebase` will rebase this PR
- `@ dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@ dependabot merge` will merge this PR after your CI passes on it
- `@ dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@ dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@ dependabot reopen` will reopen this PR if it is closed
- `@ dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@ dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency
- `@ dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@ dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@ dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/pom.xml b/java/pom.xml index 54bb7a0ae0e..77feed12f3f 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -93,7 +93,7 @@ under the License. ${project.build.directory}/generated-sources 1.9.0 - 5.10.3 + 5.11.0 2.0.16 33.2.1-jre 4.1.112.Final From 2328b6ee39b497d9f48e6d342db9f7d0c34d9791 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Mon, 26 Aug 2024 16:34:18 +0200 Subject: [PATCH 072/157] GH-15058: [C++][Python] Native support for UUID (#37298) ### Rationale for this change See #15058. UUID datatype is common in throughout the ecosystem and Arrow as supporting it as a native type would reduce friction. ### What changes are included in this PR? This PR implements logic for Arrow canonical extension type in C++ and a Python wrapper. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes, new extension type is added. * Closes: #15058 Authored-by: Rok Mihevc Signed-off-by: Antoine Pitrou --- cpp/src/arrow/CMakeLists.txt | 3 +- cpp/src/arrow/acero/hash_join_node_test.cc | 1 + cpp/src/arrow/extension/CMakeLists.txt | 2 +- .../extension/fixed_shape_tensor_test.cc | 17 +-- cpp/src/arrow/extension/uuid.cc | 58 ++++++++++ cpp/src/arrow/extension/uuid.h | 61 ++++++++++ cpp/src/arrow/extension/uuid_test.cc | 72 ++++++++++++ cpp/src/arrow/extension_type.cc | 4 +- cpp/src/arrow/extension_type_test.cc | 19 +--- .../integration/json_integration_test.cc | 2 +- cpp/src/arrow/ipc/test_common.cc | 35 ++++-- cpp/src/arrow/ipc/test_common.h | 3 + cpp/src/arrow/scalar_test.cc | 5 +- cpp/src/arrow/testing/extension_type.h | 6 +- cpp/src/arrow/testing/gtest_util.cc | 16 ++- dev/archery/archery/integration/datagen.py | 2 +- docs/source/format/CanonicalExtensions.rst | 2 + docs/source/status.rst | 2 +- python/pyarrow/__init__.py | 18 +-- python/pyarrow/array.pxi | 6 + python/pyarrow/includes/libarrow.pxd | 10 ++ python/pyarrow/lib.pxd | 3 + python/pyarrow/public-api.pxi | 11 +- python/pyarrow/scalar.pxi | 10 ++ python/pyarrow/src/arrow/python/gdb.cc | 27 +---- python/pyarrow/tests/extensions.pyx | 2 +- python/pyarrow/tests/test_extension_type.py | 105 ++++++++++++------ python/pyarrow/tests/test_gdb.py | 8 +- python/pyarrow/types.pxi | 34 ++++++ 29 files changed, 412 insertions(+), 132 deletions(-) create mode 100644 cpp/src/arrow/extension/uuid.cc create mode 100644 cpp/src/arrow/extension/uuid.h create mode 100644 cpp/src/arrow/extension/uuid_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 89f28ee416e..6b0ac8c23c7 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -375,6 +375,7 @@ set(ARROW_SRCS device.cc extension_type.cc extension/bool8.cc + extension/uuid.cc pretty_print.cc record_batch.cc result.cc @@ -1225,6 +1226,7 @@ add_subdirectory(testing) add_subdirectory(array) add_subdirectory(c) add_subdirectory(compute) +add_subdirectory(extension) add_subdirectory(io) add_subdirectory(tensor) add_subdirectory(util) @@ -1267,7 +1269,6 @@ endif() if(ARROW_JSON) add_subdirectory(json) - add_subdirectory(extension) endif() if(ARROW_ORC) diff --git a/cpp/src/arrow/acero/hash_join_node_test.cc b/cpp/src/arrow/acero/hash_join_node_test.cc index 9065e286a22..76ad9c7d650 100644 --- a/cpp/src/arrow/acero/hash_join_node_test.cc +++ b/cpp/src/arrow/acero/hash_join_node_test.cc @@ -29,6 +29,7 @@ #include "arrow/compute/kernels/test_util.h" #include 
"arrow/compute/light_array_internal.h" #include "arrow/compute/row/row_encoder_internal.h" +#include "arrow/extension/uuid.h" #include "arrow/testing/extension_type.h" #include "arrow/testing/generator.h" #include "arrow/testing/gtest_util.h" diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index 5cb4bc77af2..065ea3f1ddb 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -set(CANONICAL_EXTENSION_TESTS bool8_test.cc) +set(CANONICAL_EXTENSION_TESTS bool8_test.cc uuid_test.cc) if(ARROW_JSON) list(APPEND CANONICAL_EXTENSION_TESTS fixed_shape_tensor_test.cc opaque_test.cc) diff --git a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc index 3fd39a11ff5..842a78e1a4f 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc +++ b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc @@ -23,7 +23,7 @@ #include "arrow/array/array_primitive.h" #include "arrow/io/memory.h" #include "arrow/ipc/reader.h" -#include "arrow/ipc/writer.h" +#include "arrow/ipc/test_common.h" #include "arrow/record_batch.h" #include "arrow/tensor.h" #include "arrow/testing/gtest_util.h" @@ -33,6 +33,7 @@ namespace arrow { using FixedShapeTensorType = extension::FixedShapeTensorType; +using arrow::ipc::test::RoundtripBatch; using extension::fixed_shape_tensor; using extension::FixedShapeTensorArray; @@ -71,20 +72,6 @@ class TestExtensionType : public ::testing::Test { std::string serialized_; }; -auto RoundtripBatch = [](const std::shared_ptr& batch, - std::shared_ptr* out) { - ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); - ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), - out_stream.get())); - - ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); - - io::BufferReader reader(complete_ipc_stream); - std::shared_ptr batch_reader; - ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); - ASSERT_OK(batch_reader->ReadNext(out)); -}; - TEST_F(TestExtensionType, CheckDummyRegistration) { // We need a registered dummy type at runtime to allow for IPC deserialization auto registered_type = GetExtensionType("arrow.fixed_shape_tensor"); diff --git a/cpp/src/arrow/extension/uuid.cc b/cpp/src/arrow/extension/uuid.cc new file mode 100644 index 00000000000..43b917a17f8 --- /dev/null +++ b/cpp/src/arrow/extension/uuid.cc @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include + +#include "arrow/extension_type.h" +#include "arrow/util/logging.h" + +#include "arrow/extension/uuid.h" + +namespace arrow::extension { + +bool UuidType::ExtensionEquals(const ExtensionType& other) const { + return (other.extension_name() == this->extension_name()); +} + +std::shared_ptr UuidType::MakeArray(std::shared_ptr data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ("arrow.uuid", + static_cast(*data->type).extension_name()); + return std::make_shared(data); +} + +Result> UuidType::Deserialize( + std::shared_ptr storage_type, const std::string& serialized) const { + if (!serialized.empty()) { + return Status::Invalid("Unexpected serialized metadata: '", serialized, "'"); + } + if (!storage_type->Equals(*fixed_size_binary(16))) { + return Status::Invalid("Invalid storage type for UuidType: ", + storage_type->ToString()); + } + return std::make_shared(); +} + +std::string UuidType::ToString(bool show_metadata) const { + std::stringstream ss; + ss << "extension<" << this->extension_name() << ">"; + return ss.str(); +} + +std::shared_ptr uuid() { return std::make_shared(); } + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/uuid.h b/cpp/src/arrow/extension/uuid.h new file mode 100644 index 00000000000..42bb21cf0b2 --- /dev/null +++ b/cpp/src/arrow/extension/uuid.h @@ -0,0 +1,61 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/extension_type.h" + +namespace arrow::extension { + +/// \brief UuidArray stores array of UUIDs. Underlying storage type is +/// FixedSizeBinary(16). +class ARROW_EXPORT UuidArray : public ExtensionArray { + public: + using ExtensionArray::ExtensionArray; +}; + +/// \brief UuidType is a canonical arrow extension type for UUIDs. +/// UUIDs are stored as FixedSizeBinary(16) with big-endian notation and this +/// does not interpret the bytes in any way. Specific UUID version is not +/// required or guaranteed. +class ARROW_EXPORT UuidType : public ExtensionType { + public: + /// \brief Construct a UuidType. + UuidType() : ExtensionType(fixed_size_binary(16)) {} + + std::string extension_name() const override { return "arrow.uuid"; } + std::string ToString(bool show_metadata = false) const override; + + bool ExtensionEquals(const ExtensionType& other) const override; + + /// Create a UuidArray from ArrayData + std::shared_ptr MakeArray(std::shared_ptr data) const override; + + Result> Deserialize( + std::shared_ptr storage_type, + const std::string& serialized) const override; + + std::string Serialize() const override { return ""; } + + /// \brief Create a UuidType instance + static Result> Make() { return std::make_shared(); } +}; + +/// \brief Return a UuidType instance. 
+ARROW_EXPORT std::shared_ptr uuid(); + +} // namespace arrow::extension diff --git a/cpp/src/arrow/extension/uuid_test.cc b/cpp/src/arrow/extension/uuid_test.cc new file mode 100644 index 00000000000..3bbb6eeb4ae --- /dev/null +++ b/cpp/src/arrow/extension/uuid_test.cc @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/uuid.h" + +#include "arrow/testing/matchers.h" + +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/test_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/util/key_value_metadata.h" + +#include "arrow/testing/extension_type.h" + +namespace arrow { + +using arrow::ipc::test::RoundtripBatch; + +TEST(TestUuuidExtensionType, ExtensionTypeTest) { + auto type = uuid(); + ASSERT_EQ(type->id(), Type::EXTENSION); + + const auto& ext_type = static_cast(*type); + std::string serialized = ext_type.Serialize(); + + ASSERT_OK_AND_ASSIGN(auto deserialized, + ext_type.Deserialize(fixed_size_binary(16), serialized)); + ASSERT_TRUE(deserialized->Equals(*type)); + ASSERT_FALSE(deserialized->Equals(*fixed_size_binary(16))); +} + +TEST(TestUuuidExtensionType, RoundtripBatch) { + auto ext_type = extension::uuid(); + auto exact_ext_type = internal::checked_pointer_cast(ext_type); + auto arr = ArrayFromJSON(fixed_size_binary(16), R"(["abcdefghijklmnop", null])"); + auto ext_arr = ExtensionType::WrapArray(ext_type, arr); + + // Pass extension array, expect getting back extension array + std::shared_ptr read_batch; + auto ext_field = field(/*name=*/"f0", /*type=*/ext_type); + auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(), {ext_arr}); + RoundtripBatch(batch, &read_batch); + CompareBatch(*batch, *read_batch, /*compare_metadata=*/true); + + // Pass extension metadata and storage array, expect getting back extension array + std::shared_ptr read_batch2; + auto ext_metadata = + key_value_metadata({{"ARROW:extension:name", exact_ext_type->extension_name()}, + {"ARROW:extension:metadata", ""}}); + ext_field = field(/*name=*/"f0", /*type=*/exact_ext_type->storage_type(), + /*nullable=*/true, /*metadata=*/ext_metadata); + auto batch2 = RecordBatch::Make(schema({ext_field}), arr->length(), {arr}); + RoundtripBatch(batch2, &read_batch2); + CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true); +} + +} // namespace arrow diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc index 83c7ebed4f3..fc220f73a6b 100644 --- a/cpp/src/arrow/extension_type.cc +++ b/cpp/src/arrow/extension_type.cc @@ -32,6 +32,7 @@ #include "arrow/extension/fixed_shape_tensor.h" #include "arrow/extension/opaque.h" #endif +#include "arrow/extension/uuid.h" #include "arrow/status.h" #include "arrow/type.h" #include "arrow/util/checked_cast.h" @@ 
-147,14 +148,13 @@ static void CreateGlobalRegistry() { // Register canonical extension types g_registry = std::make_shared(); - std::vector> ext_types{extension::bool8()}; + std::vector> ext_types{extension::bool8(), extension::uuid()}; #ifdef ARROW_JSON ext_types.push_back(extension::fixed_shape_tensor(int64(), {})); ext_types.push_back(extension::opaque(null(), "", "")); #endif - // Register canonical extension types for (const auto& ext_type : ext_types) { ARROW_CHECK_OK( g_registry->RegisterType(checked_pointer_cast(ext_type))); diff --git a/cpp/src/arrow/extension_type_test.cc b/cpp/src/arrow/extension_type_test.cc index f104c984a64..f49ffc5cba5 100644 --- a/cpp/src/arrow/extension_type_test.cc +++ b/cpp/src/arrow/extension_type_test.cc @@ -30,6 +30,7 @@ #include "arrow/io/memory.h" #include "arrow/ipc/options.h" #include "arrow/ipc/reader.h" +#include "arrow/ipc/test_common.h" #include "arrow/ipc/writer.h" #include "arrow/record_batch.h" #include "arrow/status.h" @@ -41,6 +42,8 @@ namespace arrow { +using arrow::ipc::test::RoundtripBatch; + class Parametric1Array : public ExtensionArray { public: using ExtensionArray::ExtensionArray; @@ -178,7 +181,7 @@ class ExtStructType : public ExtensionType { class TestExtensionType : public ::testing::Test { public: - void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared())); } + void SetUp() { ASSERT_OK(RegisterExtensionType(std::make_shared())); } void TearDown() { if (GetExtensionType("uuid")) { @@ -211,20 +214,6 @@ TEST_F(TestExtensionType, ExtensionTypeTest) { ASSERT_EQ(deserialized->byte_width(), 16); } -auto RoundtripBatch = [](const std::shared_ptr& batch, - std::shared_ptr* out) { - ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); - ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), - out_stream.get())); - - ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); - - io::BufferReader reader(complete_ipc_stream); - std::shared_ptr batch_reader; - ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); - ASSERT_OK(batch_reader->ReadNext(out)); -}; - TEST_F(TestExtensionType, IpcRoundtrip) { auto ext_arr = ExampleUuid(); auto batch = RecordBatch::Make(schema({field("f0", uuid())}), 4, {ext_arr}); diff --git a/cpp/src/arrow/integration/json_integration_test.cc b/cpp/src/arrow/integration/json_integration_test.cc index 9b56928c688..0e84ea6124d 100644 --- a/cpp/src/arrow/integration/json_integration_test.cc +++ b/cpp/src/arrow/integration/json_integration_test.cc @@ -1046,7 +1046,7 @@ TEST(TestJsonFileReadWrite, JsonExample2) { auto storage_array = ArrayFromJSON(fixed_size_binary(16), R"(["0123456789abcdef", null])"); - AssertArraysEqual(*batch->column(0), UuidArray(uuid_type, storage_array)); + AssertArraysEqual(*batch->column(0), ExampleUuidArray(uuid_type, storage_array)); AssertArraysEqual(*batch->column(1), NullArray(2)); } diff --git a/cpp/src/arrow/ipc/test_common.cc b/cpp/src/arrow/ipc/test_common.cc index 87c02e2d87a..fb4f6bd8ead 100644 --- a/cpp/src/arrow/ipc/test_common.cc +++ b/cpp/src/arrow/ipc/test_common.cc @@ -27,8 +27,10 @@ #include "arrow/array.h" #include "arrow/array/builder_binary.h" #include "arrow/array/builder_primitive.h" -#include "arrow/array/builder_time.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/reader.h" #include "arrow/ipc/test_common.h" +#include "arrow/ipc/writer.h" #include "arrow/pretty_print.h" #include "arrow/record_batch.h" #include "arrow/status.h" @@ -242,11 +244,11 @@ Status 
MakeRandomBooleanArray(const int length, bool include_nulls, std::shared_ptr* out) { std::vector values(length); random_null_bytes(length, 0.5, values.data()); - ARROW_ASSIGN_OR_RAISE(auto data, internal::BytesToBits(values)); + ARROW_ASSIGN_OR_RAISE(auto data, arrow::internal::BytesToBits(values)); if (include_nulls) { std::vector valid_bytes(length); - ARROW_ASSIGN_OR_RAISE(auto null_bitmap, internal::BytesToBits(valid_bytes)); + ARROW_ASSIGN_OR_RAISE(auto null_bitmap, arrow::internal::BytesToBits(valid_bytes)); random_null_bytes(length, 0.1, valid_bytes.data()); *out = std::make_shared(length, data, null_bitmap, -1); } else { @@ -596,7 +598,7 @@ Status MakeStruct(std::shared_ptr* out) { std::shared_ptr no_nulls(new StructArray(type, list_batch->num_rows(), columns)); std::vector null_bytes(list_batch->num_rows(), 1); null_bytes[0] = 0; - ARROW_ASSIGN_OR_RAISE(auto null_bitmap, internal::BytesToBits(null_bytes)); + ARROW_ASSIGN_OR_RAISE(auto null_bitmap, arrow::internal::BytesToBits(null_bytes)); std::shared_ptr with_nulls( new StructArray(type, list_batch->num_rows(), columns, null_bitmap, 1)); @@ -1088,9 +1090,9 @@ Status MakeUuid(std::shared_ptr* out) { auto f1 = field("f1", uuid_type, /*nullable=*/false); auto schema = ::arrow::schema({f0, f1}); - auto a0 = std::make_shared( + auto a0 = std::make_shared( uuid_type, ArrayFromJSON(storage_type, R"(["0123456789abcdef", null])")); - auto a1 = std::make_shared( + auto a1 = std::make_shared( uuid_type, ArrayFromJSON(storage_type, R"(["ZYXWVUTSRQPONMLK", "JIHGFEDBA9876543"])")); @@ -1176,12 +1178,13 @@ enable_if_t::value, void> FillRandomData( Status MakeRandomTensor(const std::shared_ptr& type, const std::vector& shape, bool row_major_p, std::shared_ptr* out, uint32_t seed) { - const auto& element_type = internal::checked_cast(*type); + const auto& element_type = arrow::internal::checked_cast(*type); std::vector strides; if (row_major_p) { - RETURN_NOT_OK(internal::ComputeRowMajorStrides(element_type, shape, &strides)); + RETURN_NOT_OK(arrow::internal::ComputeRowMajorStrides(element_type, shape, &strides)); } else { - RETURN_NOT_OK(internal::ComputeColumnMajorStrides(element_type, shape, &strides)); + RETURN_NOT_OK( + arrow::internal::ComputeColumnMajorStrides(element_type, shape, &strides)); } const int64_t element_size = element_type.bit_width() / CHAR_BIT; @@ -1233,6 +1236,20 @@ Status MakeRandomTensor(const std::shared_ptr& type, return Tensor::Make(type, buf, shape, strides).Value(out); } +void RoundtripBatch(const std::shared_ptr& batch, + std::shared_ptr* out) { + ASSERT_OK_AND_ASSIGN(auto out_stream, io::BufferOutputStream::Create()); + ASSERT_OK(ipc::WriteRecordBatchStream({batch}, ipc::IpcWriteOptions::Defaults(), + out_stream.get())); + + ASSERT_OK_AND_ASSIGN(auto complete_ipc_stream, out_stream->Finish()); + + io::BufferReader reader(complete_ipc_stream); + std::shared_ptr batch_reader; + ASSERT_OK_AND_ASSIGN(batch_reader, ipc::RecordBatchStreamReader::Open(&reader)); + ASSERT_OK(batch_reader->ReadNext(out)); +} + } // namespace test } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/ipc/test_common.h b/cpp/src/arrow/ipc/test_common.h index db8613cbb1e..9b7e7f13e3a 100644 --- a/cpp/src/arrow/ipc/test_common.h +++ b/cpp/src/arrow/ipc/test_common.h @@ -184,6 +184,9 @@ Status MakeRandomTensor(const std::shared_ptr& type, const std::vector& shape, bool row_major_p, std::shared_ptr* out, uint32_t seed = 0); +ARROW_TESTING_EXPORT void RoundtripBatch(const std::shared_ptr& batch, + std::shared_ptr* out); + } // 
namespace test } // namespace ipc } // namespace arrow diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 104a5697b57..e9ec13e98b4 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -43,7 +43,6 @@ namespace arrow { using compute::Cast; using compute::CastOptions; - using internal::checked_cast; using internal::checked_pointer_cast; @@ -2038,7 +2037,7 @@ class TestExtensionScalar : public ::testing::Test { void SetUp() { type_ = uuid(); storage_type_ = fixed_size_binary(16); - uuid_type_ = checked_cast(type_.get()); + uuid_type_ = checked_cast(type_.get()); } protected: @@ -2049,7 +2048,7 @@ class TestExtensionScalar : public ::testing::Test { } std::shared_ptr type_, storage_type_; - const UuidType* uuid_type_{nullptr}; + const ExampleUuidType* uuid_type_{nullptr}; const std::string_view uuid_string1_{UUID_STRING1}; const std::string_view uuid_string2_{UUID_STRING2}; diff --git a/cpp/src/arrow/testing/extension_type.h b/cpp/src/arrow/testing/extension_type.h index 6515631f202..a4526e31c2b 100644 --- a/cpp/src/arrow/testing/extension_type.h +++ b/cpp/src/arrow/testing/extension_type.h @@ -27,14 +27,14 @@ namespace arrow { -class ARROW_TESTING_EXPORT UuidArray : public ExtensionArray { +class ARROW_TESTING_EXPORT ExampleUuidArray : public ExtensionArray { public: using ExtensionArray::ExtensionArray; }; -class ARROW_TESTING_EXPORT UuidType : public ExtensionType { +class ARROW_TESTING_EXPORT ExampleUuidType : public ExtensionType { public: - UuidType() : ExtensionType(fixed_size_binary(16)) {} + ExampleUuidType() : ExtensionType(fixed_size_binary(16)) {} std::string extension_name() const override { return "uuid"; } diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 95de16c715f..ae2e53b30a3 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -49,9 +49,13 @@ #include "arrow/buffer.h" #include "arrow/compute/api_vector.h" #include "arrow/datum.h" +#include "arrow/io/memory.h" #include "arrow/ipc/json_simple.h" +#include "arrow/ipc/reader.h" +#include "arrow/ipc/writer.h" #include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep #include "arrow/pretty_print.h" +#include "arrow/record_batch.h" #include "arrow/status.h" #include "arrow/table.h" #include "arrow/tensor.h" @@ -847,17 +851,17 @@ Future<> SleepABitAsync() { /////////////////////////////////////////////////////////////////////////// // Extension types -bool UuidType::ExtensionEquals(const ExtensionType& other) const { +bool ExampleUuidType::ExtensionEquals(const ExtensionType& other) const { return (other.extension_name() == this->extension_name()); } -std::shared_ptr UuidType::MakeArray(std::shared_ptr data) const { +std::shared_ptr ExampleUuidType::MakeArray(std::shared_ptr data) const { DCHECK_EQ(data->type->id(), Type::EXTENSION); DCHECK_EQ("uuid", static_cast(*data->type).extension_name()); - return std::make_shared(data); + return std::make_shared(data); } -Result> UuidType::Deserialize( +Result> ExampleUuidType::Deserialize( std::shared_ptr storage_type, const std::string& serialized) const { if (serialized != "uuid-serialized") { return Status::Invalid("Type identifier did not match: '", serialized, "'"); @@ -866,7 +870,7 @@ Result> UuidType::Deserialize( return Status::Invalid("Invalid storage type for UuidType: ", storage_type->ToString()); } - return std::make_shared(); + return std::make_shared(); } bool SmallintType::ExtensionEquals(const ExtensionType& other) const { @@ -982,7 
+986,7 @@ Result> Complex128Type::Deserialize( return std::make_shared(); } -std::shared_ptr uuid() { return std::make_shared(); } +std::shared_ptr uuid() { return std::make_shared(); } std::shared_ptr smallint() { return std::make_shared(); } diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py index d395d26cb71..f63aa0d95a4 100644 --- a/dev/archery/archery/integration/datagen.py +++ b/dev/archery/archery/integration/datagen.py @@ -1845,7 +1845,7 @@ def generate_nested_dictionary_case(): def generate_extension_case(): dict0 = Dictionary(0, StringField('dictionary0'), size=5, name='DICT0') - uuid_type = ExtensionType('uuid', 'uuid-serialized', + uuid_type = ExtensionType('arrow.uuid', '', FixedSizeBinaryField('', 16)) dict_ext_type = ExtensionType( 'dict-extension', 'dict-extension-serialized', diff --git a/docs/source/format/CanonicalExtensions.rst b/docs/source/format/CanonicalExtensions.rst index 5658f949cee..1106f8aaffd 100644 --- a/docs/source/format/CanonicalExtensions.rst +++ b/docs/source/format/CanonicalExtensions.rst @@ -272,6 +272,8 @@ JSON In the future, additional fields may be added, but they are not required to interpret the array. +.. _uuid_extension: + UUID ==== diff --git a/docs/source/status.rst b/docs/source/status.rst index 5e2c2cc19c8..b685d4bbf8a 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -121,7 +121,7 @@ Data Types +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | JSON | | | ✓ | | | | | | +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ -| UUID | | | ✓ | | | | | | +| UUID | ✓ | | ✓ | | | | | | +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ | 8-bit Boolean | ✓ | | ✓ | | | | | | +-----------------------+-------+-------+-------+------------+-------+-------+-------+-------+ diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index 807bcdc3150..d31c93119b7 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -172,9 +172,7 @@ def print_entry(label, value): union, sparse_union, dense_union, dictionary, run_end_encoded, - fixed_shape_tensor, - opaque, - bool8, + bool8, fixed_shape_tensor, opaque, uuid, field, type_for_alias, DataType, DictionaryType, StructType, @@ -184,8 +182,9 @@ def print_entry(label, value): TimestampType, Time32Type, Time64Type, DurationType, FixedSizeBinaryType, Decimal128Type, Decimal256Type, BaseExtensionType, ExtensionType, - RunEndEncodedType, FixedShapeTensorType, OpaqueType, - Bool8Type, PyExtensionType, UnknownExtensionType, + RunEndEncodedType, Bool8Type, FixedShapeTensorType, + OpaqueType, UuidType, + PyExtensionType, UnknownExtensionType, register_extension_type, unregister_extension_type, DictionaryMemo, KeyValueMetadata, @@ -218,8 +217,9 @@ def print_entry(label, value): Time32Array, Time64Array, DurationArray, MonthDayNanoIntervalArray, Decimal128Array, Decimal256Array, StructArray, ExtensionArray, - RunEndEncodedArray, FixedShapeTensorArray, OpaqueArray, - Bool8Array, scalar, NA, _NULL as NULL, Scalar, + RunEndEncodedArray, Bool8Array, FixedShapeTensorArray, + OpaqueArray, UuidArray, + scalar, NA, _NULL as NULL, Scalar, NullScalar, BooleanScalar, Int8Scalar, Int16Scalar, Int32Scalar, Int64Scalar, UInt8Scalar, UInt16Scalar, UInt32Scalar, UInt64Scalar, @@ -235,8 +235,8 @@ def print_entry(label, value): StringScalar, LargeStringScalar, StringViewScalar, FixedSizeBinaryScalar, 
DictionaryScalar, MapScalar, StructScalar, UnionScalar, - RunEndEncodedScalar, ExtensionScalar, - FixedShapeTensorScalar, OpaqueScalar, Bool8Scalar) + RunEndEncodedScalar, Bool8Scalar, ExtensionScalar, + FixedShapeTensorScalar, OpaqueScalar, UuidScalar) # Buffers, allocation from pyarrow.lib import (DeviceAllocationType, Device, MemoryManager, diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 77d6c9c06d2..1587de0e6b7 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -4338,6 +4338,12 @@ cdef class ExtensionArray(Array): return result +class UuidArray(ExtensionArray): + """ + Concrete class for Arrow arrays of UUID data type. + """ + + cdef class FixedShapeTensorArray(ExtensionArray): """ Concrete class for fixed shape tensor extension arrays. diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 6f510cfc0c0..c2346750a19 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2865,6 +2865,16 @@ cdef extern from "arrow/extension_type.h" namespace "arrow": shared_ptr[CArray] storage() +cdef extern from "arrow/extension/uuid.h" namespace "arrow::extension" nogil: + cdef cppclass CUuidType" arrow::extension::UuidType"(CExtensionType): + + @staticmethod + CResult[shared_ptr[CDataType]] Make() + + cdef cppclass CUuidArray" arrow::extension::UuidArray"(CExtensionArray): + pass + + cdef extern from "arrow/extension/fixed_shape_tensor.h" namespace "arrow::extension" nogil: cdef cppclass CFixedShapeTensorType \ " arrow::extension::FixedShapeTensorType"(CExtensionType): diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index a7c3b496a00..5c3d981c3ad 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -222,6 +222,9 @@ cdef class OpaqueType(BaseExtensionType): cdef: const COpaqueType* opaque_ext_type +cdef class UuidType(BaseExtensionType): + cdef: + const CUuidType* uuid_ext_type cdef class PyExtensionType(ExtensionType): pass diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 19a26bd6c68..d3e2ff2e99d 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -120,14 +120,17 @@ cdef api object pyarrow_wrap_data_type( elif type.get().id() == _Type_EXTENSION: ext_type = type.get() cpy_ext_type = dynamic_cast[_CPyExtensionTypePtr](ext_type) + extension_name = ext_type.extension_name() if cpy_ext_type != nullptr: return cpy_ext_type.GetInstance() - elif ext_type.extension_name() == b"arrow.fixed_shape_tensor": + elif extension_name == b"arrow.bool8": + out = Bool8Type.__new__(Bool8Type) + elif extension_name == b"arrow.fixed_shape_tensor": out = FixedShapeTensorType.__new__(FixedShapeTensorType) - elif ext_type.extension_name() == b"arrow.opaque": + elif extension_name == b"arrow.opaque": out = OpaqueType.__new__(OpaqueType) - elif ext_type.extension_name() == b"arrow.bool8": - out = Bool8Type.__new__(Bool8Type) + elif extension_name == b"arrow.uuid": + out = UuidType.__new__(UuidType) else: out = BaseExtensionType.__new__(BaseExtensionType) else: diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi index 72ae2aee5f8..68f77832c43 100644 --- a/python/pyarrow/scalar.pxi +++ b/python/pyarrow/scalar.pxi @@ -17,6 +17,7 @@ import collections from cython cimport binding +from uuid import UUID cdef class Scalar(_Weakrefable): @@ -1043,6 +1044,15 @@ cdef class ExtensionScalar(Scalar): return pyarrow_wrap_scalar( sp_scalar) +class UuidScalar(ExtensionScalar): + """ + Concrete class for Uuid extension 
scalar. + """ + + def as_py(self): + return None if self.value is None else UUID(bytes=self.value.as_py()) + + cdef class FixedShapeTensorScalar(ExtensionScalar): """ Concrete class for fixed shape tensor extension scalar. diff --git a/python/pyarrow/src/arrow/python/gdb.cc b/python/pyarrow/src/arrow/python/gdb.cc index 6941769e4ef..7c58bae3342 100644 --- a/python/pyarrow/src/arrow/python/gdb.cc +++ b/python/pyarrow/src/arrow/python/gdb.cc @@ -22,7 +22,7 @@ #include "arrow/array.h" #include "arrow/chunked_array.h" #include "arrow/datum.h" -#include "arrow/extension_type.h" +#include "arrow/extension/uuid.h" #include "arrow/ipc/json_simple.h" #include "arrow/python/gdb.h" #include "arrow/record_batch.h" @@ -37,6 +37,8 @@ namespace arrow { +using extension::uuid; +using extension::UuidType; using ipc::internal::json::ArrayFromJSON; using ipc::internal::json::ChunkedArrayFromJSON; using ipc::internal::json::ScalarFromJSON; @@ -56,29 +58,6 @@ class CustomStatusDetail : public StatusDetail { std::string ToString() const override { return "This is a detail"; } }; -class UuidType : public ExtensionType { - public: - UuidType() : ExtensionType(fixed_size_binary(16)) {} - - std::string extension_name() const override { return "uuid"; } - - bool ExtensionEquals(const ExtensionType& other) const override { - return (other.extension_name() == this->extension_name()); - } - - std::shared_ptr MakeArray(std::shared_ptr data) const override { - return std::make_shared(data); - } - - Result> Deserialize( - std::shared_ptr storage_type, - const std::string& serialized) const override { - return Status::NotImplemented(""); - } - - std::string Serialize() const override { return "uuid-serialized"; } -}; - std::shared_ptr SliceArrayFromJSON(const std::shared_ptr& ty, std::string_view json, int64_t offset = 0, int64_t length = -1) { diff --git a/python/pyarrow/tests/extensions.pyx b/python/pyarrow/tests/extensions.pyx index c1bf9aae1ec..309b574dc02 100644 --- a/python/pyarrow/tests/extensions.pyx +++ b/python/pyarrow/tests/extensions.pyx @@ -37,7 +37,7 @@ cdef extern from * namespace "arrow::py" nogil: class UuidType : public ExtensionType { public: UuidType() : ExtensionType(fixed_size_binary(16)) {} - std::string extension_name() const override { return "uuid"; } + std::string extension_name() const override { return "example-uuid"; } bool ExtensionEquals(const ExtensionType& other) const override { return other.extension_name() == this->extension_name(); diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py index 0d50c467e96..aacbd2cb6e7 100644 --- a/python/pyarrow/tests/test_extension_type.py +++ b/python/pyarrow/tests/test_extension_type.py @@ -95,18 +95,21 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized): return cls() -class UuidScalarType(pa.ExtensionScalar): +class ExampleUuidScalarType(pa.ExtensionScalar): def as_py(self): return None if self.value is None else UUID(bytes=self.value.as_py()) -class UuidType(pa.ExtensionType): +class ExampleUuidType(pa.ExtensionType): def __init__(self): - super().__init__(pa.binary(16), 'pyarrow.tests.UuidType') + super().__init__(pa.binary(16), 'pyarrow.tests.ExampleUuidType') + + def __reduce__(self): + return ExampleUuidType, () def __arrow_ext_scalar_class__(self): - return UuidScalarType + return ExampleUuidScalarType def __arrow_ext_serialize__(self): return b'' @@ -116,10 +119,10 @@ def __arrow_ext_deserialize__(cls, storage_type, serialized): return cls() -class UuidType2(pa.ExtensionType): +class 
ExampleUuidType2(pa.ExtensionType): def __init__(self): - super().__init__(pa.binary(16), 'pyarrow.tests.UuidType2') + super().__init__(pa.binary(16), 'pyarrow.tests.ExampleUuidType2') def __arrow_ext_serialize__(self): return b'' @@ -250,8 +253,8 @@ def ipc_read_batch(buf): def test_ext_type_basics(): - ty = UuidType() - assert ty.extension_name == "pyarrow.tests.UuidType" + ty = ExampleUuidType() + assert ty.extension_name == "pyarrow.tests.ExampleUuidType" def test_ext_type_str(): @@ -267,16 +270,16 @@ def test_ext_type_repr(): def test_ext_type_lifetime(): - ty = UuidType() + ty = ExampleUuidType() wr = weakref.ref(ty) del ty assert wr() is None def test_ext_type_storage_type(): - ty = UuidType() + ty = ExampleUuidType() assert ty.storage_type == pa.binary(16) - assert ty.__class__ is UuidType + assert ty.__class__ is ExampleUuidType ty = ParamExtType(5) assert ty.storage_type == pa.binary(5) assert ty.__class__ is ParamExtType @@ -284,7 +287,7 @@ def test_ext_type_storage_type(): def test_ext_type_byte_width(): # Test for fixed-size binary types - ty = UuidType() + ty = pa.uuid() assert ty.byte_width == 16 ty = ParamExtType(5) assert ty.byte_width == 5 @@ -297,7 +300,7 @@ def test_ext_type_byte_width(): def test_ext_type_bit_width(): # Test for fixed-size binary types - ty = UuidType() + ty = pa.uuid() assert ty.bit_width == 128 ty = ParamExtType(5) assert ty.bit_width == 40 @@ -309,7 +312,7 @@ def test_ext_type_bit_width(): def test_ext_type_as_py(): - ty = UuidType() + ty = ExampleUuidType() expected = uuid4() scalar = pa.ExtensionScalar.from_storage(ty, expected.bytes) assert scalar.as_py() == expected @@ -342,12 +345,22 @@ def test_ext_type_as_py(): def test_uuid_type_pickle(pickle_module): for proto in range(0, pickle_module.HIGHEST_PROTOCOL + 1): - ty = UuidType() + ty = ExampleUuidType() ser = pickle_module.dumps(ty, protocol=proto) del ty ty = pickle_module.loads(ser) wr = weakref.ref(ty) - assert ty.extension_name == "pyarrow.tests.UuidType" + assert ty.extension_name == "pyarrow.tests.ExampleUuidType" + del ty + assert wr() is None + + for proto in range(0, pickle_module.HIGHEST_PROTOCOL + 1): + ty = pa.uuid() + ser = pickle_module.dumps(ty, protocol=proto) + del ty + ty = pickle_module.loads(ser) + wr = weakref.ref(ty) + assert ty.extension_name == "arrow.uuid" del ty assert wr() is None @@ -358,8 +371,8 @@ def test_ext_type_equality(): c = ParamExtType(6) assert a != b assert b == c - d = UuidType() - e = UuidType() + d = ExampleUuidType() + e = ExampleUuidType() assert a != d assert d == e @@ -403,7 +416,7 @@ def test_ext_array_equality(): storage1 = pa.array([b"0123456789abcdef"], type=pa.binary(16)) storage2 = pa.array([b"0123456789abcdef"], type=pa.binary(16)) storage3 = pa.array([], type=pa.binary(16)) - ty1 = UuidType() + ty1 = ExampleUuidType() ty2 = ParamExtType(16) a = pa.ExtensionArray.from_storage(ty1, storage1) @@ -451,9 +464,9 @@ def test_ext_scalar_from_array(): data = [b"0123456789abcdef", b"0123456789abcdef", b"zyxwvutsrqponmlk", None] storage = pa.array(data, type=pa.binary(16)) - ty1 = UuidType() + ty1 = ExampleUuidType() ty2 = ParamExtType(16) - ty3 = UuidType2() + ty3 = ExampleUuidType2() a = pa.ExtensionArray.from_storage(ty1, storage) b = pa.ExtensionArray.from_storage(ty2, storage) @@ -462,9 +475,9 @@ def test_ext_scalar_from_array(): scalars_a = list(a) assert len(scalars_a) == 4 - assert ty1.__arrow_ext_scalar_class__() == UuidScalarType - assert isinstance(a[0], UuidScalarType) - assert isinstance(scalars_a[0], UuidScalarType) + assert 
ty1.__arrow_ext_scalar_class__() == ExampleUuidScalarType + assert isinstance(a[0], ExampleUuidScalarType) + assert isinstance(scalars_a[0], ExampleUuidScalarType) for s, val in zip(scalars_a, data): assert isinstance(s, pa.ExtensionScalar) @@ -505,7 +518,7 @@ def test_ext_scalar_from_array(): def test_ext_scalar_from_storage(): - ty = UuidType() + ty = ExampleUuidType() s = pa.ExtensionScalar.from_storage(ty, None) assert isinstance(s, pa.ExtensionScalar) @@ -706,14 +719,14 @@ def test_cast_between_extension_types(): tiny_int_arr.cast(pa.int64()).cast(IntegerType()) # Between the same extension types is okay - array = pa.array([b'1' * 16, b'2' * 16], pa.binary(16)).cast(UuidType()) - out = array.cast(UuidType()) - assert out.type == UuidType() + array = pa.array([b'1' * 16, b'2' * 16], pa.binary(16)).cast(ExampleUuidType()) + out = array.cast(ExampleUuidType()) + assert out.type == ExampleUuidType() # Will still fail casting between extensions who share storage type, # can only cast between exactly the same extension types. with pytest.raises(TypeError, match='Casting from *'): - array.cast(UuidType2()) + array.cast(ExampleUuidType2()) def test_cast_to_extension_with_extension_storage(): @@ -744,10 +757,10 @@ def test_cast_nested_extension_types(data, type_factory): def test_casting_dict_array_to_extension_type(): storage = pa.array([b"0123456789abcdef"], type=pa.binary(16)) - arr = pa.ExtensionArray.from_storage(UuidType(), storage) + arr = pa.ExtensionArray.from_storage(ExampleUuidType(), storage) dict_arr = pa.DictionaryArray.from_arrays(pa.array([0, 0], pa.int32()), arr) - out = dict_arr.cast(UuidType()) + out = dict_arr.cast(ExampleUuidType()) assert isinstance(out, pa.ExtensionArray) assert out.to_pylist() == [UUID('30313233-3435-3637-3839-616263646566'), UUID('30313233-3435-3637-3839-616263646566')] @@ -1347,7 +1360,7 @@ def test_cpp_extension_in_python(tmpdir): mod = __import__('extensions') uuid_type = mod._make_uuid_type() - assert uuid_type.extension_name == "uuid" + assert uuid_type.extension_name == "example-uuid" assert uuid_type.storage_type == pa.binary(16) array = mod._make_uuid_array() @@ -1356,6 +1369,31 @@ def test_cpp_extension_in_python(tmpdir): assert array[0].as_py() == b'abcdefghijklmno0' assert array[1].as_py() == b'0onmlkjihgfedcba' + buf = ipc_write_batch(pa.RecordBatch.from_arrays([array], ["example-uuid"])) + + batch = ipc_read_batch(buf) + reconstructed_array = batch.column(0) + assert reconstructed_array.type == uuid_type + assert reconstructed_array == array + + +def test_uuid_extension(): + data = [b"0123456789abcdef", b"0123456789abcdef", + b"zyxwvutsrqponmlk", None] + + uuid_type = pa.uuid() + assert uuid_type.extension_name == "arrow.uuid" + assert uuid_type.storage_type == pa.binary(16) + assert uuid_type.__class__ is pa.UuidType + + storage = pa.array(data, pa.binary(16)) + array = pa.ExtensionArray.from_storage(uuid_type, storage) + assert array.type == uuid_type + + assert array.to_pylist() == [x if x is None else UUID(bytes=x) for x in data] + assert array[0].as_py() == UUID(bytes=data[0]) + assert array[3].as_py() is None + buf = ipc_write_batch(pa.RecordBatch.from_arrays([array], ["uuid"])) batch = ipc_read_batch(buf) @@ -1363,6 +1401,9 @@ def test_cpp_extension_in_python(tmpdir): assert reconstructed_array.type == uuid_type assert reconstructed_array == array + assert uuid_type.__arrow_ext_scalar_class__() == pa.UuidScalar + assert isinstance(array[0], pa.UuidScalar) + def test_tensor_type(): tensor_type = pa.fixed_shape_tensor(pa.int8(), [2, 
3]) diff --git a/python/pyarrow/tests/test_gdb.py b/python/pyarrow/tests/test_gdb.py index 0d12d710dcf..2ac2f55754f 100644 --- a/python/pyarrow/tests/test_gdb.py +++ b/python/pyarrow/tests/test_gdb.py @@ -409,7 +409,7 @@ def test_types_stack(gdb_arrow): check_stack_repr( gdb_arrow, "uuid_type", - ('arrow::ExtensionType "extension" ' + ('arrow::ExtensionType "extension" ' 'with storage type arrow::fixed_size_binary(16)')) @@ -447,7 +447,7 @@ def test_types_heap(gdb_arrow): check_heap_repr( gdb_arrow, "heap_uuid_type", - ('arrow::ExtensionType "extension" ' + ('arrow::ExtensionType "extension" ' 'with storage type arrow::fixed_size_binary(16)')) @@ -716,12 +716,12 @@ def test_scalars_stack(gdb_arrow): check_stack_repr( gdb_arrow, "extension_scalar", - ('arrow::ExtensionScalar of type "extension", ' + ('arrow::ExtensionScalar of type "extension", ' 'value arrow::FixedSizeBinaryScalar of size 16, ' 'value "0123456789abcdef"')) check_stack_repr( gdb_arrow, "extension_scalar_null", - 'arrow::ExtensionScalar of type "extension", null value') + 'arrow::ExtensionScalar of type "extension", null value') def test_scalars_heap(gdb_arrow): diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 563782f0c26..f83ecc3aa43 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1765,6 +1765,25 @@ cdef class ExtensionType(BaseExtensionType): return ExtensionScalar +cdef class UuidType(BaseExtensionType): + """ + Concrete class for UUID extension type. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.uuid_ext_type = type.get() + + def __arrow_ext_class__(self): + return UuidArray + + def __reduce__(self): + return uuid, () + + def __arrow_ext_scalar_class__(self): + return UuidScalar + + cdef class FixedShapeTensorType(BaseExtensionType): """ Concrete class for fixed shape tensor extension type. @@ -5208,6 +5227,21 @@ def run_end_encoded(run_end_type, value_type): return pyarrow_wrap_data_type(ree_type) +def uuid(): + """ + Create UuidType instance. + + Returns + ------- + type : UuidType + """ + + cdef UuidType out = UuidType.__new__(UuidType) + c_uuid_ext_type = GetResultValue(CUuidType.Make()) + out.init(c_uuid_ext_type) + return out + + def fixed_shape_tensor(DataType value_type, shape, dim_names=None, permutation=None): """ Create instance of fixed shape tensor extension type with shape and optional From 8eb7bd4115da0027aad6362f0fe0901ec44b0616 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 09:12:57 +0900 Subject: [PATCH 073/157] MINOR: [Go] Bump github.com/hamba/avro/v2 from 2.24.1 to 2.25.0 in /go (#43829) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [github.com/hamba/avro/v2](https://github.com/hamba/avro) from 2.24.1 to 2.25.0.
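Illustrative of the Python-facing API added by the UUID patch (GH-15058) above: a minimal sketch mirroring the `test_uuid_extension` test in that diff, assuming a pyarrow build that includes the change.

```python
import pyarrow as pa
from uuid import UUID

# Canonical arrow.uuid extension type, backed by FixedSizeBinary(16) storage.
uuid_type = pa.uuid()
assert uuid_type.extension_name == "arrow.uuid"
assert uuid_type.storage_type == pa.binary(16)

# Wrap a fixed_size_binary(16) storage array as a UUID extension array.
storage = pa.array([b"0123456789abcdef", None], type=pa.binary(16))
arr = pa.ExtensionArray.from_storage(uuid_type, storage)

# Scalars are UuidScalar and convert to Python's uuid.UUID via as_py().
assert isinstance(arr[0], pa.UuidScalar)
assert arr[0].as_py() == UUID(bytes=b"0123456789abcdef")
assert arr[1].as_py() is None
```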
Release notes

Sourced from github.com/hamba/avro/v2's releases.

v2.25.0

What's Changed

New Contributors

Full Changelog: https://github.com/hamba/avro/compare/v2.24.1...v2.24.2

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github.com/hamba/avro/v2&package-manager=go_modules&previous-version=2.24.1&new-version=2.25.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@ dependabot rebase` will rebase this PR
- `@ dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@ dependabot merge` will merge this PR after your CI passes on it
- `@ dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@ dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@ dependabot reopen` will reopen this PR if it is closed
- `@ dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@ dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency
- `@ dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@ dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@ dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Sutou Kouhei --- go/go.mod | 2 +- go/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/go.mod b/go/go.mod index 9f4222a541b..97ac0568597 100644 --- a/go/go.mod +++ b/go/go.mod @@ -47,7 +47,7 @@ require ( require ( github.com/google/uuid v1.6.0 - github.com/hamba/avro/v2 v2.24.1 + github.com/hamba/avro/v2 v2.25.0 github.com/huandu/xstrings v1.4.0 github.com/substrait-io/substrait-go v0.6.0 github.com/tidwall/sjson v1.2.5 diff --git a/go/go.sum b/go/go.sum index c7eb3a66dee..bd761e15894 100644 --- a/go/go.sum +++ b/go/go.sum @@ -43,8 +43,8 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbu github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/hamba/avro/v2 v2.24.1 h1:Xi+7AnhaAc41aA/jmmYpxMsdEDOf1rdup6NJ85P7q2I= -github.com/hamba/avro/v2 v2.24.1/go.mod h1:7vDfy/2+kYCE8WUHoj2et59GTv0ap7ptktMXu0QHePI= +github.com/hamba/avro/v2 v2.25.0 h1:9qig/K4VP5tMq6DuKGfI6YdXncTkPJT1IJDMSv82EeI= +github.com/hamba/avro/v2 v2.25.0/go.mod h1:I8glyswHnpED3Nlx2ZdUe+4LJnCOOyiCzLMno9i/Uu0= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= From 93c5ddb957bb93421a8f84dbd7c5a5b7be2d6d45 Mon Sep 17 00:00:00 2001 From: PANKAJ9768 <48675737+PANKAJ9768@users.noreply.github.com> Date: Tue, 27 Aug 2024 05:59:09 +0530 Subject: [PATCH 074/157] GH-43667: [Java] Keeping Flight default header size consistent between server and client (#43697) ### Rationale for this change ### What changes are included in this PR? Flight client can send header size larger than server can accept. This PR is to keep default values consistent across server and client. ### Are these changes tested? ### Are there any user-facing changes? 
* GitHub Issue: #43667 Authored-by: pankaj kesari Signed-off-by: David Li --- .../org/apache/arrow/flight/FlightServer.java | 7 ++ .../arrow/flight/TestFlightService.java | 73 +++++++++++++++++++ 2 files changed, 80 insertions(+) diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java index 05dbe42c491..ac761457f57 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -188,6 +188,7 @@ public static final class Builder { private CallHeaderAuthenticator headerAuthenticator = CallHeaderAuthenticator.NO_OP; private ExecutorService executor = null; private int maxInboundMessageSize = MAX_GRPC_MESSAGE_SIZE; + private int maxHeaderListSize = MAX_GRPC_MESSAGE_SIZE; private int backpressureThreshold = DEFAULT_BACKPRESSURE_THRESHOLD; private InputStream certChain; private InputStream key; @@ -324,6 +325,7 @@ public FlightServer build() { builder .executor(exec) .maxInboundMessageSize(maxInboundMessageSize) + .maxInboundMetadataSize(maxHeaderListSize) .addService( ServerInterceptors.intercept( flightService, @@ -366,6 +368,11 @@ public FlightServer build() { return new FlightServer(location, builder.build(), grpcExecutor); } + public Builder setMaxHeaderListSize(int maxHeaderListSize) { + this.maxHeaderListSize = maxHeaderListSize; + return this; + } + /** * Set the maximum size of a message. Defaults to "unlimited", depending on the underlying * transport. diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java index 5ebeb44c1d3..fc3f83e4eaf 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java @@ -27,6 +27,7 @@ import java.nio.charset.StandardCharsets; import java.util.Collections; import java.util.Optional; +import java.util.Random; import org.apache.arrow.flight.impl.Flight; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; @@ -152,4 +153,76 @@ public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor assertEquals("No schema is present in FlightInfo", e.getMessage()); } } + + /** + * Test for GH-41584 where flight defaults for header size was not in sync b\w client and server. 
+ */ + @Test + public void testHeaderSizeExchangeInService() throws Exception { + final FlightProducer producer = + new NoOpFlightProducer() { + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + String longHeader = + context.getMiddleware(FlightConstants.HEADER_KEY).headers().get("long-header"); + return new FlightInfo( + null, + descriptor, + Collections.emptyList(), + 0, + 0, + false, + IpcOption.DEFAULT, + longHeader.getBytes(StandardCharsets.UTF_8)); + } + }; + + String headerVal = generateRandom(1024 * 10); + FlightCallHeaders callHeaders = new FlightCallHeaders(); + callHeaders.insert("long-header", headerVal); + // sever with default header limit same as client + try (final FlightServer s = + FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer) + .build() + .start(); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + FlightInfo flightInfo = + client.getInfo(FlightDescriptor.path("test"), new HeaderCallOption(callHeaders)); + assertEquals(Optional.empty(), flightInfo.getSchemaOptional()); + assertEquals(new Schema(Collections.emptyList()), flightInfo.getSchema()); + assertArrayEquals(flightInfo.getAppMetadata(), headerVal.getBytes(StandardCharsets.UTF_8)); + } + // server with 15kb header limit + try (final FlightServer s = + FlightServer.builder(allocator, forGrpcInsecure(LOCALHOST, 0), producer) + .setMaxHeaderListSize(1024 * 15) + .build() + .start(); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + FlightInfo flightInfo = + client.getInfo(FlightDescriptor.path("test"), new HeaderCallOption(callHeaders)); + assertEquals(Optional.empty(), flightInfo.getSchemaOptional()); + assertEquals(new Schema(Collections.emptyList()), flightInfo.getSchema()); + assertArrayEquals(flightInfo.getAppMetadata(), headerVal.getBytes(StandardCharsets.UTF_8)); + + callHeaders.insert("another-header", headerVal + headerVal); + FlightRuntimeException e = + assertThrows( + FlightRuntimeException.class, + () -> + client.getInfo(FlightDescriptor.path("test"), new HeaderCallOption(callHeaders))); + assertEquals("http2 exception", e.getMessage()); + } + } + + private static String generateRandom(int size) { + String aToZ = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; + Random random = new Random(); + StringBuilder res = new StringBuilder(); + for (int i = 0; i < size; i++) { + int randIndex = random.nextInt(aToZ.length()); + res.append(aToZ.charAt(randIndex)); + } + return res.toString(); + } } From 11f92491b1d2ecf700e6e023a1e413ec4c4345ae Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 11:06:13 +0900 Subject: [PATCH 075/157] MINOR: [Go] Bump github.com/substrait-io/substrait-go from 0.6.0 to 0.7.0 in /go (#43830) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [github.com/substrait-io/substrait-go](https://github.com/substrait-io/substrait-go) from 0.6.0 to 0.7.0.
Release notes

Sourced from github.com/substrait-io/substrait-go's releases.

v0.7.0 (2024-08-25)

Features

  • Add convenience literal APIs (#47) (597afdb)
    • Introduce literal package

Changes to the build process or auxiliary tools and libraries such as documentation generation

  • extensions: Minor refactoring in extension_mgr.go (#45) (cbd28cb)
    • Minor refactoring in extension_mgr.go
  • Move typeName maps to types package (#46) (5556c23)
Commits
  • 597afdb feat: Add convenience literal APIs (#47)
  • e77df67 feat(types) Make time precision value explicit (#49)
  • a3e8ee0 feat(substrait) Update to substrait v0.55.0 (#48)
  • 2229c12 ci(build-test): golangci should use the go.mod version of golang (#51)
  • cbd28cb chore(extensions): Minor refactoring in extension_mgr.go (#45)
  • 5556c23 chore: Move typeName maps to types package (#46)
  • dd790cb Add a function registry for a given BFT dialect (#32)
  • 828636c ci(build-test): Add golangci-lint to do import checking and other linting (#42)
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=github.com/substrait-io/substrait-go&package-manager=go_modules&previous-version=0.6.0&new-version=0.7.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@ dependabot rebase` will rebase this PR - `@ dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@ dependabot merge` will merge this PR after your CI passes on it - `@ dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@ dependabot cancel merge` will cancel a previously requested merge and block automerging - `@ dependabot reopen` will reopen this PR if it is closed - `@ dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@ dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@ dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@ dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@ dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Sutou Kouhei --- go/go.mod | 2 +- go/go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go/go.mod b/go/go.mod index 97ac0568597..a995eee24d5 100644 --- a/go/go.mod +++ b/go/go.mod @@ -49,7 +49,7 @@ require ( github.com/google/uuid v1.6.0 github.com/hamba/avro/v2 v2.25.0 github.com/huandu/xstrings v1.4.0 - github.com/substrait-io/substrait-go v0.6.0 + github.com/substrait-io/substrait-go v0.7.0 github.com/tidwall/sjson v1.2.5 ) diff --git a/go/go.sum b/go/go.sum index bd761e15894..6f22e11aef0 100644 --- a/go/go.sum +++ b/go/go.sum @@ -99,8 +99,8 @@ github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/substrait-io/substrait-go v0.6.0 h1:n2G/SGmrn7U5Q39VA8WeM2UfVL5Y/6HX8WAP9uJLNk4= -github.com/substrait-io/substrait-go v0.6.0/go.mod h1:cl8Wsc7aBPDfcHp9+OrUqGpjkgrYlhcDsH/lMP6KUZA= +github.com/substrait-io/substrait-go v0.7.0 h1:53yi73t4wW383+RD1YuhXhbjhP1KzF9GCxPC7SsRlqc= +github.com/substrait-io/substrait-go v0.7.0/go.mod h1:7mjSvIaxk94bOF+YZn/vBOpHK4DWTpBv7nC/btjXCmc= github.com/tidwall/gjson v1.14.2 h1:6BBkirS0rAHjumnjHF6qgy5d2YAJ1TLIaFE2lzfOLqo= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= From a49493d96bc3021af1a126ce33f859bfb7a2ec80 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 27 Aug 2024 11:44:19 +0900 Subject: [PATCH 076/157] MINOR: [Java] Downgrade gRPC to 1.65 (#43839) ### Rationale for this change Newer versions don't run in all CI pipelines due to protoc using a newer glibc. ### What changes are included in this PR? This reverts commit 4af1e491df7ac22217656668b65c3e8d55f5b5ab. ### Are these changes tested? N/A ### Are there any user-facing changes? No Authored-by: David Li Signed-off-by: David Li --- java/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/pom.xml b/java/pom.xml index 77feed12f3f..f78d02c0c65 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -97,7 +97,7 @@ under the License. 2.0.16 33.2.1-jre 4.1.112.Final - 1.66.0 + 1.65.0 3.25.4 2.17.2 3.4.0 From 23fe1ce3361b9a6825fea77deb20d0bd7f247fe2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 11:56:45 +0900 Subject: [PATCH 077/157] MINOR: [Java] Bump org.apache.commons:commons-compress from 1.27.0 to 1.27.1 in /java (#43826) Bumps org.apache.commons:commons-compress from 1.27.0 to 1.27.1. [![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=org.apache.commons:commons-compress&package-manager=maven&previous-version=1.27.0&new-version=1.27.1)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: David Li --- java/compression/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/java/compression/pom.xml b/java/compression/pom.xml index a1f2bc861da..46ed8796423 100644 --- a/java/compression/pom.xml +++ b/java/compression/pom.xml @@ -50,7 +50,7 @@ under the License. org.apache.commons commons-compress - 1.27.0 + 1.27.1 com.github.luben From fa5d158282b316819e4e23e0903b696467a61d38 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 26 Aug 2024 21:01:45 -0700 Subject: [PATCH 078/157] MINOR: [C#] Bump Microsoft.NET.Test.Sdk from 17.10.0 to 17.11.0 in /csharp (#43822) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [Microsoft.NET.Test.Sdk](https://github.com/microsoft/vstest) from 17.10.0 to 17.11.0.
Release notes

Sourced from Microsoft.NET.Test.Sdk's releases.

v17.11.0

Full Changelog: https://github.com/microsoft/vstest/compare/v17.10.0...v17.11.0-release-24352-06

v17.11.0-release-24373-02

What's Changed

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=Microsoft.NET.Test.Sdk&package-manager=nuget&previous-version=17.10.0&new-version=17.11.0)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@ dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Signed-off-by: Curt Hagenlocher --- .../Apache.Arrow.Compression.Tests.csproj | 2 +- .../Apache.Arrow.Flight.Sql.Tests.csproj | 2 +- .../Apache.Arrow.Flight.Tests/Apache.Arrow.Flight.Tests.csproj | 2 +- csharp/test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/csharp/test/Apache.Arrow.Compression.Tests/Apache.Arrow.Compression.Tests.csproj b/csharp/test/Apache.Arrow.Compression.Tests/Apache.Arrow.Compression.Tests.csproj index 047cdb94b96..4ea02e0ed21 100644 --- a/csharp/test/Apache.Arrow.Compression.Tests/Apache.Arrow.Compression.Tests.csproj +++ b/csharp/test/Apache.Arrow.Compression.Tests/Apache.Arrow.Compression.Tests.csproj @@ -7,7 +7,7 @@ - + diff --git a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj index dc95f9edf9f..fd8274230ec 100644 --- a/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj +++ b/csharp/test/Apache.Arrow.Flight.Sql.Tests/Apache.Arrow.Flight.Sql.Tests.csproj @@ -6,7 +6,7 @@ - + diff --git a/csharp/test/Apache.Arrow.Flight.Tests/Apache.Arrow.Flight.Tests.csproj b/csharp/test/Apache.Arrow.Flight.Tests/Apache.Arrow.Flight.Tests.csproj index e68a97670cc..eae9ab746f2 100644 --- a/csharp/test/Apache.Arrow.Flight.Tests/Apache.Arrow.Flight.Tests.csproj +++ b/csharp/test/Apache.Arrow.Flight.Tests/Apache.Arrow.Flight.Tests.csproj @@ -6,7 +6,7 @@ - + diff --git a/csharp/test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj b/csharp/test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj index f0533831306..ee71b203218 100644 --- a/csharp/test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj +++ b/csharp/test/Apache.Arrow.Tests/Apache.Arrow.Tests.csproj @@ -16,7 +16,7 @@ - + all From c30bb6a84536d66bc1179e2a051915d5c34b2616 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Tue, 27 Aug 2024 14:49:45 +0900 Subject: [PATCH 079/157] GH-41056: [GLib][FlightRPC] Add gaflight_client_do_put() and related APIs (#43813) ### Rationale for this change DoPut is needed to upload data. ### What changes are included in this PR? * Add `gaflight_client_do_put()` * Add `GAFlightStreamWriter` * Add `GAFlightMetadataReader` * Add `GAFlightDoPutResult` * Fix `GAFlightRecordBatchWriter` API ### Are these changes tested? No. They aren't tested yet. We will add tests when we implement server side DoPut. ### Are there any user-facing changes? Yes. * GitHub Issue: #41056 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- c_glib/arrow-flight-glib/client.cpp | 337 +++++++++++++++++++++++++++- c_glib/arrow-flight-glib/client.h | 46 ++++ c_glib/arrow-flight-glib/client.hpp | 16 ++ c_glib/arrow-flight-glib/common.cpp | 102 ++------- c_glib/arrow-flight-glib/common.h | 8 +- c_glib/arrow-glib/writer.hpp | 4 + 6 files changed, 421 insertions(+), 92 deletions(-) diff --git a/c_glib/arrow-flight-glib/client.cpp b/c_glib/arrow-flight-glib/client.cpp index 80c47e336f8..23f59c9da69 100644 --- a/c_glib/arrow-flight-glib/client.cpp +++ b/c_glib/arrow-flight-glib/client.cpp @@ -33,10 +33,19 @@ G_BEGIN_DECLS * #GAFlightStreamReader is a class for reading record batches from a * server. * + * #GAFlightStreamWriter is a class for writing record batches to a + * server. + * + * #GAFlightMetadataReader is a class for reading metadata from a + * server. + * * #GAFlightCallOptions is a class for options of each call. 
* * #GAFlightClientOptions is a class for options of each client. * + * #GAFlightDoPutResult is a class that has gaflight_client_do_put() + * result. + * * #GAFlightClient is a class for Apache Arrow Flight client. * * Since: 5.0.0 @@ -56,6 +65,128 @@ gaflight_stream_reader_class_init(GAFlightStreamReaderClass *klass) { } +G_DEFINE_TYPE(GAFlightStreamWriter, + gaflight_stream_writer, + GAFLIGHT_TYPE_RECORD_BATCH_WRITER) + +static void +gaflight_stream_writer_init(GAFlightStreamWriter *object) +{ +} + +static void +gaflight_stream_writer_class_init(GAFlightStreamWriterClass *klass) +{ +} + +/** + * gaflight_stream_writer_done_writing: + * @writer: A #GAFlightStreamWriter. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Returns: %TRUE on success, %FALSE on error. + * + * Since: 18.0.0 + */ +gboolean +gaflight_stream_writer_done_writing(GAFlightStreamWriter *writer, GError **error) +{ + auto flight_writer = std::static_pointer_cast( + garrow_record_batch_writer_get_raw(GARROW_RECORD_BATCH_WRITER(writer))); + return garrow::check(error, + flight_writer->DoneWriting(), + "[flight-stream-writer][done-writing]"); +} + +struct GAFlightMetadataReaderPrivate +{ + arrow::flight::FlightMetadataReader *reader; +}; + +enum { + PROP_METADATA_READER_READER = 1, +}; + +G_DEFINE_TYPE_WITH_PRIVATE(GAFlightMetadataReader, + gaflight_metadata_reader, + G_TYPE_OBJECT) + +#define GAFLIGHT_METADATA_READER_GET_PRIVATE(object) \ + static_cast( \ + gaflight_metadata_reader_get_instance_private(GAFLIGHT_METADATA_READER(object))) + +static void +gaflight_metadata_reader_finalize(GObject *object) +{ + auto priv = GAFLIGHT_METADATA_READER_GET_PRIVATE(object); + delete priv->reader; + G_OBJECT_CLASS(gaflight_metadata_reader_parent_class)->finalize(object); +} + +static void +gaflight_metadata_reader_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GAFLIGHT_METADATA_READER_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_METADATA_READER_READER: + priv->reader = + static_cast(g_value_get_pointer(value)); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +gaflight_metadata_reader_init(GAFlightMetadataReader *object) +{ +} + +static void +gaflight_metadata_reader_class_init(GAFlightMetadataReaderClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + + gobject_class->finalize = gaflight_metadata_reader_finalize; + gobject_class->set_property = gaflight_metadata_reader_set_property; + + GParamSpec *spec; + spec = g_param_spec_pointer( + "reader", + nullptr, + nullptr, + static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_METADATA_READER_READER, spec); +} + +/** + * gaflight_metadata_reader_read: + * @reader: A #GAFlightMetadataReader. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Returns: (transfer full): The metadata on success, %NULL on error. 
+ * + * Since: 18.0.0 + */ +GArrowBuffer * +gaflight_metadata_reader_read(GAFlightMetadataReader *reader, GError **error) +{ + auto flight_reader = gaflight_metadata_reader_get_raw(reader); + std::shared_ptr metadata; + if (garrow::check(error, + flight_reader->ReadMetadata(&metadata), + "[flight-metadata-reader][read]")) { + return garrow_buffer_new_raw(&metadata); + } else { + return nullptr; + } +} + typedef struct GAFlightCallOptionsPrivate_ { arrow::flight::FlightCallOptions options; @@ -385,6 +516,137 @@ gaflight_client_options_new(void) g_object_new(GAFLIGHT_TYPE_CLIENT_OPTIONS, NULL)); } +struct GAFlightDoPutResultPrivate +{ + GAFlightStreamWriter *writer; + GAFlightMetadataReader *reader; +}; + +enum { + PROP_DO_PUT_RESULT_RESULT = 1, + PROP_DO_PUT_RESULT_WRITER, + PROP_DO_PUT_RESULT_READER, +}; + +G_DEFINE_TYPE_WITH_PRIVATE(GAFlightDoPutResult, gaflight_do_put_result, G_TYPE_OBJECT) + +#define GAFLIGHT_DO_PUT_RESULT_GET_PRIVATE(object) \ + static_cast( \ + gaflight_do_put_result_get_instance_private(GAFLIGHT_DO_PUT_RESULT(object))) + +static void +gaflight_do_put_result_dispose(GObject *object) +{ + auto priv = GAFLIGHT_DO_PUT_RESULT_GET_PRIVATE(object); + + if (priv->writer) { + g_object_unref(priv->writer); + priv->writer = nullptr; + } + + if (priv->reader) { + g_object_unref(priv->reader); + priv->reader = nullptr; + } + + G_OBJECT_CLASS(gaflight_do_put_result_parent_class)->dispose(object); +} + +static void +gaflight_do_put_result_init(GAFlightDoPutResult *object) +{ +} + +static void +gaflight_do_put_result_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GAFLIGHT_DO_PUT_RESULT_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_DO_PUT_RESULT_RESULT: + { + auto result = static_cast( + g_value_get_pointer(value)); + priv->writer = gaflight_stream_writer_new_raw(result->writer.release()); + priv->reader = gaflight_metadata_reader_new_raw(result->reader.release()); + break; + } + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +gaflight_do_put_result_get_property(GObject *object, + guint prop_id, + GValue *value, + GParamSpec *pspec) +{ + auto priv = GAFLIGHT_DO_PUT_RESULT_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_DO_PUT_RESULT_WRITER: + g_value_set_object(value, priv->writer); + break; + case PROP_DO_PUT_RESULT_READER: + g_value_set_object(value, priv->reader); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} + +static void +gaflight_do_put_result_class_init(GAFlightDoPutResultClass *klass) +{ + auto gobject_class = G_OBJECT_CLASS(klass); + + gobject_class->dispose = gaflight_do_put_result_dispose; + gobject_class->set_property = gaflight_do_put_result_set_property; + gobject_class->get_property = gaflight_do_put_result_get_property; + + GParamSpec *spec; + spec = g_param_spec_pointer( + "result", + nullptr, + nullptr, + static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_DO_PUT_RESULT_RESULT, spec); + + /** + * GAFlightDoPutResult:writer: + * + * A writer to write record batches to. + * + * Since: 18.0.0 + */ + spec = g_param_spec_object("writer", + nullptr, + nullptr, + GAFLIGHT_TYPE_STREAM_WRITER, + static_cast(G_PARAM_READABLE)); + g_object_class_install_property(gobject_class, PROP_DO_PUT_RESULT_WRITER, spec); + + /** + * GAFlightDoPutResult:reader: + * + * A reader for application metadata from the server. 
+ * + * Since: 18.0.0 + */ + spec = g_param_spec_object("reader", + nullptr, + nullptr, + GAFLIGHT_TYPE_METADATA_READER, + static_cast(G_PARAM_READABLE)); + g_object_class_install_property(gobject_class, PROP_DO_PUT_RESULT_READER, spec); +} + struct GAFlightClientPrivate { std::shared_ptr client; @@ -661,6 +923,51 @@ gaflight_client_do_get(GAFlightClient *client, return gaflight_stream_reader_new_raw(flight_reader.release(), TRUE); } +/** + * gaflight_client_do_put: + * @client: A #GAFlightClient. + * @descriptor: A #GAFlightDescriptor. + * @schema: A #GArrowSchema. + * @options: (nullable): A #GAFlightCallOptions. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Upload data to a Flight described by the given descriptor. The + * caller must call garrow_record_batch_writer_close() on the + * returned stream once they are done writing. + * + * The reader and writer are linked; closing the writer will also + * close the reader. Use garrow_flight_stream_writer_done_writing() to + * only close the write side of the channel. + * + * Returns: (nullable) (transfer full): + * The #GAFlighDoPutResult holding a reader and a writer on success, + * %NULL on error. + * + * Since: 18.0.0 + */ +GAFlightDoPutResult * +gaflight_client_do_put(GAFlightClient *client, + GAFlightDescriptor *descriptor, + GArrowSchema *schema, + GAFlightCallOptions *options, + GError **error) +{ + auto flight_client = gaflight_client_get_raw(client); + auto flight_descriptor = gaflight_descriptor_get_raw(descriptor); + auto arrow_schema = garrow_schema_get_raw(schema); + arrow::flight::FlightCallOptions flight_default_options; + auto flight_options = &flight_default_options; + if (options) { + flight_options = gaflight_call_options_get_raw(options); + } + auto result = flight_client->DoPut(*flight_options, *flight_descriptor, arrow_schema); + if (!garrow::check(error, result, "[flight-client][do-put]")) { + return nullptr; + } + auto flight_result = std::move(*result); + return gaflight_do_put_result_new_raw(&flight_result); +} + G_END_DECLS GAFlightStreamReader * @@ -672,7 +979,28 @@ gaflight_stream_reader_new_raw(arrow::flight::FlightStreamReader *flight_reader, flight_reader, "is-owner", is_owner, - NULL)); + nullptr)); +} + +GAFlightStreamWriter * +gaflight_stream_writer_new_raw(arrow::flight::FlightStreamWriter *flight_writer) +{ + return GAFLIGHT_STREAM_WRITER( + g_object_new(GAFLIGHT_TYPE_STREAM_WRITER, "writer", flight_writer, nullptr)); +} + +GAFlightMetadataReader * +gaflight_metadata_reader_new_raw(arrow::flight::FlightMetadataReader *flight_reader) +{ + return GAFLIGHT_METADATA_READER( + g_object_new(GAFLIGHT_TYPE_METADATA_READER, "reader", flight_reader, nullptr)); +} + +arrow::flight::FlightMetadataReader * +gaflight_metadata_reader_get_raw(GAFlightMetadataReader *reader) +{ + auto priv = GAFLIGHT_METADATA_READER_GET_PRIVATE(reader); + return priv->reader; } arrow::flight::FlightCallOptions * @@ -689,6 +1017,13 @@ gaflight_client_options_get_raw(GAFlightClientOptions *options) return &(priv->options); } +GAFlightDoPutResult * +gaflight_do_put_result_new_raw(arrow::flight::FlightClient::DoPutResult *flight_result) +{ + return GAFLIGHT_DO_PUT_RESULT( + g_object_new(GAFLIGHT_TYPE_DO_PUT_RESULT, "result", flight_result, nullptr)); +} + std::shared_ptr gaflight_client_get_raw(GAFlightClient *client) { diff --git a/c_glib/arrow-flight-glib/client.h b/c_glib/arrow-flight-glib/client.h index a91bbe55e3c..12c5a06b810 100644 --- a/c_glib/arrow-flight-glib/client.h +++ 
b/c_glib/arrow-flight-glib/client.h @@ -35,6 +35,35 @@ struct _GAFlightStreamReaderClass GAFlightRecordBatchReaderClass parent_class; }; +#define GAFLIGHT_TYPE_STREAM_WRITER (gaflight_stream_writer_get_type()) +GAFLIGHT_AVAILABLE_IN_18_0 +G_DECLARE_DERIVABLE_TYPE(GAFlightStreamWriter, + gaflight_stream_writer, + GAFLIGHT, + STREAM_WRITER, + GAFlightRecordBatchWriter) +struct _GAFlightStreamWriterClass +{ + GAFlightRecordBatchWriterClass parent_class; +}; + +GAFLIGHT_AVAILABLE_IN_18_0 +gboolean +gaflight_stream_writer_done_writing(GAFlightStreamWriter *writer, GError **error); + +#define GAFLIGHT_TYPE_METADATA_READER (gaflight_metadata_reader_get_type()) +GAFLIGHT_AVAILABLE_IN_18_0 +G_DECLARE_DERIVABLE_TYPE( + GAFlightMetadataReader, gaflight_metadata_reader, GAFLIGHT, METADATA_READER, GObject) +struct _GAFlightMetadataReaderClass +{ + GObjectClass parent_class; +}; + +GAFLIGHT_AVAILABLE_IN_18_0 +GArrowBuffer * +gaflight_metadata_reader_read(GAFlightMetadataReader *reader, GError **error); + #define GAFLIGHT_TYPE_CALL_OPTIONS (gaflight_call_options_get_type()) GAFLIGHT_AVAILABLE_IN_5_0 G_DECLARE_DERIVABLE_TYPE( @@ -75,6 +104,15 @@ GAFLIGHT_AVAILABLE_IN_5_0 GAFlightClientOptions * gaflight_client_options_new(void); +#define GAFLIGHT_TYPE_DO_PUT_RESULT (gaflight_do_put_result_get_type()) +GAFLIGHT_AVAILABLE_IN_18_0 +G_DECLARE_DERIVABLE_TYPE( + GAFlightDoPutResult, gaflight_do_put_result, GAFLIGHT, DO_PUT_RESULT, GObject) +struct _GAFlightDoPutResultClass +{ + GObjectClass parent_class; +}; + #define GAFLIGHT_TYPE_CLIENT (gaflight_client_get_type()) GAFLIGHT_AVAILABLE_IN_5_0 G_DECLARE_DERIVABLE_TYPE(GAFlightClient, gaflight_client, GAFLIGHT, CLIENT, GObject) @@ -124,4 +162,12 @@ gaflight_client_do_get(GAFlightClient *client, GAFlightCallOptions *options, GError **error); +GAFLIGHT_AVAILABLE_IN_18_0 +GAFlightDoPutResult * +gaflight_client_do_put(GAFlightClient *client, + GAFlightDescriptor *descriptor, + GArrowSchema *schema, + GAFlightCallOptions *options, + GError **error); + G_END_DECLS diff --git a/c_glib/arrow-flight-glib/client.hpp b/c_glib/arrow-flight-glib/client.hpp index 185a28e6dc4..888f87ecb57 100644 --- a/c_glib/arrow-flight-glib/client.hpp +++ b/c_glib/arrow-flight-glib/client.hpp @@ -28,6 +28,18 @@ GAFlightStreamReader * gaflight_stream_reader_new_raw(arrow::flight::FlightStreamReader *flight_reader, gboolean is_owner); +GAFLIGHT_EXTERN +GAFlightStreamWriter * +gaflight_stream_writer_new_raw(arrow::flight::FlightStreamWriter *flight_writer); + +GAFLIGHT_EXTERN +GAFlightMetadataReader * +gaflight_metadata_reader_new_raw(arrow::flight::FlightMetadataReader *flight_reader); + +GAFLIGHT_EXTERN +arrow::flight::FlightMetadataReader * +gaflight_metadata_reader_get_raw(GAFlightMetadataReader *reader); + GAFLIGHT_EXTERN arrow::flight::FlightCallOptions * gaflight_call_options_get_raw(GAFlightCallOptions *options); @@ -36,6 +48,10 @@ GAFLIGHT_EXTERN arrow::flight::FlightClientOptions * gaflight_client_options_get_raw(GAFlightClientOptions *options); +GAFLIGHT_EXTERN +GAFlightDoPutResult * +gaflight_do_put_result_new_raw(arrow::flight::FlightClient::DoPutResult *flight_result); + GAFLIGHT_EXTERN std::shared_ptr gaflight_client_get_raw(GAFlightClient *client); diff --git a/c_glib/arrow-flight-glib/common.cpp b/c_glib/arrow-flight-glib/common.cpp index f7eea08c264..3deaf67cc14 100644 --- a/c_glib/arrow-flight-glib/common.cpp +++ b/c_glib/arrow-flight-glib/common.cpp @@ -1196,7 +1196,7 @@ gaflight_record_batch_reader_finalize(GObject *object) if (priv->is_owner) { delete priv->reader; } - 
G_OBJECT_CLASS(gaflight_info_parent_class)->finalize(object); + G_OBJECT_CLASS(gaflight_record_batch_reader_parent_class)->finalize(object); } static void @@ -1300,57 +1300,9 @@ gaflight_record_batch_reader_read_all(GAFlightRecordBatchReader *reader, GError } } -typedef struct GAFlightRecordBatchWriterPrivate_ -{ - arrow::flight::MetadataRecordBatchWriter *writer; - bool is_owner; -} GAFlightRecordBatchWriterPrivate; - -enum { - PROP_RECORD_BATCH_WRITER_WRITER = 1, - PROP_RECORD_BATCH_WRITER_IS_OWNER, -}; - -G_DEFINE_ABSTRACT_TYPE_WITH_PRIVATE(GAFlightRecordBatchWriter, - gaflight_record_batch_writer, - GARROW_TYPE_RECORD_BATCH_WRITER) - -#define GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(object) \ - static_cast( \ - gaflight_record_batch_writer_get_instance_private( \ - GAFLIGHT_RECORD_BATCH_WRITER(object))) - -static void -gaflight_record_batch_writer_finalize(GObject *object) -{ - auto priv = GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(object); - if (priv->is_owner) { - delete priv->writer; - } - G_OBJECT_CLASS(gaflight_info_parent_class)->finalize(object); -} - -static void -gaflight_record_batch_writer_set_property(GObject *object, - guint prop_id, - const GValue *value, - GParamSpec *pspec) -{ - auto priv = GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(object); - - switch (prop_id) { - case PROP_RECORD_BATCH_WRITER_WRITER: - priv->writer = - static_cast(g_value_get_pointer(value)); - break; - case PROP_RECORD_BATCH_WRITER_IS_OWNER: - priv->is_owner = g_value_get_boolean(value); - break; - default: - G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); - break; - } -} +G_DEFINE_ABSTRACT_TYPE(GAFlightRecordBatchWriter, + gaflight_record_batch_writer, + GARROW_TYPE_RECORD_BATCH_WRITER) static void gaflight_record_batch_writer_init(GAFlightRecordBatchWriter *object) @@ -1360,26 +1312,6 @@ gaflight_record_batch_writer_init(GAFlightRecordBatchWriter *object) static void gaflight_record_batch_writer_class_init(GAFlightRecordBatchWriterClass *klass) { - auto gobject_class = G_OBJECT_CLASS(klass); - - gobject_class->finalize = gaflight_record_batch_writer_finalize; - gobject_class->set_property = gaflight_record_batch_writer_set_property; - - GParamSpec *spec; - spec = g_param_spec_pointer( - "writer", - nullptr, - nullptr, - static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); - g_object_class_install_property(gobject_class, PROP_RECORD_BATCH_WRITER_WRITER, spec); - - spec = g_param_spec_boolean( - "is-owner", - nullptr, - nullptr, - TRUE, - static_cast(G_PARAM_WRITABLE | G_PARAM_CONSTRUCT_ONLY)); - g_object_class_install_property(gobject_class, PROP_RECORD_BATCH_WRITER_IS_OWNER, spec); } /** @@ -1402,7 +1334,8 @@ gaflight_record_batch_writer_begin(GAFlightRecordBatchWriter *writer, GArrowWriteOptions *options, GError **error) { - auto flight_writer = gaflight_record_batch_writer_get_raw(writer); + auto flight_writer = std::static_pointer_cast( + garrow_record_batch_writer_get_raw(GARROW_RECORD_BATCH_WRITER(writer))); auto arrow_schema = garrow_schema_get_raw(schema); arrow::ipc::IpcWriteOptions arrow_write_options; if (options) { @@ -1432,7 +1365,8 @@ gaflight_record_batch_writer_write_metadata(GAFlightRecordBatchWriter *writer, GArrowBuffer *metadata, GError **error) { - auto flight_writer = gaflight_record_batch_writer_get_raw(writer); + auto flight_writer = std::static_pointer_cast( + garrow_record_batch_writer_get_raw(GARROW_RECORD_BATCH_WRITER(writer))); auto arrow_metadata = garrow_buffer_get_raw(metadata); return garrow::check(error, flight_writer->WriteMetadata(arrow_metadata), @@ 
-1440,7 +1374,7 @@ gaflight_record_batch_writer_write_metadata(GAFlightRecordBatchWriter *writer, } /** - * gaflight_record_batch_writer_write: + * gaflight_record_batch_writer_write_record_batch: * @writer: A #GAFlightRecordBatchWriter. * @record_batch: A #GArrowRecordBatch. * @metadata: (nullable): A #GArrowBuffer. @@ -1453,12 +1387,13 @@ gaflight_record_batch_writer_write_metadata(GAFlightRecordBatchWriter *writer, * Since: 18.0.0 */ gboolean -gaflight_record_batch_writer_write(GAFlightRecordBatchWriter *writer, - GArrowRecordBatch *record_batch, - GArrowBuffer *metadata, - GError **error) +gaflight_record_batch_writer_write_record_batch(GAFlightRecordBatchWriter *writer, + GArrowRecordBatch *record_batch, + GArrowBuffer *metadata, + GError **error) { - auto flight_writer = gaflight_record_batch_writer_get_raw(writer); + auto flight_writer = std::static_pointer_cast( + garrow_record_batch_writer_get_raw(GARROW_RECORD_BATCH_WRITER(writer))); auto arrow_record_batch = garrow_record_batch_get_raw(record_batch); auto arrow_metadata = garrow_buffer_get_raw(metadata); return garrow::check( @@ -1599,10 +1534,3 @@ gaflight_record_batch_reader_get_raw(GAFlightRecordBatchReader *reader) auto priv = GAFLIGHT_RECORD_BATCH_READER_GET_PRIVATE(reader); return priv->reader; } - -arrow::flight::MetadataRecordBatchWriter * -gaflight_record_batch_writer_get_raw(GAFlightRecordBatchWriter *writer) -{ - auto priv = GAFLIGHT_RECORD_BATCH_WRITER_GET_PRIVATE(writer); - return priv->writer; -} diff --git a/c_glib/arrow-flight-glib/common.h b/c_glib/arrow-flight-glib/common.h index 91c828caabb..726132fe492 100644 --- a/c_glib/arrow-flight-glib/common.h +++ b/c_glib/arrow-flight-glib/common.h @@ -259,9 +259,9 @@ gaflight_record_batch_writer_write_metadata(GAFlightRecordBatchWriter *writer, GAFLIGHT_AVAILABLE_IN_18_0 gboolean -gaflight_record_batch_writer_write(GAFlightRecordBatchWriter *writer, - GArrowRecordBatch *record_batch, - GArrowBuffer *metadata, - GError **error); +gaflight_record_batch_writer_write_record_batch(GAFlightRecordBatchWriter *writer, + GArrowRecordBatch *record_batch, + GArrowBuffer *metadata, + GError **error); G_END_DECLS diff --git a/c_glib/arrow-glib/writer.hpp b/c_glib/arrow-glib/writer.hpp index aa87ffe77d7..1d85ac52f88 100644 --- a/c_glib/arrow-glib/writer.hpp +++ b/c_glib/arrow-glib/writer.hpp @@ -25,16 +25,20 @@ #include +GARROW_AVAILABLE_IN_ALL GArrowRecordBatchWriter * garrow_record_batch_writer_new_raw( std::shared_ptr *arrow_writer); +GARROW_AVAILABLE_IN_ALL std::shared_ptr garrow_record_batch_writer_get_raw(GArrowRecordBatchWriter *writer); +GARROW_AVAILABLE_IN_ALL GArrowRecordBatchStreamWriter * garrow_record_batch_stream_writer_new_raw( std::shared_ptr *arrow_writer); +GARROW_AVAILABLE_IN_ALL GArrowRecordBatchFileWriter * garrow_record_batch_file_writer_new_raw( std::shared_ptr *arrow_writer); From b83666234c05d34c23993708160033c259b9ec26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ra=C3=BAl=20Cumplido?= Date: Tue, 27 Aug 2024 10:30:23 +0200 Subject: [PATCH 080/157] GH-43815: [CI][Packaging][Python] Avoid uploading wheel to gemfury if version already exists (#43816) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What changes are included in this PR? Check whether version exists on gemfury before trying upload ### Are these changes tested? Will be tested via archery ### Are there any user-facing changes? 
No * GitHub Issue: #43815 Lead-authored-by: Raúl Cumplido Co-authored-by: Sutou Kouhei Signed-off-by: Raúl Cumplido --- dev/tasks/macros.jinja | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/dev/tasks/macros.jinja b/dev/tasks/macros.jinja index 6423ca0e9ef..df55f32222e 100644 --- a/dev/tasks/macros.jinja +++ b/dev/tasks/macros.jinja @@ -169,10 +169,14 @@ env: - name: Upload package to Gemfury shell: bash run: | - fury push \ - --api-token=${CROSSBOW_GEMFURY_TOKEN} \ - --as=${CROSSBOW_GEMFURY_ORG} \ - {{ pattern }} + if $(fury versions --as=${CROSSBOW_GEMFURY_ORG} --api-token=${CROSSBOW_GEMFURY_TOKEN} pyarrow | grep --fixed-strings -q "{{ arrow.no_rc_version }}"); then + echo "Version {{ arrow.no_rc_version }} already exists. Avoid pushing version." + else + fury push \ + --api-token=${CROSSBOW_GEMFURY_TOKEN} \ + --as=${CROSSBOW_GEMFURY_ORG} \ + {{ pattern }} + fi env: CROSSBOW_GEMFURY_TOKEN: {{ '${{ secrets.CROSSBOW_GEMFURY_TOKEN }}' }} CROSSBOW_GEMFURY_ORG: {{ '${{ secrets.CROSSBOW_GEMFURY_ORG }}' }} From 6502f0e3ad046d361aba44385ab3379ed7af5b7f Mon Sep 17 00:00:00 2001 From: Joel Lubinitsky <33523178+joellubi@users.noreply.github.com> Date: Tue, 27 Aug 2024 13:17:39 -0400 Subject: [PATCH 081/157] GH-43790: [Go][Parquet] Add support for LZ4_RAW compression codec (#43835) ### Rationale for this change Fixes: #43790 The LZ4 compression codec for Parquet is no longer ambiguous, as it has been superceded by the [LZ4_RAW](https://github.com/apache/parquet-format/blob/master/Compression.md#lz4_raw) spec. ### What changes are included in this PR? - Add `LZ4Raw` compression codec - Split out `StreamingCodec` methods from core `Codec` interface - Various conformance/roundtrip tests - Set of benchmarks for reading/writing an Arrow table to/from Parquet, using each compression codec ### Are these changes tested? Yes ### Are there any user-facing changes? 
- New codec `LZ4Raw` is available - `Codec` interface no long provides the following methods, which are now part of `StreamingCodec`: - `NewReader` - `NewWriter` - `NewWriterLevel` * GitHub Issue: #43790 Authored-by: Joel Lubinitsky Signed-off-by: Joel Lubinitsky --- go/parquet/compress/compress.go | 22 ++-- go/parquet/compress/compress_test.go | 8 +- go/parquet/compress/lz4_raw.go | 66 ++++++++++++ go/parquet/file/file_reader_test.go | 127 +++++++++++++++++++++++ go/parquet/file/file_writer_test.go | 58 ++++++++++- go/parquet/pqarrow/reader_writer_test.go | 111 ++++++++++++++++++++ 6 files changed, 380 insertions(+), 12 deletions(-) create mode 100644 go/parquet/compress/lz4_raw.go diff --git a/go/parquet/compress/compress.go b/go/parquet/compress/compress.go index b6a1349133e..92f2ae99bb1 100644 --- a/go/parquet/compress/compress.go +++ b/go/parquet/compress/compress.go @@ -49,8 +49,9 @@ var Codecs = struct { Brotli Compression // LZ4 unsupported in this library due to problematic issues between the Hadoop LZ4 spec vs regular lz4 // see: http://mail-archives.apache.org/mod_mbox/arrow-dev/202007.mbox/%3CCAAri41v24xuA8MGHLDvgSnE+7AAgOhiEukemW_oPNHMvfMmrWw@mail.gmail.com%3E - Lz4 Compression - Zstd Compression + Lz4 Compression + Zstd Compression + Lz4Raw Compression }{ Uncompressed: Compression(parquet.CompressionCodec_UNCOMPRESSED), Snappy: Compression(parquet.CompressionCodec_SNAPPY), @@ -59,17 +60,12 @@ var Codecs = struct { Brotli: Compression(parquet.CompressionCodec_BROTLI), Lz4: Compression(parquet.CompressionCodec_LZ4), Zstd: Compression(parquet.CompressionCodec_ZSTD), + Lz4Raw: Compression(parquet.CompressionCodec_LZ4_RAW), } // Codec is an interface which is implemented for each compression type in order to make the interactions easy to // implement. Most consumers won't be calling GetCodec directly. type Codec interface { - // NewReader provides a reader that wraps a stream with compressed data to stream the uncompressed data - NewReader(io.Reader) io.ReadCloser - // NewWriter provides a wrapper around a write stream to compress data before writing it. - NewWriter(io.Writer) io.WriteCloser - // NewWriterLevel is like NewWriter but allows specifying the compression level - NewWriterLevel(io.Writer, int) (io.WriteCloser, error) // Encode encodes a block of data given by src and returns the compressed block. dst should be either nil // or sized large enough to fit the compressed block (use CompressBound to allocate). dst and src should not // overlap since some of the compression types don't allow it. @@ -90,6 +86,16 @@ type Codec interface { Decode(dst, src []byte) []byte } +// StreamingCodec is an interface that may be implemented for compression codecs that expose a streaming API. +type StreamingCodec interface { + // NewReader provides a reader that wraps a stream with compressed data to stream the uncompressed data + NewReader(io.Reader) io.ReadCloser + // NewWriter provides a wrapper around a write stream to compress data before writing it. + NewWriter(io.Writer) io.WriteCloser + // NewWriterLevel is like NewWriter but allows specifying the compression level + NewWriterLevel(io.Writer, int) (io.WriteCloser, error) +} + var codecs = map[Compression]Codec{} // RegisterCodec adds or overrides a codec implementation for a given compression algorithm. 
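The compress.go changes above split the streaming methods out of `Codec`, so block compression and the optional streaming API are now reached separately. A minimal sketch of how a caller might use the new `Lz4Raw` codec after this change (the import path is the one used by the tests below; the sample data is illustrative):

```go
package main

import (
	"bytes"
	"fmt"

	"github.com/apache/arrow/go/v18/parquet/compress"
)

func main() {
	// Look up the new LZ4_RAW codec; block-level Encode/Decode stay on Codec.
	codec, err := compress.GetCodec(compress.Codecs.Lz4Raw)
	if err != nil {
		panic(err)
	}

	src := bytes.Repeat([]byte("repetitive data "), 64)

	// Size the destination with CompressBound, as the Encode contract requires.
	dst := make([]byte, 0, int(codec.CompressBound(int64(len(src)))))
	compressed := codec.Encode(dst, src)

	// Decode needs a destination large enough to hold the uncompressed data.
	decompressed := codec.Decode(make([]byte, len(src)), compressed)
	fmt.Println(bytes.Equal(src, decompressed)) // true

	// Streaming readers/writers are now opt-in: only codecs that also
	// implement StreamingCodec expose NewReader/NewWriter/NewWriterLevel.
	if _, ok := codec.(compress.StreamingCodec); !ok {
		fmt.Println("lz4_raw has no streaming API")
	}
}
```

When writing Parquet files, the codec is instead selected through the writer properties, e.g. `parquet.NewWriterProperties(parquet.WithCompression(compress.Codecs.Lz4Raw))`, as the round-trip test further down does.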
diff --git a/go/parquet/compress/compress_test.go b/go/parquet/compress/compress_test.go index 843062c0d02..5aac74759e1 100644 --- a/go/parquet/compress/compress_test.go +++ b/go/parquet/compress/compress_test.go @@ -66,8 +66,8 @@ func TestCompressDataOneShot(t *testing.T) { {compress.Codecs.Gzip}, {compress.Codecs.Brotli}, {compress.Codecs.Zstd}, + {compress.Codecs.Lz4Raw}, // {compress.Codecs.Lzo}, - // {compress.Codecs.Lz4}, } for _, tt := range tests { @@ -107,9 +107,11 @@ func TestCompressReaderWriter(t *testing.T) { var buf bytes.Buffer codec, err := compress.GetCodec(tt.c) assert.NoError(t, err) + streamingCodec, ok := codec.(compress.StreamingCodec) + assert.True(t, ok) data := makeRandomData(RandomDataSize) - wr := codec.NewWriter(&buf) + wr := streamingCodec.NewWriter(&buf) const chunkSize = 1111 input := data @@ -129,7 +131,7 @@ func TestCompressReaderWriter(t *testing.T) { } wr.Close() - rdr := codec.NewReader(&buf) + rdr := streamingCodec.NewReader(&buf) out, err := io.ReadAll(rdr) assert.NoError(t, err) assert.Exactly(t, data, out) diff --git a/go/parquet/compress/lz4_raw.go b/go/parquet/compress/lz4_raw.go new file mode 100644 index 00000000000..788d9520a66 --- /dev/null +++ b/go/parquet/compress/lz4_raw.go @@ -0,0 +1,66 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package compress + +import ( + "sync" + + "github.com/pierrec/lz4/v4" +) + +// lz4.Compressor is not goroutine-safe, so we use a pool to amortize the cost +// of allocating a new one for each call to Encode(). 
+var compressorPool = sync.Pool{New: func() interface{} { return new(lz4.Compressor) }} + +func compressBlock(src, dst []byte) (int, error) { + c := compressorPool.Get().(*lz4.Compressor) + defer compressorPool.Put(c) + return c.CompressBlock(src, dst) +} + +type lz4RawCodec struct{} + +func (c lz4RawCodec) Encode(dst, src []byte) []byte { + n, err := compressBlock(src, dst[:cap(dst)]) + if err != nil { + panic(err) + } + + return dst[:n] +} + +func (c lz4RawCodec) EncodeLevel(dst, src []byte, _ int) []byte { + // the lz4 block implementation does not allow level to be set + return c.Encode(dst, src) +} + +func (lz4RawCodec) Decode(dst, src []byte) []byte { + n, err := lz4.UncompressBlock(src, dst) + if err != nil { + panic(err) + } + + return dst[:n] +} + +func (c lz4RawCodec) CompressBound(len int64) int64 { + return int64(lz4.CompressBlockBound(int(len))) +} + +func init() { + RegisterCodec(Codecs.Lz4Raw, lz4RawCodec{}) +} diff --git a/go/parquet/file/file_reader_test.go b/go/parquet/file/file_reader_test.go index 547ec475c27..35f4da4e866 100644 --- a/go/parquet/file/file_reader_test.go +++ b/go/parquet/file/file_reader_test.go @@ -644,3 +644,130 @@ func TestDeltaBinaryPackedMultipleBatches(t *testing.T) { require.Equalf(t, size, totalRows, "Expected %d rows, but got %d rows", size, totalRows) } + +// Test read file lz4_raw_compressed.parquet +// Contents documented at https://github.com/apache/parquet-testing/commit/ddd898958803cb89b7156c6350584d1cda0fe8de +func TestLZ4RawFileRead(t *testing.T) { + dir := os.Getenv("PARQUET_TEST_DATA") + if dir == "" { + t.Skip("no path supplied with PARQUET_TEST_DATA") + } + require.DirExists(t, dir) + + props := parquet.NewReaderProperties(memory.DefaultAllocator) + fileReader, err := file.OpenParquetFile(path.Join(dir, "lz4_raw_compressed.parquet"), + false, file.WithReadProps(props)) + require.NoError(t, err) + defer fileReader.Close() + + nRows := 4 + nCols := 3 + require.Equal(t, 1, fileReader.NumRowGroups()) + rgr := fileReader.RowGroup(0) + require.EqualValues(t, nRows, rgr.NumRows()) + require.EqualValues(t, nCols, rgr.NumColumns()) + + rdr, err := rgr.Column(0) + require.NoError(t, err) + + rowsInt64, ok := rdr.(*file.Int64ColumnChunkReader) + require.True(t, ok) + + valsInt64 := make([]int64, nRows) + total, read, err := rowsInt64.ReadBatch(int64(nRows), valsInt64, nil, nil) + require.NoError(t, err) + require.Equal(t, int64(nRows), total) + require.Equal(t, nRows, read) + + expectedValsInt64 := []int64{ + 1593604800, + 1593604800, + 1593604801, + 1593604801, + } + require.Equal(t, expectedValsInt64, valsInt64) + + rdr, err = rgr.Column(1) + require.NoError(t, err) + + rowsByteArray, ok := rdr.(*file.ByteArrayColumnChunkReader) + require.True(t, ok) + + valsByteArray := make([]parquet.ByteArray, nRows) + total, read, err = rowsByteArray.ReadBatch(int64(nRows), valsByteArray, nil, nil) + require.NoError(t, err) + require.Equal(t, int64(nRows), total) + require.Equal(t, nRows, read) + + expectedValsByteArray := []parquet.ByteArray{ + []byte("abc"), + []byte("def"), + []byte("abc"), + []byte("def"), + } + require.Equal(t, expectedValsByteArray, valsByteArray) + + rdr, err = rgr.Column(2) + require.NoError(t, err) + + rowsFloat64, ok := rdr.(*file.Float64ColumnChunkReader) + require.True(t, ok) + + valsFloat64 := make([]float64, nRows) + total, read, err = rowsFloat64.ReadBatch(int64(nRows), valsFloat64, nil, nil) + require.NoError(t, err) + require.Equal(t, int64(nRows), total) + require.Equal(t, nRows, read) + + expectedValsFloat64 := 
[]float64{ + 42.0, + 7.7, + 42.125, + 7.7, + } + require.Equal(t, expectedValsFloat64, valsFloat64) +} + +// Test read file lz4_raw_compressed_larger.parquet +// Contents documented at https://github.com/apache/parquet-testing/commit/ddd898958803cb89b7156c6350584d1cda0fe8de +func TestLZ4RawLargerFileRead(t *testing.T) { + dir := os.Getenv("PARQUET_TEST_DATA") + if dir == "" { + t.Skip("no path supplied with PARQUET_TEST_DATA") + } + require.DirExists(t, dir) + + props := parquet.NewReaderProperties(memory.DefaultAllocator) + fileReader, err := file.OpenParquetFile(path.Join(dir, "lz4_raw_compressed_larger.parquet"), + false, file.WithReadProps(props)) + require.NoError(t, err) + defer fileReader.Close() + + nRows := 10000 + nCols := 1 + require.Equal(t, 1, fileReader.NumRowGroups()) + rgr := fileReader.RowGroup(0) + require.EqualValues(t, nRows, rgr.NumRows()) + require.EqualValues(t, nCols, rgr.NumColumns()) + + rdr, err := rgr.Column(0) + require.NoError(t, err) + + rows, ok := rdr.(*file.ByteArrayColumnChunkReader) + require.True(t, ok) + + vals := make([]parquet.ByteArray, nRows) + total, read, err := rows.ReadBatch(int64(nRows), vals, nil, nil) + require.NoError(t, err) + require.Equal(t, int64(nRows), total) + require.Equal(t, nRows, read) + + expectedValsHead := []parquet.ByteArray{ + []byte("c7ce6bef-d5b0-4863-b199-8ea8c7fb117b"), + []byte("e8fb9197-cb9f-4118-b67f-fbfa65f61843"), + []byte("885136e1-0aa1-4fdb-8847-63d87b07c205"), + []byte("ce7b2019-8ebe-4906-a74d-0afa2409e5df"), + []byte("a9ee2527-821b-4b71-a926-03f73c3fc8b7"), + } + require.Equal(t, expectedValsHead, vals[:len(expectedValsHead)]) +} diff --git a/go/parquet/file/file_writer_test.go b/go/parquet/file/file_writer_test.go index 0faf3f7233b..12ac93d1ef4 100644 --- a/go/parquet/file/file_writer_test.go +++ b/go/parquet/file/file_writer_test.go @@ -260,7 +260,7 @@ func (t *SerializeTestSuite) TestSmallFile() { compress.Codecs.Brotli, compress.Codecs.Gzip, compress.Codecs.Zstd, - // compress.Codecs.Lz4, + compress.Codecs.Lz4Raw, // compress.Codecs.Lzo, } for _, c := range codecs { @@ -540,3 +540,59 @@ func TestBatchedByteStreamSplitFileRoundtrip(t *testing.T) { require.NoError(t, rdr.Close()) } + +func TestLZ4RawFileRoundtrip(t *testing.T) { + input := []int64{ + -1, 0, 1, 2, 3, 4, 5, 123456789, -123456789, + } + + size := len(input) + + field, err := schema.NewPrimitiveNodeLogical("int64", parquet.Repetitions.Required, nil, parquet.Types.Int64, 0, 1) + require.NoError(t, err) + + schema, err := schema.NewGroupNode("test", parquet.Repetitions.Required, schema.FieldList{field}, 0) + require.NoError(t, err) + + sink := encoding.NewBufferWriter(0, memory.DefaultAllocator) + writer := file.NewParquetWriter(sink, schema, file.WithWriterProps(parquet.NewWriterProperties(parquet.WithCompression(compress.Codecs.Lz4Raw)))) + + rgw := writer.AppendRowGroup() + cw, err := rgw.NextColumn() + require.NoError(t, err) + + i64ColumnWriter, ok := cw.(*file.Int64ColumnChunkWriter) + require.True(t, ok) + + nVals, err := i64ColumnWriter.WriteBatch(input, nil, nil) + require.NoError(t, err) + require.EqualValues(t, size, nVals) + + require.NoError(t, cw.Close()) + require.NoError(t, rgw.Close()) + require.NoError(t, writer.Close()) + + rdr, err := file.NewParquetReader(bytes.NewReader(sink.Bytes())) + require.NoError(t, err) + + require.Equal(t, 1, rdr.NumRowGroups()) + require.EqualValues(t, size, rdr.NumRows()) + + rgr := rdr.RowGroup(0) + cr, err := rgr.Column(0) + require.NoError(t, err) + + i64ColumnReader, ok := 
cr.(*file.Int64ColumnChunkReader) + require.True(t, ok) + + output := make([]int64, size) + + total, valuesRead, err := i64ColumnReader.ReadBatch(int64(size), output, nil, nil) + require.NoError(t, err) + require.EqualValues(t, size, total) + require.EqualValues(t, size, valuesRead) + + require.Equal(t, input, output) + + require.NoError(t, rdr.Close()) +} diff --git a/go/parquet/pqarrow/reader_writer_test.go b/go/parquet/pqarrow/reader_writer_test.go index 31bd0eba843..e020c7d9457 100644 --- a/go/parquet/pqarrow/reader_writer_test.go +++ b/go/parquet/pqarrow/reader_writer_test.go @@ -19,6 +19,8 @@ package pqarrow_test import ( "bytes" "context" + "fmt" + "math" "testing" "unsafe" @@ -26,8 +28,10 @@ import ( "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/memory" "github.com/apache/arrow/go/v18/parquet" + "github.com/apache/arrow/go/v18/parquet/compress" "github.com/apache/arrow/go/v18/parquet/file" "github.com/apache/arrow/go/v18/parquet/pqarrow" + "github.com/stretchr/testify/require" "golang.org/x/exp/rand" "gonum.org/v1/gonum/stat/distuv" ) @@ -275,3 +279,110 @@ func BenchmarkReadColumnFloat64(b *testing.B) { benchReadTable(b, tt.name, tbl, int64(arrow.Int32Traits.BytesRequired(SIZELEN))) } } + +var compressTestCases = []struct { + c compress.Compression +}{ + {compress.Codecs.Uncompressed}, + {compress.Codecs.Snappy}, + {compress.Codecs.Gzip}, + {compress.Codecs.Brotli}, + {compress.Codecs.Zstd}, + {compress.Codecs.Lz4Raw}, + // {compress.Codecs.Lzo}, +} + +func buildTableForTest(mem memory.Allocator) arrow.Table { + schema := arrow.NewSchema( + []arrow.Field{ + {Name: "int64s", Type: arrow.PrimitiveTypes.Int64}, + {Name: "strings", Type: arrow.BinaryTypes.String}, + {Name: "bools", Type: arrow.FixedWidthTypes.Boolean}, + {Name: "repeated_int64s", Type: arrow.PrimitiveTypes.Int64}, + {Name: "repeated_strings", Type: arrow.BinaryTypes.String}, + {Name: "repeated_bools", Type: arrow.FixedWidthTypes.Boolean}, + }, + nil, + ) + bldr := array.NewRecordBuilder(mem, schema) + defer bldr.Release() + + for i := 0; i < SIZELEN; i++ { + bldr.Field(0).(*array.Int64Builder).Append(int64(i)) + bldr.Field(1).(*array.StringBuilder).Append(fmt.Sprint(i)) + bldr.Field(2).(*array.BooleanBuilder).Append(i%2 == 0) + bldr.Field(3).(*array.Int64Builder).Append(0) + bldr.Field(4).(*array.StringBuilder).Append("the string is the same") + bldr.Field(5).(*array.BooleanBuilder).Append(true) + } + + rec := bldr.NewRecord() + return array.NewTableFromRecords(schema, []arrow.Record{rec}) +} + +func BenchmarkWriteTableCompressed(b *testing.B) { + mem := memory.DefaultAllocator + table := buildTableForTest(mem) + defer table.Release() + + var uncompressedSize uint64 + for idxCol := 0; int64(idxCol) < table.NumCols(); idxCol++ { + column := table.Column(idxCol) + for _, chunk := range column.Data().Chunks() { + uncompressedSize += chunk.Data().SizeInBytes() + } + } + + var buf bytes.Buffer + buf.Grow(int(uncompressedSize)) + for _, tc := range compressTestCases { + b.Run(fmt.Sprintf("codec=%s", tc.c), func(b *testing.B) { + buf.Reset() + b.ResetTimer() + b.SetBytes(int64(uncompressedSize)) + for n := 0; n < b.N; n++ { + require.NoError(b, + pqarrow.WriteTable( + table, + &buf, + math.MaxInt64, + parquet.NewWriterProperties(parquet.WithAllocator(mem), parquet.WithCompression(tc.c)), + pqarrow.DefaultWriterProps(), + ), + ) + } + }) + } +} + +func BenchmarkReadTableCompressed(b *testing.B) { + ctx := context.Background() + mem := memory.DefaultAllocator + table := 
buildTableForTest(mem) + defer table.Release() + + for _, tc := range compressTestCases { + b.Run(fmt.Sprintf("codec=%s", tc.c), func(b *testing.B) { + var buf bytes.Buffer + err := pqarrow.WriteTable( + table, + &buf, + math.MaxInt64, + parquet.NewWriterProperties(parquet.WithAllocator(mem), parquet.WithCompression(tc.c)), + pqarrow.DefaultWriterProps(), + ) + require.NoError(b, err) + + compressedBytes := buf.Len() + rdr := bytes.NewReader(buf.Bytes()) + + b.ResetTimer() + b.SetBytes(int64(compressedBytes)) + for n := 0; n < b.N; n++ { + tab, err := pqarrow.ReadTable(ctx, rdr, nil, pqarrow.ArrowReadProperties{}, mem) + require.NoError(b, err) + defer tab.Release() + } + }) + } +} From ce1e724d7ea292746ede6a538519658f1ecab849 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 27 Aug 2024 19:17:55 +0200 Subject: [PATCH 082/157] MINOR: [CI] Use `docker compose` on self-hosted ARM builds (#43844) ### Rationale for this change The Docker client version on the ARM64 self-hosted runners is now recent enough, so we don't need to use `docker-compose` there anymore. Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- .github/workflows/cpp.yml | 5 +---- .github/workflows/go.yml | 5 ----- dev/tasks/java-jars/github.yml | 2 -- dev/tasks/linux-packages/github.linux.yml | 1 - dev/tasks/python-wheels/github.linux.yml | 1 - 5 files changed, 1 insertion(+), 13 deletions(-) diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index a82e1eb7666..c5482f73082 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -99,7 +99,6 @@ jobs: cat <> "$GITHUB_OUTPUT" { "arch": "arm64v8", - "archery-use-legacy-docker-compose": "1", "clang-tools": "10", "image": "ubuntu-cpp", "llvm": "10", @@ -124,9 +123,6 @@ jobs: include: ${{ fromJson(needs.docker-targets.outputs.targets) }} env: ARCH: ${{ matrix.arch }} - # By default, use `docker compose` because docker-compose v1 is obsolete, - # except where the Docker client version is too old. - ARCHERY_USE_LEGACY_DOCKER_COMPOSE: ${{ matrix.archery-use-legacy-docker-compose || '0' }} ARROW_SIMD_LEVEL: ${{ matrix.simd-level }} CLANG_TOOLS: ${{ matrix.clang-tools }} LLVM: ${{ matrix.llvm }} @@ -147,6 +143,7 @@ jobs: run: | sudo apt update sudo apt install -y --no-install-recommends python3 python3-dev python3-pip + python3 -m pip install -U pip - name: Setup Archery run: python3 -m pip install -e dev/archery[docker] - name: Execute Docker Build diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 20c78d86cb2..ffd543691d5 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -78,14 +78,12 @@ jobs: { "arch-label": "ARM64", "arch": "arm64v8", - "archery-use-legacy-docker-compose": "1", "go": "1.21", "runs-on": ["self-hosted", "arm", "linux"] }, { "arch-label": "ARM64", "arch": "arm64v8", - "archery-use-legacy-docker-compose": "1", "go": "1.22", "runs-on": ["self-hosted", "arm", "linux"] } @@ -106,9 +104,6 @@ jobs: include: ${{ fromJson(needs.docker-targets.outputs.targets) }} env: ARCH: ${{ matrix.arch }} - # By default, use Docker CLI because docker-compose v1 is obsolete, - # except where the Docker client version is too old. 
- ARCHERY_USE_LEGACY_DOCKER_COMPOSE: ${{ matrix.archery-use-legacy-docker-compose || '0' }} GO: ${{ matrix.go }} steps: - name: Checkout Arrow diff --git a/dev/tasks/java-jars/github.yml b/dev/tasks/java-jars/github.yml index 7cbd5f05dab..bdbed1bd678 100644 --- a/dev/tasks/java-jars/github.yml +++ b/dev/tasks/java-jars/github.yml @@ -30,7 +30,6 @@ jobs: ARCH: {{ '${{ matrix.platform.archery_arch }}' }} ARCH_ALIAS: {{ '${{ matrix.platform.archery_arch_alias }}' }} ARCH_SHORT: {{ '${{ matrix.platform.archery_arch_short }}' }} - ARCHERY_USE_LEGACY_DOCKER_COMPOSE: {{ "${{matrix.platform.archery_use_legacy_docker_compose || '0'}}" }} strategy: fail-fast: false matrix: @@ -45,7 +44,6 @@ jobs: archery_arch: "arm64v8" archery_arch_alias: "aarch64" archery_arch_short: "arm64" - archery_use_legacy_docker_compose: "1" steps: {{ macros.github_checkout_arrow()|indent }} {{ macros.github_free_space()|indent }} diff --git a/dev/tasks/linux-packages/github.linux.yml b/dev/tasks/linux-packages/github.linux.yml index 4bf2295ef3e..cce976cd60e 100644 --- a/dev/tasks/linux-packages/github.linux.yml +++ b/dev/tasks/linux-packages/github.linux.yml @@ -29,7 +29,6 @@ jobs: {% endif %} env: ARCHITECTURE: {{ architecture }} - ARCHERY_USE_LEGACY_DOCKER_COMPOSE: {{ '1' if architecture == 'arm64' else '0' }} steps: {{ macros.github_checkout_arrow()|indent }} {{ macros.github_login_dockerhub()|indent }} diff --git a/dev/tasks/python-wheels/github.linux.yml b/dev/tasks/python-wheels/github.linux.yml index 2854d4349fb..97746ba3f9b 100644 --- a/dev/tasks/python-wheels/github.linux.yml +++ b/dev/tasks/python-wheels/github.linux.yml @@ -33,7 +33,6 @@ jobs: ARCH: amd64 {% else %} ARCH: arm64v8 - ARCHERY_USE_LEGACY_DOCKER_COMPOSE: 1 {% endif %} PYTHON: "{{ python_version }}" {% if python_version == "3.13" %} From 75ca5b3631144f58ea3edbe6b4933a686c0e0fd9 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Wed, 28 Aug 2024 05:47:43 +0900 Subject: [PATCH 083/157] GH-43805: [C++] Enable filesystem automatically when one of ARROW_{AZURE,GCS,HDFS,S3}=ON is specified (#43806) ### Rationale for this change `ARROW_{AZURE,GCS,HDFS,S3}=ON` are meaningful only when filesystem is enabled. If the user specified one of them, we can assume that the user wants to enable filesystem. ### What changes are included in this PR? Enable `ARROW_FILESYSTEM` when one of `ARROW_{AZURE,GCS,HDFS,S3}=ON` are specified. ### Are these changes tested? Yes. ### Are there any user-facing changes? Yes. `ARROW_FILESYSTEM` is enabled automatically with one of `ARROW_{AZURE,GCS,HDFS,S3}=ON`. 
* GitHub Issue: #43805 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- cpp/cmake_modules/DefineOptions.cmake | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 41466a1c224..755887314d1 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -303,7 +303,10 @@ takes precedence over ccache if a storage backend is configured" ON) ARROW_IPC) define_option(ARROW_AZURE - "Build Arrow with Azure support (requires the Azure SDK for C++)" OFF) + "Build Arrow with Azure support (requires the Azure SDK for C++)" + OFF + DEPENDS + ARROW_FILESYSTEM) define_option(ARROW_BUILD_UTILITIES "Build Arrow commandline utilities" OFF) @@ -346,9 +349,16 @@ takes precedence over ccache if a storage backend is configured" ON) ARROW_WITH_UTF8PROC) define_option(ARROW_GCS - "Build Arrow with GCS support (requires the GCloud SDK for C++)" OFF) + "Build Arrow with GCS support (requires the GCloud SDK for C++)" + OFF + DEPENDS + ARROW_FILESYSTEM) - define_option(ARROW_HDFS "Build the Arrow HDFS bridge" OFF) + define_option(ARROW_HDFS + "Build the Arrow HDFS bridge" + OFF + DEPENDS + ARROW_FILESYSTEM) define_option(ARROW_IPC "Build the Arrow IPC extensions" ON) @@ -398,7 +408,11 @@ takes precedence over ccache if a storage backend is configured" ON) ARROW_HDFS ARROW_JSON) - define_option(ARROW_S3 "Build Arrow with S3 support (requires the AWS SDK for C++)" OFF) + define_option(ARROW_S3 + "Build Arrow with S3 support (requires the AWS SDK for C++)" + OFF + DEPENDS + ARROW_FILESYSTEM) define_option(ARROW_SKYHOOK "Build the Skyhook libraries" From 09bb24a5cdf5b6e73334e9a8b521f0188d940c73 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Wed, 28 Aug 2024 06:13:31 +0530 Subject: [PATCH 084/157] MINOR: [Java] Logback dependency upgrade (#43842) ### Rationale for this change Fusing https://github.com/apache/arrow/pull/43752 and https://github.com/apache/arrow/pull/43827 dependabot PRs into a single PR. ### What changes are included in this PR? Keeping a single version for both `logback-classic` and `logback-core`. ### Are these changes tested? N/A ### Are there any user-facing changes? No Authored-by: Vibhatha Lakmal Abeykoon Signed-off-by: David Li --- java/memory/memory-netty/pom.xml | 1 - java/pom.xml | 13 ++++++++++++- java/tools/pom.xml | 1 - 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/java/memory/memory-netty/pom.xml b/java/memory/memory-netty/pom.xml index f2d4d2d0fe3..6cf573dd4d3 100644 --- a/java/memory/memory-netty/pom.xml +++ b/java/memory/memory-netty/pom.xml @@ -56,7 +56,6 @@ under the License. ch.qos.logback logback-core - 1.3.14 test diff --git a/java/pom.xml b/java/pom.xml index f78d02c0c65..577f23e6a71 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -111,6 +111,7 @@ under the License. 5.11.0 5.2.0 3.46.0 + 1.5.7 none -Xdoclint:none @@ -221,6 +222,16 @@ under the License. pom import + + ch.qos.logback + logback-classic + ${logback.version} + + + ch.qos.logback + logback-core + ${logback.version} + @@ -274,7 +285,7 @@ under the License. ch.qos.logback logback-classic - 1.4.14 + ${logback.version} test diff --git a/java/tools/pom.xml b/java/tools/pom.xml index 94566495dff..082f06860c6 100644 --- a/java/tools/pom.xml +++ b/java/tools/pom.xml @@ -59,7 +59,6 @@ under the License. ch.qos.logback logback-classic - 1.4.14 test