diff --git a/pubsub/gochannel/pubsub.go b/pubsub/gochannel/pubsub.go index fdd22044d..9cd0b5f9e 100644 --- a/pubsub/gochannel/pubsub.go +++ b/pubsub/gochannel/pubsub.go @@ -356,12 +356,10 @@ func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields wat s.sending.Lock() defer s.sending.Unlock() - var ctx context.Context + ctx := msg.Context() //This is getting the context from the message, not the subscriber - if s.preserveContext { - ctx = msg.Context() - } else { + if !s.preserveContext { var cancelCtx context.CancelFunc ctx, cancelCtx = context.WithCancel(s.ctx) defer cancelCtx() diff --git a/pubsub/tests/test_asserts.go b/pubsub/tests/test_asserts.go index 4660ed931..5942b2838 100644 --- a/pubsub/tests/test_asserts.go +++ b/pubsub/tests/test_asserts.go @@ -95,17 +95,11 @@ func AssertMessagesMetadata(t *testing.T, key string, expectedValues map[string] } // AssertAllMessagesHaveSameContext checks if context of all received messages is the same as in expectedValues, if PreserveContext is enabled. -func AssertAllMessagesHaveSameContext(t *testing.T, contextKeyString string, expectedValues map[string]context.Context, received []*message.Message) bool { +func AssertAllMessagesHaveSameContext(t *testing.T, contextKeyString string, expectedValues map[string]context.Context, received []*message.Message) { assert.Len(t, received, len(expectedValues)) - - ok := true for _, msg := range received { expectedValue := expectedValues[msg.UUID].Value(contextKey(contextKeyString)).(string) actualValue := msg.Context().Value(contextKeyString) - if !assert.Equal(t, expectedValue, actualValue) { - ok = false - } + assert.Equal(t, expectedValue, actualValue) } - - return ok }