Skip to content

Commit

Permalink
kgo: allow clients to cancel requests
Browse files Browse the repository at this point in the history
Previously, if the client got to the point of writing a request, it
would not be possible to cancel the request.

This now writes and reads in goroutines, and allows for killing requests
(with immediate deadlines) if a context closes early.

Importantly, this also ensures that client shutdown or broker migration
cancels requests as well. Previously, a client shutting down did not
actually kill an active write / read. We used to just close the request
channel and trust that things would eventually close; the new logic
kills writes if the client's context closes.

Also even more importantly, this new logic allows for a dying broker
connection, as well as the client closing, to quit the throttle waiting.
Previously, a throttle was only quit if the request's context quit.

This updates the read and write hooks to track the exact number of bytes
written, and to more correctly count bytes read.

The throttle logic has been fixed to call the throttle hooks for
pre-2.0.0 (i.e., !throttlesAfterResp). I accidentally broke the logic
when introducing throttlesAfterResp (thanks @akesle for noticing this).
  • Loading branch information
twmb committed Jan 13, 2021
1 parent 2493ae7 commit d07538d
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 72 deletions.
177 changes: 114 additions & 63 deletions pkg/kgo/broker.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type promisedReq struct {
}

type promisedResp struct {
ctx context.Context
corrID int32

readTimeout time.Duration
Expand Down Expand Up @@ -286,18 +287,18 @@ func (b *broker) handleReqs() {
}

// Juuuust before we issue the request, we check if it was
// canceled. If it is not, we do not cancel hereafter.
// We only check the promised req's ctx, not our clients.
// The client ctx is closed on shutdown, which kills the
// cxn anyway.
// canceled. We could have previously tried this request, which
// then failed and retried due to the error being ErrConnDead.
// Checking the context was canceled here ensures we do not
// loop. We could be more precise with error tracking, though.
select {
case <-pr.ctx.Done():
pr.promise(nil, pr.ctx.Err())
continue
default:
}

corrID, err := cxn.writeRequest(pr.ctx, time.Since(pr.enqueue), req)
corrID, err := cxn.writeRequest(pr.ctx, pr.enqueue, req)

if err != nil {
pr.promise(nil, err)
Expand All @@ -308,6 +309,7 @@ func (b *broker) handleReqs() {
rt, _ := cxn.cl.connTimeoutFn(req)

cxn.waitResp(promisedResp{
pr.ctx,
corrID,
rt,
req.IsFlexible() && req.Key() != 18, // response header not flexible if ApiVersions; see promisedResp doc
Expand Down Expand Up @@ -353,8 +355,9 @@ func (b *broker) loadConnection(ctx context.Context, reqKey int16) (*brokerCxn,
cl: b.cl,
b: b,

addr: b.addr,
conn: conn,
addr: b.addr,
conn: conn,
deadCh: make(chan struct{}),
}
if err = cxn.init(); err != nil {
b.cl.cfg.logger.Log(LogLevelDebug, "connection initialization failed", "addr", b.addr, "id", b.meta.NodeID, "err", err)
Expand Down Expand Up @@ -414,6 +417,8 @@ type brokerCxn struct {
resps chan promisedResp
// dead is an atomic so that a backed up resps cannot block cxn death.
dead int32
// closed in closeConn; allows throttle waiting to quit
deadCh chan struct{}
}

func (cxn *brokerCxn) init() error {
Expand Down Expand Up @@ -447,13 +452,13 @@ start:
ClientSoftwareVersion: cxn.cl.cfg.softwareVersion,
}
cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing api versions request", "version", maxVersion)
corrID, err := cxn.writeRequest(nil, 0, req)
corrID, err := cxn.writeRequest(nil, time.Now(), req)
if err != nil {
return err
}

rt, _ := cxn.cl.connTimeoutFn(req)
rawResp, err := cxn.readResponse(0, req.Key(), corrID, rt, false) // api versions does *not* use flexible response headers; see comment in promisedResp
rawResp, err := cxn.readResponse(nil, rt, time.Now(), req.Key(), corrID, false) // api versions does *not* use flexible response headers; see comment in promisedResp
if err != nil {
return err
}
Expand Down Expand Up @@ -515,13 +520,13 @@ start:
req.Mechanism = mechanism.Name()
req.Version = cxn.versions[req.Key()]
cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing SASLHandshakeRequest")
corrID, err := cxn.writeRequest(nil, 0, req)
corrID, err := cxn.writeRequest(nil, time.Now(), req)
if err != nil {
return err
}

rt, _ := cxn.cl.connTimeoutFn(req)
rawResp, err := cxn.readResponse(0, req.Key(), corrID, rt, req.IsFlexible())
rawResp, err := cxn.readResponse(nil, rt, time.Now(), req.Key(), corrID, req.IsFlexible())
if err != nil {
return err
}
Expand Down Expand Up @@ -582,22 +587,16 @@ func (cxn *brokerCxn) doSasl(authenticate bool) error {
binary.BigEndian.PutUint32(buf, uint32(len(clientWrite)))
buf = append(buf, clientWrite...)

if wt > 0 {
cxn.conn.SetWriteDeadline(time.Now().Add(wt))
}
cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing raw sasl authenticate", "step", step)
_, err = cxn.conn.Write(buf)
if wt > 0 {
cxn.conn.SetWriteDeadline(time.Time{})
}
_, err, _, _ = cxn.writeConn(context.Background(), buf, wt, time.Now())

cxn.cl.bufPool.put(buf)

if err != nil {
return ErrConnDead
}
if !done {
if challenge, err = readConn(cxn.conn, cxn.b.cl.cfg.maxBrokerReadBytes, rt); err != nil {
if _, challenge, err, _, _ = cxn.readConn(context.Background(), rt, time.Now()); err != nil {
return err
}
}
Expand All @@ -609,12 +608,12 @@ func (cxn *brokerCxn) doSasl(authenticate bool) error {
req.Version = cxn.versions[req.Key()]
cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing SASLAuthenticate", "version", req.Version, "step", step)

corrID, err := cxn.writeRequest(nil, 0, req)
corrID, err := cxn.writeRequest(nil, time.Now(), req)
if err != nil {
return err
}
if !done {
rawResp, err := cxn.readResponse(0, req.Key(), corrID, rt, req.IsFlexible())
rawResp, err := cxn.readResponse(nil, rt, time.Now(), req.Key(), corrID, req.IsFlexible())
if err != nil {
return err
}
Expand Down Expand Up @@ -659,7 +658,7 @@ func (cxn *brokerCxn) doSasl(authenticate bool) error {

// writeRequest writes a message request to the broker connection, bumping the
// connection's correlation ID as appropriate for the next write.
func (cxn *brokerCxn) writeRequest(ctx context.Context, writeWait time.Duration, req kmsg.Request) (int32, error) {
func (cxn *brokerCxn) writeRequest(ctx context.Context, enqueuedForWritingAt time.Time, req kmsg.Request) (int32, error) {
// A nil ctx means we cannot be throttled.
if ctx != nil {
throttleUntil := time.Unix(0, atomic.LoadInt64(&cxn.throttleUntil))
Expand All @@ -669,12 +668,17 @@ func (cxn *brokerCxn) writeRequest(ctx context.Context, writeWait time.Duration,
case <-after.C:
case <-ctx.Done():
after.Stop()
return 0, ctx.Err()
case <-cxn.cl.ctx.Done():
after.Stop()
return 0, ctx.Err()
case <-cxn.deadCh:
after.Stop()
return 0, ErrConnDead
}
}
}

// TODO: write in a goroutine, use ctx to allow for early cancel.

buf := cxn.cl.bufPool.get()
defer cxn.cl.bufPool.put(buf)
buf = cxn.cl.reqFormatter.AppendRequest(
Expand All @@ -684,64 +688,108 @@ func (cxn *brokerCxn) writeRequest(ctx context.Context, writeWait time.Duration,
)

_, wt := cxn.cl.connTimeoutFn(req)
if wt > 0 {
cxn.conn.SetWriteDeadline(time.Now().Add(wt))
defer cxn.conn.SetWriteDeadline(time.Time{})
}

writeStart := time.Now()
_, err := cxn.conn.Write(buf)
timeToWrite := time.Since(writeStart)
bytesWritten, writeErr, writeWait, timeToWrite := cxn.writeConn(ctx, buf, wt, enqueuedForWritingAt)

cxn.cl.cfg.hooks.each(func(h Hook) {
if h, ok := h.(BrokerWriteHook); ok {
h.OnWrite(cxn.b.meta, req.Key(), len(buf), writeWait, timeToWrite, err)
h.OnWrite(cxn.b.meta, req.Key(), bytesWritten, writeWait, timeToWrite, writeErr)
}
})

if err != nil {
if writeErr != nil {
return 0, ErrConnDead
}
id := cxn.corrID
cxn.corrID++
return id, nil
}

func readConn(conn net.Conn, maxSize int32, timeout time.Duration) ([]byte, error) {
sizeBuf := make([]byte, 4)
if timeout > 0 {
conn.SetReadDeadline(time.Now().Add(timeout))
defer conn.SetReadDeadline(time.Time{})
}
if _, err := io.ReadFull(conn, sizeBuf); err != nil {
return nil, ErrConnDead
// writeConn writes buf to the broker connection, performing the write in a
// separate goroutine so that it can be interrupted.
//
// A nil ctx is treated as context.Background(). If timeout is positive, it is
// applied as the write deadline before the write starts. If ctx or the
// client's context is done before the write finishes, the write deadline is
// moved to now and writeConn waits for the writing goroutine to exit before
// returning.
//
// It returns the number of bytes written, the error from the write, the time
// between enqueuedForWritingAt and the start of the write, and how long the
// write itself took.
func (cxn *brokerCxn) writeConn(ctx context.Context, buf []byte, timeout time.Duration, enqueuedForWritingAt time.Time) (bytesWritten int, writeErr error, writeWait, timeToWrite time.Duration) {
if ctx == nil {
ctx = context.Background()
}
size := int32(binary.BigEndian.Uint32(sizeBuf))
if size < 0 {
return nil, ErrInvalidRespSize
}
if size > maxSize {
return nil, &ErrLargeRespSize{Size: size, Limit: maxSize}
if timeout > 0 {
cxn.conn.SetWriteDeadline(time.Now().Add(timeout))
}
// Clear the write deadline on return, whether it was set by the timeout
// above or by a cancellation below.
defer cxn.conn.SetWriteDeadline(time.Time{})
// writeDone is closed once the goroutine has finished writing and has
// populated the named results.
writeDone := make(chan struct{})
go func() {
defer close(writeDone)
writeStart := time.Now()
bytesWritten, writeErr = cxn.conn.Write(buf)
timeToWrite = time.Since(writeStart)
writeWait = writeStart.Sub(enqueuedForWritingAt)
}()
select {
case <-writeDone:
case <-cxn.cl.ctx.Done():
// The client's context is done: move the deadline to now, then wait for
// the goroutine so the named results are no longer being written.
cxn.conn.SetWriteDeadline(time.Now())
<-writeDone
case <-ctx.Done():
// The request's context is done: same handling as above.
cxn.conn.SetWriteDeadline(time.Now())
<-writeDone
}
return
}

buf := make([]byte, size)
if _, err := io.ReadFull(conn, buf); err != nil {
return nil, ErrConnDead
// readConn reads one size-prefixed message from the broker connection: a
// 4-byte big-endian length followed by that many bytes of body.
//
// The read is performed in a separate goroutine so that it can be
// interrupted. A nil ctx is treated as context.Background(). If timeout is
// positive, it is applied as the read deadline before reading starts; the
// deadline is always cleared on return. If ctx or the client's context is
// done before the read finishes, the read deadline is moved to now (forcing
// the blocked read to return) and readConn waits for the reading goroutine to
// exit before returning.
//
// Results:
//   - nread: total bytes read, including the 4-byte size prefix; may be
//     partial if an error occurred.
//   - read: the message body; nil unless the whole message was read.
//   - err: ErrConnDead if either read fails, ErrInvalidRespSize if the size
//     prefix is negative, or *ErrLargeRespSize if the size exceeds the
//     client's configured maxBrokerReadBytes.
//   - readWait: time between enqueuedForReadingAt and the start of the read.
//   - timeToRead: how long the read itself took.
func (cxn *brokerCxn) readConn(ctx context.Context, timeout time.Duration, enqueuedForReadingAt time.Time) (nread int, read []byte, err error, readWait, timeToRead time.Duration) {
	if ctx == nil {
		ctx = context.Background()
	}
	if timeout > 0 {
		cxn.conn.SetReadDeadline(time.Now().Add(timeout))
	}
	// Clear the read deadline on return, whether it was set by the timeout
	// above or by a cancellation below.
	defer cxn.conn.SetReadDeadline(time.Time{})

	// readDone is closed once the goroutine has finished and has populated
	// the named results; they must not be touched here until then.
	readDone := make(chan struct{})
	go func() {
		defer close(readDone)
		sizeBuf := make([]byte, 4)
		readStart := time.Now()
		defer func() {
			timeToRead = time.Since(readStart)
			readWait = readStart.Sub(enqueuedForReadingAt)
		}()
		if nread, err = io.ReadFull(cxn.conn, sizeBuf); err != nil {
			err = ErrConnDead
			return
		}
		size := int32(binary.BigEndian.Uint32(sizeBuf))
		if size < 0 {
			err = ErrInvalidRespSize
			return
		}
		if maxSize := cxn.b.cl.cfg.maxBrokerReadBytes; size > maxSize {
			err = &ErrLargeRespSize{Size: size, Limit: maxSize}
			return
		}
		buf := make([]byte, size)
		var nread2 int
		nread2, err = io.ReadFull(cxn.conn, buf)
		nread += nread2
		if err != nil {
			err = ErrConnDead
			return
		}
		// Hand the body back to the caller only once it has been read in
		// full; previously the buffer was read into and then dropped, so
		// callers always received a nil slice.
		read = buf
	}()

	select {
	case <-readDone:
	case <-cxn.cl.ctx.Done():
		// The client's context is done: force the blocked read to return,
		// then wait for the goroutine.
		cxn.conn.SetReadDeadline(time.Now())
		<-readDone
	case <-ctx.Done():
		// The request's context is done: same handling as above.
		cxn.conn.SetReadDeadline(time.Now())
		<-readDone
	}
	return
}

// readResponse reads a response from conn, ensures the correlation ID is
// correct, and returns a newly allocated slice on success.
func (cxn *brokerCxn) readResponse(readWait time.Duration, key int16, corrID int32, timeout time.Duration, flexibleHeader bool) ([]byte, error) {
readStart := time.Now()
buf, err := readConn(cxn.conn, cxn.b.cl.cfg.maxBrokerReadBytes, timeout)
timeToRead := time.Since(readStart)
func (cxn *brokerCxn) readResponse(ctx context.Context, timeout time.Duration, enqueuedForReadingAt time.Time, key int16, corrID int32, flexibleHeader bool) ([]byte, error) {
nread, buf, err, readWait, timeToRead := cxn.readConn(ctx, timeout, enqueuedForReadingAt)

cxn.cl.cfg.hooks.each(func(h Hook) {
if h, ok := h.(BrokerReadHook); ok {
// readConn reads four size bytes in addition to the buf.
h.OnRead(cxn.b.meta, key, 4+len(buf), readWait, timeToRead, err)
h.OnRead(cxn.b.meta, key, nread, readWait, timeToRead, err)
}
})

Expand Down Expand Up @@ -775,6 +823,7 @@ func (cxn *brokerCxn) closeConn() {
}
})
cxn.conn.Close()
close(cxn.deadCh)
}

// die kills a broker connection (which could be dead already) and replies to
Expand Down Expand Up @@ -825,7 +874,7 @@ func (cxn *brokerCxn) handleResps() {

var successes uint64
for pr := range cxn.resps {
raw, err := cxn.readResponse(time.Since(pr.enqueue), pr.resp.Key(), pr.corrID, pr.readTimeout, pr.flexibleHeader)
raw, err := cxn.readResponse(pr.ctx, pr.readTimeout, pr.enqueue, pr.resp.Key(), pr.corrID, pr.flexibleHeader)
if err != nil {
if successes > 0 || len(cxn.b.cl.cfg.sasls) > 0 {
cxn.b.cl.cfg.logger.Log(LogLevelDebug, "read from broker errored, killing connection", "addr", cxn.b.addr, "id", cxn.b.meta.NodeID, "successful_reads", successes, "err", err)
Expand All @@ -845,10 +894,12 @@ func (cxn *brokerCxn) handleResps() {
if readErr == nil {
if throttleResponse, ok := pr.resp.(kmsg.ThrottleResponse); ok {
millis, throttlesAfterResp := throttleResponse.Throttle()
if throttlesAfterResp && millis > 0 {
throttleUntil := time.Now().Add(time.Millisecond * time.Duration(millis)).UnixNano()
if throttleUntil > cxn.throttleUntil {
atomic.StoreInt64(&cxn.throttleUntil, throttleUntil)
if millis > 0 {
if throttlesAfterResp {
throttleUntil := time.Now().Add(time.Millisecond * time.Duration(millis)).UnixNano()
if throttleUntil > cxn.throttleUntil {
atomic.StoreInt64(&cxn.throttleUntil, throttleUntil)
}
}
cxn.cl.cfg.hooks.each(func(h Hook) {
if h, ok := h.(BrokerThrottleHook); ok {
Expand Down
5 changes: 2 additions & 3 deletions pkg/kgo/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,8 @@ func (cl *Client) Close() {
// of request is being issued.
//
// The passed context can be used to cancel a request and return early. Note
// that if the request is not canceled before it is written to Kafka, you may
// just end up canceling and not receiving the response to what Kafka
// inevitably does.
// that if the request was written to Kafka but the context is canceled before
// a response is received, Kafka may still operate on the received request.
func (cl *Client) Request(ctx context.Context, req kmsg.Request) (kmsg.Response, error) {
resps, merge := cl.shardedRequest(ctx, req)
// If there is no merge function, only one request was issued directly
Expand Down
13 changes: 7 additions & 6 deletions pkg/kgo/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ type BrokerDisconnectHook interface {
// key is used (even though sasl authenticate requests are not being issued).
type BrokerWriteHook interface {
// OnWrite is passed the broker metadata, the key for the request that
// was written, the number of bytes written, how long the request
// waited before being written, how long it took to write the request,
// and any error.
// was written, the number of bytes that were written (may not be the
// whole request if there was an error), how long the request waited
// before being written (including throttling waiting), how long it
// took to write the request, and any error.
//
// The bytes written does not count any tls overhead.
OnWrite(meta BrokerMetadata, key int16, bytesWritten int, writeWait, timeToWrite time.Duration, err error)
Expand All @@ -63,9 +64,9 @@ type BrokerWriteHook interface {
// key is used (even though sasl authenticate requests are not being issued).
type BrokerReadHook interface {
// OnRead is passed the broker metadata, the key for the response that
// was read, the number of bytes read, how long the client waited
// before reading the response, how long it took to read the response,
// and any error.
// was read, the number of bytes read (may not be the whole read if
// there was an error), how long the client waited before reading the
// response, how long it took to read the response, and any error.
//
// The bytes read does not count any tls overhead.
OnRead(meta BrokerMetadata, key int16, bytesRead int, readWait, timeToRead time.Duration, err error)
Expand Down

0 comments on commit d07538d

Please sign in to comment.