Skip to content

Commit

Permalink
Add support for COPY IN protocol (#72)
Browse files Browse the repository at this point in the history
* Add constants for Describe message types

and add a bit more debugging

* Remove unused function parameter

* handleMessageSizeExceeded doesn't need to be a method of srv

* Add failing test and begin work on CopyIn support

* Extend NewDataWriter to receive a CopyData function

* Extract command loop processing

* Rough working version

* Rewrite CopyIn() to return an io.Reader, rather than accept an io.Writer

* Extend CopyIn to allow specifying the format for each column

* Remove extraneous debugging

* Move command loop into consumeCommands, so inner function can be stepped through one command at a time

* Simplify CopyIn to use io.Reader rather than custom function signature

* Remove unused CopyDataFn type

* Further simplify handleCopyInCommand

* feat: introducing the binary column reader and improved copy reader implementation

---------

Co-authored-by: Jeroen Rinzema <[email protected]>
  • Loading branch information
flimzy and jeroenrinzema authored Nov 11, 2024
1 parent 3c2d029 commit ebab8df
Show file tree
Hide file tree
Showing 11 changed files with 526 additions and 48 deletions.
4 changes: 2 additions & 2 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (cache *DefaultPortalCache) Get(ctx context.Context, name string) (*Portal,
return portal, nil
}

func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, writer *buffer.Writer) (err error) {
func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, reader *buffer.Reader, writer *buffer.Writer) (err error) {
defer func() {
r := recover()
if r != nil {
Expand All @@ -121,5 +121,5 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, write
return nil
}

return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, portal.formats, writer), portal.parameters)
return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, portal.formats, reader, writer), portal.parameters)
}
93 changes: 53 additions & 40 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/lib/pq/oid"
)

// NewErrUnimplementedMessageType is called whenever a unimplemented message
// NewErrUnimplementedMessageType is called whenever an unimplemented message
// type is sent. This error indicates to the client that the sent message cannot
// be processed at this moment in time.
func NewErrUnimplementedMessageType(t types.ClientMessage) error {
Expand Down Expand Up @@ -45,62 +45,73 @@ func NewErrMultipleCommandsStatements() error {
return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Syntax), psqlerr.LevelError)
}

// newErrClientCopyFailed is returned whenever the client aborts a copy operation.
func newErrClientCopyFailed(desc string) error {
err := fmt.Errorf("client aborted copy: %s", desc)
// TODO: What error code should this really be?
return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Uncategorized), psqlerr.LevelError)
}

// consumeCommands consumes incoming commands sent over the Postgres wire connection.
// Commands consumed from the connection are returned through a go channel.
// Responses for the given message type are written back to the client.
// This method keeps consuming messages until the client issues a close message
// or the connection is terminated.
func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) (err error) {
func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error {
srv.logger.Debug("ready for query... starting to consume commands")

// TODO: Include a value to identify unique connections
//
// include a identification value inside the context that
// could be used to identify connections at a later stage.

err = readyForQuery(writer, types.ServerIdle)
err := readyForQuery(writer, types.ServerIdle)
if err != nil {
return err
}

for {
t, length, err := reader.ReadTypedMsg()
if err == io.EOF {
return nil
if err = srv.consumeSingleCommand(ctx, reader, writer, conn); err != nil {
return err
}
}
}

// NOTE: we could recover from this scenario
if errors.Is(err, buffer.ErrMessageSizeExceeded) {
err = srv.handleMessageSizeExceeded(reader, writer, err)
if err != nil {
return err
}

continue
}
func (srv *Server) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, conn net.Conn) error {
t, length, err := reader.ReadTypedMsg()
if err == io.EOF {
return nil
}

// NOTE: we could recover from this scenario
if errors.Is(err, buffer.ErrMessageSizeExceeded) {
err = handleMessageSizeExceeded(reader, writer, err)
if err != nil {
return err
}

if srv.closing.Load() {
return nil
}
return nil
}

// NOTE: we increase the wait group by one in order to make sure that idle
// connections are not blocking a close.
srv.wg.Add(1)
srv.logger.Debug("<- incoming command", slog.Int("length", length), slog.String("type", t.String()))
err = srv.handleCommand(ctx, conn, t, reader, writer)
srv.wg.Done()
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return err
}

if err != nil {
return err
}
if srv.closing.Load() {
return nil
}

// NOTE: we increase the wait group by one in order to make sure that idle
// connections are not blocking a close.
srv.wg.Add(1)
srv.logger.Debug("<- incoming command", slog.Int("length", length), slog.String("type", t.String()))
err = srv.handleCommand(ctx, conn, t, reader, writer)
srv.wg.Done()
if errors.Is(err, io.EOF) {
return nil
}

return err
}

// handleMessageSizeExceeded attempts to unwrap the given error message as
Expand All @@ -112,7 +123,7 @@ func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *b
// type. A fatal error is returned when an unexpected error is returned while
// consuming the expected message size or when attempting to write the error
// message back to the client.
func (srv *Server) handleMessageSizeExceeded(reader *buffer.Reader, writer *buffer.Writer, exceeded error) (err error) {
func handleMessageSizeExceeded(reader *buffer.Reader, writer *buffer.Writer, exceeded error) (err error) {
unwrapped, has := buffer.UnwrapMessageSizeExceeded(exceeded)
if !has {
return exceeded
Expand All @@ -130,7 +141,7 @@ func (srv *Server) handleMessageSizeExceeded(reader *buffer.Reader, writer *buff
// message type and reader buffer containing the actual message. The type
// indecates a action executed by the client.
// https://www.postgresql.org/docs/14/protocol-message-formats.html
func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) (err error) {
func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

Expand Down Expand Up @@ -209,7 +220,7 @@ func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.Cli
writer.End() //nolint:errcheck
return nil
case types.ClientTerminate:
err = srv.handleConnTerminate(ctx)
err := srv.handleConnTerminate(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -267,7 +278,7 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader,
return ErrorCode(writer, err)
}

err = statements[index].fn(ctx, NewDataWriter(ctx, statements[index].columns, nil, writer), nil)
err = statements[index].fn(ctx, NewDataWriter(ctx, statements[index].columns, nil, reader, writer), nil)
if err != nil {
return ErrorCode(writer, err)
}
Expand Down Expand Up @@ -337,8 +348,10 @@ func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, wr
return err
}

switch d[0] {
case 'S':
srv.logger.Debug("incoming describe request", slog.String("type", types.DescribeMessage(d[0]).String()), slog.String("name", name))

switch types.DescribeMessage(d[0]) {
case types.DescribeStatement:
statement, err := srv.Statements.Get(ctx, name)
if err != nil {
return err
Expand All @@ -355,7 +368,7 @@ func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, wr

// NOTE: the format codes are not yet known at this point in time.
return srv.writeColumnDescription(ctx, writer, nil, statement.columns)
case 'P':
case types.DescribePortal:
portal, err := srv.Portals.Get(ctx, name)
if err != nil {
return err
Expand Down Expand Up @@ -412,7 +425,7 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer
return err
}

formats, err := srv.readColumnTypes(ctx, reader)
formats, err := srv.readColumnTypes(reader)
if err != nil {
return err
}
Expand Down Expand Up @@ -503,7 +516,7 @@ func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([
return parameters, nil
}

func (srv *Server) readColumnTypes(ctx context.Context, reader *buffer.Reader) ([]FormatCode, error) {
func (srv *Server) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error) {
length, err := reader.GetUint16()
if err != nil {
return nil, err
Expand Down Expand Up @@ -544,7 +557,7 @@ func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, wri
}

srv.logger.Debug("executing", slog.String("name", name), slog.Uint64("limit", uint64(limit)))
err = srv.Portals.Execute(ctx, name, writer)
err = srv.Portals.Execute(ctx, name, reader, writer)
if err != nil {
return ErrorCode(writer, err)
}
Expand Down
Loading

0 comments on commit ebab8df

Please sign in to comment.