From 22e217ac69edd9dce468af5b268a81c8b5ffc8ae Mon Sep 17 00:00:00 2001 From: Ivan Kozlovic Date: Sat, 18 Oct 2025 11:26:08 -0600 Subject: [PATCH] [FIXED] Message Tracing: Hop header set properly per gateway Setting the "hop" header for each gateway could cause header corruption. This is now fixed. A test dealing with gateway has been improved to include more than one gateway, which would have demonstrated the issue. The test now passes and ensures that the hop is different per gateway. Related to #7442 Signed-off-by: Ivan Kozlovic --- server/gateway.go | 27 ++++++++-- server/msgtrace_test.go | 110 +++++++++++++++++++++++++++------------- 2 files changed, 97 insertions(+), 40 deletions(-) diff --git a/server/gateway.go b/server/gateway.go index 085f0e18f54..f4982cc2ce6 100644 --- a/server/gateway.go +++ b/server/gateway.go @@ -2561,11 +2561,18 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr return false } + // Copy off original pa in case it changes. + pa := c.pa + mt, _ := c.isMsgTraceEnabled() if mt != nil { - pa := c.pa + // We are going to replace "pa" with our copy of c.pa, but to restore + // to the original copy of c.pa, we need to save it again. + cpa := c.pa msg = mt.setOriginAccountHeaderIfNeeded(c, acc, msg) - defer func() { c.pa = pa }() + defer func() { c.pa = cpa }() + // Update pa with our current c.pa state. + pa = c.pa } var ( @@ -2579,6 +2586,7 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr didDeliver bool prodIsMQTT = c.isMqtt() dlvMsgs int64 + dlvExtraSz int64 ) // Get a subscription from the pool @@ -2676,8 +2684,11 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr } } + // Assume original message + dmsg := msg if mt != nil { - msg = mt.setHopHeader(c, msg) + // If trace is enabled, we need to set the hop header per gateway. + dmsg = mt.setHopHeader(c, dmsg) } // Setup the message header. @@ -2727,16 +2738,22 @@ func (c *client) sendMsgToGateways(acc *Account, msg, subject, reply []byte, qgr sub.nm, sub.max = 0, 0 sub.client = gwc sub.subject = subject - if c.deliverMsg(prodIsMQTT, sub, acc, subject, mreply, mh, msg, false) { + if c.deliverMsg(prodIsMQTT, sub, acc, subject, mreply, mh, dmsg, false) { // We don't count internal deliveries so count only if sub.icb is nil if sub.icb == nil { dlvMsgs++ + dlvExtraSz += int64(len(dmsg) - len(msg)) } didDeliver = true } + + // If we set the header reset the origin pub args. + if mt != nil { + c.pa = pa + } } if dlvMsgs > 0 { - totalBytes := dlvMsgs * int64(len(msg)) + totalBytes := dlvMsgs*int64(len(msg)) + dlvExtraSz // For non MQTT producers, remove the CR_LF * number of messages if !prodIsMQTT { totalBytes -= dlvMsgs * int64(LEN_CR_LF) diff --git a/server/msgtrace_test.go b/server/msgtrace_test.go index 57b72186f94..769a5e2cd4c 100644 --- a/server/msgtrace_test.go +++ b/server/msgtrace_test.go @@ -1496,17 +1496,33 @@ func TestMsgTraceWithGateways(t *testing.T) { s1 := runGatewayServer(o1) defer s1.Shutdown() - waitForOutboundGateways(t, s1, 1, time.Second) - waitForInboundGateways(t, s2, 1, time.Second) - waitForOutboundGateways(t, s2, 1, time.Second) + o3 := testGatewayOptionsFromToWithServers(t, "C", "B", s2) + o3.NoSystemAccount = false + s3 := runGatewayServer(o3) + defer s3.Shutdown() + + waitForOutboundGateways(t, s1, 2, time.Second) + waitForInboundGateways(t, s1, 2, time.Second) + waitForInboundGateways(t, s2, 2, time.Second) + waitForOutboundGateways(t, s2, 2, time.Second) + waitForInboundGateways(t, s3, 2, time.Second) + waitForOutboundGateways(t, s3, 2, time.Second) nc2 := natsConnect(t, s2.ClientURL(), nats.Name("sub2")) defer nc2.Close() - sub2 := natsQueueSubSync(t, nc2, "foo.*", "my_queue") + sub2 := natsQueueSubSync(t, nc2, "foo.*", "my_queue_2") + + nc22 := natsConnect(t, s2.ClientURL(), nats.Name("sub22")) + defer nc22.Close() + sub22 := natsQueueSubSync(t, nc22, "*.*", "my_queue_22") - nc3 := natsConnect(t, s2.ClientURL(), nats.Name("sub3")) + nc3 := natsConnect(t, s3.ClientURL(), nats.Name("sub3")) defer nc3.Close() - sub3 := natsQueueSubSync(t, nc3, "*.*", "my_queue_2") + sub3 := natsQueueSubSync(t, nc3, "foo.*", "my_queue_3") + + nc32 := natsConnect(t, s3.ClientURL(), nats.Name("sub32")) + defer nc32.Close() + sub32 := natsQueueSubSync(t, nc32, "*.*", "my_queue_32") nc1 := natsConnect(t, s1.ClientURL(), nats.Name("sub1")) defer nc1.Close() @@ -1540,17 +1556,18 @@ func TestMsgTraceWithGateways(t *testing.T) { checkAppMsg := func(sub *nats.Subscription, expected bool) { if expected { appMsg := natsNexMsg(t, sub, time.Second) - require_Equal[string](t, string(appMsg.Data), "hello!") + require_Equal(t, string(appMsg.Data), "hello!") } // Check that no (more) messages are received. if msg, err := sub.NextMsg(100 * time.Millisecond); err != nats.ErrTimeout { t.Fatalf("Did not expect application message, got %s", msg.Data) } } - for _, sub := range []*nats.Subscription{sub1, sub2, sub3} { + for _, sub := range []*nats.Subscription{sub1, sub2, sub22, sub3, sub32} { checkAppMsg(sub, test.deliverMsg) } + var previousHop string check := func() { traceMsg := natsNexMsg(t, traceSub, time.Second) var e MsgTraceEvent @@ -1560,58 +1577,81 @@ func TestMsgTraceWithGateways(t *testing.T) { require_True(t, ingress != nil) switch ingress.Kind { case CLIENT: - require_Equal[string](t, e.Server.Name, s1.Name()) - require_Equal[string](t, ingress.Account, globalAccountName) - require_Equal[string](t, ingress.Subject, "foo.bar") + require_Equal(t, e.Server.Name, s1.Name()) + require_Equal(t, ingress.Account, globalAccountName) + require_Equal(t, ingress.Subject, "foo.bar") egress := e.Egresses() - require_Equal[int](t, len(egress), 2) + require_Equal(t, len(egress), 3) for _, eg := range egress { switch eg.Kind { case CLIENT: - require_Equal[string](t, eg.Name, "sub1") - require_Equal[string](t, eg.Subscription, "*.bar") - require_Equal[string](t, eg.Queue, _EMPTY_) + require_Equal(t, eg.Name, "sub1") + require_Equal(t, eg.Subscription, "*.bar") + require_Equal(t, eg.Queue, _EMPTY_) case GATEWAY: - require_Equal[string](t, eg.Name, s2.Name()) - require_Equal[string](t, eg.Error, _EMPTY_) - require_Equal[string](t, eg.Subscription, _EMPTY_) - require_Equal[string](t, eg.Queue, _EMPTY_) + if eg.Name != s2.Name() && eg.Name != s3.Name() { + t.Fatalf("Expected name to be %q or %q, got %q", s2.Name(), s3.Name(), eg.Name) + } + require_Equal(t, eg.Error, _EMPTY_) + require_Equal(t, eg.Subscription, _EMPTY_) + require_Equal(t, eg.Queue, _EMPTY_) default: t.Fatalf("Unexpected egress: %+v", eg) } } case GATEWAY: - require_Equal[string](t, e.Server.Name, s2.Name()) - require_Equal[string](t, ingress.Account, globalAccountName) - require_Equal[string](t, ingress.Subject, "foo.bar") + require_True(t, e.Request.Header != nil) + require_Len(t, len(e.Request.Header[MsgTraceHop]), 1) + hop := e.Request.Header[MsgTraceHop][0] + require_True(t, hop == "1" || hop == "2") + if previousHop == _EMPTY_ { + previousHop = hop + } else if hop == previousHop { + t.Fatalf("Expected different hop value, got the same %q", hop) + } + var sub2Name, queue2Name, sub3Name, queue3Name string + switch e.Server.Name { + case s2.Name(): + require_Equal(t, e.Server.Cluster, "B") + sub2Name, sub3Name = "sub2", "sub22" + queue2Name, queue3Name = "my_queue_2", "my_queue_22" + case s3.Name(): + require_Equal(t, e.Server.Cluster, "C") + sub2Name, sub3Name = "sub3", "sub32" + queue2Name, queue3Name = "my_queue_3", "my_queue_32" + default: + t.Fatalf("Unexpected server name %q", e.Server.Name) + } + require_Equal(t, ingress.Account, globalAccountName) + require_Equal(t, ingress.Subject, "foo.bar") egress := e.Egresses() - require_Equal[int](t, len(egress), 2) + require_Equal(t, len(egress), 2) var gotSub2, gotSub3 int for _, eg := range egress { require_True(t, eg.Kind == CLIENT) switch eg.Name { - case "sub2": - require_Equal[string](t, eg.Subscription, "foo.*") - require_Equal[string](t, eg.Queue, "my_queue") + case sub2Name: + require_Equal(t, eg.Subscription, "foo.*") + require_Equal(t, eg.Queue, queue2Name) gotSub2++ - case "sub3": - require_Equal[string](t, eg.Subscription, "*.*") - require_Equal[string](t, eg.Queue, "my_queue_2") + case sub3Name: + require_Equal(t, eg.Subscription, "*.*") + require_Equal(t, eg.Queue, queue3Name) gotSub3++ default: t.Fatalf("Unexpected egress name: %+v", eg) } } - require_Equal[int](t, gotSub2, 1) - require_Equal[int](t, gotSub3, 1) - + require_Equal(t, gotSub2, 1) + require_Equal(t, gotSub3, 1) default: t.Fatalf("Unexpected ingress: %+v", ingress) } } - // We should get 2 events - check() - check() + // We should get 3 events + for range 3 { + check() + } // Make sure we are not receiving more traces if tm, err := traceSub.NextMsg(250 * time.Millisecond); err == nil { t.Fatalf("Should not have received trace message: %s", tm.Data)