From 8da5ca7a4059c717c94eb8f42b7502098475fdf0 Mon Sep 17 00:00:00 2001 From: Zarina Sayfullina Date: Mon, 13 May 2024 19:51:48 +0300 Subject: [PATCH] Add transport response callback --- conn.go | 7 +++++++ ctx.go | 28 ++++++++++++++++++++++++++++ ctx_test.go | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+) create mode 100644 ctx.go create mode 100644 ctx_test.go diff --git a/conn.go b/conn.go index 6358e64..ca6d772 100644 --- a/conn.go +++ b/conn.go @@ -277,6 +277,12 @@ func (c *conn) doRequest(ctx context.Context, req *http.Request) (io.ReadCloser, c.cancel = nil return nil, fmt.Errorf("doRequest: transport failed to send a request to ClickHouse: %w", err) } + + if err = callCtxTransportCallback(ctx, req, resp); err != nil { + c.cancel = nil + return nil, fmt.Errorf("doRequest: transport callback: %w", err) + } + if resp.StatusCode != 200 { msg, err := readResponse(resp) c.cancel = nil @@ -287,6 +293,7 @@ func (c *conn) doRequest(ctx context.Context, req *http.Request) (io.ReadCloser, // response return nil, newError(string(msg)) } + return resp.Body, nil } diff --git a/ctx.go b/ctx.go new file mode 100644 index 0000000..638d84b --- /dev/null +++ b/ctx.go @@ -0,0 +1,28 @@ +package clickhouse + +import ( + "context" + "net/http" +) + +type ctxKey uint8 + +const ( + ctxTransportCallbackKey ctxKey = iota + 1 +) + +// TransportCallback is a transport response callback. Called before processing the http response. +type TransportCallback func(*http.Request, *http.Response) error + +// CtxAddTransportCallback adds callback to work with transport response. +func CtxAddTransportCallback(ctx context.Context, f TransportCallback) context.Context { + return context.WithValue(ctx, ctxTransportCallbackKey, f) +} + +func callCtxTransportCallback(ctx context.Context, req *http.Request, resp *http.Response) error { + if f, ok := ctx.Value(ctxTransportCallbackKey).(TransportCallback); ok && f != nil { + return f(req, resp) + } + + return nil +} diff --git a/ctx_test.go b/ctx_test.go new file mode 100644 index 0000000..9f79bab --- /dev/null +++ b/ctx_test.go @@ -0,0 +1,41 @@ +package clickhouse + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_CtxAddTransportCallback(t *testing.T) { + var flag bool + ctx := context.Background() + + ctx = CtxAddTransportCallback(ctx, func(_ *http.Request, _ *http.Response) error { + flag = true + return nil + }) + + assert.NoError(t, callCtxTransportCallback(ctx, + httptest.NewRequest(http.MethodGet, "http://localhost", nil), httptest.NewRecorder().Result(), + )) + assert.True(t, flag) +} + +func Test_CtxAddTransportCallback_err(t *testing.T) { + var flag bool + ctx := context.Background() + + ctx = CtxAddTransportCallback(ctx, func(_ *http.Request, _ *http.Response) error { + flag = true + return errors.New("some error") + }) + + assert.EqualError(t, callCtxTransportCallback(ctx, + httptest.NewRequest(http.MethodGet, "http://localhost", nil), httptest.NewRecorder().Result(), + ), "some error") + assert.True(t, flag) +}