Skip to content

Commit

Permalink
Merge pull request #192 from t2y/take-context-to-consume
Browse files Browse the repository at this point in the history
Add Channel.ConsumeWithContext to be able to cancel delivering
  • Loading branch information
Zerpet authored Jun 20, 2023
2 parents 5c9eb22 + d02f590 commit 13dde10
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
22 changes: 22 additions & 0 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,17 @@ When the consumer tag is cancelled, all inflight messages will be delivered unti
the returned chan is closed.
*/
func (ch *Channel) Consume(queue, consumer string, autoAck, exclusive, noLocal, noWait bool, args Table) (<-chan Delivery, error) {
return ch.ConsumeWithContext(context.Background(), queue, consumer, autoAck, exclusive, noLocal, noWait, args)
}

/*
ConsumeWithContext immediately starts delivering queued messages.
This is similar to Consume() function but has different semantics.
The caller can cancel via the given context, then call ch.Cancel() and stop
receiving messages.
*/
func (ch *Channel) ConsumeWithContext(ctx context.Context, queue, consumer string, autoAck, exclusive, noLocal, noWait bool, args Table) (<-chan Delivery, error) {
// When we return from ch.call, there may be a delivery already for the
// consumer that hasn't been added to the consumer hash yet. Because of
// this, we never rely on the server picking a consumer tag for us.
Expand Down Expand Up @@ -1123,6 +1134,17 @@ func (ch *Channel) Consume(queue, consumer string, autoAck, exclusive, noLocal,
return nil, err
}

go func() {
select {
case <-ch.consumers.closed:
return
case <-ctx.Done():
if ch != nil {
_ = ch.Cancel(consumer, false)
}
}
}()

return deliveries, nil
}

Expand Down
44 changes: 44 additions & 0 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package amqp091

import (
"bytes"
"context"
devrand "crypto/rand"
"encoding/binary"
"fmt"
Expand Down Expand Up @@ -819,6 +820,49 @@ func TestIntegrationConsumeCancel(t *testing.T) {
}
}

func TestIntegrationConsumeCancelWithContext(t *testing.T) {
queue := "test.integration.consume-cancel-with-context"

c := integrationConnection(t, "pub")

if c != nil {
defer c.Close()

ch, _ := c.Channel()

if _, e := ch.QueueDeclare(queue, false, true, false, false, nil); e != nil {
t.Fatalf("error declaring queue %s: %v", queue, e)
}

defer integrationQueueDelete(t, ch, queue)

ctx, cancel := context.WithCancel(context.Background())
messages, _ := ch.ConsumeWithContext(ctx, queue, "integration-tag-with-context", false, false, false, false, nil)

if e := ch.Publish("", queue, false, false, Publishing{Body: []byte("1")}); e != nil {
t.Fatalf("error publishing: %v", e)
}

assertConsumeBody(t, messages, []byte("1"))

cancel()
<-time.After(100 * time.Millisecond) // wait to call cancel asynchronously

if e := ch.Publish("", queue, false, false, Publishing{Body: []byte("2")}); e != nil {
t.Fatalf("error publishing: %v", e)
}

select {
case <-time.After(100 * time.Millisecond):
t.Fatalf("Timeout on Close")
case _, ok := <-messages:
if ok {
t.Fatalf("Extra message on consumer when consumer should have been closed")
}
}
}
}

func (c *Connection) Generate(_ *rand.Rand, _ int) reflect.Value {
urlStr := amqpURL

Expand Down

0 comments on commit 13dde10

Please sign in to comment.