diff --git a/lib/srv/db/sqlserver/engine.go b/lib/srv/db/sqlserver/engine.go index ef4fa9f83401a..30de9779a6f2f 100644 --- a/lib/srv/db/sqlserver/engine.go +++ b/lib/srv/db/sqlserver/engine.go @@ -17,6 +17,7 @@ limitations under the License. package sqlserver import ( + "bytes" "context" "io" "net" @@ -138,6 +139,10 @@ func (e *Engine) receiveFromClient(clientConn, serverConn io.ReadWriteCloser, cl }() msgFromClient := common.GetMessagesFromClientMetric(sessionCtx.Database) + // initialPacketHeader and chunkData are used to accumulate chunked packets + // to build a single packet with full contents for auditing. + var initialPacketHeader protocol.PacketHeader + var chunkData bytes.Buffer for { p, err := protocol.ReadPacket(clientConn) @@ -150,16 +155,26 @@ func (e *Engine) receiveFromClient(clientConn, serverConn io.ReadWriteCloser, cl clientErrCh <- err return } - msgFromClient.Inc() - sqlPacket, err := protocol.ToSQLPacket(p) - switch { - case err != nil: - e.Log.WithError(err).Errorf("Failed to parse SQLServer packet.") - e.emitMalformedPacket(e.Context, sessionCtx, p) - default: - e.auditPacket(e.Context, sessionCtx, sqlPacket) + // Audit events are going to be emitted only on final messages, this way + // the packet parsing can be complete and provide the query/RPC + // contents. + if protocol.IsFinalPacket(p) { + sqlPacket, err := e.toSQLPacket(initialPacketHeader, p, &chunkData) + switch { + case err != nil: + e.Log.WithError(err).Errorf("Failed to parse SQLServer packet.") + e.emitMalformedPacket(e.Context, sessionCtx, p) + default: + e.auditPacket(e.Context, sessionCtx, sqlPacket) + } + } else { + if chunkData.Len() == 0 { + initialPacketHeader = p.Header() + } + + chunkData.Write(p.Data()) } _, err = serverConn.Write(p.Bytes()) @@ -171,6 +186,27 @@ func (e *Engine) receiveFromClient(clientConn, serverConn io.ReadWriteCloser, cl } } +// toSQLPacket Parses a regular (self-contained) or chunked packet into an SQL +// packet (used for auditing). +func (e *Engine) toSQLPacket(header protocol.PacketHeader, packet *protocol.BasicPacket, chunks *bytes.Buffer) (protocol.Packet, error) { + if chunks.Len() > 0 { + defer chunks.Reset() + chunks.Write(packet.Data()) + // We're safe to "read" chunk using `Bytes()` function because the + // packet processing copies the packet contents. + packetData := chunks.Bytes() + + var err error + // The final chucked packet header must be the first packet header. + packet, err = protocol.NewBasicPacket(header, packetData) + if err != nil { + return nil, trace.Wrap(err) + } + } + + return protocol.ToSQLPacket(packet) +} + // receiveFromServer relays protocol messages received from MySQL database // to MySQL client. func (e *Engine) receiveFromServer(serverConn, clientConn io.ReadWriteCloser, serverErrCh chan<- error) { diff --git a/lib/srv/db/sqlserver/engine_test.go b/lib/srv/db/sqlserver/engine_test.go index 5f4271c31a83e..b484f15d102c9 100644 --- a/lib/srv/db/sqlserver/engine_test.go +++ b/lib/srv/db/sqlserver/engine_test.go @@ -66,13 +66,13 @@ func TestHandleConnectionAuditEvents(t *testing.T) { } tests := []struct { - name string - packet []byte - checks []check + name string + packets [][]byte + checks []check }{ { - name: "rpc request procedure", - packet: fixtures.RPCClientRequest, + name: "rpc request procedure", + packets: [][]byte{fixtures.GenerateCustomRPCCallPacket("foo3")}, checks: []check{ hasNoErr(), hasAuditEventCode(libevents.DatabaseSessionStartCode), @@ -94,8 +94,8 @@ func TestHandleConnectionAuditEvents(t *testing.T) { }, }, { - name: "rpc request param", - packet: fixtures.RPCClientRequestParam, + name: "rpc request param", + packets: [][]byte{fixtures.GenerateExecuteSQLRPCPacket("select @@version")}, checks: []check{ hasNoErr(), hasAuditEventCode(libevents.DatabaseSessionStartCode), @@ -118,8 +118,8 @@ func TestHandleConnectionAuditEvents(t *testing.T) { }, }, { - name: "sql batch", - packet: fixtures.SQLBatch, + name: "sql batch", + packets: [][]byte{fixtures.GenerateBatchQueryPacket("\nselect 'foo' as 'bar'\n ")}, checks: []check{ hasNoErr(), hasAuditEventCode(libevents.DatabaseSessionStartCode), @@ -144,8 +144,8 @@ func TestHandleConnectionAuditEvents(t *testing.T) { }, }, { - name: "malformed packet", - packet: fixtures.MalformedPacketTest, + name: "malformed packet", + packets: [][]byte{fixtures.MalformedPacketTest}, checks: []check{ hasNoErr(), hasAuditEventCode(libevents.DatabaseSessionStartCode), @@ -153,6 +153,129 @@ func TestHandleConnectionAuditEvents(t *testing.T) { hasAuditEventCode(libevents.DatabaseSessionMalformedPacketCode), }, }, + { + name: "sql batch chunked packets", + packets: fixtures.GenerateBatchQueryChunkedPacket(5, "select 'foo' as 'bar'"), + checks: []check{ + hasNoErr(), + hasAuditEventCode(libevents.DatabaseSessionStartCode), + hasAuditEventCode(libevents.DatabaseSessionEndCode), + hasAuditEvent(1, &events.DatabaseSessionQuery{ + DatabaseMetadata: events.DatabaseMetadata{ + DatabaseUser: "sa", + DatabaseType: "self-hosted", + DatabaseService: "dummy", + DatabaseURI: "uri", + DatabaseProtocol: "test", + }, + Metadata: events.Metadata{ + Type: libevents.DatabaseSessionQueryEvent, + Code: libevents.DatabaseSessionQueryCode, + }, + DatabaseQuery: "select 'foo' as 'bar'", + Status: events.Status{ + Success: true, + }, + }), + }, + }, + { + name: "rpc request param chunked", + packets: fixtures.GenerateExecuteSQLRPCChunkedPacket(5, "select @@version"), + checks: []check{ + hasNoErr(), + hasAuditEventCode(libevents.DatabaseSessionStartCode), + hasAuditEventCode(libevents.DatabaseSessionEndCode), + hasAuditEvent(1, &events.SQLServerRPCRequest{ + DatabaseMetadata: events.DatabaseMetadata{ + DatabaseUser: "sa", + DatabaseType: "self-hosted", + DatabaseService: "dummy", + DatabaseURI: "uri", + DatabaseProtocol: "test", + }, + Metadata: events.Metadata{ + Type: libevents.DatabaseSessionSQLServerRPCRequestEvent, + Code: libevents.SQLServerRPCRequestCode, + }, + Parameters: []string{"select @@version"}, + Procname: "Sp_ExecuteSql", + }), + }, + }, + { + name: "intercalated chunked messages", + packets: intercalateChunkedPacketMessages( + fixtures.GenerateExecuteSQLRPCChunkedPacket(5, "select @@version"), + fixtures.GenerateExecuteSQLRPCPacket("select 1"), + 2, + ), + checks: []check{ + hasNoErr(), + hasAuditEventCode(libevents.DatabaseSessionStartCode), + hasAuditEventCode(libevents.DatabaseSessionEndCode), + hasAuditEvent(1, &events.SQLServerRPCRequest{ + DatabaseMetadata: events.DatabaseMetadata{ + DatabaseUser: "sa", + DatabaseType: "self-hosted", + DatabaseService: "dummy", + DatabaseURI: "uri", + DatabaseProtocol: "test", + }, + Metadata: events.Metadata{ + Type: libevents.DatabaseSessionSQLServerRPCRequestEvent, + Code: libevents.SQLServerRPCRequestCode, + }, + Parameters: []string{"select @@version"}, + Procname: "Sp_ExecuteSql", + }), + hasAuditEvent(2, &events.SQLServerRPCRequest{ + DatabaseMetadata: events.DatabaseMetadata{ + DatabaseUser: "sa", + DatabaseType: "self-hosted", + DatabaseService: "dummy", + DatabaseURI: "uri", + DatabaseProtocol: "test", + }, + Metadata: events.Metadata{ + Type: libevents.DatabaseSessionSQLServerRPCRequestEvent, + Code: libevents.SQLServerRPCRequestCode, + }, + Parameters: []string{"select 1"}, + Procname: "Sp_ExecuteSql", + }), + hasAuditEvent(3, &events.SQLServerRPCRequest{ + DatabaseMetadata: events.DatabaseMetadata{ + DatabaseUser: "sa", + DatabaseType: "self-hosted", + DatabaseService: "dummy", + DatabaseURI: "uri", + DatabaseProtocol: "test", + }, + Metadata: events.Metadata{ + Type: libevents.DatabaseSessionSQLServerRPCRequestEvent, + Code: libevents.SQLServerRPCRequestCode, + }, + Parameters: []string{"select @@version"}, + Procname: "Sp_ExecuteSql", + }), + hasAuditEvent(4, &events.SQLServerRPCRequest{ + DatabaseMetadata: events.DatabaseMetadata{ + DatabaseUser: "sa", + DatabaseType: "self-hosted", + DatabaseService: "dummy", + DatabaseURI: "uri", + DatabaseProtocol: "test", + }, + Metadata: events.Metadata{ + Type: libevents.DatabaseSessionSQLServerRPCRequestEvent, + Code: libevents.SQLServerRPCRequestCode, + }, + Parameters: []string{"select 1"}, + Procname: "Sp_ExecuteSql", + }), + }, + }, } for _, tc := range tests { @@ -170,8 +293,11 @@ func TestHandleConnectionAuditEvents(t *testing.T) { }) require.NoError(t, err) - _, err = b.Write(tc.packet) - require.NoError(t, err) + for _, packet := range tc.packets { + _, err = b.Write(packet) + require.NoError(t, err) + } + emitterMock := &eventstest.MockRecorderEmitter{} audit, err := common.NewAudit(common.AuditConfig{ Emitter: emitterMock, @@ -208,6 +334,17 @@ func TestHandleConnectionAuditEvents(t *testing.T) { } } +// intercalateChunkedPacketMessages intercalates a chunked packet with a regular packet a specified number of times. +func intercalateChunkedPacketMessages(chunkedPacket [][]byte, regularPacket []byte, repeat int) [][]byte { + var result [][]byte + for i := 0; i < repeat; i++ { + result = append(result, chunkedPacket...) + result = append(result, regularPacket) + } + + return result +} + type mockConn struct { net.Conn buff bytes.Buffer diff --git a/lib/srv/db/sqlserver/protocol/constants.go b/lib/srv/db/sqlserver/protocol/constants.go index 7cbdd3db464f5..46f9d98f143e0 100644 --- a/lib/srv/db/sqlserver/protocol/constants.go +++ b/lib/srv/db/sqlserver/protocol/constants.go @@ -32,8 +32,9 @@ const ( // packetHeaderSize is the size of the protocol packet header. packetHeaderSize = 8 - // packetStatusLast indicates that the packet is the last in the request. - packetStatusLast uint8 = 0x01 + // PacketStatusLast indicates that the packet is the last in the request. + // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/ce398f9a-7d47-4ede-8f36-9dd6fc21ca43 + PacketStatusLast uint8 = 0x01 preLoginOptionVersion = 0x00 preLoginOptionEncryption = 0x01 diff --git a/lib/srv/db/sqlserver/protocol/fixtures/packets.go b/lib/srv/db/sqlserver/protocol/fixtures/packets.go index 9fc711b13685c..286c27b22a98d 100644 --- a/lib/srv/db/sqlserver/protocol/fixtures/packets.go +++ b/lib/srv/db/sqlserver/protocol/fixtures/packets.go @@ -18,6 +18,33 @@ package fixtures import ( "encoding/binary" + "unicode/utf16" +) + +const ( + // packetStatusFinalMessage packet status value used to indicate the message + // does not contain more chunks. + packetStatusFinalMessage = 0x01 + // packetStatusNotFinalMessage packet status value used to indicate the + // message is not the final one. It must not contain the final message flag + // bit. + packetStatusNotFinalMessage = 0x04 + // packetTypeSQLBatch is the packet type for SQL Batch. + packetTypeSQLBatch = 0x01 + // PacketTypeSQLBatch is the packet type for RPC Call. + packetTypeRPCCall = 0x03 + // procIDExecuteSQL is the RPC ID for Sp_ExecuteSQL. + procIDExecuteSQL = 10 + // nvarcharType is the flag that represents the type NVARCHAR. + nvarcharType = 0xef + // ntextType is the flag that represents the type NTEXT. + ntextType = 0x63 + // intnType is the flag that represents the type INTN. + intnType = 0x26 + // intnTinyType is the flag that indicates the integer type tiny int. + intnTinyType = 0x01 + // statusFlags consists 3 flag bits + 5 reserved bits. + statusFlags = 0x00 ) var ( @@ -45,36 +72,6 @@ var ( 0x4C, 0x00, 0x2D, 0x00, 0x33, 0x00, 0x32, 0x00, 0x4F, 0x00, 0x44, 0x00, 0x42, 0x00, 0x43, 0x00, } - // SQLBatch is an example of SQLBatchClientRequest client request packet from the protocol spec: - // - // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/b05b006b-3cbf-404b-bcaf-7ec584b54706 - SQLBatch = []byte{ - 0x01, 0x01, 0x00, 0x5c, 0x00, 0x00, 0x01, 0x00, 0x16, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00, - 0x73, 0x00, 0x65, 0x00, 0x6c, 0x00, 0x65, 0x00, 0x63, 0x00, 0x74, 0x00, 0x20, 0x00, 0x27, 0x00, - 0x66, 0x00, 0x6f, 0x00, 0x6f, 0x00, 0x27, 0x00, 0x20, 0x00, 0x61, 0x00, 0x73, 0x00, 0x20, 0x00, - 0x27, 0x00, 0x62, 0x00, 0x61, 0x00, 0x72, 0x00, 0x27, 0x00, 0x0a, 0x00, 0x20, 0x00, 0x20, 0x00, - 0x20, 0x00, 0x20, 0x00, 0x20, 0x00, 0x20, 0x00, 0x20, 0x00, 0x20, 0x00, - } - - // RPCClientRequest is an example of RPCClientRequest client request packet from the protocol spec: - // - // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/1469b2fa-6cab-42e9-91f9-044d358b306b - RPCClientRequest = []byte{ - 0x03, 0x01, 0x00, 0x2F, 0x00, 0x00, 0x01, 0x00, 0x16, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, - 0x66, 0x00, 0x6F, 0x00, 0x6F, 0x00, 0x33, 0x00, 0x00, 0x00, 0x00, 0x02, 0x26, 0x02, 0x00, - } - - // RPCClientRequestParam is a custom RPC Request with SQL param. - RPCClientRequestParam = []byte{ - 0x03, 0x01, 0x00, 0x50, 0x00, 0x00, 0x01, 0x00, 0x16, 0x00, 0x00, 0x00, 0x12, 0x00, 0x00, 0x00, - 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xff, 0xff, - 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0xe7, 0x40, 0x1f, 0x09, 0x04, 0xd0, 0x00, 0x34, 0x20, 0x00, - 0x73, 0x00, 0x65, 0x00, 0x6c, 0x00, 0x65, 0x00, 0x63, 0x00, 0x74, 0x00, 0x20, 0x00, 0x40, 0x00, - 0x40, 0x00, 0x76, 0x00, 0x65, 0x00, 0x72, 0x00, 0x73, 0x00, 0x69, 0x00, 0x6f, 0x00, 0x6e, 0x00, - } - // MalformedPacketTest is an RPC Request malformed packet. MalformedPacketTest = []byte{ 0x03, 0x01, 0x00, 0x90, 0x00, 0x00, 0x02, 0x00, 0x72, 0x00, 0x61, 0x00, 0x6d, 0x00, 0x5f, 0x00, @@ -87,50 +84,235 @@ var ( 0x00, 0x70, 0x00, 0x61, 0x00, 0x72, 0x00, 0x61, 0x00, 0x6d, 0x00, 0x5f, 0x00, 0x31, 0x00, 0x00, 0xe7, 0x40, 0x1f, 0x09, 0x04, 0xd0, 0x00, 0x34, 0x06, 0x00, 0x64, 0x00, 0x62, 0x00, 0x6f, 0x00, } + + // AllHeadersSliceWithTransactionDescriptor is a ALL_HEADERS data stream + // header containing the TransactionDescriptor data. It is required for + // SQLBatch and RPC packets. + // + // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/e17e54ae-0fac-48b7-b8a8-c267be297923 + AllHeadersSliceWithTransactionDescriptor = []byte{ + 0x16, 0x00, 0x00, 0x00, // Total length + 0x12, 0x00, 0x00, 0x00, // Header length + 0x02, 0x00, // Header type: Transaction descriptor + // BEGIN Transaction description: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/4257dd95-ef6c-4621-b75d-270738487d68 + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // TransactionDescriptor + 0x00, 0x00, 0x00, 0x00, // OutstandingRequestCount + // End transaction description. + } + + // FieldCollation definition for data parameters. Using "raw collation" is + // ok for testing because the server is not processing it. + // + // https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/3d29e8dc-218a-42c6-9ba4-947ebca9fd7e + FieldCollation = []byte{0x00, 0x00, 0x00, 0x00, 0x00} ) // RPCClientVariableLength returns a RPCCLientRequest packet containing a // partially Length-prefixed Bytes request, as described here: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/3f983fde-0509-485a-8c40-a9fa6679a828 -func RPCClientPartiallyLength(length uint64, chunks uint64) []byte { - packet := []byte{ - 0x03, 0x01, - 0x00, 0x00, // Length placeholder - 0x00, 0x00, 0x01, 0x00, 0x16, 0x00, 0x00, 0x00, - 0x12, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, - 0x00, 0x00, 0x04, 0x00, 0x66, 0x00, 0x6F, 0x00, - 0x6F, 0x00, 0x33, 0x00, 0x00, 0x00, 0x00, 0x02, - 0xef, // NVARCHARTYPE - 0xff, 0xff, // NULL length - 0x00, 0x00, 0x00, 0x00, 0x00, // NVARCHARTYPE flags - } - - // NVARCHARTYPE must have even length +func RPCClientPartiallyLength(procName string, length uint64, chunks uint64) []byte { + params := []byte{ + 0x00, // Parameter name (B_VARCHAR) + statusFlags, // Status flags + nvarcharType, // NVARCHARYTYPE + 0xff, 0xff, // NULL length (this indicates it is a PLP parameter) + } + params = append(params, FieldCollation...) + + // Since we're not encoding the string into UC2, here we force it to have a + // valid size. if length%2 != 0 { length += 1 } - packet = binary.LittleEndian.AppendUint64(packet, length) + // PLP_BODY length is ULONGLONGLEN (64-bit). + params = binary.LittleEndian.AppendUint64(params, length) if length > 0 && chunks > 1 { chunkSize := length / chunks rem := length for i := uint64(0); i < chunks-1; i++ { - packet = binary.LittleEndian.AppendUint32(packet, uint32(chunkSize)) + // PLP_CHUNK length size is ULONGLEN (32-bit). + params = binary.LittleEndian.AppendUint32(params, uint32(chunkSize)) data := make([]byte, chunkSize) - packet = append(packet, data...) + params = append(params, data...) rem -= chunkSize } // Last chunk will contain the remaining data. - packet = binary.LittleEndian.AppendUint32(packet, uint32(rem)) + params = binary.LittleEndian.AppendUint32(params, uint32(rem)) data := make([]byte, rem) - packet = append(packet, data...) + params = append(params, data...) } // PLP_TERMINATOR - packet = append(packet, []byte{0x00, 0x00, 0x00, 0x00}...) + params = append(params, []byte{0x00, 0x00, 0x00, 0x00}...) + return generateRPCCallPacket(packetStatusFinalMessage, true, 1, rpcProcName(procName), params) +} + +// GeneratePacketHeader generates a packet header based on the specified parameters. +func GeneratePacketHeader(packetType byte, packetStatus byte, length int, seq int) []byte { + header := []byte{ + packetType, // Type: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/9b4a463c-2634-4a4b-ac35-bebfff2fb0f7 + packetStatus, // Status: https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/ce398f9a-7d47-4ede-8f36-9dd6fc21ca43 + 0x00, 0x00, // Packet length (placeholder). https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/c1cddd03-b448-470a-946a-9b1b908f27a7 + 0x00, 0x00, // Sever process ID (SPID). This is only sent by the server. https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/fcfc00d0-6df1-42c8-8d34-93007b9a80f0 + byte(seq), // PacketID (currently ignored). https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/ec9e8663-191c-4dd1-baa8-48bbfba5ed7e + 0x00, // Window (currently ignored). https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/fbc2f523-92e6-4316-aaed-3a6966e548ad + } + binary.BigEndian.PutUint16(header[2:], uint16(length+len(header))) + return header +} + +// GenerateBatchQueryPacket generates a final SQLBatch with the provided query. +func GenerateBatchQueryPacket(query string) []byte { + return generateBatchQueryPacket(0x01, 1, true, query) +} + +func generateBatchQueryPacket(packetStatus byte, seq int, withAllHeaders bool, query string) []byte { + var packet []byte + if withAllHeaders { + packet = append(packet, AllHeadersSliceWithTransactionDescriptor...) + } + packet = append(packet, encodeString(query)...) + return append(GeneratePacketHeader(packetTypeSQLBatch, packetStatus, len(packet), seq), packet...) +} + +// GenerateBatchQueryChunkedPacket split a batch query into multiple network packets. +func GenerateBatchQueryChunkedPacket(chunks int, query string) [][]byte { + queryLen := len(query) + chunkSize := queryLen / chunks + + packets := [][]byte{generateBatchQueryPacket(0x04, 1, true, query[0:chunkSize])} + for i := 1; i < chunks-1; i++ { + // Sequence packets must not have the all headers information. + packets = append(packets, generateBatchQueryPacket(packetStatusNotFinalMessage, i+2, false, query[chunkSize*i:chunkSize*(i+1)])) + } + + // Last packet must indicate the final message. + return append(packets, generateBatchQueryPacket(packetStatusFinalMessage, chunks, false, query[chunkSize*(chunks-1):])) +} + +// GenerateExecuteSQLRPCChunkedPacket slipt a RPC Call into multiple network +// packets. +func GenerateExecuteSQLRPCChunkedPacket(chunks int, query string) [][]byte { + rpcProcName := rpcProcID(procIDExecuteSQL) + queryLen := len(query) + chunkSize := queryLen / chunks + + packets := [][]byte{generateRPCCallPacket(0x04, true, 1, rpcProcName, generateNVARCHARParam(query[0:chunkSize], len(encodeString(query))))} + for i := 1; i < chunks-1; i++ { + packetData := encodeString(query[chunkSize*i : chunkSize*(i+1)]) + packets = append(packets, append(GeneratePacketHeader(0x03, packetStatusNotFinalMessage, len(packetData), i+1), packetData...)) + } + + // Last packet must indicate the final message. + packetData := encodeString(query[chunkSize*(chunks-1):]) + return append(packets, append(GeneratePacketHeader(packetTypeRPCCall, packetStatusFinalMessage, len(packetData), chunks), packetData...)) +} + +// GenerateCustomRPCCallPacket generates a packet containing a custom RPC call +// with an empty integer parameter. +func GenerateCustomRPCCallPacket(procName string) []byte { + params := []byte{ + 0x00, // Parameter name (B_VARCHAR). Here we're passing a length 0 (BYTELEN) which means empty name. + statusFlags, // Status flags. + intnType, // Parameter type. INTNTYPE + intnTinyType, // Integer type. + 0x00, // Actual data length. Providing length 0 zero we don't need to encode the integer. + } + + return generateRPCCallPacket(packetStatusFinalMessage, true, 1, rpcProcName(procName), params) +} + +// generateRPCCallPacket generates a RPC call packet. +func generateRPCCallPacket(packetStatus byte, withAllHeaders bool, seq int, rpcProcName []byte, params []byte) []byte { + var packet []byte + if withAllHeaders { + packet = append(packet, AllHeadersSliceWithTransactionDescriptor...) + } + // Proc name + packet = append(packet, rpcProcName...) + // Option flags: 3 flag bits + 13 reserved bits. + packet = append(packet, []byte{0x00, 0x00}...) + // Parameters + packet = append(packet, params...) + return append(GeneratePacketHeader(packetTypeRPCCall, packetStatus, len(packet), seq), packet...) +} + +// GenerateExecuteSQLRPCPacket generates a RPC call packet containing a +// single parameter (NVARCHARTYPE). +func GenerateExecuteSQLRPCPacket(query string) []byte { + return generateRPCCallPacket(packetStatusFinalMessage, true, 1, rpcProcID(procIDExecuteSQL), generateNVARCHARParam(query, 0)) +} + +// GenerateExecuteSQLRPCPacketNTEXT generates a RPC call packet containing a +// single parameter (NTEXT). +func GenerateExecuteSQLRPCPacketNTEXT(query string) []byte { + return generateRPCCallPacket(packetStatusFinalMessage, true, 1, rpcProcID(procIDExecuteSQL), generateNTEXTParam(query)) +} + +// generateNVARCHARParam generates a NVARCHARTYPE parameter. +func generateNVARCHARParam(contents string, totalLength int) []byte { + encodedContents := encodeString(contents) + length := len(encodedContents) + if totalLength > 0 { + length = totalLength + } + + // Parameter length (USHORTLEN_TYPE for NVARCHARTYPE). + encodedLength := binary.LittleEndian.AppendUint16([]byte{}, uint16(length)) + packet := []byte{ + 0x00, // Parameter name (B_VARCHAR). Here we're passing a length 0 (BYTELEN) which means empty name. + statusFlags, // Status flags. + nvarcharType, // Parameter type: NVARCHARTYPE + encodedLength[0], encodedLength[1], // Param length + } + // Data collation flags. + packet = append(packet, FieldCollation...) + // Param data also has the parameter length (same encoding). + packet = append(packet, encodedLength...) + return append(packet, encodedContents...) +} + +// generateNTEXTParam generates a NTEXT parameter. +// +// The parameter format is based on the official documentataion and compared +// with requests generated by Azure Data Studio. +func generateNTEXTParam(contents string) []byte { + encodedContents := encodeString(contents) + // Parameter length (LONGLEN_TYPE for NTEXT). + encodedLength := binary.LittleEndian.AppendUint32([]byte{}, uint32(len(encodedContents))) + packet := append([]byte{ + 0x00, // Parameter name (B_VARCHAR). Here we're passing a length 0 (BYTELEN) which means empty name. + statusFlags, // Status flags. + ntextType, // Parameter type: NTEXT + }, encodedLength...) + // Data collation flags. + packet = append(packet, FieldCollation...) + // Param data also has the parameter length (same encoding). + packet = append(packet, encodedLength...) + return append(packet, encodedContents...) +} + +// rpcProcName returns PROC NAME field used on RPC calls. +func rpcProcName(name string) []byte { + var packet []byte + packet = binary.LittleEndian.AppendUint16(packet, uint16(len(name))) + return append(packet, encodeString(name)...) +} + +// rpcProcID returns the PROC ID field used on RPC calls. +func rpcProcID(id uint16) []byte { + packet := []byte{0xff, 0xff} + return binary.LittleEndian.AppendUint16(packet, id) +} + +// encodeString encodes the string into UTF-16 LittleEndian. +func encodeString(s string) []byte { + var encodedString []byte + for _, r := range utf16.Encode([]rune(s)) { + encodedString = binary.LittleEndian.AppendUint16(encodedString, r) + } - binary.BigEndian.PutUint16(packet[2:], uint16(len(packet))) - return packet + return encodedString } diff --git a/lib/srv/db/sqlserver/protocol/fuzz_test.go b/lib/srv/db/sqlserver/protocol/fuzz_test.go index c4da2802bed9b..6265e38480aae 100644 --- a/lib/srv/db/sqlserver/protocol/fuzz_test.go +++ b/lib/srv/db/sqlserver/protocol/fuzz_test.go @@ -45,7 +45,7 @@ func FuzzMSSQLLogin(f *testing.F) { func FuzzRPCClientPartialLength(f *testing.F) { f.Fuzz(func(t *testing.T, length uint64, chunks uint64) { - packet, err := ReadPacket(bytes.NewReader(fixtures.RPCClientPartiallyLength(length, chunks))) + packet, err := ReadPacket(bytes.NewReader(fixtures.RPCClientPartiallyLength("foo3", length, chunks))) require.NoError(t, err) require.Equal(t, packet.Type(), PacketTypeRPCRequest) diff --git a/lib/srv/db/sqlserver/protocol/packet.go b/lib/srv/db/sqlserver/protocol/packet.go index 5b2535d651955..174f395a3371c 100644 --- a/lib/srv/db/sqlserver/protocol/packet.go +++ b/lib/srv/db/sqlserver/protocol/packet.go @@ -118,6 +118,21 @@ func ReadPacket(r io.Reader) (*BasicPacket, error) { return p, nil } +// NewBasicPacket creates a new BasicPacket instance with the specified +// PacketHeader and data. +func NewBasicPacket(header PacketHeader, data []byte) (*BasicPacket, error) { + headerBytes, err := header.Marshal() + if err != nil { + return nil, trace.Wrap(err) + } + raw := bytes.NewBuffer(append(headerBytes, data...)) + return &BasicPacket{ + header: header, + data: data, + raw: *raw, + }, nil +} + // ToSQLPacket tries to convert basicPacket to MSServer SQL packet. func ToSQLPacket(p *BasicPacket) (out Packet, err error) { defer func() { @@ -147,7 +162,7 @@ func ToSQLPacket(p *BasicPacket) (out Packet, err error) { func makePacket(pktType uint8, pktData []byte) ([]byte, error) { header := PacketHeader{ Type: pktType, - Status: packetStatusLast, + Status: PacketStatusLast, Length: uint16(packetHeaderSize + len(pktData)), } @@ -158,3 +173,8 @@ func makePacket(pktType uint8, pktData []byte) ([]byte, error) { return append(headerBytes, pktData...), nil } + +// IsFinalPacket returns true there are no more packets on the message. +func IsFinalPacket(packet Packet) bool { + return packet.Header().Status&PacketStatusLast == 1 +} diff --git a/lib/srv/db/sqlserver/protocol/protocol_test.go b/lib/srv/db/sqlserver/protocol/protocol_test.go index fb3c9f146616c..3ff81acdb5629 100644 --- a/lib/srv/db/sqlserver/protocol/protocol_test.go +++ b/lib/srv/db/sqlserver/protocol/protocol_test.go @@ -62,7 +62,7 @@ func TestErrorResponse(t *testing.T) { // TestSQLBatch verifies SQLPatch packet parsing. func TestSQLBatch(t *testing.T) { - packet, err := ReadPacket(bytes.NewReader(fixtures.SQLBatch)) + packet, err := ReadPacket(bytes.NewReader(fixtures.GenerateBatchQueryPacket("\nselect 'foo' as 'bar'\n "))) require.NoError(t, err) r, err := ToSQLPacket(packet) require.NoError(t, err) @@ -74,7 +74,7 @@ func TestSQLBatch(t *testing.T) { // TestRPCClientRequestParam verifies RPC Request with param packet parsing. func TestRPCClientRequestParam(t *testing.T) { - packet, err := ReadPacket(bytes.NewReader(fixtures.RPCClientRequestParam)) + packet, err := ReadPacket(bytes.NewReader(fixtures.GenerateExecuteSQLRPCPacket("select @@version"))) require.NoError(t, err) require.Equal(t, packet.Type(), PacketTypeRPCRequest) r, err := ToSQLPacket(packet) @@ -86,7 +86,7 @@ func TestRPCClientRequestParam(t *testing.T) { // TestRPCClientRequest verifies rpc request packet parsing. func TestRPCClientRequest(t *testing.T) { - packet, err := ReadPacket(bytes.NewReader(fixtures.RPCClientRequest)) + packet, err := ReadPacket(bytes.NewReader(fixtures.GenerateCustomRPCCallPacket("foo3"))) require.NoError(t, err) require.Equal(t, packet.Type(), PacketTypeRPCRequest) r, err := ToSQLPacket(packet) @@ -97,7 +97,7 @@ func TestRPCClientRequest(t *testing.T) { } func TestRPCClientRequestPartialLength(t *testing.T) { - packet, err := ReadPacket(bytes.NewReader(fixtures.RPCClientPartiallyLength(32, 4))) + packet, err := ReadPacket(bytes.NewReader(fixtures.RPCClientPartiallyLength("foo3", 32, 4))) require.NoError(t, err) require.Equal(t, packet.Type(), PacketTypeRPCRequest) @@ -109,3 +109,21 @@ func TestRPCClientRequestPartialLength(t *testing.T) { require.Equal(t, "foo3", p.ProcName) require.NoError(t, err) } + +func TestRPCClientRequestParamNTEXT(t *testing.T) { + // Currently the ReadTypeInfo is not parsing the NTEXT contents correctly, + // giving a invalid memory access. + // + // TODO(gabrielcorado): validate this use case and ensure the parameter is + // correctly parsed on the driver. + t.Skip() + + packet, err := ReadPacket(bytes.NewReader(fixtures.GenerateExecuteSQLRPCPacketNTEXT("select @@version"))) + require.NoError(t, err) + require.Equal(t, packet.Type(), PacketTypeRPCRequest) + r, err := ToSQLPacket(packet) + require.NoError(t, err) + p, ok := r.(*RPCRequest) + require.True(t, ok) + require.Equal(t, "select @@version", p.Parameters[0]) +}