Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions packetbeat/protos/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import (
"github.com/elastic/beats/v7/packetbeat/procs"
"github.com/elastic/beats/v7/packetbeat/protos"
"github.com/elastic/beats/v7/packetbeat/protos/tcp"

"go.mongodb.org/mongo-driver/bson/primitive"
)

var debugf = logp.MakeDebug("mongodb")
Expand All @@ -54,7 +56,7 @@ type mongodbPlugin struct {

type transactionKey struct {
tcp common.HashableTCPTuple
id int
id int32
}

var unmatchedRequests = monitoring.NewInt(nil, "mongodb.unmatched_requests")
Expand Down Expand Up @@ -232,7 +234,7 @@ func (mongodb *mongodbPlugin) handleMongodb(

func (mongodb *mongodbPlugin) onRequest(conn *mongodbConnectionData, msg *mongodbMessage) {
// publish request only transaction
if !awaitsReply(msg.opCode) {
if !awaitsReply(msg) {
mongodb.onTransComplete(msg, nil)
return
}
Expand Down Expand Up @@ -273,7 +275,6 @@ func (mongodb *mongodbPlugin) onResponse(conn *mongodbConnectionData, msg *mongo
func (mongodb *mongodbPlugin) onTransComplete(requ, resp *mongodbMessage) {
trans := newTransaction(requ, resp)
debugf("Mongodb transaction completed: %s", trans.mongodb)

mongodb.publishTransaction(trans)
}

Expand All @@ -294,8 +295,9 @@ func newTransaction(requ, resp *mongodbMessage) *transaction {
}
trans.params = requ.params
trans.resource = requ.resource
trans.bytesIn = requ.messageLength
trans.bytesIn = int(requ.messageLength)
trans.documents = requ.documents
trans.requestDocuments = requ.documents // preserving request documents that contains mongodb query for the new OP_MSG based protocol
}

// fill response
Expand All @@ -308,7 +310,7 @@ func newTransaction(requ, resp *mongodbMessage) *transaction {
trans.documents = resp.documents

trans.endTime = resp.ts
trans.bytesOut = resp.messageLength
trans.bytesOut = int(resp.messageLength)

}

Expand All @@ -325,10 +327,17 @@ func (mongodb *mongodbPlugin) ReceivedFin(tcptuple *common.TCPTuple, dir uint8,
return private
}

func copyMapWithoutKey(d map[string]interface{}, key string) map[string]interface{} {
func copyMapWithoutKey(d map[string]interface{}, keys ...string) map[string]interface{} {
res := map[string]interface{}{}
for k, v := range d {
if k != key {
found := false
for _, excludeKey := range keys {
if k == excludeKey {
found = true
break
}
}
if !found {
res[k] = v
}
}
Expand All @@ -337,29 +346,40 @@ func copyMapWithoutKey(d map[string]interface{}, key string) map[string]interfac

func reconstructQuery(t *transaction, full bool) (query string) {
query = t.resource + "." + t.method + "("
var doc interface{}

if len(t.params) > 0 {
var err error
var params string
if !full {
// remove the actual data.
// TODO: review if we need to add other commands here
switch t.method {
case "insert":
params, err = doc2str(copyMapWithoutKey(t.params, "documents"))
doc = copyMapWithoutKey(t.params, "documents")
case "update":
params, err = doc2str(copyMapWithoutKey(t.params, "updates"))
doc = copyMapWithoutKey(t.params, "updates")
case "findandmodify":
params, err = doc2str(copyMapWithoutKey(t.params, "update"))
doc = copyMapWithoutKey(t.params, "update")
}
} else {
params, err = doc2str(t.params)
doc = t.params
}
if err != nil {
debugf("Error marshaling params: %v", err)
} else {
query += params
} else if len(t.requestDocuments) > 0 { // This recovers the query document from OP_MSG
if m, ok := t.requestDocuments[0].(primitive.M); ok {
excludeKeys := []string{"lsid"}
if !full {
excludeKeys = append(excludeKeys, "documents")
}
doc = copyMapWithoutKey(m, excludeKeys...)
}
}

queryString, err := doc2str(doc)
if err != nil {
debugf("Error marshaling query document: %v", err)
} else {
query += queryString
}

query += ")"
skip, _ := t.event["numberToSkip"].(int)
if skip > 0 {
Expand All @@ -370,7 +390,7 @@ func reconstructQuery(t *transaction, full bool) (query string) {
if limit > 0 && limit < 0x7fffffff {
query += fmt.Sprintf(".limit(%d)", limit)
}
return
return query
}

func (mongodb *mongodbPlugin) publishTransaction(t *transaction) {
Expand Down
48 changes: 20 additions & 28 deletions packetbeat/protos/mongodb/mongodb_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,13 @@ func mongodbMessageParser(s *stream) (bool, bool) {
return true, false
}

if length > len(s.data) {
if int(length) > len(s.data) {
// Not yet reached the end of message
return true, false
}

// Tell decoder to only consider current message
d.truncate(length)
d.truncate(int(length))

// fill up the header common to all messages
// see http://docs.mongodb.org/meta-driver/latest/legacy/mongodb-wire-protocol/#standard-message-header
Expand All @@ -72,8 +72,7 @@ func mongodbMessageParser(s *stream) (bool, bool) {
}

s.message.opCode = opCode
s.message.isResponse = false // default is that the message is a request. If not opReplyParse will set this to false
s.message.expectsResponse = false
s.message.isResponse = false // default is that the message is a request. If not opReplyParse will set this to true
debugf("opCode = %d (%v)", s.message.opCode, s.message.opCode)

// then split depending on operation type
Expand All @@ -93,11 +92,9 @@ func mongodbMessageParser(s *stream) (bool, bool) {
s.message.method = "insert"
return opInsertParse(d, s.message)
case opQuery:
s.message.expectsResponse = true
return opQueryParse(d, s.message)
case opGetMore:
s.message.method = "getMore"
s.message.expectsResponse = true
return opGetMoreParse(d, s.message)
case opDelete:
s.message.method = "delete"
Expand All @@ -107,6 +104,11 @@ func mongodbMessageParser(s *stream) (bool, bool) {
return opKillCursorsParse(d, s.message)
case opMsg:
s.message.method = "msg"
// The assumption is that the message with responseTo == 0 is the request
// TODO: handle the cases where moreToCome flag is set (multiple responses chained by responseTo)
if s.message.responseTo > 0 {
s.message.isResponse = true
}
return opMsgParse(d, s.message)
}

Expand Down Expand Up @@ -141,7 +143,7 @@ func opReplyParse(d *decoder, m *mongodbMessage) (bool, bool) {
debugf("Prepare to read %d document from reply", m.event["numberReturned"])

documents := make([]interface{}, numberReturned)
for i := 0; i < numberReturned; i++ {
for i := int32(0); i < numberReturned; i++ {
var document bson.M
document, err = d.readDocument()
if err != nil {
Expand Down Expand Up @@ -235,19 +237,6 @@ func opInsertParse(d *decoder, m *mongodbMessage) (bool, bool) {
return true, true
}

func extractDocuments(query map[string]interface{}) []interface{} {
docsVi, present := query["documents"]
if !present {
return []interface{}{}
}

docs, ok := docsVi.([]interface{})
if !ok {
return []interface{}{}
}
return docs
}

// Try to guess whether this key:value pair found in
// the query represents a command.
func isDatabaseCommand(key string, val interface{}) bool {
Expand Down Expand Up @@ -387,12 +376,14 @@ func opKillCursorsParse(d *decoder, m *mongodbMessage) (bool, bool) {

func opMsgParse(d *decoder, m *mongodbMessage) (bool, bool) {
// ignore flagbits
_, err := d.readInt32()
flagBits, err := d.readInt32()
if err != nil {
logp.Err("An error occurred while parsing OP_MSG message: %s", err)
return false, false
}

m.SetFlagBits(flagBits)

// read sections
kind, err := d.readByte()
if err != nil {
Expand Down Expand Up @@ -423,7 +414,7 @@ func opMsgParse(d *decoder, m *mongodbMessage) (bool, bool) {
}
m.event["message"] = cstring
var documents []interface{}
for d.i < start+size {
for d.i < start+int(size) {
document, err := d.readDocument()
if err != nil {
logp.Err("An error occurred while parsing OP_MSG message: %s", err)
Expand All @@ -432,7 +423,8 @@ func opMsgParse(d *decoder, m *mongodbMessage) (bool, bool) {
documents = append(documents, document)
}
m.documents = documents

case msgKindInternal:
// Ignore the internal purposes section
default:
logp.Err("Unknown message kind: %v", kind)
return false, false
Expand Down Expand Up @@ -482,25 +474,25 @@ func (d *decoder) readByte() (byte, error) {
return d.in[i], nil
}

func (d *decoder) readInt32() (int, error) {
func (d *decoder) readInt32() (int32, error) {
b, err := d.readBytes(4)
if err != nil {
return 0, err
}

return int((uint32(b[0]) << 0) |
return int32((uint32(b[0]) << 0) |
(uint32(b[1]) << 8) |
(uint32(b[2]) << 16) |
(uint32(b[3]) << 24)), nil
}

func (d *decoder) readInt64() (int, error) {
func (d *decoder) readInt64() (int64, error) {
b, err := d.readBytes(8)
if err != nil {
return 0, err
}

return int((uint64(b[0]) << 0) |
return int64((uint64(b[0]) << 0) |
(uint64(b[1]) << 8) |
(uint64(b[2]) << 16) |
(uint64(b[3]) << 24) |
Expand All @@ -516,7 +508,7 @@ func (d *decoder) readDocument() (bson.M, error) {
if err != nil {
return nil, err
}
d.i = start + documentLength
d.i = start + int(documentLength)
if len(d.in) < d.i {
return nil, errors.New("document out of bounds")
}
Expand Down
69 changes: 36 additions & 33 deletions packetbeat/protos/mongodb/mongodb_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
package mongodb

import (
"encoding/json"
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -77,6 +80,39 @@ func TestMongodbParser_simpleRequest(t *testing.T) {
}
}

func TestMongodbParser_OpMsg(t *testing.T) {
files := []string{
"1req.bin",
"1res.bin",
"2req.bin",
"2req.bin",
"3req.bin",
"3res.bin",
}

for _, fn := range files {
data, err := os.ReadFile(filepath.Join("testdata", fn))
if err != nil {
t.Fatal(err)
}

st := &stream{data: data, message: new(mongodbMessage)}

ok, complete := mongodbMessageParser(st)

if !ok {
t.Errorf("Parsing returned error")
}
if !complete {
t.Errorf("Expecting a complete message")
}
_, err = json.Marshal(st.message.documents)
if err != nil {
t.Fatal(err)
}
}
}

func TestMongodbParser_unknownOpCode(t *testing.T) {
var data []byte
data = addInt32(data, 16) // length = 16
Expand Down Expand Up @@ -107,39 +143,6 @@ func addInt32(in []byte, v int32) []byte {
return append(in, byte(u), byte(u>>8), byte(u>>16), byte(u>>24))
}

func Test_extract_documents(t *testing.T) {
type io struct {
Input map[string]interface{}
Output []interface{}
}
tests := []io{
{
Input: map[string]interface{}{
"a": 1,
"documents": []interface{}{"a", "b", "c"},
},
Output: []interface{}{"a", "b", "c"},
},
{
Input: map[string]interface{}{
"a": 1,
},
Output: []interface{}{},
},
{
Input: map[string]interface{}{
"a": 1,
"documents": 1,
},
Output: []interface{}{},
},
}

for _, test := range tests {
assert.Equal(t, test.Output, extractDocuments(test.Input))
}
}

func Test_isDatabaseCommand(t *testing.T) {
type io struct {
Key string
Expand Down
Loading