Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQLServer add suport for SSMS client #13337

Merged
merged 2 commits into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions lib/srv/db/sqlserver/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ import (
"net"
"strconv"

"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/sqlserver/protocol"
"github.com/gravitational/trace"

mssql "github.com/denisenkom/go-mssqldb"
"github.com/denisenkom/go-mssqldb/msdsn"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/sqlserver/protocol"
)

// Connector defines an interface for connecting to a SQL Server so it can be
Expand Down Expand Up @@ -76,6 +76,7 @@ func (c *connector) Connect(ctx context.Context, sessionCtx *common.Session, log
LoginOptions: options,
Encryption: msdsn.EncryptionRequired,
TLSConfig: tlsConfig,
PacketSize: loginPacket.PacketSize(),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow. how did you figure this out

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wireshark has TDS packet filter that allows to visualise SQLServer traffic. Since the tsh proxy db has --tunnel is pretty easy to look/debug db wire protocol.

}, auth)

conn, err := connector.Connect(ctx)
Expand Down
74 changes: 45 additions & 29 deletions lib/srv/db/sqlserver/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (e *Engine) SendError(err error) {
}

// HandleConnection authorizes the incoming client connection, connects to the
// target SQL Server server and starts proxying messages between client/server.
// target SQLServer server and starts proxying messages between client/server.
func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session) error {
// Pre-Login packet was handled on the Proxy. Now we expect the client to
// send us a Login7 packet that contains username/database information and
Expand Down Expand Up @@ -99,10 +99,18 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
return trace.Wrap(err)
}

// Start proxying packets between client and server.
err = e.proxy(ctx, serverConn)
if err != nil {
return trace.Wrap(err)
clientErrCh := make(chan error, 1)
serverErrCh := make(chan error, 1)
go e.receiveFromClient(e.clientConn, serverConn, clientErrCh)
go e.receiveFromServer(serverConn, e.clientConn, serverErrCh)

select {
case err := <-clientErrCh:
e.Log.WithError(err).Debug("Client done.")
case err := <-serverErrCh:
e.Log.WithError(err).Debug("Server done.")
case <-ctx.Done():
e.Log.Debug("Context canceled.")
}

return nil
Expand Down Expand Up @@ -149,33 +157,41 @@ func (e *Engine) checkAccess(ctx context.Context, sessionCtx *common.Session) er
return nil
}

// proxy proxies all traffic between the client and server connections.
func (e *Engine) proxy(ctx context.Context, serverConn io.ReadWriteCloser) error {
errCh := make(chan error, 2)

go func() {
defer serverConn.Close()
_, err := io.Copy(serverConn, e.clientConn)
errCh <- err
// receiveFromClient relays protocol messages received from SQL Server client
// to SQL Server database.
func (e *Engine) receiveFromClient(clientConn, serverConn io.ReadWriteCloser, clientErrCh chan<- error) {
defer func() {
serverConn.Close()
e.Log.Debug("Stop receiving from client.")
close(clientErrCh)
}()

go func() {
defer serverConn.Close()
_, err := io.Copy(e.clientConn, serverConn)
errCh <- err
}()

var errs []error
for i := 0; i < 2; i++ {
select {
case err := <-errCh:
if err != nil && !utils.IsOKNetworkError(err) {
errs = append(errs, err)
for {
p, err := protocol.ReadPacket(clientConn)
if err != nil {
if utils.IsOKNetworkError(err) {
e.Log.Debug("Client connection closed.")
return
}
case <-ctx.Done():
return trace.Wrap(ctx.Err())
e.Log.WithError(err).Error("Failed to read client packet.")
clientErrCh <- err
return
}

_, err = serverConn.Write(p.Bytes())
if err != nil {
e.Log.WithError(err).Error("Failed to write server packet.")
clientErrCh <- err
return
}
}
}

return trace.NewAggregate(errs...)
// receiveFromServer relays protocol messages received from SQLServer database
// to SQLServer client.
func (e *Engine) receiveFromServer(serverConn, clientConn io.ReadWriteCloser, serverErrCh chan<- error) {
defer clientConn.Close()
_, err := io.Copy(clientConn, serverConn)
if err != nil && !utils.IsOKNetworkError(err) {
serverErrCh <- trace.Wrap(err)
}
}
6 changes: 6 additions & 0 deletions lib/srv/db/sqlserver/protocol/login7.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ func (p *Login7Packet) TypeFlags() uint8 {
return p.header.TypeFlags
}

// PacketSize return the packet size from the Login7 packet.
// Packet size is used by a server to negation the size of max packet length.
func (p *Login7Packet) PacketSize() uint16 {
return uint16(p.header.PacketSize)
}

// Login7Header contains options and offset/length pairs parsed from the Login7
// packet sent by client.
//
Expand Down
9 changes: 8 additions & 1 deletion lib/srv/db/sqlserver/protocol/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,13 @@ type Packet struct {
PacketHeader

// Data is the packet data bytes without header.
Data []byte
Data []byte
headerBytes [8]byte
}

func (p Packet) Bytes() []byte {
return append(p.headerBytes[:], p.Data...)

}

// ReadPacket reads a single full packet from the reader.
Expand All @@ -78,6 +84,7 @@ func ReadPacket(r io.Reader) (*Packet, error) {
}

return &Packet{
headerBytes: headerBytes,
PacketHeader: header,
Data: dataBytes,
}, nil
Expand Down