From d07538d00638c22af18df97bffae52b4b917a8ca Mon Sep 17 00:00:00 2001
From: Travis Bischel
Date: Wed, 13 Jan 2021 02:17:26 -0700
Subject: [PATCH] kgo: allow clients to cancel requests

Previously, if the client got to the point of writing a request, it
would not be possible to cancel the request. This now writes and reads
in goroutines, and allows for killing requests (with immediate
deadlines) if a context closes early.

Importantly, this also ensures that client shutdown or broker migration
cancels requests as well. Previously, a client shutting down did not
actually kill an active write / read. We used to just close the request
channel and trust that things would eventually close; the new logic
kills writes if the client's context closes.

Also, even more importantly, this new logic allows a dying broker
connection, as well as the client closing, to quit the throttle wait.
Previously, a throttle was only quit if the request's context quit.

This updates the read and write hooks to track the exact number of
bytes written, and to more correctly count bytes read.

The throttle logic has been fixed to call the throttle hooks for
pre-2.0.0 (i.e., !throttlesAfterResp). I accidentally broke the logic
when introducing throttlesAfterResp (thanks @akesle for noticing this).
---
 pkg/kgo/broker.go | 177 +++++++++++++++++++++++++++++-----------------
 pkg/kgo/client.go |   5 +-
 pkg/kgo/hooks.go  |  13 ++--
 3 files changed, 123 insertions(+), 72 deletions(-)

diff --git a/pkg/kgo/broker.go b/pkg/kgo/broker.go
index 58f8fa67..ffb5aef3 100644
--- a/pkg/kgo/broker.go
+++ b/pkg/kgo/broker.go
@@ -26,6 +26,7 @@ type promisedReq struct {
 }
 
 type promisedResp struct {
+	ctx context.Context
 	corrID int32
 
 	readTimeout time.Duration
@@ -286,10 +287,10 @@ func (b *broker) handleReqs() {
 		}
 
 		// Juuuust before we issue the request, we check if it was
-		// canceled. If it is not, we do not cancel hereafter.
-		// We only check the promised req's ctx, not our clients.
-		// The client ctx is closed on shutdown, which kills the
-		// cxn anyway.
+		// canceled. We could have previously tried this request, which
+		// then failed and retried due to the error being ErrConnDead.
+		// Checking whether the context was canceled here ensures we do
+		// not loop. We could be more precise with error tracking, though.
 		select {
 		case <-pr.ctx.Done():
 			pr.promise(nil, pr.ctx.Err())
@@ -297,7 +298,7 @@
 		default:
 		}
 
-		corrID, err := cxn.writeRequest(pr.ctx, time.Since(pr.enqueue), req)
+		corrID, err := cxn.writeRequest(pr.ctx, pr.enqueue, req)
 
 		if err != nil {
 			pr.promise(nil, err)
@@ -308,6 +309,7 @@
 
 		rt, _ := cxn.cl.connTimeoutFn(req)
 		cxn.waitResp(promisedResp{
+			pr.ctx,
 			corrID,
 			rt,
 			req.IsFlexible() && req.Key() != 18, // response header not flexible if ApiVersions; see promisedResp doc
@@ -353,8 +355,9 @@ func (b *broker) loadConnection(ctx context.Context, reqKey int16) (*brokerCxn,
 		cl: b.cl,
 		b:  b,
 
-		addr: b.addr,
-		conn: conn,
+		addr:   b.addr,
+		conn:   conn,
+		deadCh: make(chan struct{}),
 	}
 	if err = cxn.init(); err != nil {
 		b.cl.cfg.logger.Log(LogLevelDebug, "connection initialization failed", "addr", b.addr, "id", b.meta.NodeID, "err", err)
@@ -414,6 +417,8 @@ type brokerCxn struct {
 	resps chan promisedResp
 	// dead is an atomic so that a backed up resps cannot block cxn death.
 	dead int32
+	// deadCh is closed in closeConn; it allows throttle waiting to quit.
+	deadCh chan struct{}
 }
 
 func (cxn *brokerCxn) init() error {
@@ -447,13 +452,13 @@ start:
 		ClientSoftwareVersion: cxn.cl.cfg.softwareVersion,
 	}
 	cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing api versions request", "version", maxVersion)
-	corrID, err := cxn.writeRequest(nil, 0, req)
+	corrID, err := cxn.writeRequest(nil, time.Now(), req)
 	if err != nil {
 		return err
 	}
 
 	rt, _ := cxn.cl.connTimeoutFn(req)
-	rawResp, err := cxn.readResponse(0, req.Key(), corrID, rt, false) // api versions does *not* use flexible response headers; see comment in promisedResp
+	rawResp, err := cxn.readResponse(nil, rt, time.Now(), req.Key(), corrID, false) // api versions does *not* use flexible response headers; see comment in promisedResp
 	if err != nil {
 		return err
 	}
@@ -515,13 +520,13 @@ start:
 		req.Mechanism = mechanism.Name()
 		req.Version = cxn.versions[req.Key()]
 		cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing SASLHandshakeRequest")
-		corrID, err := cxn.writeRequest(nil, 0, req)
+		corrID, err := cxn.writeRequest(nil, time.Now(), req)
 		if err != nil {
 			return err
 		}
 
 		rt, _ := cxn.cl.connTimeoutFn(req)
-		rawResp, err := cxn.readResponse(0, req.Key(), corrID, rt, req.IsFlexible())
+		rawResp, err := cxn.readResponse(nil, rt, time.Now(), req.Key(), corrID, req.IsFlexible())
 		if err != nil {
 			return err
 		}
@@ -582,14 +587,8 @@ func (cxn *brokerCxn) doSasl(authenticate bool) error {
 			binary.BigEndian.PutUint32(buf, uint32(len(clientWrite)))
 			buf = append(buf, clientWrite...)
 
-			if wt > 0 {
-				cxn.conn.SetWriteDeadline(time.Now().Add(wt))
-			}
 			cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing raw sasl authenticate", "step", step)
-			_, err = cxn.conn.Write(buf)
-			if wt > 0 {
-				cxn.conn.SetWriteDeadline(time.Time{})
-			}
+			_, err, _, _ = cxn.writeConn(context.Background(), buf, wt, time.Now())
 
 			cxn.cl.bufPool.put(buf)
 
@@ -597,7 +596,7 @@ func (cxn *brokerCxn) doSasl(authenticate bool) error {
 				return ErrConnDead
 			}
 			if !done {
-				if challenge, err = readConn(cxn.conn, cxn.b.cl.cfg.maxBrokerReadBytes, rt); err != nil {
+				if _, challenge, err, _, _ = cxn.readConn(context.Background(), rt, time.Now()); err != nil {
 					return err
 				}
 			}
@@ -609,12 +608,12 @@ func (cxn *brokerCxn) doSasl(authenticate bool) error {
 			req.Version = cxn.versions[req.Key()]
 			cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing SASLAuthenticate", "version", req.Version, "step", step)
 
-			corrID, err := cxn.writeRequest(nil, 0, req)
+			corrID, err := cxn.writeRequest(nil, time.Now(), req)
 			if err != nil {
 				return err
 			}
 			if !done {
-				rawResp, err := cxn.readResponse(0, req.Key(), corrID, rt, req.IsFlexible())
+				rawResp, err := cxn.readResponse(nil, rt, time.Now(), req.Key(), corrID, req.IsFlexible())
 				if err != nil {
 					return err
 				}
@@ -659,7 +658,7 @@ func (cxn *brokerCxn) doSasl(authenticate bool) error {
 
 // writeRequest writes a message request to the broker connection, bumping the
 // connection's correlation ID as appropriate for the next write.
-func (cxn *brokerCxn) writeRequest(ctx context.Context, writeWait time.Duration, req kmsg.Request) (int32, error) {
+func (cxn *brokerCxn) writeRequest(ctx context.Context, enqueuedForWritingAt time.Time, req kmsg.Request) (int32, error) {
 	// A nil ctx means we cannot be throttled.
 	if ctx != nil {
 		throttleUntil := time.Unix(0, atomic.LoadInt64(&cxn.throttleUntil))
@@ -669,12 +668,17 @@ func (cxn *brokerCxn) writeRequest(ctx context.Context, writeWait time.Duration,
 			case <-after.C:
 			case <-ctx.Done():
 				after.Stop()
+				return 0, ctx.Err()
+			case <-cxn.cl.ctx.Done():
+				after.Stop()
+				return 0, cxn.cl.ctx.Err()
+			case <-cxn.deadCh:
+				after.Stop()
+				return 0, ErrConnDead
 			}
 		}
 	}
 
-	// TODO: write in a goroutine, use ctx to allow for early cancel.
-
 	buf := cxn.cl.bufPool.get()
 	defer cxn.cl.bufPool.put(buf)
 	buf = cxn.cl.reqFormatter.AppendRequest(
@@ -684,22 +688,15 @@ func (cxn *brokerCxn) writeRequest(ctx context.Context, writeWait time.Duration,
 	)
 
 	_, wt := cxn.cl.connTimeoutFn(req)
-	if wt > 0 {
-		cxn.conn.SetWriteDeadline(time.Now().Add(wt))
-		defer cxn.conn.SetWriteDeadline(time.Time{})
-	}
-
-	writeStart := time.Now()
-	_, err := cxn.conn.Write(buf)
-	timeToWrite := time.Since(writeStart)
+	bytesWritten, writeErr, writeWait, timeToWrite := cxn.writeConn(ctx, buf, wt, enqueuedForWritingAt)
 
 	cxn.cl.cfg.hooks.each(func(h Hook) {
 		if h, ok := h.(BrokerWriteHook); ok {
-			h.OnWrite(cxn.b.meta, req.Key(), len(buf), writeWait, timeToWrite, err)
+			h.OnWrite(cxn.b.meta, req.Key(), bytesWritten, writeWait, timeToWrite, writeErr)
 		}
 	})
 
-	if err != nil {
+	if writeErr != nil {
 		return 0, ErrConnDead
 	}
 	id := cxn.corrID
@@ -707,41 +704,92 @@ func (cxn *brokerCxn) writeRequest(ctx context.Context, writeWait time.Duration,
 	return id, nil
 }
 
-func readConn(conn net.Conn, maxSize int32, timeout time.Duration) ([]byte, error) {
-	sizeBuf := make([]byte, 4)
-	if timeout > 0 {
-		conn.SetReadDeadline(time.Now().Add(timeout))
-		defer conn.SetReadDeadline(time.Time{})
-	}
-	if _, err := io.ReadFull(conn, sizeBuf); err != nil {
-		return nil, ErrConnDead
+func (cxn *brokerCxn) writeConn(ctx context.Context, buf []byte, timeout time.Duration, enqueuedForWritingAt time.Time) (bytesWritten int, writeErr error, writeWait, timeToWrite time.Duration) {
+	if ctx == nil {
+		ctx = context.Background()
 	}
-	size := int32(binary.BigEndian.Uint32(sizeBuf))
-	if size < 0 {
-		return nil, ErrInvalidRespSize
-	}
-	if size > maxSize {
-		return nil, &ErrLargeRespSize{Size: size, Limit: maxSize}
+	if timeout > 0 {
+		cxn.conn.SetWriteDeadline(time.Now().Add(timeout))
 	}
+	defer cxn.conn.SetWriteDeadline(time.Time{})
+	writeDone := make(chan struct{})
+	go func() {
+		defer close(writeDone)
+		writeStart := time.Now()
+		bytesWritten, writeErr = cxn.conn.Write(buf)
+		timeToWrite = time.Since(writeStart)
+		writeWait = writeStart.Sub(enqueuedForWritingAt)
+	}()
+	select {
+	case <-writeDone:
+	case <-cxn.cl.ctx.Done():
+		cxn.conn.SetWriteDeadline(time.Now())
+		<-writeDone
+	case <-ctx.Done():
+		cxn.conn.SetWriteDeadline(time.Now())
+		<-writeDone
+	}
+	return
+}
 
-	buf := make([]byte, size)
-	if _, err := io.ReadFull(conn, buf); err != nil {
-		return nil, ErrConnDead
+func (cxn *brokerCxn) readConn(ctx context.Context, timeout time.Duration, enqueuedForReadingAt time.Time) (nread int, read []byte, err error, readWait, timeToRead time.Duration) {
+	if ctx == nil {
+		ctx = context.Background()
 	}
-	return buf, nil
+	if timeout > 0 {
+		cxn.conn.SetReadDeadline(time.Now().Add(timeout))
+	}
+	defer cxn.conn.SetReadDeadline(time.Time{})
+	readDone := make(chan struct{})
+	go func() {
+		defer close(readDone)
+		sizeBuf := make([]byte, 4)
+		readStart := time.Now()
+		defer func() {
+			timeToRead = time.Since(readStart)
+			readWait = readStart.Sub(enqueuedForReadingAt)
+		}()
+		if nread, err = io.ReadFull(cxn.conn, sizeBuf); err != nil {
+			err = ErrConnDead
+			return
+		}
+		size := int32(binary.BigEndian.Uint32(sizeBuf))
+		if size < 0 {
+			err = ErrInvalidRespSize
+			return
+		}
+		if maxSize := cxn.b.cl.cfg.maxBrokerReadBytes; size > maxSize {
+			err = &ErrLargeRespSize{Size: size, Limit: maxSize}
+			return
+		}
+		nread2, buf := 0, make([]byte, size)
+		nread2, err = io.ReadFull(cxn.conn, buf)
+		nread += nread2
+		if err != nil {
+			err = ErrConnDead
+			return
+		}
+		read = buf
+	}()
+	select {
+	case <-readDone:
+	case <-cxn.cl.ctx.Done():
+		cxn.conn.SetReadDeadline(time.Now())
+		<-readDone
+	case <-ctx.Done():
+		cxn.conn.SetReadDeadline(time.Now())
+		<-readDone
+	}
+	return
 }
 
 // readResponse reads a response from conn, ensures the correlation ID is
 // correct, and returns a newly allocated slice on success.
-func (cxn *brokerCxn) readResponse(readWait time.Duration, key int16, corrID int32, timeout time.Duration, flexibleHeader bool) ([]byte, error) {
-	readStart := time.Now()
-	buf, err := readConn(cxn.conn, cxn.b.cl.cfg.maxBrokerReadBytes, timeout)
-	timeToRead := time.Since(readStart)
+func (cxn *brokerCxn) readResponse(ctx context.Context, timeout time.Duration, enqueuedForReadingAt time.Time, key int16, corrID int32, flexibleHeader bool) ([]byte, error) {
+	nread, buf, err, readWait, timeToRead := cxn.readConn(ctx, timeout, enqueuedForReadingAt)
 	cxn.cl.cfg.hooks.each(func(h Hook) {
 		if h, ok := h.(BrokerReadHook); ok {
-			// readConn reads four size bytes in addition to the buf.
-			h.OnRead(cxn.b.meta, key, 4+len(buf), readWait, timeToRead, err)
+			h.OnRead(cxn.b.meta, key, nread, readWait, timeToRead, err)
 		}
 	})
 
@@ -775,6 +823,7 @@ func (cxn *brokerCxn) closeConn() {
 		}
 	})
 	cxn.conn.Close()
+	close(cxn.deadCh)
 }
 
 // die kills a broker connection (which could be dead already) and replies to
@@ -825,7 +874,7 @@ func (cxn *brokerCxn) handleResps() {
 
 	var successes uint64
 	for pr := range cxn.resps {
-		raw, err := cxn.readResponse(time.Since(pr.enqueue), pr.resp.Key(), pr.corrID, pr.readTimeout, pr.flexibleHeader)
+		raw, err := cxn.readResponse(pr.ctx, pr.readTimeout, pr.enqueue, pr.resp.Key(), pr.corrID, pr.flexibleHeader)
 		if err != nil {
 			if successes > 0 || len(cxn.b.cl.cfg.sasls) > 0 {
 				cxn.b.cl.cfg.logger.Log(LogLevelDebug, "read from broker errored, killing connection", "addr", cxn.b.addr, "id", cxn.b.meta.NodeID, "successful_reads", successes, "err", err)
@@ -845,10 +894,12 @@
 			if readErr == nil {
 				if throttleResponse, ok := pr.resp.(kmsg.ThrottleResponse); ok {
 					millis, throttlesAfterResp := throttleResponse.Throttle()
-					if throttlesAfterResp && millis > 0 {
-						throttleUntil := time.Now().Add(time.Millisecond * time.Duration(millis)).UnixNano()
-						if throttleUntil > cxn.throttleUntil {
-							atomic.StoreInt64(&cxn.throttleUntil, throttleUntil)
+					if millis > 0 {
+						if throttlesAfterResp {
+							throttleUntil := time.Now().Add(time.Millisecond * time.Duration(millis)).UnixNano()
+							if throttleUntil > cxn.throttleUntil {
+								atomic.StoreInt64(&cxn.throttleUntil, throttleUntil)
+							}
 						}
 						cxn.cl.cfg.hooks.each(func(h Hook) {
 							if h, ok := h.(BrokerThrottleHook); ok {
diff --git a/pkg/kgo/client.go b/pkg/kgo/client.go
index 960c9e48..de7f8b2a 100644
--- a/pkg/kgo/client.go
+++ b/pkg/kgo/client.go
@@ -480,9 +480,8 @@ func (cl *Client) Close() {
 // of request is being issued.
 //
 // The passed context can be used to cancel a request and return early. Note
-// that if the request is not canceled before it is written to Kafka, you may
-// just end up canceling and not receiving the response to what Kafka
-// inevitably does.
+// that if the request is written to Kafka but the context is canceled before
+// a response is received, Kafka may still act on the request.
 func (cl *Client) Request(ctx context.Context, req kmsg.Request) (kmsg.Response, error) {
 	resps, merge := cl.shardedRequest(ctx, req)
 	// If there is no merge function, only one request was issued directly
diff --git a/pkg/kgo/hooks.go b/pkg/kgo/hooks.go
index 863e965b..fc1f9a1f 100644
--- a/pkg/kgo/hooks.go
+++ b/pkg/kgo/hooks.go
@@ -48,9 +48,10 @@ type BrokerDisconnectHook interface {
 // key is used (even though sasl authenticate requests are not being issued).
 type BrokerWriteHook interface {
 	// OnWrite is passed the broker metadata, the key for the request that
-	// was written, the number of bytes written, how long the request
-	// waited before being written, how long it took to write the request,
-	// and any error.
+	// was written, the number of bytes written (which may not be the
+	// whole request if there was an error), how long the request waited
+	// before being written (including any throttle wait), how long it
+	// took to write the request, and any error.
 	//
 	// The bytes written does not count any tls overhead.
 	OnWrite(meta BrokerMetadata, key int16, bytesWritten int, writeWait, timeToWrite time.Duration, err error)
@@ -63,9 +64,9 @@ type BrokerWriteHook interface {
 // key is used (even though sasl authenticate requests are not being issued).
 type BrokerReadHook interface {
 	// OnRead is passed the broker metadata, the key for the response that
-	// was read, the number of bytes read, how long the client waited
-	// before reading the response, how long it took to read the response,
-	// and any error.
+	// was read, the number of bytes read (which may not be the whole
+	// response if there was an error), how long the client waited before
+	// reading the response, how long it took to read it, and any error.
 	//
 	// The bytes read does not count any tls overhead.
 	OnRead(meta BrokerMetadata, key int16, bytesRead int, readWait, timeToRead time.Duration, err error)
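
Not part of the patch itself: a minimal caller-side sketch of the cancellation behavior this change enables. The seed broker address and the Metadata request are illustrative, and only kgo.NewClient, kgo.SeedBrokers, Client.Request, and Client.Close are assumed from the library as it exists here. With this change, canceling the context (or closing the client) also aborts a request that is already being written to or read from the broker, and quits any throttle wait, rather than only canceling a request that is still queued.

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/twmb/franz-go/pkg/kgo"
	"github.com/twmb/franz-go/pkg/kmsg"
)

func main() {
	// The seed address is illustrative.
	cl, err := kgo.NewClient(kgo.SeedBrokers("localhost:9092"))
	if err != nil {
		panic(err)
	}
	defer cl.Close()

	// A short timeout: on cancellation, the in-flight write or read is
	// killed by setting an immediate deadline on the connection.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	if _, err := cl.Request(ctx, &kmsg.MetadataRequest{}); err != nil {
		fmt.Println("request failed (possibly canceled):", err)
	}
}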