diff --git a/go/arrow/ipc/compression.go b/go/arrow/ipc/compression.go index 35815a09e10..462b1bad23e 100644 --- a/go/arrow/ipc/compression.go +++ b/go/arrow/ipc/compression.go @@ -21,6 +21,7 @@ import ( "github.com/apache/arrow/go/v11/arrow/internal/debug" "github.com/apache/arrow/go/v11/arrow/internal/flatbuf" + "github.com/apache/arrow/go/v11/arrow/memory" "github.com/klauspost/compress/zstd" "github.com/pierrec/lz4/v4" ) @@ -118,3 +119,17 @@ func getDecompressor(codec flatbuf.CompressionType) decompressor { } return nil } + +type bufferWriter struct { + buf *memory.Buffer + pos int +} + +func (bw *bufferWriter) Write(p []byte) (n int, err error) { + if bw.pos+len(p) >= bw.buf.Cap() { + bw.buf.Reserve(bw.pos + len(p)) + } + n = copy(bw.buf.Buf()[bw.pos:], p) + bw.pos += n + return +} diff --git a/go/arrow/ipc/writer.go b/go/arrow/ipc/writer.go index 088c6419526..b8d502c15f4 100644 --- a/go/arrow/ipc/writer.go +++ b/go/arrow/ipc/writer.go @@ -17,7 +17,6 @@ package ipc import ( - "bytes" "context" "encoding/binary" "errors" @@ -330,19 +329,23 @@ func (w *recordEncoder) compressBodyBuffers(p *Payload) error { if p.body[idx] == nil || p.body[idx].Len() == 0 { return nil } - var buf bytes.Buffer - buf.Grow(codec.MaxCompressedLen(p.body[idx].Len()) + arrow.Int64SizeBytes) - if err := binary.Write(&buf, binary.LittleEndian, uint64(p.body[idx].Len())); err != nil { - return err - } - codec.Reset(&buf) + + buf := memory.NewResizableBuffer(w.mem) + buf.Reserve(codec.MaxCompressedLen(p.body[idx].Len()) + arrow.Int64SizeBytes) + + binary.LittleEndian.PutUint64(buf.Buf(), uint64(p.body[idx].Len())) + bw := &bufferWriter{buf: buf, pos: arrow.Int64SizeBytes} + codec.Reset(bw) if _, err := codec.Write(p.body[idx].Bytes()); err != nil { return err } if err := codec.Close(); err != nil { return err } - p.body[idx] = memory.NewBufferBytes(buf.Bytes()) + + buf.Resize(bw.pos) + p.body[idx].Release() + p.body[idx] = buf return nil } diff --git a/go/arrow/ipc/writer_test.go b/go/arrow/ipc/writer_test.go index 73e1bbf0fe3..9ebdf267353 100644 --- a/go/arrow/ipc/writer_test.go +++ b/go/arrow/ipc/writer_test.go @@ -144,3 +144,25 @@ func TestWriterCatchPanic(t *testing.T) { writer := NewWriter(buf, WithSchema(schema)) assert.EqualError(t, writer.Write(rec), "arrow/ipc: unknown error while writing: runtime error: slice bounds out of range [-1:]") } + +func TestWriterMemCompression(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "s", Type: arrow.BinaryTypes.String}, + }, nil) + + b := array.NewRecordBuilder(mem, schema) + defer b.Release() + + b.Field(0).(*array.StringBuilder).AppendValues([]string{"foo", "bar", "baz"}, nil) + rec := b.NewRecord() + defer rec.Release() + + var buf bytes.Buffer + w := NewWriter(&buf, WithAllocator(mem), WithSchema(schema), WithZstd()) + defer w.Close() + + require.NoError(t, w.Write(rec)) +}