diff --git a/message/message.go b/message/message.go index 410f5e3a7..a064ef823 100644 --- a/message/message.go +++ b/message/message.go @@ -193,3 +193,11 @@ func (m *Message) Copy() *Message { } return msg } + +// CopyWithContext copies all message without Acks/Nacks. +// The context is also propagated to the copy. +func (m *Message) CopyWithContext() *Message { + msg := m.Copy() + msg.ctx = m.ctx + return msg +} diff --git a/message/message_test.go b/message/message_test.go index 435fa5aca..e89cfdff9 100644 --- a/message/message_test.go +++ b/message/message_test.go @@ -1,6 +1,7 @@ package message_test import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -113,6 +114,37 @@ func TestMessage_Copy(t *testing.T) { assert.True(t, msg.Equals(msgCopy)) } +func TestMessage_CopyWithContext(t *testing.T) { + msg := message.NewMessage("1", []byte("foo")) + testCtx := context.Background() + testCtx = context.WithValue(testCtx, "foo", "bar") + msg.SetContext(testCtx) + + msgCopy := msg.CopyWithContext() + testCtx = context.WithValue(testCtx, "foo", "baz") + testCtx = context.WithValue(testCtx, "abc", "def") + + copyMsgCtx := msgCopy.Context() + assert.True(t, copyMsgCtx.Value("foo") == "bar", "expected context not being copied") + assert.False(t, copyMsgCtx.Value("abc") == "def", "non-expected context being copied") + assert.True(t, msg.Equals(msgCopy)) +} + +func TestMessage_CopyWithContextAndMetadata(t *testing.T) { + msg := message.NewMessage("1", []byte("foo")) + testCtx := context.Background() + testCtx = context.WithValue(testCtx, "foo", "bar") + msg.SetContext(testCtx) + msg.Metadata.Set("foo", "bar") + msgCopy := msg.CopyWithContext() + + msg.Metadata.Set("foo", "baz") + + copyMsgCtx := msgCopy.Context() + assert.True(t, copyMsgCtx.Value("foo") == "bar", "expected context not being copied") + assert.Equal(t, msgCopy.Metadata.Get("foo"), "bar", "did not expect changing source message's metadata to alter copy's metadata") +} + func TestMessage_CopyMetadata(t *testing.T) { msg := message.NewMessage("1", []byte("foo")) msg.Metadata.Set("foo", "bar") diff --git a/message/pubsub.go b/message/pubsub.go index 258e1767e..bc2441ee5 100644 --- a/message/pubsub.go +++ b/message/pubsub.go @@ -31,7 +31,7 @@ type Subscriber interface { // If message processing fails and the message should be redelivered `Nack()` should be called instead. // // When the provided ctx is canceled, the subscriber closes the subscription and the output channel. - // The provided ctx is passed to all produced messages. + // The provided ctx is passed to all produced messages (this is configurable for the local Pub/Sub implementation). // When Nack or Ack is called on the message, the context of the message is canceled. Subscribe(ctx context.Context, topic string) (<-chan *Message, error) // Close closes all subscriptions with their output channels and flushes offsets etc. when needed. diff --git a/pubsub/gochannel/pubsub.go b/pubsub/gochannel/pubsub.go index 0d4fc61c7..9cd0b5f9e 100644 --- a/pubsub/gochannel/pubsub.go +++ b/pubsub/gochannel/pubsub.go @@ -26,6 +26,11 @@ type Config struct { // When true, Publish will block until subscriber Ack's the message. // If there are no subscribers, Publish will not block (also when Persistent is true). BlockPublishUntilSubscriberAck bool + + // PreserveContext is a flag that determines if the context should be preserved when sending messages to subscribers. + // This behavior is different from other implementations of Publishers where data travels over the network, + // hence context can't be preserved in those cases + PreserveContext bool } // GoChannel is the simplest Pub/Sub implementation. @@ -87,7 +92,11 @@ func (g *GoChannel) Publish(topic string, messages ...*message.Message) error { messagesToPublish := make(message.Messages, len(messages)) for i, msg := range messages { - messagesToPublish[i] = msg.Copy() + if g.config.PreserveContext { + messagesToPublish[i] = msg.CopyWithContext() + } else { + messagesToPublish[i] = msg.Copy() + } } g.subscribersLock.RLock() @@ -187,11 +196,12 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag subLock.(*sync.Mutex).Lock() s := &subscriber{ - ctx: ctx, - uuid: watermill.NewUUID(), - outputChannel: make(chan *message.Message, g.config.OutputChannelBuffer), - logger: g.logger, - closing: make(chan struct{}), + ctx: ctx, + uuid: watermill.NewUUID(), + outputChannel: make(chan *message.Message, g.config.OutputChannelBuffer), + logger: g.logger, + closing: make(chan struct{}), + preserveContext: g.config.PreserveContext, } go func(s *subscriber, g *GoChannel) { @@ -320,6 +330,8 @@ type subscriber struct { logger watermill.LoggerAdapter closed bool closing chan struct{} + + preserveContext bool } func (s *subscriber) Close() { @@ -344,8 +356,14 @@ func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields wat s.sending.Lock() defer s.sending.Unlock() - ctx, cancelCtx := context.WithCancel(s.ctx) - defer cancelCtx() + ctx := msg.Context() + + //This is getting the context from the message, not the subscriber + if !s.preserveContext { + var cancelCtx context.CancelFunc + ctx, cancelCtx = context.WithCancel(s.ctx) + defer cancelCtx() + } SendToSubscriber: for { diff --git a/pubsub/gochannel/pubsub_test.go b/pubsub/gochannel/pubsub_test.go index 012064091..2478d080f 100644 --- a/pubsub/gochannel/pubsub_test.go +++ b/pubsub/gochannel/pubsub_test.go @@ -29,6 +29,18 @@ func createPersistentPubSub(t *testing.T) (message.Publisher, message.Subscriber return pubSub, pubSub } +func createPersistentPubSubWithContextPreserved(t *testing.T) (message.Publisher, message.Subscriber) { + pubSub := gochannel.NewGoChannel( + gochannel.Config{ + OutputChannelBuffer: 10000, + Persistent: true, + PreserveContext: true, + }, + watermill.NewStdLogger(true, true), + ) + return pubSub, pubSub +} + func TestPublishSubscribe_persistent(t *testing.T) { tests.TestPubSub( t, @@ -44,6 +56,22 @@ func TestPublishSubscribe_persistent(t *testing.T) { ) } +func TestPublishSubscribe_context_preserved(t *testing.T) { + tests.TestPubSub( + t, + tests.Features{ + ConsumerGroups: false, + ExactlyOnceDelivery: true, + GuaranteedOrder: false, + Persistent: false, + RequireSingleInstance: true, + ContextPreserved: true, + }, + createPersistentPubSubWithContextPreserved, + nil, + ) +} + func TestPublishSubscribe_not_persistent(t *testing.T) { messagesCount := 100 pubSub := gochannel.NewGoChannel( @@ -63,6 +91,31 @@ func TestPublishSubscribe_not_persistent(t *testing.T) { assert.NoError(t, pubSub.Close()) } +func TestPublishSubscribe_not_persistent_with_context(t *testing.T) { + messagesCount := 100 + pubSub := gochannel.NewGoChannel( + gochannel.Config{OutputChannelBuffer: int64(messagesCount), PreserveContext: true}, + watermill.NewStdLogger(true, true), + ) + topicName := "test_topic_" + watermill.NewUUID() + + msgs, err := pubSub.Subscribe(context.Background(), topicName) + require.NoError(t, err) + + const contextKeyString = "foo" + sendMessages := tests.PublishSimpleMessagesWithContext(t, messagesCount, contextKeyString, pubSub, topicName) + receivedMsgs, _ := subscriber.BulkRead(msgs, messagesCount, time.Second) + + expectedContexts := make(map[string]context.Context) + for _, msg := range sendMessages { + expectedContexts[msg.UUID] = msg.Context() + } + tests.AssertAllMessagesReceived(t, sendMessages, receivedMsgs) + tests.AssertAllMessagesHaveSameContext(t, contextKeyString, expectedContexts, receivedMsgs) + + assert.NoError(t, pubSub.Close()) +} + func TestPublishSubscribe_block_until_ack(t *testing.T) { pubSub := gochannel.NewGoChannel( gochannel.Config{BlockPublishUntilSubscriberAck: true}, diff --git a/pubsub/tests/test_asserts.go b/pubsub/tests/test_asserts.go index b0b788245..7f2798b92 100644 --- a/pubsub/tests/test_asserts.go +++ b/pubsub/tests/test_asserts.go @@ -1,6 +1,7 @@ package tests import ( + "context" "sort" "testing" @@ -92,3 +93,13 @@ func AssertMessagesMetadata(t *testing.T, key string, expectedValues map[string] return ok } + +// AssertAllMessagesHaveSameContext checks if context of all received messages is the same as in expectedValues, if PreserveContext is enabled. +func AssertAllMessagesHaveSameContext(t *testing.T, contextKeyString string, expectedValues map[string]context.Context, received []*message.Message) { + assert.Len(t, received, len(expectedValues)) + for _, msg := range received { + expectedValue := expectedValues[msg.UUID].Value(contextKey(contextKeyString)).(string) + actualValue := msg.Context().Value(contextKey(contextKeyString)) + assert.Equal(t, expectedValue, actualValue) + } +} diff --git a/pubsub/tests/test_pubsub.go b/pubsub/tests/test_pubsub.go index 9c0232a78..ac6efbdba 100644 --- a/pubsub/tests/test_pubsub.go +++ b/pubsub/tests/test_pubsub.go @@ -124,6 +124,10 @@ type Features struct { // GenerateTopicFunc overrides standard topic name generation. GenerateTopicFunc func(tctx TestContext) string + + // ContextPreserved should be set to true if the Pub/Sub implementation preserves the context + // of the message when it's published and consumed. + ContextPreserved bool } // RunOnlyFastTests returns true if -short flag was provided -race was not provided. @@ -993,7 +997,14 @@ func TestSubscribeCtx( if subscribeInitializer, ok := sub.(message.SubscribeInitializer); ok { require.NoError(t, subscribeInitializer.SubscribeInitialize(topicName)) } - publishedMessages := PublishSimpleMessages(t, messagesCount, pub, topicName) + + var publishedMessages message.Messages + var contextKeyString = "abc" + if tCtx.Features.ContextPreserved { + publishedMessages = PublishSimpleMessagesWithContext(t, messagesCount, contextKeyString, pub, topicName) + } else { + publishedMessages = PublishSimpleMessages(t, messagesCount, pub, topicName) + } msgsToCancel, err := sub.Subscribe(ctxWithCancel, topicName) require.NoError(t, err) @@ -1017,14 +1028,24 @@ ClosedLoop: } ctx := context.WithValue(context.Background(), contextKey("foo"), "bar") + + // For mocking the output of pub-subs where context is preserved vs not preserved + expectedContexts := make(map[string]context.Context) + for _, msg := range publishedMessages { + if tCtx.Features.ContextPreserved { + expectedContexts[msg.UUID] = msg.Context() + } else { + expectedContexts[msg.UUID] = ctx + } + } + msgs, err := sub.Subscribe(ctx, topicName) require.NoError(t, err) receivedMessages, _ := bulkRead(tCtx, msgs, messagesCount, defaultTimeout) AssertAllMessagesReceived(t, publishedMessages, receivedMessages) - - for _, msg := range receivedMessages { - assert.EqualValues(t, "bar", msg.Context().Value(contextKey("foo"))) + if tCtx.Features.ContextPreserved { + AssertAllMessagesHaveSameContext(t, contextKeyString, expectedContexts, receivedMessages) } } @@ -1271,6 +1292,24 @@ func PublishSimpleMessages(t *testing.T, messagesCount int, publisher message.Pu return messagesToPublish } +// PublishSimpleMessagesWithContext publishes provided number of simple messages without a payload, but custom context +func PublishSimpleMessagesWithContext(t *testing.T, messagesCount int, contextKeyString string, publisher message.Publisher, topicName string) message.Messages { + var messagesToPublish []*message.Message + + for i := 0; i < messagesCount; i++ { + id := watermill.NewUUID() + + msg := message.NewMessage(id, nil) + msg.SetContext(context.WithValue(context.Background(), contextKey(contextKeyString), "bar"+strconv.Itoa(i))) + messagesToPublish = append(messagesToPublish, msg) + + err := publishWithRetry(publisher, topicName, msg) + require.NoError(t, err, "cannot publish messages") + } + + return messagesToPublish +} + // AddSimpleMessagesParallel publishes provided number of simple messages without a payload // using the provided number of publishers (goroutines). func AddSimpleMessagesParallel(t *testing.T, messagesCount int, publisher message.Publisher, topicName string, publishers int) message.Messages {