Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for COPY IN protocol #72

Merged
merged 16 commits into from
Nov 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func (cache *DefaultPortalCache) Get(ctx context.Context, name string) (*Portal,
return portal, nil
}

func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, writer *buffer.Writer) (err error) {
func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, reader *buffer.Reader, writer *buffer.Writer) (err error) {
defer func() {
r := recover()
if r != nil {
Expand All @@ -121,5 +121,5 @@ func (cache *DefaultPortalCache) Execute(ctx context.Context, name string, write
return nil
}

return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, portal.formats, writer), portal.parameters)
return portal.statement.fn(ctx, NewDataWriter(ctx, portal.statement.columns, portal.formats, reader, writer), portal.parameters)
}
93 changes: 53 additions & 40 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/lib/pq/oid"
)

// NewErrUnimplementedMessageType is called whenever a unimplemented message
// NewErrUnimplementedMessageType is called whenever an unimplemented message
// type is sent. This error indicates to the client that the sent message cannot
// be processed at this moment in time.
func NewErrUnimplementedMessageType(t types.ClientMessage) error {
Expand Down Expand Up @@ -45,62 +45,73 @@ func NewErrMultipleCommandsStatements() error {
return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Syntax), psqlerr.LevelError)
}

// newErrClientCopyFailed is returned whenever the client aborts a copy operation.
func newErrClientCopyFailed(desc string) error {
err := fmt.Errorf("client aborted copy: %s", desc)
// TODO: What error code should this really be?
return psqlerr.WithSeverity(psqlerr.WithCode(err, codes.Uncategorized), psqlerr.LevelError)
}

// consumeCommands consumes incoming commands sent over the Postgres wire connection.
// Commands consumed from the connection are returned through a go channel.
// Responses for the given message type are written back to the client.
// This method keeps consuming messages until the client issues a close message
// or the connection is terminated.
func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) (err error) {
func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *buffer.Reader, writer *buffer.Writer) error {
srv.logger.Debug("ready for query... starting to consume commands")

// TODO: Include a value to identify unique connections
//
// include a identification value inside the context that
// could be used to identify connections at a later stage.

err = readyForQuery(writer, types.ServerIdle)
err := readyForQuery(writer, types.ServerIdle)
if err != nil {
return err
}

for {
t, length, err := reader.ReadTypedMsg()
if err == io.EOF {
return nil
if err = srv.consumeSingleCommand(ctx, reader, writer, conn); err != nil {
return err
}
}
}

// NOTE: we could recover from this scenario
if errors.Is(err, buffer.ErrMessageSizeExceeded) {
err = srv.handleMessageSizeExceeded(reader, writer, err)
if err != nil {
return err
}

continue
}
func (srv *Server) consumeSingleCommand(ctx context.Context, reader *buffer.Reader, writer *buffer.Writer, conn net.Conn) error {
t, length, err := reader.ReadTypedMsg()
if err == io.EOF {
return nil
}

// NOTE: we could recover from this scenario
if errors.Is(err, buffer.ErrMessageSizeExceeded) {
err = handleMessageSizeExceeded(reader, writer, err)
if err != nil {
return err
}

if srv.closing.Load() {
return nil
}
return nil
}

// NOTE: we increase the wait group by one in order to make sure that idle
// connections are not blocking a close.
srv.wg.Add(1)
srv.logger.Debug("<- incoming command", slog.Int("length", length), slog.String("type", t.String()))
err = srv.handleCommand(ctx, conn, t, reader, writer)
srv.wg.Done()
if errors.Is(err, io.EOF) {
return nil
}
if err != nil {
return err
}

if err != nil {
return err
}
if srv.closing.Load() {
return nil
}

// NOTE: we increase the wait group by one in order to make sure that idle
// connections are not blocking a close.
srv.wg.Add(1)
srv.logger.Debug("<- incoming command", slog.Int("length", length), slog.String("type", t.String()))
err = srv.handleCommand(ctx, conn, t, reader, writer)
srv.wg.Done()
if errors.Is(err, io.EOF) {
return nil
}

return err
}

// handleMessageSizeExceeded attempts to unwrap the given error message as
Expand All @@ -112,7 +123,7 @@ func (srv *Server) consumeCommands(ctx context.Context, conn net.Conn, reader *b
// type. A fatal error is returned when an unexpected error is returned while
// consuming the expected message size or when attempting to write the error
// message back to the client.
func (srv *Server) handleMessageSizeExceeded(reader *buffer.Reader, writer *buffer.Writer, exceeded error) (err error) {
func handleMessageSizeExceeded(reader *buffer.Reader, writer *buffer.Writer, exceeded error) (err error) {
unwrapped, has := buffer.UnwrapMessageSizeExceeded(exceeded)
if !has {
return exceeded
Expand All @@ -130,7 +141,7 @@ func (srv *Server) handleMessageSizeExceeded(reader *buffer.Reader, writer *buff
// message type and reader buffer containing the actual message. The type
// indecates a action executed by the client.
// https://www.postgresql.org/docs/14/protocol-message-formats.html
func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) (err error) {
func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.ClientMessage, reader *buffer.Reader, writer *buffer.Writer) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

Expand Down Expand Up @@ -209,7 +220,7 @@ func (srv *Server) handleCommand(ctx context.Context, conn net.Conn, t types.Cli
writer.End() //nolint:errcheck
return nil
case types.ClientTerminate:
err = srv.handleConnTerminate(ctx)
err := srv.handleConnTerminate(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -267,7 +278,7 @@ func (srv *Server) handleSimpleQuery(ctx context.Context, reader *buffer.Reader,
return ErrorCode(writer, err)
}

err = statements[index].fn(ctx, NewDataWriter(ctx, statements[index].columns, nil, writer), nil)
err = statements[index].fn(ctx, NewDataWriter(ctx, statements[index].columns, nil, reader, writer), nil)
if err != nil {
return ErrorCode(writer, err)
}
Expand Down Expand Up @@ -337,8 +348,10 @@ func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, wr
return err
}

switch d[0] {
case 'S':
srv.logger.Debug("incoming describe request", slog.String("type", types.DescribeMessage(d[0]).String()), slog.String("name", name))

switch types.DescribeMessage(d[0]) {
case types.DescribeStatement:
statement, err := srv.Statements.Get(ctx, name)
if err != nil {
return err
Expand All @@ -355,7 +368,7 @@ func (srv *Server) handleDescribe(ctx context.Context, reader *buffer.Reader, wr

// NOTE: the format codes are not yet known at this point in time.
return srv.writeColumnDescription(ctx, writer, nil, statement.columns)
case 'P':
case types.DescribePortal:
portal, err := srv.Portals.Get(ctx, name)
if err != nil {
return err
Expand Down Expand Up @@ -412,7 +425,7 @@ func (srv *Server) handleBind(ctx context.Context, reader *buffer.Reader, writer
return err
}

formats, err := srv.readColumnTypes(ctx, reader)
formats, err := srv.readColumnTypes(reader)
if err != nil {
return err
}
Expand Down Expand Up @@ -503,7 +516,7 @@ func (srv *Server) readParameters(ctx context.Context, reader *buffer.Reader) ([
return parameters, nil
}

func (srv *Server) readColumnTypes(ctx context.Context, reader *buffer.Reader) ([]FormatCode, error) {
func (srv *Server) readColumnTypes(reader *buffer.Reader) ([]FormatCode, error) {
length, err := reader.GetUint16()
if err != nil {
return nil, err
Expand Down Expand Up @@ -544,7 +557,7 @@ func (srv *Server) handleExecute(ctx context.Context, reader *buffer.Reader, wri
}

srv.logger.Debug("executing", slog.String("name", name), slog.Uint64("limit", uint64(limit)))
err = srv.Portals.Execute(ctx, name, writer)
err = srv.Portals.Execute(ctx, name, reader, writer)
if err != nil {
return ErrorCode(writer, err)
}
Expand Down
Loading
Loading