diff --git a/arrow-testing b/arrow-testing
index d2a13712..6a7b02fa 160000
--- a/arrow-testing
+++ b/arrow-testing
@@ -1 +1 @@
-Subproject commit d2a13712303498963395318a4eb42872e66aead7
+Subproject commit 6a7b02fac93d8addbcdbb213264e58bfdc3068e4
diff --git a/arrow/extensions/variant.go b/arrow/extensions/variant.go
index fe97f247..659f571c 100644
--- a/arrow/extensions/variant.go
+++ b/arrow/extensions/variant.go
@@ -18,6 +18,7 @@ package extensions
 
 import (
 	"bytes"
+	"errors"
 	"fmt"
 	"math"
 	"reflect"
@@ -171,21 +172,23 @@ func NewVariantType(storage arrow.DataType) (*VariantType, error) {
 		return nil, fmt.Errorf("%w: missing non-nullable field 'metadata' in variant storage type %s", arrow.ErrInvalid, storage)
 	}
 
-	if valueFieldIdx, ok = s.FieldIdx("value"); !ok {
-		return nil, fmt.Errorf("%w: missing non-nullable field 'value' in variant storage type %s", arrow.ErrInvalid, storage)
+	var valueOk, typedValueOk bool
+	valueFieldIdx, valueOk = s.FieldIdx("value")
+	typedValueFieldIdx, typedValueOk = s.FieldIdx("typed_value")
+
+	if !valueOk && !typedValueOk {
+		return nil, fmt.Errorf("%w: there must be at least one of 'value' or 'typed_value' fields in variant storage type %s", arrow.ErrInvalid, storage)
 	}
 
-	if s.NumFields() > 3 {
-		return nil, fmt.Errorf("%w: too many fields in variant storage type %s, expected 2 or 3", arrow.ErrInvalid, storage)
+	if s.NumFields() == 3 && (!valueOk || !typedValueOk) {
+		return nil, fmt.Errorf("%w: has 3 fields, but missing one of 'value' or 'typed_value' fields, %s", arrow.ErrInvalid, storage)
 	}
 
-	if s.NumFields() == 3 {
-		if typedValueFieldIdx, ok = s.FieldIdx("typed_value"); !ok {
-			return nil, fmt.Errorf("%w: has 3 fields, but missing 'typed_value' field, %s", arrow.ErrInvalid, storage)
-		}
+	if s.NumFields() > 3 {
+		return nil, fmt.Errorf("%w: too many fields in variant storage type %s, expected 2 or 3", arrow.ErrInvalid, storage)
 	}
 
-	mdField, valField := s.Field(metadataFieldIdx), s.Field(valueFieldIdx)
+	mdField := s.Field(metadataFieldIdx)
 	if mdField.Nullable {
 		return nil, fmt.Errorf("%w: metadata field must be non-nullable binary type, got %s", arrow.ErrInvalid, mdField.Type)
 	}
@@ -196,11 +199,14 @@ func NewVariantType(storage arrow.DataType) (*VariantType, error) {
 		}
 	}
 
-	if !isBinary(valField.Type) || (valField.Nullable && typedValueFieldIdx == -1) {
-		return nil, fmt.Errorf("%w: value field must be non-nullable binary type, got %s", arrow.ErrInvalid, valField.Type)
+	if valueOk {
+		valField := s.Field(valueFieldIdx)
+		if !isBinary(valField.Type) {
+			return nil, fmt.Errorf("%w: value field must be binary type, got %s", arrow.ErrInvalid, valField.Type)
+		}
 	}
 
-	if typedValueFieldIdx == -1 {
+	if !typedValueOk {
 		return &VariantType{
 			ExtensionBase:    arrow.ExtensionBase{Storage: storage},
 			metadataFieldIdx: metadataFieldIdx,
@@ -209,17 +215,17 @@ func NewVariantType(storage arrow.DataType) (*VariantType, error) {
 		}, nil
 	}
 
-	valueField := s.Field(valueFieldIdx)
-	if !valueField.Nullable {
-		return nil, fmt.Errorf("%w: value field must be nullable if typed_value is present, got %s", arrow.ErrInvalid, valueField.Type)
-	}
-
 	typedValueField := s.Field(typedValueFieldIdx)
 	if !typedValueField.Nullable {
 		return nil, fmt.Errorf("%w: typed_value field must be nullable, got %s", arrow.ErrInvalid, typedValueField.Type)
 	}
 
-	if nt, ok := typedValueField.Type.(arrow.NestedType); ok {
+	dt := typedValueField.Type
+	if dt.ID() == arrow.EXTENSION {
+		dt = dt.(arrow.ExtensionType).StorageType()
+	}
+
+	if nt, ok := dt.(arrow.NestedType); ok {
 		if !validNestedType(nt) {
 			return nil, fmt.Errorf("%w: typed_value field must be a valid nested type, got %s", arrow.ErrInvalid, typedValueField.Type)
 		}
@@ -242,6 +248,9 @@ func (v *VariantType) Metadata() arrow.Field {
 }
 
 func (v *VariantType) Value() arrow.Field {
+	if v.valueFieldIdx == -1 {
+		return arrow.Field{}
+	}
 	return v.StorageType().(*arrow.StructType).Field(v.valueFieldIdx)
 }
 
@@ -286,7 +295,7 @@ func validStruct(s *arrow.StructType) bool {
 	switch s.NumFields() {
 	case 1:
 		f := s.Field(0)
-		return f.Name == "value" && !f.Nullable && isBinary(f.Type)
+		return (f.Name == "value" && isBinary(f.Type)) || f.Name == "typed_value"
 	case 2:
 		valField, ok := s.FieldByName("value")
 		if !ok || !valField.Nullable || !isBinary(valField.Type) {
@@ -365,8 +374,6 @@ func (v *VariantArray) initReader() {
 		vt := v.ExtensionType().(*VariantType)
 		st := v.Storage().(*array.Struct)
 		metaField := st.Field(vt.metadataFieldIdx)
-		valueField := st.Field(vt.valueFieldIdx)
-
 		metadata, ok := metaField.(arrow.TypedArray[[]byte])
 		if !ok {
 			// we already validated that if the metadata field isn't a binary
@@ -374,24 +381,30 @@
 			metadata, _ = array.NewDictWrapper[[]byte](metaField.(*array.Dictionary))
 		}
 
-		if vt.typedValueFieldIdx == -1 {
+		var value arrow.TypedArray[[]byte]
+		if vt.valueFieldIdx != -1 {
+			valueField := st.Field(vt.valueFieldIdx)
+			value = valueField.(arrow.TypedArray[[]byte])
+		}
+
+		var ivreader typedValReader
+		var err error
+		if vt.typedValueFieldIdx != -1 {
+			ivreader, err = getReader(st.Field(vt.typedValueFieldIdx))
+			if err != nil {
+				v.rdrErr = err
+				return
+			}
+			v.rdr = &shreddedVariantReader{
+				metadata:   metadata,
+				value:      value,
+				typedValue: ivreader,
+			}
+		} else {
 			v.rdr = &basicVariantReader{
 				metadata: metadata,
-				value:    valueField.(arrow.TypedArray[[]byte]),
+				value:    value,
 			}
-			return
-		}
-
-		ivreader, err := getReader(st.Field(vt.typedValueFieldIdx))
-		if err != nil {
-			v.rdrErr = err
-			return
-		}
-
-		v.rdr = &shreddedVariantReader{
-			metadata:   metadata,
-			value:      valueField.(arrow.TypedArray[[]byte]),
-			typedValue: ivreader,
 		}
 	})
 }
@@ -419,6 +432,9 @@ func (v *VariantArray) Metadata() arrow.TypedArray[[]byte] {
 // value of null).
 func (v *VariantArray) UntypedValues() arrow.TypedArray[[]byte] {
 	vt := v.ExtensionType().(*VariantType)
+	if vt.valueFieldIdx == -1 {
+		return nil
+	}
 	return v.Storage().(*array.Struct).Field(vt.valueFieldIdx).(arrow.TypedArray[[]byte])
 }
 
@@ -451,7 +467,6 @@ func (v *VariantArray) IsNull(i int) bool {
 	}
 
 	vt := v.ExtensionType().(*VariantType)
-	valArr := v.Storage().(*array.Struct).Field(vt.valueFieldIdx)
 	if vt.typedValueFieldIdx != -1 {
 		typedArr := v.Storage().(*array.Struct).Field(vt.typedValueFieldIdx)
 		if !typedArr.IsNull(i) {
@@ -459,6 +474,7 @@
 		}
 	}
 
+	valArr := v.Storage().(*array.Struct).Field(vt.valueFieldIdx)
 	b := valArr.(arrow.TypedArray[[]byte]).Value(i)
 	return len(b) == 1 && b[0] == 0 // variant null
 }
@@ -747,9 +763,20 @@ func getReader(typedArr arrow.Array) (typedValReader, error) {
 			childType := child.DataType().(*arrow.StructType)
 			valueIdx, _ := childType.FieldIdx("value")
-			valueArr := child.Field(valueIdx).(arrow.TypedArray[[]byte])
+			var valueArr arrow.TypedArray[[]byte]
+			if valueIdx != -1 {
+				valueArr = child.Field(valueIdx).(arrow.TypedArray[[]byte])
+			}
+
+			typedValueIdx, exists := childType.FieldIdx("typed_value")
+			if !exists {
+				fieldReaders[fieldList[i].Name] = fieldReaderPair{
+					values:   valueArr,
+					typedVal: nil,
+				}
+				continue
+			}
 
-			typedValueIdx, _ := childType.FieldIdx("typed_value")
 			typedRdr, err := getReader(child.Field(typedValueIdx))
 			if err != nil {
 				return nil, fmt.Errorf("error getting typed value reader for field %s: %w", fieldList[i].Name, err)
 			}
@@ -768,13 +795,22 @@ func getReader(typedArr arrow.Array) (typedValReader, error) {
 	case array.ListLike:
 		listValues := arr.ListValues().(*array.Struct)
 		elemType := listValues.DataType().(*arrow.StructType)
+
+		var valueArr arrow.TypedArray[[]byte]
+		var typedRdr typedValReader
+
 		valueIdx, _ := elemType.FieldIdx("value")
-		valueArr := listValues.Field(valueIdx).(arrow.TypedArray[[]byte])
+		if valueIdx != -1 {
+			valueArr = listValues.Field(valueIdx).(arrow.TypedArray[[]byte])
+		}
 
 		typedValueIdx, _ := elemType.FieldIdx("typed_value")
-		typedRdr, err := getReader(listValues.Field(typedValueIdx))
-		if err != nil {
-			return nil, fmt.Errorf("error getting typed value reader: %w", err)
+		if typedValueIdx != -1 {
+			var err error
+			typedRdr, err = getReader(listValues.Field(typedValueIdx))
+			if err != nil {
+				return nil, fmt.Errorf("error getting typed value reader: %w", err)
+			}
 		}
 
 		return &typedListReader{
@@ -796,6 +832,7 @@ func constructVariant(b *variant.Builder, meta variant.Metadata, value []byte, t
 	switch v := typedVal.(type) {
 	case nil:
 		if len(value) == 0 {
+			b.AppendNull()
 			return nil
 		}
@@ -846,6 +883,9 @@ func constructVariant(b *variant.Builder, meta variant.Metadata, value []byte, t
 		return b.FinishArray(arrstart, elems)
 	case []byte:
+		if len(value) > 0 {
+			return errors.New("invalid variant, conflicting value and typed_value")
+		}
 		return b.UnsafeAppendEncoded(v)
 	default:
 		return fmt.Errorf("%w: unsupported typed value type %T for variant", arrow.ErrInvalid, v)
@@ -876,14 +916,24 @@ func (v *typedObjReader) Value(meta variant.Metadata, i int) (any, error) {
 		return nil, nil
 	}
 
+	var err error
 	result := make(map[string]typedPair)
 	for name, rdr := range v.fieldRdrs {
-		typedValue, err := rdr.typedVal.Value(meta, i)
-		if err != nil {
-			return nil, fmt.Errorf("error reading typed value for field %s at index %d: %w", name, i, err)
+		var typedValue any
+		if rdr.typedVal != nil {
+			typedValue, err = rdr.typedVal.Value(meta, i)
+			if err != nil {
+				return nil, fmt.Errorf("error reading typed value for field %s at index %d: %w", name, i, err)
+			}
 		}
+
+		var val []byte
+		if rdr.values != nil {
+			val = rdr.values.Value(i)
+		}
+
 		result[name] = typedPair{
-			Value:      rdr.values.Value(i),
+			Value:      val,
 			TypedValue: typedValue,
 		}
 	}
@@ -913,7 +963,11 @@ func (v *typedListReader) Value(meta variant.Metadata, i int) (any, error) {
 	result := make([]typedPair, 0, end-start)
 	for j := start; j < end; j++ {
-		val := v.valueArr.Value(int(j))
+		var val []byte
+		if v.valueArr != nil {
+			val = v.valueArr.Value(int(j))
+		}
+
 		typedValue, err := v.typedVal.Value(meta, int(j))
 		if err != nil {
 			return nil, fmt.Errorf("error reading typed value at index %d: %w", j, err)
 		}
@@ -956,12 +1010,17 @@ func (v *shreddedVariantReader) Value(i int) (variant.Value, error) {
 	}
 
 	b := variant.NewBuilderFromMeta(meta)
+	b.SetAllowDuplicates(true)
 	typed, err := v.typedValue.Value(meta, i)
 	if err != nil {
 		return variant.NullValue, fmt.Errorf("error reading typed value at index %d: %w", i, err)
 	}
 
-	if err := constructVariant(b, meta, v.value.Value(i), typed); err != nil {
+	var value []byte
+	if v.value != nil {
+		value = v.value.Value(i)
+	}
+	if err := constructVariant(b, meta, value, typed); err != nil {
 		return variant.NullValue, fmt.Errorf("error constructing variant at index %d: %w", i, err)
 	}
 	return b.Build()
diff --git a/arrow/extensions/variant_test.go b/arrow/extensions/variant_test.go
index 6e539ee5..925d0621 100644
--- a/arrow/extensions/variant_test.go
+++ b/arrow/extensions/variant_test.go
@@ -61,21 +61,18 @@ func TestVariantExtensionType(t *testing.T) {
 		expectedErr string
 	}{
 		{arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary}),
-			"missing non-nullable field 'value'"},
+			"there must be at least one of 'value' or 'typed_value' fields in variant storage type"},
 		{arrow.StructOf(arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary}),
 			"missing non-nullable field 'metadata'"},
 		{arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary},
 			arrow.Field{Name: "value", Type: arrow.PrimitiveTypes.Int32}),
-			"value field must be non-nullable binary type, got int32"},
+			"value field must be binary type, got int32"},
 		{arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary},
 			arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary},
 			arrow.Field{Name: "extra", Type: arrow.BinaryTypes.Binary}),
-			"has 3 fields, but missing 'typed_value' field"},
+			"has 3 fields, but missing one of 'value' or 'typed_value' field"},
 		{arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: true},
 			arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: false}),
 			"metadata field must be non-nullable binary type"},
-		{arrow.StructOf(arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false},
-			arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true}),
-			"value field must be non-nullable binary type"},
 		{arrow.FixedWidthTypes.Boolean, "bad storage type bool for variant type"},
 		{arrow.StructOf(
 			arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false},
@@ -86,16 +83,6 @@ func TestVariantExtensionType(t *testing.T) {
 			arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.String, Nullable: false},
 			arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: false}),
 			"metadata field must be non-nullable binary type, got utf8"},
-		{arrow.StructOf(
-			arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false},
-			arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: false},
-			arrow.Field{Name: "typed_value", Type: arrow.BinaryTypes.String, Nullable: true}),
-			"value field must be nullable if typed_value is present"},
-		{arrow.StructOf(
-			arrow.Field{Name: "metadata", Type: arrow.BinaryTypes.Binary, Nullable: false},
-			arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
-			arrow.Field{Name: "typed_value", Type: arrow.BinaryTypes.String, Nullable: false}),
-			"typed_value field must be nullable"},
 	}
 
 	for _, tt := range tests {
@@ -126,11 +113,6 @@ func TestVariantExtensionBadNestedTypes(t *testing.T) {
 			), Nullable: false})},
 		{"empty struct elem", arrow.StructOf(
 			arrow.Field{Name: "foobar", Type: arrow.StructOf(), Nullable: false})},
-		{"nullable value struct elem",
-			arrow.StructOf(
-				arrow.Field{Name: "foobar", Type: arrow.StructOf(
-					arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
-				), Nullable: false})},
 		{"non-nullable two elem struct", arrow.StructOf(
 			arrow.Field{Name: "foobar", Type: arrow.StructOf(
 				arrow.Field{Name: "value", Type: arrow.BinaryTypes.Binary, Nullable: true},
diff --git a/parquet-testing b/parquet-testing
index 2dc8bf14..a3d96a65 160000
--- a/parquet-testing
+++ b/parquet-testing
@@ -1 +1 @@
-Subproject commit 2dc8bf140ed6e28652fc347211c7d661714c7f95
+Subproject commit a3d96a65e11e2bbca7d22a894e8313ede90a33a3
diff --git a/parquet/pqarrow/encode_arrow.go b/parquet/pqarrow/encode_arrow.go
index cdaba241..5724e9f8 100644
--- a/parquet/pqarrow/encode_arrow.go
+++ b/parquet/pqarrow/encode_arrow.go
@@ -333,7 +333,7 @@ func writeDenseArrow(ctx *arrowWriteContext, cw file.ColumnChunkWriter, leafArr
 		case arrow.DECIMAL128:
 			for idx, val := range leafArr.(*array.Decimal128).Values() {
 				debug.Assert(val.HighBits() == 0 || val.HighBits() == -1, "casting Decimal128 greater than the value range; high bits must be 0 or -1")
-				debug.Assert(val.LowBits() <= math.MaxUint32, "casting Decimal128 to int32 when value > MaxUint32")
+				debug.Assert(int64(val.LowBits()) <= math.MaxUint32, "casting Decimal128 to int32 when value > MaxUint32")
 				data[idx] = int32(val.LowBits())
 			}
 		case arrow.DECIMAL256:
diff --git a/parquet/pqarrow/schema.go b/parquet/pqarrow/schema.go
index 2c0e70b5..7c56e333 100644
--- a/parquet/pqarrow/schema.go
+++ b/parquet/pqarrow/schema.go
@@ -242,7 +242,7 @@ func repFromNullable(isnullable bool) parquet.Repetition {
 }
 
 func variantToNode(t *extensions.VariantType, field arrow.Field, props *parquet.WriterProperties, arrProps ArrowWriterProperties) (schema.Node, error) {
-	fields := make(schema.FieldList, 2, 3)
+	fields := make(schema.FieldList, 1, 3)
 
 	var err error
 	fields[0], err = fieldToNode("metadata", t.Metadata(), props, arrProps)
@@ -250,9 +250,12 @@ func variantToNode(t *extensions.VariantType, field arrow.Field, props *parquet.
 		return nil, err
 	}
 
-	fields[1], err = fieldToNode("value", t.Value(), props, arrProps)
-	if err != nil {
-		return nil, err
+	if value := t.Value(); value.Type != nil {
+		valueField, err := fieldToNode("value", value, props, arrProps)
+		if err != nil {
+			return nil, err
+		}
+		fields = append(fields, valueField)
 	}
 
 	if typed := t.TypedValue(); typed.Type != nil {
@@ -594,8 +597,9 @@ func getParquetType(typ arrow.DataType, props *parquet.WriterProperties, arrprop
 		precision := int(dectype.GetPrecision())
 		scale := int(dectype.GetScale())
 
+		logicalType := schema.NewDecimalLogicalType(int32(precision), int32(scale))
 		if !props.StoreDecimalAsInteger() || precision > 18 {
-			return parquet.Types.FixedLenByteArray, schema.NewDecimalLogicalType(int32(precision), int32(scale)), int(DecimalSize(int32(precision))), nil
+			return parquet.Types.FixedLenByteArray, logicalType, int(DecimalSize(int32(precision))), nil
 		}
 
 		pqType := parquet.Types.Int32
@@ -603,7 +607,7 @@ func getParquetType(typ arrow.DataType, props *parquet.WriterProperties, arrprop
 			pqType = parquet.Types.Int64
 		}
 
-		return pqType, schema.NoLogicalType{}, -1, nil
+		return pqType, logicalType, -1, nil
 	case arrow.DATE32:
 		return parquet.Types.Int32, schema.DateLogicalType{}, -1, nil
 	case arrow.DATE64:
@@ -612,14 +616,14 @@
 		pqType, logicalType, err := getTimestampMeta(typ.(*arrow.TimestampType), props, arrprops)
 		return pqType, logicalType, -1, err
 	case arrow.TIME32:
-		return parquet.Types.Int32, schema.NewTimeLogicalType(true, schema.TimeUnitMillis), -1, nil
+		return parquet.Types.Int32, schema.NewTimeLogicalType(false, schema.TimeUnitMillis), -1, nil
 	case arrow.TIME64:
 		pqTimeUnit := schema.TimeUnitMicros
 		if typ.(*arrow.Time64Type).Unit == arrow.Nanosecond {
 			pqTimeUnit = schema.TimeUnitNanos
 		}
 
-		return parquet.Types.Int64, schema.NewTimeLogicalType(true, pqTimeUnit), -1, nil
+		return parquet.Types.Int64, schema.NewTimeLogicalType(false, pqTimeUnit), -1, nil
 	case arrow.FLOAT16:
 		return parquet.Types.FixedLenByteArray, schema.Float16LogicalType{}, arrow.Float16SizeBytes, nil
 	case arrow.EXTENSION:
diff --git a/parquet/pqarrow/schema_test.go b/parquet/pqarrow/schema_test.go
index 6f3da880..6f5d14c7 100644
--- a/parquet/pqarrow/schema_test.go
+++ b/parquet/pqarrow/schema_test.go
@@ -184,11 +184,11 @@ func TestConvertArrowFlatPrimitives(t *testing.T) {
 	arrowFields = append(arrowFields, arrow.Field{Name: "date64", Type: arrow.FixedWidthTypes.Date64, Nullable: false})
 
 	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("time32", parquet.Repetitions.Required,
-		schema.NewTimeLogicalType(true, schema.TimeUnitMillis), parquet.Types.Int32, 0, -1)))
+		schema.NewTimeLogicalType(false, schema.TimeUnitMillis), parquet.Types.Int32, 0, -1)))
 	arrowFields = append(arrowFields, arrow.Field{Name: "time32", Type: arrow.FixedWidthTypes.Time32ms, Nullable: false})
 
 	parquetFields = append(parquetFields, schema.Must(schema.NewPrimitiveNodeLogical("time64", parquet.Repetitions.Required,
-		schema.NewTimeLogicalType(true, schema.TimeUnitMicros), parquet.Types.Int64, 0, -1)))
+		schema.NewTimeLogicalType(false, schema.TimeUnitMicros), parquet.Types.Int64, 0, -1)))
 	arrowFields = append(arrowFields, arrow.Field{Name: "time64", Type: arrow.FixedWidthTypes.Time64us, Nullable: false})
 
 	parquetFields = append(parquetFields, schema.NewInt96Node("timestamp96", parquet.Repetitions.Required, -1))
diff --git a/parquet/pqarrow/variant_test.go b/parquet/pqarrow/variant_test.go
new file mode 100644
index 00000000..81fa246b
--- /dev/null
+++ b/parquet/pqarrow/variant_test.go
@@ -0,0 +1,326 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package pqarrow_test
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"iter"
+	"os"
+	"path/filepath"
+	"slices"
+	"strings"
+	"testing"
+	"unsafe"
+
+	"github.com/apache/arrow-go/v18/arrow"
+	"github.com/apache/arrow-go/v18/arrow/endian"
+	"github.com/apache/arrow-go/v18/arrow/extensions"
+	"github.com/apache/arrow-go/v18/arrow/memory"
+	"github.com/apache/arrow-go/v18/internal/json"
+	"github.com/apache/arrow-go/v18/parquet"
+	"github.com/apache/arrow-go/v18/parquet/pqarrow"
+	"github.com/apache/arrow-go/v18/parquet/variant"
+	"github.com/stretchr/testify/suite"
+)
+
+type ShreddedVariantTestSuite struct {
+	suite.Suite
+
+	generate bool
+
+	dirPrefix string
+	outDir    string
+	cases     []Case
+
+	errorCases    []Case
+	singleVariant []Case
+	multiVariant  []Case
+}
+
+func (s *ShreddedVariantTestSuite) SetupSuite() {
+	dir := os.Getenv("PARQUET_TEST_DATA")
+	if dir == "" {
+		s.T().Skip("PARQUET_TEST_DATA environment variable not set")
+	}
+
+	s.dirPrefix = filepath.Join(dir, "..", "shredded_variant")
+	s.outDir = filepath.Join(dir, "..", "go_variant")
+	if s.generate {
+		s.Require().NoError(os.MkdirAll(s.outDir, 0o755), "Failed to create output directory: %s", s.outDir)
+	}
+
+	cases, err := os.Open(filepath.Join(s.dirPrefix, "cases.json"))
+	s.Require().NoError(err, "Failed to open cases.json")
+	defer cases.Close()
+
+	s.Require().NoError(json.NewDecoder(cases).Decode(&s.cases))
+
+	s.errorCases = slices.DeleteFunc(slices.Clone(s.cases), func(c Case) bool {
+		return c.ErrorMessage == ""
+	})
+
+	s.singleVariant = slices.DeleteFunc(slices.Clone(s.cases), func(c Case) bool {
+		return c.ErrorMessage != "" || c.VariantFile == "" || len(c.VariantFiles) > 0
+	})
+
+	s.multiVariant = slices.DeleteFunc(slices.Clone(s.cases), func(c Case) bool {
+		return c.ErrorMessage != "" || c.VariantFile != "" || len(c.VariantFiles) == 0
+	})
+
+	if s.generate {
+		cases.Seek(0, io.SeekStart)
+		outCases, err := os.Create(filepath.Join(s.outDir, "cases.json"))
+		s.Require().NoError(err, "Failed to create cases.json")
+		defer outCases.Close()
+
+		io.Copy(outCases, cases)
+		outCases.Sync()
+	}
+}
+
+type Case struct {
+	Number       int       `json:"case_number"`
+	Title        string    `json:"test"`
+	Notes        string    `json:"notes,omitempty"`
+	ParquetFile  string    `json:"parquet_file"`
+	VariantFile  string    `json:"variant_file,omitempty"`
+	VariantFiles []*string `json:"variant_files,omitempty"`
+	VariantData  string    `json:"variant,omitempty"`
+	Variants     string    `json:"variants,omitempty"`
+	ErrorMessage string    `json:"error_message,omitempty"`
+}
+
+func readUnsigned(b []byte) (result uint32) {
+	v := (*[4]byte)(unsafe.Pointer(&result))
+	copy(v[:], b)
+	return endian.FromLE(result)
+}
+
+func (s *ShreddedVariantTestSuite) readVariant(filename string) variant.Value {
+	data, err := os.ReadFile(filename)
+	s.Require().NoError(err, "Failed to read variant file: %s", filename)
+
+	hdr := data[0]
+	offsetSize := int(1 + ((hdr & 0b11000000) >> 6))
+	dictSize := int(readUnsigned(data[1 : 1+offsetSize]))
+	offsetListOffset := 1 + offsetSize
+	dataOffset := offsetListOffset + ((1 + dictSize) * offsetSize)
+
+	idx := offsetListOffset + (offsetSize * dictSize)
+	endOffset := dataOffset + int(readUnsigned(data[idx:idx+offsetSize]))
+	val, err := variant.New(data[:endOffset], data[endOffset:])
+	s.Require().NoError(err, "Failed to create variant from data: %s", filename)
+	return val
+}
+
+func (s *ShreddedVariantTestSuite) readParquet(filename string) arrow.Table {
+	file, err := os.Open(filepath.Join(s.dirPrefix, filename))
+	s.Require().NoError(err, "Failed to open Parquet file: %s", filename)
+	defer file.Close()
+
+	tbl, err := pqarrow.ReadTable(context.Background(), file, nil, pqarrow.ArrowReadProperties{}, memory.DefaultAllocator)
+	s.Require().NoError(err, "Failed to read Parquet file: %s", filename)
+	return tbl
+}
+
+func (s *ShreddedVariantTestSuite) writeVariantFile(filename string, val variant.Value) {
+	out, err := os.Create(filepath.Join(s.outDir, filename))
+	s.Require().NoError(err)
+	defer out.Close()
+
+	_, err = out.Write(val.Metadata().Bytes())
+	s.Require().NoError(err)
+	_, err = out.Write(val.Bytes())
+	s.Require().NoError(err)
+}
+
+func (s *ShreddedVariantTestSuite) writeParquetFile(filename string, tbl arrow.Table) {
+	out, err := os.Create(filepath.Join(s.outDir, filename))
+	s.Require().NoError(err)
+	defer out.Close()
+
+	s.Require().NoError(pqarrow.WriteTable(tbl, out, max(1, tbl.NumRows()), parquet.NewWriterProperties(
+		parquet.WithDictionaryDefault(false), parquet.WithStats(false),
+		parquet.WithStoreDecimalAsInteger(true),
+	), pqarrow.DefaultWriterProps()))
+}
+
+func zip[T, U any](a iter.Seq[T], b iter.Seq[U]) iter.Seq2[T, U] {
+	return func(yield func(T, U) bool) {
+		nexta, stopa := iter.Pull(a)
+		nextb, stopb := iter.Pull(b)
+		defer stopa()
+		defer stopb()
+
+		for {
+			a, ok := nexta()
+			if !ok {
+				return
+			}
+			b, ok := nextb()
+			if !ok {
+				return
+			}
+			if !yield(a, b) {
+				return
+			}
+		}
+	}
+}
+
+func (s *ShreddedVariantTestSuite) assertVariantEqual(expected, actual variant.Value) {
+	switch expected.BasicType() {
+	case variant.BasicObject:
+		exp := expected.Value().(variant.ObjectValue)
+		act := actual.Value().(variant.ObjectValue)
+
+		s.Equal(exp.NumElements(), act.NumElements(), "Expected %d elements in object, got %d", exp.NumElements(), act.NumElements())
+		for i := range exp.NumElements() {
+			expectedField, err := exp.FieldAt(i)
+			s.Require().NoError(err, "Failed to get expected field at index %d", i)
+			actualField, err := act.FieldAt(i)
+			s.Require().NoError(err, "Failed to get actual field at index %d", i)
+
+			s.Equal(expectedField.Key, actualField.Key, "Expected field key %s, got %s", expectedField.Key, actualField.Key)
+			s.assertVariantEqual(expectedField.Value, actualField.Value)
+		}
+	case variant.BasicArray:
+		exp := expected.Value().(variant.ArrayValue)
+		act := actual.Value().(variant.ArrayValue)
+
+		s.Equal(exp.Len(), act.Len(), "Expected array length %d, got %d", exp.Len(), act.Len())
+		for e, a := range zip(exp.Values(), act.Values()) {
+			s.assertVariantEqual(e, a)
+		}
+	default:
+		switch expected.Type() {
+		case variant.Decimal4, variant.Decimal8, variant.Decimal16:
+			e, err := json.Marshal(expected.Value())
+			s.Require().NoError(err, "Failed to marshal expected value")
+			a, err := json.Marshal(actual.Value())
+			s.Require().NoError(err, "Failed to marshal actual value")
+			s.JSONEq(string(e), string(a), "Expected variant value %s, got %s", e, a)
+		default:
+			s.EqualValues(expected.Value(), actual.Value(), "Expected variant value %v, got %v", expected.Value(), actual.Value())
+		}
+	}
+}
+
+func (s *ShreddedVariantTestSuite) TestSingleVariantCases() {
+	for _, c := range s.singleVariant {
+		s.Run(c.Title, func() {
+			s.Run(fmt.Sprint(c.Number), func() {
+				if strings.Contains(c.ParquetFile, "-INVALID") {
+					s.T().Skip(c.Notes)
+				}
+
+				expected := s.readVariant(filepath.Join(s.dirPrefix, c.VariantFile))
+				if s.generate {
+					s.writeVariantFile(c.VariantFile, expected)
+				}
+
+				tbl := s.readParquet(c.ParquetFile)
+				defer tbl.Release()
+
+				if s.generate {
+					s.writeParquetFile(c.ParquetFile, tbl)
+				}
+
+				col := tbl.Column(1).Data().Chunk(0)
+				s.Require().IsType(&extensions.VariantArray{}, col)
+
+				variantArray := col.(*extensions.VariantArray)
+				s.Require().Equal(1, variantArray.Len(), "Expected single variant value")
+
+				val, err := variantArray.Value(0)
+				s.Require().NoError(err, "Failed to get variant value from array")
+				s.assertVariantEqual(expected, val)
+			})
+		})
+	}
+}
+
+func (s *ShreddedVariantTestSuite) TestMultiVariantCases() {
+	for _, c := range s.multiVariant {
+		s.Run(c.Title, func() {
+			s.Run(fmt.Sprint(c.Number), func() {
+				tbl := s.readParquet(c.ParquetFile)
+				defer tbl.Release()
+
+				if s.generate {
+					s.writeParquetFile(c.ParquetFile, tbl)
+				}
+
+				s.Require().EqualValues(len(c.VariantFiles), tbl.NumRows(), "Expected number of rows to match number of variant files")
+				col := tbl.Column(1).Data().Chunk(0)
+				s.Require().IsType(&extensions.VariantArray{}, col)
+
+				variantArray := col.(*extensions.VariantArray)
+				for i, variantFile := range c.VariantFiles {
+					if variantFile == nil {
+						s.True(variantArray.IsNull(i), "Expected null value at index %d", i)
+						continue
+					}
+
+					expected := s.readVariant(filepath.Join(s.dirPrefix, *variantFile))
+					if s.generate {
+						s.writeVariantFile(*variantFile, expected)
+					}
+
+					actual, err := variantArray.Value(i)
+					s.Require().NoError(err, "Failed to get variant value at index %d", i)
+					s.assertVariantEqual(expected, actual)
+				}
+			})
+		})
+	}
+}
+
+func (s *ShreddedVariantTestSuite) TestErrorCases() {
+	for _, c := range s.errorCases {
+		s.Run(c.Title, func() {
+			s.Run(fmt.Sprint(c.Number), func() {
+				switch c.Number {
+				case 127:
+					s.T().Skip("Skipping case 127: test says uint32 should error, we just upcast to int64")
+				case 137:
+					s.T().Skip("Skipping case 137: test says flba(4) should error, we just treat it as a binary variant")
+				}
+
+				tbl := s.readParquet(c.ParquetFile)
+				defer tbl.Release()
+
+				if s.generate {
+					s.writeParquetFile(c.ParquetFile, tbl)
+				}
+
+				col := tbl.Column(1).Data().Chunk(0)
+				s.Require().IsType(&extensions.VariantArray{}, col)
+
+				variantArray := col.(*extensions.VariantArray)
+				_, err := variantArray.Value(0)
+				s.Error(err, "Expected error for case %d: %s", c.Number, c.ErrorMessage)
+			})
+		})
+	}
+}
+
+func TestShreddedVariantExamples(t *testing.T) {
+	suite.Run(t, &ShreddedVariantTestSuite{generate: false})
+}
diff --git a/parquet/schema/logical_types.go b/parquet/schema/logical_types.go
index 0c0ce559..e7f1c29f 100644
--- a/parquet/schema/logical_types.go
+++ b/parquet/schema/logical_types.go
@@ -24,6 +24,7 @@ import (
 	"github.com/apache/arrow-go/v18/parquet"
"github.com/apache/arrow-go/v18/parquet/internal/debug" format "github.com/apache/arrow-go/v18/parquet/internal/gen-go/parquet" + "github.com/apache/thrift/lib/go/thrift" ) // DecimalMetadata is a struct for managing scale and precision information between @@ -1139,7 +1140,7 @@ func (VariantLogicalType) IsCompatible(ct ConvertedType, _ DecimalMetadata) bool func (VariantLogicalType) IsApplicable(parquet.Type, int32) bool { return false } func (VariantLogicalType) toThrift() *format.LogicalType { - return &format.LogicalType{VARIANT: format.NewVariantType()} + return &format.LogicalType{VARIANT: &format.VariantType{SpecificationVersion: thrift.Int8Ptr(1)}} } func (VariantLogicalType) Equals(rhs LogicalType) bool { diff --git a/parquet/variant/builder.go b/parquet/variant/builder.go index 194814c6..68fc178d 100644 --- a/parquet/variant/builder.go +++ b/parquet/variant/builder.go @@ -887,7 +887,7 @@ func (b *Builder) Build() (Value, error) { type variantPrimitiveType interface { constraints.Integer | constraints.Float | string | []byte | arrow.Date32 | arrow.Time64 | arrow.Timestamp | bool | - uuid.UUID | DecimalValue[decimal.Decimal32] | + uuid.UUID | DecimalValue[decimal.Decimal32] | time.Time | DecimalValue[decimal.Decimal64] | DecimalValue[decimal.Decimal128] } @@ -895,17 +895,25 @@ type variantPrimitiveType interface { // variant value. At the moment this is just delegating to the [Builder.Append] method, // but in the future it will be optimized to avoid the extra overhead and reduce allocations. func Encode[T variantPrimitiveType](v T, opt ...AppendOpt) ([]byte, error) { + out, err := Of(v, opt...) + if err != nil { + return nil, fmt.Errorf("failed to encode variant value: %w", err) + } + return out.value, nil +} + +func Of[T variantPrimitiveType](v T, opt ...AppendOpt) (Value, error) { var b Builder if err := b.Append(v, opt...); err != nil { - return nil, fmt.Errorf("failed to append value: %w", err) + return Value{}, fmt.Errorf("failed to append value: %w", err) } val, err := b.Build() if err != nil { - return nil, fmt.Errorf("failed to build variant value: %w", err) + return Value{}, fmt.Errorf("failed to build variant value: %w", err) } - return val.value, nil + return val, nil } func ParseJSON(data string, allowDuplicateKeys bool) (Value, error) { diff --git a/parquet/variant/builder_test.go b/parquet/variant/builder_test.go index 09fa80eb..982fa4e9 100644 --- a/parquet/variant/builder_test.go +++ b/parquet/variant/builder_test.go @@ -57,9 +57,7 @@ func TestBuildPrimitive(t *testing.T) { {"primitive_int8", func(b *variant.Builder) error { return b.AppendInt(42) }}, {"primitive_int16", func(b *variant.Builder) error { return b.AppendInt(1234) }}, {"primitive_int32", func(b *variant.Builder) error { return b.AppendInt(123456) }}, - // FIXME: https://github.com/apache/parquet-testing/issues/82 - // primitive_int64 is an int32 value, but the metadata is int64 - {"primitive_int64", func(b *variant.Builder) error { return b.AppendInt(12345678) }}, + {"primitive_int64", func(b *variant.Builder) error { return b.AppendInt(1234567890123456789) }}, {"primitive_float", func(b *variant.Builder) error { return b.AppendFloat32(1234568000) }}, {"primitive_double", func(b *variant.Builder) error { return b.AppendFloat64(1234567890.1234) }}, {"primitive_string", func(b *variant.Builder) error { diff --git a/parquet/variant/variant.go b/parquet/variant/variant.go index 800b7eb2..254bc3c3 100644 --- a/parquet/variant/variant.go +++ b/parquet/variant/variant.go @@ -650,7 +650,10 @@ func (v Value) 
Value() any { } case BasicShortString: sz := int(v.value[0] >> 2) - return unsafe.String(&v.value[1], sz) + if sz > 0 { + return unsafe.String(&v.value[1], sz) + } + return "" case BasicObject: valueHdr := (v.value[0] >> basicTypeBits) fieldOffsetSz := (valueHdr & 0b11) + 1 diff --git a/parquet/variant/variant_test.go b/parquet/variant/variant_test.go index 2ef4da38..c623f646 100644 --- a/parquet/variant/variant_test.go +++ b/parquet/variant/variant_test.go @@ -152,9 +152,7 @@ func TestPrimitiveVariants(t *testing.T) { {"primitive_int8", int8(42), variant.Int8, "42"}, {"primitive_int16", int16(1234), variant.Int16, "1234"}, {"primitive_int32", int32(123456), variant.Int32, "123456"}, - // FIXME: https://github.com/apache/parquet-testing/issues/82 - // primitive_int64 is an int32 value, but the metadata is int64 - {"primitive_int64", int32(12345678), variant.Int32, "12345678"}, + {"primitive_int64", int64(1234567890123456789), variant.Int64, "1234567890123456789"}, {"primitive_float", float32(1234567940.0), variant.Float, "1234568000"}, {"primitive_double", float64(1234567890.1234), variant.Double, "1234567890.1234"}, {"primitive_string",