diff --git a/pkg/kgo/broker.go b/pkg/kgo/broker.go index 58f8fa67..ffb5aef3 100644 --- a/pkg/kgo/broker.go +++ b/pkg/kgo/broker.go @@ -26,6 +26,7 @@ type promisedReq struct { } type promisedResp struct { + ctx context.Context corrID int32 readTimeout time.Duration @@ -286,10 +287,10 @@ func (b *broker) handleReqs() { } // Juuuust before we issue the request, we check if it was - // canceled. If it is not, we do not cancel hereafter. - // We only check the promised req's ctx, not our clients. - // The client ctx is closed on shutdown, which kills the - // cxn anyway. + // canceled. We could have previously tried this request, which + // then failed and retried due to the error being ErrConnDead. + // Checking the context was canceled here ensures we do not + // loop. We could be more precise with error tracking, though. select { case <-pr.ctx.Done(): pr.promise(nil, pr.ctx.Err()) @@ -297,7 +298,7 @@ func (b *broker) handleReqs() { default: } - corrID, err := cxn.writeRequest(pr.ctx, time.Since(pr.enqueue), req) + corrID, err := cxn.writeRequest(pr.ctx, pr.enqueue, req) if err != nil { pr.promise(nil, err) @@ -308,6 +309,7 @@ func (b *broker) handleReqs() { rt, _ := cxn.cl.connTimeoutFn(req) cxn.waitResp(promisedResp{ + pr.ctx, corrID, rt, req.IsFlexible() && req.Key() != 18, // response header not flexible if ApiVersions; see promisedResp doc @@ -353,8 +355,9 @@ func (b *broker) loadConnection(ctx context.Context, reqKey int16) (*brokerCxn, cl: b.cl, b: b, - addr: b.addr, - conn: conn, + addr: b.addr, + conn: conn, + deadCh: make(chan struct{}), } if err = cxn.init(); err != nil { b.cl.cfg.logger.Log(LogLevelDebug, "connection initialization failed", "addr", b.addr, "id", b.meta.NodeID, "err", err) @@ -414,6 +417,8 @@ type brokerCxn struct { resps chan promisedResp // dead is an atomic so that a backed up resps cannot block cxn death. dead int32 + // closed in cloneConn; allows throttle waiting to quit + deadCh chan struct{} } func (cxn *brokerCxn) init() error { @@ -447,13 +452,13 @@ start: ClientSoftwareVersion: cxn.cl.cfg.softwareVersion, } cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing api versions request", "version", maxVersion) - corrID, err := cxn.writeRequest(nil, 0, req) + corrID, err := cxn.writeRequest(nil, time.Now(), req) if err != nil { return err } rt, _ := cxn.cl.connTimeoutFn(req) - rawResp, err := cxn.readResponse(0, req.Key(), corrID, rt, false) // api versions does *not* use flexible response headers; see comment in promisedResp + rawResp, err := cxn.readResponse(nil, rt, time.Now(), req.Key(), corrID, false) // api versions does *not* use flexible response headers; see comment in promisedResp if err != nil { return err } @@ -515,13 +520,13 @@ start: req.Mechanism = mechanism.Name() req.Version = cxn.versions[req.Key()] cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing SASLHandshakeRequest") - corrID, err := cxn.writeRequest(nil, 0, req) + corrID, err := cxn.writeRequest(nil, time.Now(), req) if err != nil { return err } rt, _ := cxn.cl.connTimeoutFn(req) - rawResp, err := cxn.readResponse(0, req.Key(), corrID, rt, req.IsFlexible()) + rawResp, err := cxn.readResponse(nil, rt, time.Now(), req.Key(), corrID, req.IsFlexible()) if err != nil { return err } @@ -582,14 +587,8 @@ func (cxn *brokerCxn) doSasl(authenticate bool) error { binary.BigEndian.PutUint32(buf, uint32(len(clientWrite))) buf = append(buf, clientWrite...) - if wt > 0 { - cxn.conn.SetWriteDeadline(time.Now().Add(wt)) - } cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing raw sasl authenticate", "step", step) - _, err = cxn.conn.Write(buf) - if wt > 0 { - cxn.conn.SetWriteDeadline(time.Time{}) - } + _, err, _, _ = cxn.writeConn(context.Background(), buf, wt, time.Now()) cxn.cl.bufPool.put(buf) @@ -597,7 +596,7 @@ func (cxn *brokerCxn) doSasl(authenticate bool) error { return ErrConnDead } if !done { - if challenge, err = readConn(cxn.conn, cxn.b.cl.cfg.maxBrokerReadBytes, rt); err != nil { + if _, challenge, err, _, _ = cxn.readConn(context.Background(), rt, time.Now()); err != nil { return err } } @@ -609,12 +608,12 @@ func (cxn *brokerCxn) doSasl(authenticate bool) error { req.Version = cxn.versions[req.Key()] cxn.cl.cfg.logger.Log(LogLevelDebug, "issuing SASLAuthenticate", "version", req.Version, "step", step) - corrID, err := cxn.writeRequest(nil, 0, req) + corrID, err := cxn.writeRequest(nil, time.Now(), req) if err != nil { return err } if !done { - rawResp, err := cxn.readResponse(0, req.Key(), corrID, rt, req.IsFlexible()) + rawResp, err := cxn.readResponse(nil, rt, time.Now(), req.Key(), corrID, req.IsFlexible()) if err != nil { return err } @@ -659,7 +658,7 @@ func (cxn *brokerCxn) doSasl(authenticate bool) error { // writeRequest writes a message request to the broker connection, bumping the // connection's correlation ID as appropriate for the next write. -func (cxn *brokerCxn) writeRequest(ctx context.Context, writeWait time.Duration, req kmsg.Request) (int32, error) { +func (cxn *brokerCxn) writeRequest(ctx context.Context, enqueuedForWritingAt time.Time, req kmsg.Request) (int32, error) { // A nil ctx means we cannot be throttled. if ctx != nil { throttleUntil := time.Unix(0, atomic.LoadInt64(&cxn.throttleUntil)) @@ -669,12 +668,17 @@ func (cxn *brokerCxn) writeRequest(ctx context.Context, writeWait time.Duration, case <-after.C: case <-ctx.Done(): after.Stop() + return 0, ctx.Err() + case <-cxn.cl.ctx.Done(): + after.Stop() + return 0, ctx.Err() + case <-cxn.deadCh: + after.Stop() + return 0, ErrConnDead } } } - // TODO: write in a goroutine, use ctx to allow for early cancel. - buf := cxn.cl.bufPool.get() defer cxn.cl.bufPool.put(buf) buf = cxn.cl.reqFormatter.AppendRequest( @@ -684,22 +688,15 @@ func (cxn *brokerCxn) writeRequest(ctx context.Context, writeWait time.Duration, ) _, wt := cxn.cl.connTimeoutFn(req) - if wt > 0 { - cxn.conn.SetWriteDeadline(time.Now().Add(wt)) - defer cxn.conn.SetWriteDeadline(time.Time{}) - } - - writeStart := time.Now() - _, err := cxn.conn.Write(buf) - timeToWrite := time.Since(writeStart) + bytesWritten, writeErr, writeWait, timeToWrite := cxn.writeConn(ctx, buf, wt, enqueuedForWritingAt) cxn.cl.cfg.hooks.each(func(h Hook) { if h, ok := h.(BrokerWriteHook); ok { - h.OnWrite(cxn.b.meta, req.Key(), len(buf), writeWait, timeToWrite, err) + h.OnWrite(cxn.b.meta, req.Key(), bytesWritten, writeWait, timeToWrite, writeErr) } }) - if err != nil { + if writeErr != nil { return 0, ErrConnDead } id := cxn.corrID @@ -707,41 +704,92 @@ func (cxn *brokerCxn) writeRequest(ctx context.Context, writeWait time.Duration, return id, nil } -func readConn(conn net.Conn, maxSize int32, timeout time.Duration) ([]byte, error) { - sizeBuf := make([]byte, 4) - if timeout > 0 { - conn.SetReadDeadline(time.Now().Add(timeout)) - defer conn.SetReadDeadline(time.Time{}) - } - if _, err := io.ReadFull(conn, sizeBuf); err != nil { - return nil, ErrConnDead +func (cxn *brokerCxn) writeConn(ctx context.Context, buf []byte, timeout time.Duration, enqueuedForWritingAt time.Time) (bytesWritten int, writeErr error, writeWait, timeToWrite time.Duration) { + if ctx == nil { + ctx = context.Background() } - size := int32(binary.BigEndian.Uint32(sizeBuf)) - if size < 0 { - return nil, ErrInvalidRespSize - } - if size > maxSize { - return nil, &ErrLargeRespSize{Size: size, Limit: maxSize} + if timeout > 0 { + cxn.conn.SetWriteDeadline(time.Now().Add(timeout)) } + defer cxn.conn.SetWriteDeadline(time.Time{}) + writeDone := make(chan struct{}) + go func() { + defer close(writeDone) + writeStart := time.Now() + bytesWritten, writeErr = cxn.conn.Write(buf) + timeToWrite = time.Since(writeStart) + writeWait = writeStart.Sub(enqueuedForWritingAt) + }() + select { + case <-writeDone: + case <-cxn.cl.ctx.Done(): + cxn.conn.SetWriteDeadline(time.Now()) + <-writeDone + case <-ctx.Done(): + cxn.conn.SetWriteDeadline(time.Now()) + <-writeDone + } + return +} - buf := make([]byte, size) - if _, err := io.ReadFull(conn, buf); err != nil { - return nil, ErrConnDead +func (cxn *brokerCxn) readConn(ctx context.Context, timeout time.Duration, enqueuedForReadingAt time.Time) (nread int, read []byte, err error, readWait, timeToRead time.Duration) { + if ctx == nil { + ctx = context.Background() } - return buf, nil + if timeout > 0 { + cxn.conn.SetReadDeadline(time.Now().Add(timeout)) + } + defer cxn.conn.SetReadDeadline(time.Time{}) + readDone := make(chan struct{}) + go func() { + defer close(readDone) + sizeBuf := make([]byte, 4) + readStart := time.Now() + defer func() { + timeToRead = time.Since(readStart) + readWait = readStart.Sub(enqueuedForReadingAt) + }() + if nread, err = io.ReadFull(cxn.conn, sizeBuf); err != nil { + err = ErrConnDead + return + } + size := int32(binary.BigEndian.Uint32(sizeBuf)) + if size < 0 { + err = ErrInvalidRespSize + return + } + if maxSize := cxn.b.cl.cfg.maxBrokerReadBytes; size > maxSize { + err = &ErrLargeRespSize{Size: size, Limit: maxSize} + return + } + nread2, buf := 0, make([]byte, size) + nread2, err = io.ReadFull(cxn.conn, buf) + nread += nread2 + if err != nil { + err = ErrConnDead + return + } + }() + select { + case <-readDone: + case <-cxn.cl.ctx.Done(): + cxn.conn.SetReadDeadline(time.Now()) + <-readDone + case <-ctx.Done(): + cxn.conn.SetReadDeadline(time.Now()) + <-readDone + } + return } // readResponse reads a response from conn, ensures the correlation ID is // correct, and returns a newly allocated slice on success. -func (cxn *brokerCxn) readResponse(readWait time.Duration, key int16, corrID int32, timeout time.Duration, flexibleHeader bool) ([]byte, error) { - readStart := time.Now() - buf, err := readConn(cxn.conn, cxn.b.cl.cfg.maxBrokerReadBytes, timeout) - timeToRead := time.Since(readStart) +func (cxn *brokerCxn) readResponse(ctx context.Context, timeout time.Duration, enqueuedForReadingAt time.Time, key int16, corrID int32, flexibleHeader bool) ([]byte, error) { + nread, buf, err, readWait, timeToRead := cxn.readConn(ctx, timeout, enqueuedForReadingAt) cxn.cl.cfg.hooks.each(func(h Hook) { if h, ok := h.(BrokerReadHook); ok { - // readConn reads four size bytes in addition to the buf. - h.OnRead(cxn.b.meta, key, 4+len(buf), readWait, timeToRead, err) + h.OnRead(cxn.b.meta, key, nread, readWait, timeToRead, err) } }) @@ -775,6 +823,7 @@ func (cxn *brokerCxn) closeConn() { } }) cxn.conn.Close() + close(cxn.deadCh) } // die kills a broker connection (which could be dead already) and replies to @@ -825,7 +874,7 @@ func (cxn *brokerCxn) handleResps() { var successes uint64 for pr := range cxn.resps { - raw, err := cxn.readResponse(time.Since(pr.enqueue), pr.resp.Key(), pr.corrID, pr.readTimeout, pr.flexibleHeader) + raw, err := cxn.readResponse(pr.ctx, pr.readTimeout, pr.enqueue, pr.resp.Key(), pr.corrID, pr.flexibleHeader) if err != nil { if successes > 0 || len(cxn.b.cl.cfg.sasls) > 0 { cxn.b.cl.cfg.logger.Log(LogLevelDebug, "read from broker errored, killing connection", "addr", cxn.b.addr, "id", cxn.b.meta.NodeID, "successful_reads", successes, "err", err) @@ -845,10 +894,12 @@ func (cxn *brokerCxn) handleResps() { if readErr == nil { if throttleResponse, ok := pr.resp.(kmsg.ThrottleResponse); ok { millis, throttlesAfterResp := throttleResponse.Throttle() - if throttlesAfterResp && millis > 0 { - throttleUntil := time.Now().Add(time.Millisecond * time.Duration(millis)).UnixNano() - if throttleUntil > cxn.throttleUntil { - atomic.StoreInt64(&cxn.throttleUntil, throttleUntil) + if millis > 0 { + if throttlesAfterResp { + throttleUntil := time.Now().Add(time.Millisecond * time.Duration(millis)).UnixNano() + if throttleUntil > cxn.throttleUntil { + atomic.StoreInt64(&cxn.throttleUntil, throttleUntil) + } } cxn.cl.cfg.hooks.each(func(h Hook) { if h, ok := h.(BrokerThrottleHook); ok { diff --git a/pkg/kgo/client.go b/pkg/kgo/client.go index 960c9e48..de7f8b2a 100644 --- a/pkg/kgo/client.go +++ b/pkg/kgo/client.go @@ -480,9 +480,8 @@ func (cl *Client) Close() { // of request is being issued. // // The passed context can be used to cancel a request and return early. Note -// that if the request is not canceled before it is written to Kafka, you may -// just end up canceling and not receiving the response to what Kafka -// inevitably does. +// that if the request was written to Kafka but the context canceled before a +// response is received, Kafka may still operate on the received request. func (cl *Client) Request(ctx context.Context, req kmsg.Request) (kmsg.Response, error) { resps, merge := cl.shardedRequest(ctx, req) // If there is no merge function, only one request was issued directly diff --git a/pkg/kgo/hooks.go b/pkg/kgo/hooks.go index 863e965b..fc1f9a1f 100644 --- a/pkg/kgo/hooks.go +++ b/pkg/kgo/hooks.go @@ -48,9 +48,10 @@ type BrokerDisconnectHook interface { // key is used (even though sasl authenticate requests are not being issued). type BrokerWriteHook interface { // OnWrite is passed the broker metadata, the key for the request that - // was written, the number of bytes written, how long the request - // waited before being written, how long it took to write the request, - // and any error. + // was written, the number of bytes that were written (may not be the + // whole request if there was an error), how long the request waited + // before being written (including throttling waiting), how long it + // took to write the request, and any error. // // The bytes written does not count any tls overhead. OnWrite(meta BrokerMetadata, key int16, bytesWritten int, writeWait, timeToWrite time.Duration, err error) @@ -63,9 +64,9 @@ type BrokerWriteHook interface { // key is used (even though sasl authenticate requests are not being issued). type BrokerReadHook interface { // OnRead is passed the broker metadata, the key for the response that - // was read, the number of bytes read, how long the client waited - // before reading the response, how long it took to read the response, - // and any error. + // was read, the number of bytes read (may not be the whole read if + // there was an error), how long the client waited before reading the + // response, how long it took to read the response, and any error. // // The bytes read does not count any tls overhead. OnRead(meta BrokerMetadata, key int16, bytesRead int, readWait, timeToRead time.Duration, err error)