Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions server/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -2561,11 +2561,18 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr
return false
}

// Copy off original pa in case it changes.
pa := c.pa

mt, _ := c.isMsgTraceEnabled()
if mt != nil {
pa := c.pa
// We are going to replace "pa" with our copy of c.pa, but to restore
// to the original copy of c.pa, we need to save it again.
cpa := c.pa
msg = mt.setOriginAccountHeaderIfNeeded(c, acc, msg)
defer func() { c.pa = pa }()
defer func() { c.pa = cpa }()
// Update pa with our current c.pa state.
pa = c.pa
}

var (
Expand All @@ -2579,6 +2586,7 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr
didDeliver bool
prodIsMQTT = c.isMqtt()
dlvMsgs int64
dlvExtraSz int64
)

// Get a subscription from the pool
Expand Down Expand Up @@ -2676,8 +2684,11 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr
}
}

// Assume original message
dmsg := msg
if mt != nil {
msg = mt.setHopHeader(c, msg)
// If trace is enabled, we need to set the hop header per gateway.
dmsg = mt.setHopHeader(c, dmsg)
}

// Setup the message header.
Expand Down Expand Up @@ -2727,16 +2738,22 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr
sub.nm, sub.max = 0, 0
sub.client = gwc
sub.subject = subject
if c.deliverMsg(prodIsMQTT, sub, acc, subject, mreply, mh, msg, false) {
if c.deliverMsg(prodIsMQTT, sub, acc, subject, mreply, mh, dmsg, false) {
// We don't count internal deliveries so count only if sub.icb is nil
if sub.icb == nil {
dlvMsgs++
dlvExtraSz += int64(len(dmsg) - len(msg))
}
didDeliver = true
}

// If we set the header reset the origin pub args.
if mt != nil {
c.pa = pa
}
}
if dlvMsgs > 0 {
totalBytes := dlvMsgs * int64(len(msg))
totalBytes := dlvMsgs*int64(len(msg)) + dlvExtraSz
// For non MQTT producers, remove the CR_LF * number of messages
if !prodIsMQTT {
totalBytes -= dlvMsgs * int64(LEN_CR_LF)
Expand Down
110 changes: 75 additions & 35 deletions server/msgtrace_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1496,17 +1496,33 @@ func TestMsgTraceWithGateways(t *testing.T) {
s1 := runGatewayServer(o1)
defer s1.Shutdown()

waitForOutboundGateways(t, s1, 1, time.Second)
waitForInboundGateways(t, s2, 1, time.Second)
waitForOutboundGateways(t, s2, 1, time.Second)
o3 := testGatewayOptionsFromToWithServers(t, "C", "B", s2)
o3.NoSystemAccount = false
s3 := runGatewayServer(o3)
defer s3.Shutdown()

waitForOutboundGateways(t, s1, 2, time.Second)
waitForInboundGateways(t, s1, 2, time.Second)
waitForInboundGateways(t, s2, 2, time.Second)
waitForOutboundGateways(t, s2, 2, time.Second)
waitForInboundGateways(t, s3, 2, time.Second)
waitForOutboundGateways(t, s3, 2, time.Second)

nc2 := natsConnect(t, s2.ClientURL(), nats.Name("sub2"))
defer nc2.Close()
sub2 := natsQueueSubSync(t, nc2, "foo.*", "my_queue")
sub2 := natsQueueSubSync(t, nc2, "foo.*", "my_queue_2")

nc22 := natsConnect(t, s2.ClientURL(), nats.Name("sub22"))
defer nc22.Close()
sub22 := natsQueueSubSync(t, nc22, "*.*", "my_queue_22")

nc3 := natsConnect(t, s2.ClientURL(), nats.Name("sub3"))
nc3 := natsConnect(t, s3.ClientURL(), nats.Name("sub3"))
defer nc3.Close()
sub3 := natsQueueSubSync(t, nc3, "*.*", "my_queue_2")
sub3 := natsQueueSubSync(t, nc3, "foo.*", "my_queue_3")

nc32 := natsConnect(t, s3.ClientURL(), nats.Name("sub32"))
defer nc32.Close()
sub32 := natsQueueSubSync(t, nc32, "*.*", "my_queue_32")

nc1 := natsConnect(t, s1.ClientURL(), nats.Name("sub1"))
defer nc1.Close()
Expand Down Expand Up @@ -1540,17 +1556,18 @@ func TestMsgTraceWithGateways(t *testing.T) {
checkAppMsg := func(sub *nats.Subscription, expected bool) {
if expected {
appMsg := natsNexMsg(t, sub, time.Second)
require_Equal[string](t, string(appMsg.Data), "hello!")
require_Equal(t, string(appMsg.Data), "hello!")
}
// Check that no (more) messages are received.
if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout {
t.Fatalf("Did not expect application message, got %s", msg.Data)
}
}
for _, sub := range []*nats.Subscription{sub1, sub2, sub3} {
for _, sub := range []*nats.Subscription{sub1, sub2, sub22, sub3, sub32} {
checkAppMsg(sub, test.deliverMsg)
}

var previousHop string
check := func() {
traceMsg := natsNexMsg(t, traceSub, time.Second)
var e MsgTraceEvent
Expand All @@ -1560,58 +1577,81 @@ func TestMsgTraceWithGateways(t *testing.T) {
require_True(t, ingress != nil)
switch ingress.Kind {
case CLIENT:
require_Equal[string](t, e.Server.Name, s1.Name())
require_Equal[string](t, ingress.Account, globalAccountName)
require_Equal[string](t, ingress.Subject, "foo.bar")
require_Equal(t, e.Server.Name, s1.Name())
require_Equal(t, ingress.Account, globalAccountName)
require_Equal(t, ingress.Subject, "foo.bar")
egress := e.Egresses()
require_Equal[int](t, len(egress), 2)
require_Equal(t, len(egress), 3)
for _, eg := range egress {
switch eg.Kind {
case CLIENT:
require_Equal[string](t, eg.Name, "sub1")
require_Equal[string](t, eg.Subscription, "*.bar")
require_Equal[string](t, eg.Queue, _EMPTY_)
require_Equal(t, eg.Name, "sub1")
require_Equal(t, eg.Subscription, "*.bar")
require_Equal(t, eg.Queue, _EMPTY_)
case GATEWAY:
require_Equal[string](t, eg.Name, s2.Name())
require_Equal[string](t, eg.Error, _EMPTY_)
require_Equal[string](t, eg.Subscription, _EMPTY_)
require_Equal[string](t, eg.Queue, _EMPTY_)
if eg.Name != s2.Name() && eg.Name != s3.Name() {
t.Fatalf("Expected name to be %q or %q, got %q", s2.Name(), s3.Name(), eg.Name)
}
require_Equal(t, eg.Error, _EMPTY_)
require_Equal(t, eg.Subscription, _EMPTY_)
require_Equal(t, eg.Queue, _EMPTY_)
default:
t.Fatalf("Unexpected egress: %+v", eg)
}
}
case GATEWAY:
require_Equal[string](t, e.Server.Name, s2.Name())
require_Equal[string](t, ingress.Account, globalAccountName)
require_Equal[string](t, ingress.Subject, "foo.bar")
require_True(t, e.Request.Header != nil)
require_Len(t, len(e.Request.Header[MsgTraceHop]), 1)
hop := e.Request.Header[MsgTraceHop][0]
require_True(t, hop == "1" || hop == "2")
if previousHop == _EMPTY_ {
previousHop = hop
} else if hop == previousHop {
t.Fatalf("Expected different hop value, got the same %q", hop)
}
var sub2Name, queue2Name, sub3Name, queue3Name string
switch e.Server.Name {
case s2.Name():
require_Equal(t, e.Server.Cluster, "B")
sub2Name, sub3Name = "sub2", "sub22"
queue2Name, queue3Name = "my_queue_2", "my_queue_22"
case s3.Name():
require_Equal(t, e.Server.Cluster, "C")
sub2Name, sub3Name = "sub3", "sub32"
queue2Name, queue3Name = "my_queue_3", "my_queue_32"
default:
t.Fatalf("Unexpected server name %q", e.Server.Name)
}
require_Equal(t, ingress.Account, globalAccountName)
require_Equal(t, ingress.Subject, "foo.bar")
egress := e.Egresses()
require_Equal[int](t, len(egress), 2)
require_Equal(t, len(egress), 2)
var gotSub2, gotSub3 int
for _, eg := range egress {
require_True(t, eg.Kind == CLIENT)
switch eg.Name {
case "sub2":
require_Equal[string](t, eg.Subscription, "foo.*")
require_Equal[string](t, eg.Queue, "my_queue")
case sub2Name:
require_Equal(t, eg.Subscription, "foo.*")
require_Equal(t, eg.Queue, queue2Name)
gotSub2++
case "sub3":
require_Equal[string](t, eg.Subscription, "*.*")
require_Equal[string](t, eg.Queue, "my_queue_2")
case sub3Name:
require_Equal(t, eg.Subscription, "*.*")
require_Equal(t, eg.Queue, queue3Name)
gotSub3++
default:
t.Fatalf("Unexpected egress name: %+v", eg)
}
}
require_Equal[int](t, gotSub2, 1)
require_Equal[int](t, gotSub3, 1)

require_Equal(t, gotSub2, 1)
require_Equal(t, gotSub3, 1)
default:
t.Fatalf("Unexpected ingress: %+v", ingress)
}
}
// We should get 2 events
check()
check()
// We should get 3 events
for range 3 {
check()
}
// Make sure we are not receiving more traces
if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil {
t.Fatalf("Should not have received trace message: %s", tm.Data)
Expand Down