Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 108 additions & 49 deletions arrow/extensions/variant.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package extensions

import (
"bytes"
"errors"
"fmt"
"math"
"reflect"
Expand Down Expand Up @@ -171,21 +172,23 @@ func NewVariantType(storage arrow.DataType) (*VariantType, error) {
return nil, fmt.Errorf("%w: missing non-nullable field 'metadata' in variant storage type %s", arrow.ErrInvalid, storage)
}

if valueFieldIdx, ok = s.FieldIdx("value"); !ok {
return nil, fmt.Errorf("%w: missing non-nullable field 'value' in variant storage type %s", arrow.ErrInvalid, storage)
var valueOk, typedValueOk bool
valueFieldIdx, valueOk = s.FieldIdx("value")
typedValueFieldIdx, typedValueOk = s.FieldIdx("typed_value")

if !valueOk && !typedValueOk {
return nil, fmt.Errorf("%w: there must be at least one of 'value' or 'typed_value' fields in variant storage type %s", arrow.ErrInvalid, storage)
}

if s.NumFields() > 3 {
return nil, fmt.Errorf("%w: too many fields in variant storage type %s, expected 2 or 3", arrow.ErrInvalid, storage)
if s.NumFields() == 3 && (!valueOk || !typedValueOk) {
return nil, fmt.Errorf("%w: has 3 fields, but missing one of 'value' or 'typed_value' fields, %s", arrow.ErrInvalid, storage)
}

if s.NumFields() == 3 {
if typedValueFieldIdx, ok = s.FieldIdx("typed_value"); !ok {
return nil, fmt.Errorf("%w: has 3 fields, but missing 'typed_value' field, %s", arrow.ErrInvalid, storage)
}
if s.NumFields() > 3 {
return nil, fmt.Errorf("%w: too many fields in variant storage type %s, expected 2 or 3", arrow.ErrInvalid, storage)
}

mdField, valField := s.Field(metadataFieldIdx), s.Field(valueFieldIdx)
mdField := s.Field(metadataFieldIdx)
if mdField.Nullable {
return nil, fmt.Errorf("%w: metadata field must be non-nullable binary type, got %s", arrow.ErrInvalid, mdField.Type)
}
Expand All @@ -196,11 +199,14 @@ func NewVariantType(storage arrow.DataType) (*VariantType, error) {
}
}

if !isBinary(valField.Type) || (valField.Nullable && typedValueFieldIdx == -1) {
return nil, fmt.Errorf("%w: value field must be non-nullable binary type, got %s", arrow.ErrInvalid, valField.Type)
if valueOk {
valField := s.Field(valueFieldIdx)
if !isBinary(valField.Type) {
return nil, fmt.Errorf("%w: value field must be binary type, got %s", arrow.ErrInvalid, valField.Type)
}
}

if typedValueFieldIdx == -1 {
if !typedValueOk {
return &VariantType{
ExtensionBase: arrow.ExtensionBase{Storage: storage},
metadataFieldIdx: metadataFieldIdx,
Expand All @@ -209,17 +215,17 @@ func NewVariantType(storage arrow.DataType) (*VariantType, error) {
}, nil
}

valueField := s.Field(valueFieldIdx)
if !valueField.Nullable {
return nil, fmt.Errorf("%w: value field must be nullable if typed_value is present, got %s", arrow.ErrInvalid, valueField.Type)
}

typedValueField := s.Field(typedValueFieldIdx)
if !typedValueField.Nullable {
return nil, fmt.Errorf("%w: typed_value field must be nullable, got %s", arrow.ErrInvalid, typedValueField.Type)
}

if nt, ok := typedValueField.Type.(arrow.NestedType); ok {
dt := typedValueField.Type
if dt.ID() == arrow.EXTENSION {
dt = dt.(arrow.ExtensionType).StorageType()
}

if nt, ok := dt.(arrow.NestedType); ok {
if !validNestedType(nt) {
return nil, fmt.Errorf("%w: typed_value field must be a valid nested type, got %s", arrow.ErrInvalid, typedValueField.Type)
}
Expand All @@ -242,6 +248,9 @@ func (v *VariantType) Metadata() arrow.Field {
}

// Value returns the "value" field of the variant's storage struct.
//
// When the type carries no value field (valueFieldIdx is -1), a zero-value
// arrow.Field is returned instead.
func (v *VariantType) Value() arrow.Field {
	idx := v.valueFieldIdx
	if idx == -1 {
		return arrow.Field{}
	}

	st := v.StorageType().(*arrow.StructType)
	return st.Field(idx)
}

Expand Down Expand Up @@ -286,7 +295,7 @@ func validStruct(s *arrow.StructType) bool {
switch s.NumFields() {
case 1:
f := s.Field(0)
return f.Name == "value" && !f.Nullable && isBinary(f.Type)
return (f.Name == "value" && isBinary(f.Type)) || f.Name == "typed_value"
case 2:
valField, ok := s.FieldByName("value")
if !ok || !valField.Nullable || !isBinary(valField.Type) {
Expand Down Expand Up @@ -365,33 +374,37 @@ func (v *VariantArray) initReader() {
vt := v.ExtensionType().(*VariantType)
st := v.Storage().(*array.Struct)
metaField := st.Field(vt.metadataFieldIdx)
valueField := st.Field(vt.valueFieldIdx)

metadata, ok := metaField.(arrow.TypedArray[[]byte])
if !ok {
// we already validated that if the metadata field isn't a binary
// type directly, it must be a dictionary with a binary value type.
metadata, _ = array.NewDictWrapper[[]byte](metaField.(*array.Dictionary))
}

if vt.typedValueFieldIdx == -1 {
var value arrow.TypedArray[[]byte]
if vt.valueFieldIdx != -1 {
valueField := st.Field(vt.valueFieldIdx)
value = valueField.(arrow.TypedArray[[]byte])
}

var ivreader typedValReader
var err error
if vt.typedValueFieldIdx != -1 {
ivreader, err = getReader(st.Field(vt.typedValueFieldIdx))
if err != nil {
v.rdrErr = err
return
}
v.rdr = &shreddedVariantReader{
metadata: metadata,
value: value,
typedValue: ivreader,
}
} else {
v.rdr = &basicVariantReader{
metadata: metadata,
value: valueField.(arrow.TypedArray[[]byte]),
value: value,
}
return
}

ivreader, err := getReader(st.Field(vt.typedValueFieldIdx))
if err != nil {
v.rdrErr = err
return
}

v.rdr = &shreddedVariantReader{
metadata: metadata,
value: valueField.(arrow.TypedArray[[]byte]),
typedValue: ivreader,
}
})
}
Expand Down Expand Up @@ -419,6 +432,9 @@ func (v *VariantArray) Metadata() arrow.TypedArray[[]byte] {
// value of null).
func (v *VariantArray) UntypedValues() arrow.TypedArray[[]byte] {
	idx := v.ExtensionType().(*VariantType).valueFieldIdx
	if idx == -1 {
		// The variant type records no value field, so there is no
		// child array to hand back.
		return nil
	}

	storage := v.Storage().(*array.Struct)
	return storage.Field(idx).(arrow.TypedArray[[]byte])
}

Expand Down Expand Up @@ -451,14 +467,14 @@ func (v *VariantArray) IsNull(i int) bool {
}

vt := v.ExtensionType().(*VariantType)
valArr := v.Storage().(*array.Struct).Field(vt.valueFieldIdx)
if vt.typedValueFieldIdx != -1 {
typedArr := v.Storage().(*array.Struct).Field(vt.typedValueFieldIdx)
if !typedArr.IsNull(i) {
return false
}
}

valArr := v.Storage().(*array.Struct).Field(vt.valueFieldIdx)
b := valArr.(arrow.TypedArray[[]byte]).Value(i)
return len(b) == 1 && b[0] == 0 // variant null
}
Expand Down Expand Up @@ -747,9 +763,20 @@ func getReader(typedArr arrow.Array) (typedValReader, error) {
childType := child.DataType().(*arrow.StructType)

valueIdx, _ := childType.FieldIdx("value")
valueArr := child.Field(valueIdx).(arrow.TypedArray[[]byte])
var valueArr arrow.TypedArray[[]byte]
if valueIdx != -1 {
valueArr = child.Field(valueIdx).(arrow.TypedArray[[]byte])
}

typedValueIdx, exists := childType.FieldIdx("typed_value")
if !exists {
fieldReaders[fieldList[i].Name] = fieldReaderPair{
values: valueArr,
typedVal: nil,
}
continue
}

typedValueIdx, _ := childType.FieldIdx("typed_value")
typedRdr, err := getReader(child.Field(typedValueIdx))
if err != nil {
return nil, fmt.Errorf("error getting typed value reader for field %s: %w", fieldList[i].Name, err)
Expand All @@ -768,13 +795,22 @@ func getReader(typedArr arrow.Array) (typedValReader, error) {
case array.ListLike:
listValues := arr.ListValues().(*array.Struct)
elemType := listValues.DataType().(*arrow.StructType)

var valueArr arrow.TypedArray[[]byte]
var typedRdr typedValReader

valueIdx, _ := elemType.FieldIdx("value")
valueArr := listValues.Field(valueIdx).(arrow.TypedArray[[]byte])
if valueIdx != -1 {
valueArr = listValues.Field(valueIdx).(arrow.TypedArray[[]byte])
}

typedValueIdx, _ := elemType.FieldIdx("typed_value")
typedRdr, err := getReader(listValues.Field(typedValueIdx))
if err != nil {
return nil, fmt.Errorf("error getting typed value reader: %w", err)
if typedValueIdx != -1 {
var err error
typedRdr, err = getReader(listValues.Field(typedValueIdx))
if err != nil {
return nil, fmt.Errorf("error getting typed value reader: %w", err)
}
}

return &typedListReader{
Expand All @@ -796,6 +832,7 @@ func constructVariant(b *variant.Builder, meta variant.Metadata, value []byte, t
switch v := typedVal.(type) {
case nil:
if len(value) == 0 {
b.AppendNull()
return nil
}

Expand Down Expand Up @@ -846,6 +883,9 @@ func constructVariant(b *variant.Builder, meta variant.Metadata, value []byte, t

return b.FinishArray(arrstart, elems)
case []byte:
if len(value) > 0 {
return errors.New("invalid variant, conflicting value and typed_value")
}
return b.UnsafeAppendEncoded(v)
default:
return fmt.Errorf("%w: unsupported typed value type %T for variant", arrow.ErrInvalid, v)
Expand Down Expand Up @@ -876,14 +916,24 @@ func (v *typedObjReader) Value(meta variant.Metadata, i int) (any, error) {
return nil, nil
}

var err error
result := make(map[string]typedPair)
for name, rdr := range v.fieldRdrs {
typedValue, err := rdr.typedVal.Value(meta, i)
if err != nil {
return nil, fmt.Errorf("error reading typed value for field %s at index %d: %w", name, i, err)
var typedValue any
if rdr.typedVal != nil {
typedValue, err = rdr.typedVal.Value(meta, i)
if err != nil {
return nil, fmt.Errorf("error reading typed value for field %s at index %d: %w", name, i, err)
}
}

var val []byte
if rdr.values != nil {
val = rdr.values.Value(i)
}

result[name] = typedPair{
Value: rdr.values.Value(i),
Value: val,
TypedValue: typedValue,
}
}
Expand Down Expand Up @@ -913,7 +963,11 @@ func (v *typedListReader) Value(meta variant.Metadata, i int) (any, error) {

result := make([]typedPair, 0, end-start)
for j := start; j < end; j++ {
val := v.valueArr.Value(int(j))
var val []byte
if v.valueArr != nil {
val = v.valueArr.Value(int(j))
}

typedValue, err := v.typedVal.Value(meta, int(j))
if err != nil {
return nil, fmt.Errorf("error reading typed value at index %d: %w", j, err)
Expand Down Expand Up @@ -956,12 +1010,17 @@ func (v *shreddedVariantReader) Value(i int) (variant.Value, error) {
}

b := variant.NewBuilderFromMeta(meta)
b.SetAllowDuplicates(true)
typed, err := v.typedValue.Value(meta, i)
if err != nil {
return variant.NullValue, fmt.Errorf("error reading typed value at index %d: %w", i, err)
}

if err := constructVariant(b, meta, v.value.Value(i), typed); err != nil {
var value []byte
if v.value != nil {
value = v.value.Value(i)
}
if err := constructVariant(b, meta, value, typed); err != nil {
return variant.NullValue, fmt.Errorf("error constructing variant at index %d: %w", i, err)
}
return b.Build()
Expand Down
24 changes: 3 additions & 21 deletions arrow/extensions/variant_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,18 @@ func TestVariantExtensionType(t *testing.T) {
expectedErr string
}{
{arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary}),
"missing non-nullable field 'value'"},
"there must be at least one of 'value' or 'typed_value' fields in variant storage type"},
{arrow.StructOf(arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary}), "missing non-nullable field 'metadata'"},
{arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary},
arrow.Field{Name: "value", Type: arrow.PrimitiveTypes.Int32}),
"value field must be non-nullable binary type, got int32"},
"value field must be binary type, got int32"},
{arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary},
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary},
arrow.Field{Name: "extra", Type: arrow.BinaryTypes.Binary}),
"has 3 fields, but missing 'typed_value' field"},
"has 3 fields, but missing one of 'value' or 'typed_value' field"},
{arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: true},
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: false}),
"metadata field must be non-nullable binary type"},
{arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false},
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true}),
"value field must be non-nullable binary type"},
{arrow.FixedWidthTypes.Boolean, "bad storage type bool for variant type"},
{arrow.StructOf(
arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false},
Expand All @@ -86,16 +83,6 @@ func TestVariantExtensionType(t *testing.T) {
arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.String, Nullable: false},
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: false}),
"metadata field must be non-nullable binary type, got utf8"},
{arrow.StructOf(
arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false},
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: false},
arrow.Field{Name: "typed_value", Type: arrow.BinaryTypes.String, Nullable: true}),
"value field must be nullable if typed_value is present"},
{arrow.StructOf(
arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false},
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
arrow.Field{Name: "typed_value", Type: arrow.BinaryTypes.String, Nullable: false}),
"typed_value field must be nullable"},
}

for _, tt := range tests {
Expand Down Expand Up @@ -126,11 +113,6 @@ func TestVariantExtensionBadNestedTypes(t *testing.T) {
), Nullable: false})},
{"empty struct elem", arrow.StructOf(
arrow.Field{Name: "foobar", Type: arrow.StructOf(), Nullable: false})},
{"nullable value struct elem",
arrow.StructOf(
arrow.Field{Name: "foobar", Type: arrow.StructOf(
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
), Nullable: false})},
{"non-nullable two elem struct", arrow.StructOf(
arrow.Field{Name: "foobar", Type: arrow.StructOf(
arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
Expand Down
2 changes: 1 addition & 1 deletion parquet-testing
Submodule parquet-testing updated 292 files
2 changes: 1 addition & 1 deletion parquet/pqarrow/encode_arrow.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func writeDenseArrow(ctx *arrowWriteContext, cw file.ColumnChunkWriter, leafArr
case arrow.DECIMAL128:
for idx, val := range leafArr.(*array.Decimal128).Values() {
debug.Assert(val.HighBits() == 0 || val.HighBits() == -1, "casting Decimal128 greater than the value range; high bits must be 0 or -1")
debug.Assert(val.LowBits() <= math.MaxUint32, "casting Decimal128 to int32 when value > MaxUint32")
debug.Assert(int64(val.LowBits()) <= math.MaxUint32, "casting Decimal128 to int32 when value > MaxUint32")
data[idx] = int32(val.LowBits())
}
case arrow.DECIMAL256:
Expand Down
Loading
Loading