Skip to content

Commit 7b8bcb3

Browse files
author
yash.bansal
committed
Address comments
1 parent 867fb77 commit 7b8bcb3

File tree

4 files changed

+24
-20
lines changed

4 files changed

+24
-20
lines changed

pubsub/gochannel/pubsub.go

+16-12
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ func (g *GoChannel) sendMessage(topic string, message *message.Message) (<-chan
163163

164164
wg.Add(1)
165165
go func() {
166-
subscriber.sendMessageToSubscriber(message, logFields, g.config.PreserveContext)
166+
subscriber.sendMessageToSubscriber(message, logFields)
167167
wg.Done()
168168
}()
169169
}
@@ -196,11 +196,12 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag
196196
subLock.(*sync.Mutex).Lock()
197197

198198
s := &subscriber{
199-
ctx: ctx,
200-
uuid: watermill.NewUUID(),
201-
outputChannel: make(chan *message.Message, g.config.OutputChannelBuffer),
202-
logger: g.logger,
203-
closing: make(chan struct{}),
199+
ctx: ctx,
200+
uuid: watermill.NewUUID(),
201+
outputChannel: make(chan *message.Message, g.config.OutputChannelBuffer),
202+
logger: g.logger,
203+
closing: make(chan struct{}),
204+
preserveContext: g.config.PreserveContext,
204205
}
205206

206207
go func(s *subscriber, g *GoChannel) {
@@ -246,7 +247,7 @@ func (g *GoChannel) Subscribe(ctx context.Context, topic string) (<-chan *messag
246247
msg := g.persistedMessages[topic][i]
247248
logFields := watermill.LogFields{"message_uuid": msg.UUID, "topic": topic}
248249

249-
go s.sendMessageToSubscriber(msg, logFields, g.config.PreserveContext)
250+
go s.sendMessageToSubscriber(msg, logFields)
250251
}
251252
}
252253

@@ -329,6 +330,8 @@ type subscriber struct {
329330
logger watermill.LoggerAdapter
330331
closed bool
331332
closing chan struct{}
333+
334+
preserveContext bool
332335
}
333336

334337
func (s *subscriber) Close() {
@@ -349,19 +352,20 @@ func (s *subscriber) Close() {
349352
close(s.outputChannel)
350353
}
351354

352-
func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields watermill.LogFields, preserveContext bool) {
355+
func (s *subscriber) sendMessageToSubscriber(msg *message.Message, logFields watermill.LogFields) {
353356
s.sending.Lock()
354357
defer s.sending.Unlock()
355358

356359
var ctx context.Context
357-
var cancelCtx context.CancelFunc
358360

359-
if preserveContext {
360-
ctx, cancelCtx = context.WithCancel(msg.Context())
361+
//This is getting the context from the message, not the subscriber
362+
if s.preserveContext {
363+
ctx = msg.Context()
361364
} else {
365+
var cancelCtx context.CancelFunc
362366
ctx, cancelCtx = context.WithCancel(s.ctx)
367+
defer cancelCtx()
363368
}
364-
defer cancelCtx()
365369

366370
SendToSubscriber:
367371
for {

pubsub/gochannel/pubsub_test.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -102,16 +102,16 @@ func TestPublishSubscribe_not_persistent_with_context(t *testing.T) {
102102
msgs, err := pubSub.Subscribe(context.Background(), topicName)
103103
require.NoError(t, err)
104104

105-
const contextKey = "foo"
106-
sendMessages := tests.PublishSimpleMessagesWithContext(t, messagesCount, contextKey, pubSub, topicName)
105+
const contextKeyString = "foo"
106+
sendMessages := tests.PublishSimpleMessagesWithContext(t, messagesCount, contextKeyString, pubSub, topicName)
107107
receivedMsgs, _ := subscriber.BulkRead(msgs, messagesCount, time.Second)
108108

109109
expectedContexts := make(map[string]context.Context)
110110
for _, msg := range sendMessages {
111111
expectedContexts[msg.UUID] = msg.Context()
112112
}
113113
tests.AssertAllMessagesReceived(t, sendMessages, receivedMsgs)
114-
tests.AssertAllMessagesHaveSameContext(t, contextKey, expectedContexts, receivedMsgs)
114+
tests.AssertAllMessagesHaveSameContext(t, contextKeyString, expectedContexts, receivedMsgs)
115115

116116
assert.NoError(t, pubSub.Close())
117117
}

pubsub/tests/test_asserts.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,13 @@ func AssertMessagesMetadata(t *testing.T, key string, expectedValues map[string]
9595
}
9696

9797
// AssertAllMessagesHaveSameContext checks if context of all received messages is the same as in expectedValues, if PreserveContext is enabled.
98-
func AssertAllMessagesHaveSameContext(t *testing.T, contextKey string, expectedValues map[string]context.Context, received []*message.Message) bool {
98+
func AssertAllMessagesHaveSameContext(t *testing.T, contextKeyString string, expectedValues map[string]context.Context, received []*message.Message) bool {
9999
assert.Len(t, received, len(expectedValues))
100100

101101
ok := true
102102
for _, msg := range received {
103-
expectedValue := expectedValues[msg.UUID].Value(contextKey)
104-
actualValue := msg.Context().Value(contextKey)
103+
expectedValue := expectedValues[msg.UUID].Value(contextKey(contextKeyString)).(string)
104+
actualValue := msg.Context().Value(contextKeyString)
105105
if !assert.Equal(t, expectedValue, actualValue) {
106106
ok = false
107107
}

pubsub/tests/test_pubsub.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -1284,14 +1284,14 @@ func PublishSimpleMessages(t *testing.T, messagesCount int, publisher message.Pu
12841284
}
12851285

12861286
// PublishSimpleMessagesWithContext publishes provided number of simple messages without a payload, but custom context
1287-
func PublishSimpleMessagesWithContext(t *testing.T, messagesCount int, contextKey string, publisher message.Publisher, topicName string) message.Messages {
1287+
func PublishSimpleMessagesWithContext(t *testing.T, messagesCount int, contextKeyString string, publisher message.Publisher, topicName string) message.Messages {
12881288
var messagesToPublish []*message.Message
12891289

12901290
for i := 0; i < messagesCount; i++ {
12911291
id := watermill.NewUUID()
12921292

12931293
msg := message.NewMessage(id, nil)
1294-
msg.SetContext(context.WithValue(context.Background(), contextKey, "bar"+strconv.Itoa(i)))
1294+
msg.SetContext(context.WithValue(context.Background(), contextKey(contextKeyString), "bar"+strconv.Itoa(i)))
12951295
messagesToPublish = append(messagesToPublish, msg)
12961296

12971297
err := publishWithRetry(publisher, topicName, msg)

0 commit comments

Comments
 (0)