diff --git a/arrow/cdata/cdata.go b/arrow/cdata/cdata.go index 4085ed3d..63419469 100644 --- a/arrow/cdata/cdata.go +++ b/arrow/cdata/cdata.go @@ -407,7 +407,9 @@ func (imp *cimporter) doImportChildren() error { st := imp.dt.(*arrow.StructType) for i, c := range children { imp.children[i].dt = st.Field(i).Type - imp.children[i].importChild(imp, c) + if err := imp.children[i].importChild(imp, c); err != nil { + return err + } } case arrow.RUN_END_ENCODED: // import run-ends and values st := imp.dt.(*arrow.RunEndEncodedType) @@ -428,13 +430,17 @@ func (imp *cimporter) doImportChildren() error { dt := imp.dt.(*arrow.DenseUnionType) for i, c := range children { imp.children[i].dt = dt.Fields()[i].Type - imp.children[i].importChild(imp, c) + if err := imp.children[i].importChild(imp, c); err != nil { + return err + } } case arrow.SPARSE_UNION: dt := imp.dt.(*arrow.SparseUnionType) for i, c := range children { imp.children[i].dt = dt.Fields()[i].Type - imp.children[i].importChild(imp, c) + if err := imp.children[i].importChild(imp, c); err != nil { + return err + } } } @@ -461,33 +467,28 @@ func (imp *cimporter) doImportArr(src *CArrowArray) error { // and only null columns, then we can release the CArrowArray // struct immediately after import, since we have no imported // memory that we have to track the lifetime of. + // On error, we always release regardless of buffer count to avoid leaks. + var importErr error defer func() { - if imp.alloc.bufCount.Load() == 0 { - C.ArrowArrayRelease(imp.arr) - C.free(unsafe.Pointer(imp.arr)) + if importErr != nil || imp.alloc.bufCount.Load() == 0 { + imp.alloc.forceRelease() } }() - return imp.doImport() + importErr = imp.doImport() + return importErr } // import is called recursively as needed for importing an array and its children // in order to generate array.Data objects func (imp *cimporter) doImport() error { - // move the array from the src object passed in to the one referenced by - // this importer. That way we can set up a finalizer on the created - // arrow.ArrayData object so we clean up our Array's memory when garbage collected. - defer func(arr *CArrowArray) { - // this should only occur in the case of an error happening - // during import, at which point we need to clean up the - // ArrowArray struct we allocated. - if imp.data == nil { - C.free(unsafe.Pointer(arr)) - } - }(imp.arr) - // import any children if err := imp.doImportChildren(); err != nil { + for _, c := range imp.children { + if c.data != nil { + c.data.Release() + } + } return err } diff --git a/arrow/cdata/cdata_test.go b/arrow/cdata/cdata_test.go index 170a5151..8fa690f2 100644 --- a/arrow/cdata/cdata_test.go +++ b/arrow/cdata/cdata_test.go @@ -669,8 +669,8 @@ func createTestDenseUnion() arrow.Array { func createTestUnionArr(mode arrow.UnionMode) arrow.Array { fields := []arrow.Field{ - arrow.Field{Name: "u0", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, - arrow.Field{Name: "u1", Type: arrow.PrimitiveTypes.Uint8, Nullable: true}, + {Name: "u0", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "u1", Type: arrow.PrimitiveTypes.Uint8, Nullable: true}, } typeCodes := []arrow.UnionTypeCode{5, 10} bld := array.NewBuilder(memory.DefaultAllocator, arrow.UnionOf(mode, fields, typeCodes)).(array.UnionBuilder) @@ -785,6 +785,104 @@ func TestRecordBatch(t *testing.T) { assert.True(t, array.RecordEqual(rb, rec)) } +func TestImportStructWithInvalidSchema(t *testing.T) { + mem := mallocator.NewMallocator() + defer mem.AssertSize(t, 0) + + arr := createTestStructArr() + defer arr.Release() + + carr := createCArr(arr, mem) + defer freeTestMallocatorArr(carr, mem) + + sc := testStruct([]string{"+s", "c", "l"}, []string{"", "a", "b"}, []int64{0, flagIsNullable, flagIsNullable}) + defer freeMallocedSchemas(sc) + + top := (*[1]*CArrowSchema)(unsafe.Pointer(sc))[0] + _, err := ImportCRecordBatch(carr, top) + assert.Error(t, err) +} + +func TestImportDenseUnionWithInvalidSchema(t *testing.T) { + mem := mallocator.NewMallocator() + defer mem.AssertSize(t, 0) + + unionArr := createTestDenseUnion() + defer unionArr.Release() + + structBld := array.NewStructBuilder(memory.DefaultAllocator, arrow.StructOf( + arrow.Field{Name: "union_field", Type: unionArr.DataType(), Nullable: false}, + )) + defer structBld.Release() + + unionBld := structBld.FieldBuilder(0).(*array.DenseUnionBuilder) + structBld.Append(true) + du := unionArr.(*array.DenseUnion) + for i := 0; i < du.Len(); i++ { + unionBld.Append(du.TypeCode(i)) + if du.TypeCode(i) == 5 { + unionBld.Child(0).(*array.Int32Builder).Append(du.Field(0).(*array.Int32).Value(int(du.ValueOffset(i)))) + } else { + unionBld.Child(1).(*array.Uint8Builder).Append(du.Field(1).(*array.Uint8).Value(int(du.ValueOffset(i)))) + } + } + + structArr := structBld.NewArray() + defer structArr.Release() + + carr := createCArr(structArr, mem) + defer freeTestMallocatorArr(carr, mem) + + // Create an invalid schema: wrong type for union field (using "i" instead of proper union schema) + sc := testStruct([]string{"+s", "i"}, []string{"", "union_field"}, []int64{0, flagIsNullable}) + defer freeMallocedSchemas(sc) + + top := (*[1]*CArrowSchema)(unsafe.Pointer(sc))[0] + _, err := ImportCRecordBatch(carr, top) + assert.Error(t, err) +} + +func TestImportSPARSEUnionWithInvalidSchema(t *testing.T) { + mem := mallocator.NewMallocator() + defer mem.AssertSize(t, 0) + + unionArr := createTestSparseUnion() + defer unionArr.Release() + + structBld := array.NewStructBuilder(memory.DefaultAllocator, arrow.StructOf( + arrow.Field{Name: "union_field", Type: unionArr.DataType(), Nullable: false}, + )) + defer structBld.Release() + + unionBld := structBld.FieldBuilder(0).(*array.SparseUnionBuilder) + structBld.Append(true) + su := unionArr.(*array.SparseUnion) + for i := 0; i < su.Len(); i++ { + unionBld.Append(su.TypeCode(i)) + if su.TypeCode(i) == 5 { + unionBld.Child(0).(*array.Int32Builder).Append(su.Field(0).(*array.Int32).Value(i)) + unionBld.Child(1).(*array.Uint8Builder).AppendNull() + } else { + unionBld.Child(0).(*array.Int32Builder).AppendNull() + unionBld.Child(1).(*array.Uint8Builder).Append(su.Field(1).(*array.Uint8).Value(i)) + } + } + + structArr := structBld.NewArray() + defer structArr.Release() + + carr := createCArr(structArr, mem) + defer freeTestMallocatorArr(carr, mem) + + // Create an invalid schema: wrong type for union field (using "u" instead of proper union schema) + sc := testStruct([]string{"+s", "u"}, []string{"", "union_field"}, []int64{0, flagIsNullable}) + defer freeMallocedSchemas(sc) + + top := (*[1]*CArrowSchema)(unsafe.Pointer(sc))[0] + _, err := ImportCRecordBatch(carr, top) + assert.Error(t, err) +} + func TestRecordReaderStream(t *testing.T) { stream := arrayStreamTest() defer releaseStreamTest(stream) @@ -1006,17 +1104,21 @@ func (r *failingReader) Schema() *arrow.Schema { } return arrdata.Records["primitives"][0].Schema() } + func (r *failingReader) Next() bool { r.opCount -= 1 return r.opCount > 0 } + func (r *failingReader) RecordBatch() arrow.RecordBatch { arrdata.Records["primitives"][0].Retain() return arrdata.Records["primitives"][0] } + func (r *failingReader) Record() arrow.Record { return r.RecordBatch() } + func (r *failingReader) Err() error { if r.opCount == 0 { return fmt.Errorf("Expected error message") diff --git a/arrow/cdata/import_allocator.go b/arrow/cdata/import_allocator.go index d2cc44b7..2dea1336 100644 --- a/arrow/cdata/import_allocator.go +++ b/arrow/cdata/import_allocator.go @@ -29,6 +29,7 @@ import "C" type importAllocator struct { bufCount atomic.Int64 + released atomic.Bool arr *CArrowArray } @@ -49,6 +50,12 @@ func (i *importAllocator) Free([]byte) { debug.Assert(i.bufCount.Load() > 0, "too many releases") if i.bufCount.Add(-1) == 0 { + i.forceRelease() + } +} + +func (i *importAllocator) forceRelease() { + if i.released.CompareAndSwap(false, true) { defer C.free(unsafe.Pointer(i.arr)) C.ArrowArrayRelease(i.arr) if C.ArrowArrayIsReleased(i.arr) != 1 {