From 3db13f2382469460bc0769fa9882e15d4693145a Mon Sep 17 00:00:00 2001 From: Marek Smolinski Date: Thu, 9 Jun 2022 16:45:27 +0200 Subject: [PATCH] SQLServer add suport for SSMS client --- lib/srv/db/sqlserver/connect.go | 9 +-- lib/srv/db/sqlserver/engine.go | 74 +++++++++++++++---------- lib/srv/db/sqlserver/protocol/login7.go | 6 ++ lib/srv/db/sqlserver/protocol/packet.go | 9 ++- 4 files changed, 64 insertions(+), 34 deletions(-) diff --git a/lib/srv/db/sqlserver/connect.go b/lib/srv/db/sqlserver/connect.go index 0f15c28f4046d..4393cd81c48fb 100644 --- a/lib/srv/db/sqlserver/connect.go +++ b/lib/srv/db/sqlserver/connect.go @@ -22,12 +22,12 @@ import ( "net" "strconv" - "github.com/gravitational/teleport/lib/srv/db/common" - "github.com/gravitational/teleport/lib/srv/db/sqlserver/protocol" - "github.com/gravitational/trace" - mssql "github.com/denisenkom/go-mssqldb" "github.com/denisenkom/go-mssqldb/msdsn" + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/srv/db/sqlserver/protocol" ) // Connector defines an interface for connecting to a SQL Server so it can be @@ -76,6 +76,7 @@ func (c *connector) Connect(ctx context.Context, sessionCtx *common.Session, log LoginOptions: options, Encryption: msdsn.EncryptionRequired, TLSConfig: tlsConfig, + PacketSize: loginPacket.PacketSize(), }, auth) conn, err := connector.Connect(ctx) diff --git a/lib/srv/db/sqlserver/engine.go b/lib/srv/db/sqlserver/engine.go index 554cc7228e8ff..82269107105e3 100644 --- a/lib/srv/db/sqlserver/engine.go +++ b/lib/srv/db/sqlserver/engine.go @@ -70,7 +70,7 @@ func (e *Engine) SendError(err error) { } // HandleConnection authorizes the incoming client connection, connects to the -// target SQL Server server and starts proxying messages between client/server. +// target SQLServer server and starts proxying messages between client/server. func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session) error { // Pre-Login packet was handled on the Proxy. Now we expect the client to // send us a Login7 packet that contains username/database information and @@ -99,10 +99,18 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio return trace.Wrap(err) } - // Start proxying packets between client and server. - err = e.proxy(ctx, serverConn) - if err != nil { - return trace.Wrap(err) + clientErrCh := make(chan error, 1) + serverErrCh := make(chan error, 1) + go e.receiveFromClient(e.clientConn, serverConn, clientErrCh) + go e.receiveFromServer(serverConn, e.clientConn, serverErrCh) + + select { + case err := <-clientErrCh: + e.Log.WithError(err).Debug("Client done.") + case err := <-serverErrCh: + e.Log.WithError(err).Debug("Server done.") + case <-ctx.Done(): + e.Log.Debug("Context canceled.") } return nil @@ -149,33 +157,41 @@ func (e *Engine) checkAccess(ctx context.Context, sessionCtx *common.Session) er return nil } -// proxy proxies all traffic between the client and server connections. -func (e *Engine) proxy(ctx context.Context, serverConn io.ReadWriteCloser) error { - errCh := make(chan error, 2) - - go func() { - defer serverConn.Close() - _, err := io.Copy(serverConn, e.clientConn) - errCh <- err +// receiveFromClient relays protocol messages received from SQL Server client +// to SQL Server database. +func (e *Engine) receiveFromClient(clientConn, serverConn io.ReadWriteCloser, clientErrCh chan<- error) { + defer func() { + serverConn.Close() + e.Log.Debug("Stop receiving from client.") + close(clientErrCh) }() - - go func() { - defer serverConn.Close() - _, err := io.Copy(e.clientConn, serverConn) - errCh <- err - }() - - var errs []error - for i := 0; i < 2; i++ { - select { - case err := <-errCh: - if err != nil && !utils.IsOKNetworkError(err) { - errs = append(errs, err) + for { + p, err := protocol.ReadPacket(clientConn) + if err != nil { + if utils.IsOKNetworkError(err) { + e.Log.Debug("Client connection closed.") + return } - case <-ctx.Done(): - return trace.Wrap(ctx.Err()) + e.Log.WithError(err).Error("Failed to read client packet.") + clientErrCh <- err + return + } + + _, err = serverConn.Write(p.Bytes()) + if err != nil { + e.Log.WithError(err).Error("Failed to write server packet.") + clientErrCh <- err + return } } +} - return trace.NewAggregate(errs...) +// receiveFromServer relays protocol messages received from SQLServer database +// to SQLServer client. +func (e *Engine) receiveFromServer(serverConn, clientConn io.ReadWriteCloser, serverErrCh chan<- error) { + defer clientConn.Close() + _, err := io.Copy(clientConn, serverConn) + if err != nil && !utils.IsOKNetworkError(err) { + serverErrCh <- trace.Wrap(err) + } } diff --git a/lib/srv/db/sqlserver/protocol/login7.go b/lib/srv/db/sqlserver/protocol/login7.go index a413fb559ed46..c60b1e83efaa0 100644 --- a/lib/srv/db/sqlserver/protocol/login7.go +++ b/lib/srv/db/sqlserver/protocol/login7.go @@ -61,6 +61,12 @@ func (p *Login7Packet) TypeFlags() uint8 { return p.header.TypeFlags } +// PacketSize return the packet size from the Login7 packet. +// Packet size is used by a server to negation the size of max packet length. +func (p *Login7Packet) PacketSize() uint16 { + return uint16(p.header.PacketSize) +} + // Login7Header contains options and offset/length pairs parsed from the Login7 // packet sent by client. // diff --git a/lib/srv/db/sqlserver/protocol/packet.go b/lib/srv/db/sqlserver/protocol/packet.go index accc54ef91d57..f0ba74654957a 100644 --- a/lib/srv/db/sqlserver/protocol/packet.go +++ b/lib/srv/db/sqlserver/protocol/packet.go @@ -54,7 +54,13 @@ type Packet struct { PacketHeader // Data is the packet data bytes without header. - Data []byte + Data []byte + headerBytes [8]byte +} + +func (p Packet) Bytes() []byte { + return append(p.headerBytes[:], p.Data...) + } // ReadPacket reads a single full packet from the reader. @@ -78,6 +84,7 @@ func ReadPacket(r io.Reader) (*Packet, error) { } return &Packet{ + headerBytes: headerBytes, PacketHeader: header, Data: dataBytes, }, nil