diff --git a/go/arrow/internal/arrdata/ioutil.go b/go/arrow/internal/arrdata/ioutil.go index 7065f64b503..33aab24bb3c 100644 --- a/go/arrow/internal/arrdata/ioutil.go +++ b/go/arrow/internal/arrdata/ioutil.go @@ -17,8 +17,10 @@ package arrdata // import "github.com/apache/arrow/go/arrow/internal/arrdata" import ( + "fmt" "io" "os" + "sync" "testing" "github.com/apache/arrow/go/arrow" @@ -59,6 +61,54 @@ func CheckArrowFile(t *testing.T, f *os.File, mem memory.Allocator, schema *arro } +func CheckArrowConcurrentFile(t *testing.T, f *os.File, mem memory.Allocator, schema *arrow.Schema, recs []array.Record) { + t.Helper() + + _, err := f.Seek(0, io.SeekStart) + if err != nil { + t.Fatal(err) + } + + r, err := ipc.NewFileReader(f, ipc.WithSchema(schema), ipc.WithAllocator(mem)) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + var g sync.WaitGroup + errs := make(chan error, r.NumRecords()) + checkRecord := func(i int) { + defer g.Done() + rec, err := r.RecordAt(i) + if err != nil { + errs <- fmt.Errorf("could not read record %d: %v", i, err) + return + } + if !array.RecordEqual(rec, recs[i]) { + errs <- fmt.Errorf("records[%d] differ", i) + } + } + + for i := 0; i < r.NumRecords(); i++ { + g.Add(1) + go checkRecord(i) + } + + g.Wait() + close(errs) + + for err := range errs { + if err != nil { + t.Fatal(err) + } + } + + err = r.Close() + if err != nil { + t.Fatal(err) + } +} + // CheckArrowStream checks whether a given ARROW stream contains the expected list of records. func CheckArrowStream(t *testing.T, f *os.File, mem memory.Allocator, schema *arrow.Schema, recs []array.Record) { t.Helper() diff --git a/go/arrow/ipc/file_reader.go b/go/arrow/ipc/file_reader.go index 961803b33ef..cf324482018 100644 --- a/go/arrow/ipc/file_reader.go +++ b/go/arrow/ipc/file_reader.go @@ -244,6 +244,23 @@ func (f *FileReader) Close() error { // The returned value is valid until the next call to Record. // Users need to call Retain on that Record to keep it valid for longer. func (f *FileReader) Record(i int) (array.Record, error) { + record, err := f.RecordAt(i) + if err != nil { + return nil, err + } + + if f.record != nil { + f.record.Release() + } + + f.record = record + return record, nil +} + +// Record returns the i-th record from the file. Ownership is transferred to the +// caller and must call Release() to free the memory. This method is safe to +// call concurrently. +func (f *FileReader) RecordAt(i int) (array.Record, error) { if i < 0 || i > f.NumRecords() { panic("arrow/ipc: record index out of bounds") } @@ -271,12 +288,7 @@ func (f *FileReader) Record(i int) (array.Record, error) { return nil, xerrors.Errorf("arrow/ipc: message %d is not a Record", i) } - if f.record != nil { - f.record.Release() - } - - f.record = newRecord(f.schema, msg.meta, bytes.NewReader(msg.body.Bytes())) - return f.record, nil + return newRecord(f.schema, msg.meta, bytes.NewReader(msg.body.Bytes())), nil } // Read reads the current record from the underlying stream and an error, if any. diff --git a/go/arrow/ipc/file_test.go b/go/arrow/ipc/file_test.go index 8c5d515ba5e..d0ef9605e61 100644 --- a/go/arrow/ipc/file_test.go +++ b/go/arrow/ipc/file_test.go @@ -45,6 +45,7 @@ func TestFile(t *testing.T) { arrdata.WriteFile(t, f, mem, recs[0].Schema(), recs) arrdata.CheckArrowFile(t, f, mem, recs[0].Schema(), recs) + arrdata.CheckArrowConcurrentFile(t, f, mem, recs[0].Schema(), recs) }) } }