From 636872b2af7737e8b58b8fa6abe9c8b6eb7a736d Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Mon, 31 Jul 2023 12:30:52 -0700 Subject: [PATCH 1/6] Validate and encode query parameters If the endpoint passed to runtime.NewRequest contains query parameters they must be validated and encoded. The typical case is when a paged operation's nextLink value contains query parameters. --- sdk/azcore/CHANGELOG.md | 1 + sdk/azcore/internal/exported/request.go | 6 ++++++ sdk/azcore/internal/exported/request_test.go | 15 +++++++++++++++ sdk/azcore/runtime/request.go | 3 ++- 4 files changed, 24 insertions(+), 1 deletion(-) diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 7012e03e46cd..555b92786865 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -28,6 +28,7 @@ ### Bugs Fixed * Propagate any query parameters when constructing a fake poller and/or injecting next links. +* Calling `runtime.NewRequest` will encode any query parameters in the `endpoint` parameter. ## 1.7.1 (2023-08-14) diff --git a/sdk/azcore/internal/exported/request.go b/sdk/azcore/internal/exported/request.go index 659f2a7d2ead..8abd2c72c4b5 100644 --- a/sdk/azcore/internal/exported/request.go +++ b/sdk/azcore/internal/exported/request.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "net/http" + "net/url" "reflect" "strconv" @@ -80,6 +81,11 @@ func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Reque if !(req.URL.Scheme == "http" || req.URL.Scheme == "https") { return nil, fmt.Errorf("unsupported protocol scheme %s", req.URL.Scheme) } + qp, err := url.ParseQuery(req.URL.RawQuery) + if err != nil { + return nil, err + } + req.URL.RawQuery = qp.Encode() return &Request{req: req}, nil } diff --git a/sdk/azcore/internal/exported/request_test.go b/sdk/azcore/internal/exported/request_test.go index 3acc8e7a76ae..a8b16de7019f 100644 --- a/sdk/azcore/internal/exported/request_test.go +++ b/sdk/azcore/internal/exported/request_test.go @@ -211,3 +211,18 @@ func TestRequestWithContext(t *testing.T) { req2.Raw().Header.Add("added-req2", "value") require.EqualValues(t, "value", req1.Raw().Header.Get("added-req2")) } + +func TestNewRequestWithEncoding(t *testing.T) { + req, err := NewRequest(context.Background(), http.MethodGet, testURL+"query?$skip=5&$filter='foo eq bar'") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", req.Raw().URL.String()) + req, err = NewRequest(context.Background(), http.MethodGet, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", req.Raw().URL.String()) + req, err = NewRequest(context.Background(), http.MethodGet, testURL+"query?foo=bar&one=two") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?foo=bar&one=two", req.Raw().URL.String()) + req, err = NewRequest(context.Background(), http.MethodGet, testURL+"query?invalid=;semicolon") + require.Error(t, err) + require.Nil(t, req) +} diff --git a/sdk/azcore/runtime/request.go b/sdk/azcore/runtime/request.go index 7938b5fb7d26..b1a4ccaf7be8 100644 --- a/sdk/azcore/runtime/request.go +++ b/sdk/azcore/runtime/request.go @@ -38,7 +38,8 @@ const ( ) // NewRequest creates a new policy.Request with the specified input. -// The endpoint MUST be properly encoded before calling this function. +// It's assumed that the URL's path segments have been properly encoded. +// Any query parameters will be validated and encoded during construction. func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*policy.Request, error) { return exported.NewRequest(ctx, httpMethod, endpoint) } From 45c8f053f880ef9c224672531ef0fc9322a93248 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Mon, 31 Jul 2023 13:27:23 -0700 Subject: [PATCH 2/6] add test with no query params --- sdk/azcore/internal/exported/request_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sdk/azcore/internal/exported/request_test.go b/sdk/azcore/internal/exported/request_test.go index a8b16de7019f..55e47502fa5f 100644 --- a/sdk/azcore/internal/exported/request_test.go +++ b/sdk/azcore/internal/exported/request_test.go @@ -222,6 +222,9 @@ func TestNewRequestWithEncoding(t *testing.T) { req, err = NewRequest(context.Background(), http.MethodGet, testURL+"query?foo=bar&one=two") require.NoError(t, err) require.EqualValues(t, testURL+"query?foo=bar&one=two", req.Raw().URL.String()) + req, err = NewRequest(context.Background(), http.MethodGet, testURL) + require.NoError(t, err) + require.EqualValues(t, testURL, req.Raw().URL.String()) req, err = NewRequest(context.Background(), http.MethodGet, testURL+"query?invalid=;semicolon") require.Error(t, err) require.Nil(t, req) From c71012c1e8da984cddd38eb8d7f1cc7dd4fab62f Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Tue, 1 Aug 2023 10:03:26 -0700 Subject: [PATCH 3/6] consolidate creating PagingHandler[T].Fetchers --- sdk/azcore/CHANGELOG.md | 4 +- sdk/azcore/internal/exported/request.go | 6 -- sdk/azcore/internal/exported/request_test.go | 18 ----- sdk/azcore/runtime/pager.go | 39 ++++++++++ sdk/azcore/runtime/pager_test.go | 76 ++++++++++++++++++++ sdk/azcore/runtime/request.go | 3 +- 6 files changed, 118 insertions(+), 28 deletions(-) diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 555b92786865..c5522581a01f 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +* Added function `FetcherHelper` to the `runtime` package to centralize creation of `Pager[T].Fetcher` values. + ### Breaking Changes ### Bugs Fixed @@ -15,7 +17,6 @@ ### Features Added * Added function `SanitizePagerPollerPath` to the `server` package to centralize sanitization and formalize the contract. -* Added `TokenRequestOptions.EnableCAE` to indicate whether to request a CAE token. ### Breaking Changes @@ -28,7 +29,6 @@ ### Bugs Fixed * Propagate any query parameters when constructing a fake poller and/or injecting next links. -* Calling `runtime.NewRequest` will encode any query parameters in the `endpoint` parameter. ## 1.7.1 (2023-08-14) diff --git a/sdk/azcore/internal/exported/request.go b/sdk/azcore/internal/exported/request.go index 8abd2c72c4b5..659f2a7d2ead 100644 --- a/sdk/azcore/internal/exported/request.go +++ b/sdk/azcore/internal/exported/request.go @@ -13,7 +13,6 @@ import ( "fmt" "io" "net/http" - "net/url" "reflect" "strconv" @@ -81,11 +80,6 @@ func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*Reque if !(req.URL.Scheme == "http" || req.URL.Scheme == "https") { return nil, fmt.Errorf("unsupported protocol scheme %s", req.URL.Scheme) } - qp, err := url.ParseQuery(req.URL.RawQuery) - if err != nil { - return nil, err - } - req.URL.RawQuery = qp.Encode() return &Request{req: req}, nil } diff --git a/sdk/azcore/internal/exported/request_test.go b/sdk/azcore/internal/exported/request_test.go index 55e47502fa5f..3acc8e7a76ae 100644 --- a/sdk/azcore/internal/exported/request_test.go +++ b/sdk/azcore/internal/exported/request_test.go @@ -211,21 +211,3 @@ func TestRequestWithContext(t *testing.T) { req2.Raw().Header.Add("added-req2", "value") require.EqualValues(t, "value", req1.Raw().Header.Get("added-req2")) } - -func TestNewRequestWithEncoding(t *testing.T) { - req, err := NewRequest(context.Background(), http.MethodGet, testURL+"query?$skip=5&$filter='foo eq bar'") - require.NoError(t, err) - require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", req.Raw().URL.String()) - req, err = NewRequest(context.Background(), http.MethodGet, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5") - require.NoError(t, err) - require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", req.Raw().URL.String()) - req, err = NewRequest(context.Background(), http.MethodGet, testURL+"query?foo=bar&one=two") - require.NoError(t, err) - require.EqualValues(t, testURL+"query?foo=bar&one=two", req.Raw().URL.String()) - req, err = NewRequest(context.Background(), http.MethodGet, testURL) - require.NoError(t, err) - require.EqualValues(t, testURL, req.Raw().URL.String()) - req, err = NewRequest(context.Background(), http.MethodGet, testURL+"query?invalid=;semicolon") - require.Error(t, err) - require.Nil(t, req) -} diff --git a/sdk/azcore/runtime/pager.go b/sdk/azcore/runtime/pager.go index b7e59527a3d8..c59992bc6a3d 100644 --- a/sdk/azcore/runtime/pager.go +++ b/sdk/azcore/runtime/pager.go @@ -11,8 +11,12 @@ import ( "encoding/json" "errors" "fmt" + "net/http" + "net/url" "reflect" + "strings" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" ) @@ -88,3 +92,38 @@ func (p *Pager[T]) NextPage(ctx context.Context) (T, error) { func (p *Pager[T]) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, &p.current) } + +// FetcherHelper is a helper containing boilerplate code to simplify creating a PagingHandler[T].Fetcher. +func FetcherHelper(ctx context.Context, pl Pipeline, nextLink string, createReq func(context.Context) (*policy.Request, error)) (*http.Response, error) { + var req *policy.Request + var err error + if nextLink == "" { + req, err = createReq(ctx) + } else if nextLink, err = encodeNextLink(nextLink); err == nil { + req, err = NewRequest(ctx, http.MethodGet, nextLink) + } + if err != nil { + return nil, err + } + resp, err := pl.Do(req) + if err != nil { + return nil, err + } + if !HasStatusCode(resp, http.StatusOK) { + return nil, NewResponseError(resp) + } + return resp, nil +} + +// encode any query parameters in the nextLink +func encodeNextLink(nextLink string) (string, error) { + before, after, found := strings.Cut(nextLink, "?") + if !found { + return nextLink, nil + } + qp, err := url.ParseQuery(after) + if err != nil { + return "", err + } + return before + "?" + qp.Encode(), nil +} diff --git a/sdk/azcore/runtime/pager_test.go b/sdk/azcore/runtime/pager_test.go index 3d33330caa6e..a9686fe13f8f 100644 --- a/sdk/azcore/runtime/pager_test.go +++ b/sdk/azcore/runtime/pager_test.go @@ -14,6 +14,7 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/stretchr/testify/require" ) @@ -256,3 +257,78 @@ func TestPagerResponderError(t *testing.T) { require.Error(t, err) require.Empty(t, page) } + +func TestFetcherHelper(t *testing.T) { + srv, close := mock.NewServer() + defer close() + pl := exported.NewPipeline(srv) + + srv.AppendResponse() + createReqCalled := false + resp, err := FetcherHelper(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { + createReqCalled = true + return NewRequest(ctx, http.MethodGet, srv.URL()) + }) + require.NoError(t, err) + require.True(t, createReqCalled) + require.NotNil(t, resp) + require.EqualValues(t, http.StatusOK, resp.StatusCode) + + srv.AppendResponse() + createReqCalled = false + resp, err = FetcherHelper(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) { + createReqCalled = true + return NewRequest(ctx, http.MethodGet, srv.URL()) + }) + require.NoError(t, err) + require.False(t, createReqCalled) + require.NotNil(t, resp) + require.EqualValues(t, http.StatusOK, resp.StatusCode) + + resp, err = FetcherHelper(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { + return nil, errors.New("failed") + }) + require.Error(t, err) + require.Nil(t, resp) + + srv.AppendError(errors.New("failed")) + resp, err = FetcherHelper(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { + createReqCalled = true + return NewRequest(ctx, http.MethodGet, srv.URL()) + }) + require.Error(t, err) + require.True(t, createReqCalled) + require.Nil(t, resp) + + srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest), mock.WithBody([]byte(`{ "error": { "code": "InvalidResource", "message": "doesn't exist" } }`))) + createReqCalled = false + resp, err = FetcherHelper(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) { + createReqCalled = true + return NewRequest(ctx, http.MethodGet, srv.URL()) + }) + require.Error(t, err) + var respErr *exported.ResponseError + require.ErrorAs(t, err, &respErr) + require.EqualValues(t, "InvalidResource", respErr.ErrorCode) + require.False(t, createReqCalled) + require.Nil(t, resp) +} + +func TestEncodeNextLink(t *testing.T) { + const testURL = "https://contoso.com/" + nextLink, err := encodeNextLink(testURL + "query?$skip=5&$filter='foo eq bar'") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink) + nextLink, err = encodeNextLink(testURL + "query?%24filter=%27foo+eq+bar%27&%24skip=5") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink) + nextLink, err = encodeNextLink(testURL + "query?foo=bar&one=two") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?foo=bar&one=two", nextLink) + nextLink, err = encodeNextLink(testURL) + require.NoError(t, err) + require.EqualValues(t, testURL, nextLink) + nextLink, err = encodeNextLink(testURL + "query?invalid=;semicolon") + require.Error(t, err) + require.Empty(t, nextLink) +} diff --git a/sdk/azcore/runtime/request.go b/sdk/azcore/runtime/request.go index b1a4ccaf7be8..7938b5fb7d26 100644 --- a/sdk/azcore/runtime/request.go +++ b/sdk/azcore/runtime/request.go @@ -38,8 +38,7 @@ const ( ) // NewRequest creates a new policy.Request with the specified input. -// It's assumed that the URL's path segments have been properly encoded. -// Any query parameters will be validated and encoded during construction. +// The endpoint MUST be properly encoded before calling this function. func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*policy.Request, error) { return exported.NewRequest(ctx, httpMethod, endpoint) } From ca68b79464d4aa16e4d886bb04cc8143f0491fa3 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Thu, 17 Aug 2023 08:48:15 -0700 Subject: [PATCH 4/6] make EncodeQueryParams its own func --- sdk/azcore/runtime/pager.go | 17 +---------------- sdk/azcore/runtime/pager_test.go | 19 ------------------- sdk/azcore/runtime/request.go | 14 ++++++++++++++ sdk/azcore/runtime/request_test.go | 19 +++++++++++++++++++ 4 files changed, 34 insertions(+), 35 deletions(-) diff --git a/sdk/azcore/runtime/pager.go b/sdk/azcore/runtime/pager.go index c59992bc6a3d..37168957ff9b 100644 --- a/sdk/azcore/runtime/pager.go +++ b/sdk/azcore/runtime/pager.go @@ -12,9 +12,7 @@ import ( "errors" "fmt" "net/http" - "net/url" "reflect" - "strings" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" @@ -99,7 +97,7 @@ func FetcherHelper(ctx context.Context, pl Pipeline, nextLink string, createReq var err error if nextLink == "" { req, err = createReq(ctx) - } else if nextLink, err = encodeNextLink(nextLink); err == nil { + } else if nextLink, err = EncodeQueryParams(nextLink); err == nil { req, err = NewRequest(ctx, http.MethodGet, nextLink) } if err != nil { @@ -114,16 +112,3 @@ func FetcherHelper(ctx context.Context, pl Pipeline, nextLink string, createReq } return resp, nil } - -// encode any query parameters in the nextLink -func encodeNextLink(nextLink string) (string, error) { - before, after, found := strings.Cut(nextLink, "?") - if !found { - return nextLink, nil - } - qp, err := url.ParseQuery(after) - if err != nil { - return "", err - } - return before + "?" + qp.Encode(), nil -} diff --git a/sdk/azcore/runtime/pager_test.go b/sdk/azcore/runtime/pager_test.go index a9686fe13f8f..ef8bc87c56ef 100644 --- a/sdk/azcore/runtime/pager_test.go +++ b/sdk/azcore/runtime/pager_test.go @@ -313,22 +313,3 @@ func TestFetcherHelper(t *testing.T) { require.False(t, createReqCalled) require.Nil(t, resp) } - -func TestEncodeNextLink(t *testing.T) { - const testURL = "https://contoso.com/" - nextLink, err := encodeNextLink(testURL + "query?$skip=5&$filter='foo eq bar'") - require.NoError(t, err) - require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink) - nextLink, err = encodeNextLink(testURL + "query?%24filter=%27foo+eq+bar%27&%24skip=5") - require.NoError(t, err) - require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink) - nextLink, err = encodeNextLink(testURL + "query?foo=bar&one=two") - require.NoError(t, err) - require.EqualValues(t, testURL+"query?foo=bar&one=two", nextLink) - nextLink, err = encodeNextLink(testURL) - require.NoError(t, err) - require.EqualValues(t, testURL, nextLink) - nextLink, err = encodeNextLink(testURL + "query?invalid=;semicolon") - require.Error(t, err) - require.Empty(t, nextLink) -} diff --git a/sdk/azcore/runtime/request.go b/sdk/azcore/runtime/request.go index 7938b5fb7d26..6f3b8966998b 100644 --- a/sdk/azcore/runtime/request.go +++ b/sdk/azcore/runtime/request.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "mime/multipart" + "net/url" "os" "path" "reflect" @@ -43,6 +44,19 @@ func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*polic return exported.NewRequest(ctx, httpMethod, endpoint) } +// EncodeQueryParams will parse and encode any query parameters in the specified URL. +func EncodeQueryParams(u string) (string, error) { + before, after, found := strings.Cut(u, "?") + if !found { + return u, nil + } + qp, err := url.ParseQuery(after) + if err != nil { + return "", err + } + return before + "?" + qp.Encode(), nil +} + // JoinPaths concatenates multiple URL path segments into one path, // inserting path separation characters as required. JoinPaths will preserve // query parameters in the root path diff --git a/sdk/azcore/runtime/request_test.go b/sdk/azcore/runtime/request_test.go index aa7c1b3ecb16..0a8da39cf021 100644 --- a/sdk/azcore/runtime/request_test.go +++ b/sdk/azcore/runtime/request_test.go @@ -740,3 +740,22 @@ func TestSetMultipartFormData(t *testing.T) { require.Equal(t, "second part", string(second)) require.Equal(t, "third part", string(third)) } + +func TestEncodeQueryParams(t *testing.T) { + const testURL = "https://contoso.com/" + nextLink, err := EncodeQueryParams(testURL + "query?$skip=5&$filter='foo eq bar'") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink) + nextLink, err = EncodeQueryParams(testURL + "query?%24filter=%27foo+eq+bar%27&%24skip=5") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink) + nextLink, err = EncodeQueryParams(testURL + "query?foo=bar&one=two") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?foo=bar&one=two", nextLink) + nextLink, err = EncodeQueryParams(testURL) + require.NoError(t, err) + require.EqualValues(t, testURL, nextLink) + nextLink, err = EncodeQueryParams(testURL + "query?invalid=;semicolon") + require.Error(t, err) + require.Empty(t, nextLink) +} From fe30e16b517614a62600763c277373535250f925 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Thu, 17 Aug 2023 10:03:56 -0700 Subject: [PATCH 5/6] fix bag merge --- sdk/azcore/CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index c5522581a01f..c3822cff07cf 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -17,6 +17,7 @@ ### Features Added * Added function `SanitizePagerPollerPath` to the `server` package to centralize sanitization and formalize the contract. +* Added `TokenRequestOptions.EnableCAE` to indicate whether to request a CAE token. ### Breaking Changes From f0e6690cbe227d3622ef6e628b5298926cd7d827 Mon Sep 17 00:00:00 2001 From: Joel Hendrix Date: Mon, 21 Aug 2023 10:46:02 -0700 Subject: [PATCH 6/6] rename --- sdk/azcore/CHANGELOG.md | 2 +- sdk/azcore/runtime/pager.go | 4 ++-- sdk/azcore/runtime/pager_test.go | 12 ++++++------ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index c3822cff07cf..6cbd90e0cfa1 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -4,7 +4,7 @@ ### Features Added -* Added function `FetcherHelper` to the `runtime` package to centralize creation of `Pager[T].Fetcher` values. +* Added function `FetcherForNextLink` to the `runtime` package to centralize creation of `Pager[T].Fetcher` from a next link URL. ### Breaking Changes diff --git a/sdk/azcore/runtime/pager.go b/sdk/azcore/runtime/pager.go index 37168957ff9b..f1daac50d32a 100644 --- a/sdk/azcore/runtime/pager.go +++ b/sdk/azcore/runtime/pager.go @@ -91,8 +91,8 @@ func (p *Pager[T]) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, &p.current) } -// FetcherHelper is a helper containing boilerplate code to simplify creating a PagingHandler[T].Fetcher. -func FetcherHelper(ctx context.Context, pl Pipeline, nextLink string, createReq func(context.Context) (*policy.Request, error)) (*http.Response, error) { +// FetcherForNextLink is a helper containing boilerplate code to simplify creating a PagingHandler[T].Fetcher from a next link URL. +func FetcherForNextLink(ctx context.Context, pl Pipeline, nextLink string, createReq func(context.Context) (*policy.Request, error)) (*http.Response, error) { var req *policy.Request var err error if nextLink == "" { diff --git a/sdk/azcore/runtime/pager_test.go b/sdk/azcore/runtime/pager_test.go index ef8bc87c56ef..9d205b297d60 100644 --- a/sdk/azcore/runtime/pager_test.go +++ b/sdk/azcore/runtime/pager_test.go @@ -258,14 +258,14 @@ func TestPagerResponderError(t *testing.T) { require.Empty(t, page) } -func TestFetcherHelper(t *testing.T) { +func TestFetcherForNextLink(t *testing.T) { srv, close := mock.NewServer() defer close() pl := exported.NewPipeline(srv) srv.AppendResponse() createReqCalled := false - resp, err := FetcherHelper(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { + resp, err := FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { createReqCalled = true return NewRequest(ctx, http.MethodGet, srv.URL()) }) @@ -276,7 +276,7 @@ func TestFetcherHelper(t *testing.T) { srv.AppendResponse() createReqCalled = false - resp, err = FetcherHelper(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) { + resp, err = FetcherForNextLink(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) { createReqCalled = true return NewRequest(ctx, http.MethodGet, srv.URL()) }) @@ -285,14 +285,14 @@ func TestFetcherHelper(t *testing.T) { require.NotNil(t, resp) require.EqualValues(t, http.StatusOK, resp.StatusCode) - resp, err = FetcherHelper(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { + resp, err = FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { return nil, errors.New("failed") }) require.Error(t, err) require.Nil(t, resp) srv.AppendError(errors.New("failed")) - resp, err = FetcherHelper(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { + resp, err = FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { createReqCalled = true return NewRequest(ctx, http.MethodGet, srv.URL()) }) @@ -302,7 +302,7 @@ func TestFetcherHelper(t *testing.T) { srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest), mock.WithBody([]byte(`{ "error": { "code": "InvalidResource", "message": "doesn't exist" } }`))) createReqCalled = false - resp, err = FetcherHelper(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) { + resp, err = FetcherForNextLink(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) { createReqCalled = true return NewRequest(ctx, http.MethodGet, srv.URL()) })