diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 7012e03e46cd..6cbd90e0cfa1 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -4,6 +4,8 @@ ### Features Added +* Added function `FetcherForNextLink` to the `runtime` package to centralize creation of `Pager[T].Fetcher` from a next link URL. + ### Breaking Changes ### Bugs Fixed diff --git a/sdk/azcore/runtime/pager.go b/sdk/azcore/runtime/pager.go index b7e59527a3d8..f1daac50d32a 100644 --- a/sdk/azcore/runtime/pager.go +++ b/sdk/azcore/runtime/pager.go @@ -11,8 +11,10 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "reflect" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing" ) @@ -88,3 +90,25 @@ func (p *Pager[T]) NextPage(ctx context.Context) (T, error) { func (p *Pager[T]) UnmarshalJSON(data []byte) error { return json.Unmarshal(data, &p.current) } + +// FetcherForNextLink is a helper containing boilerplate code to simplify creating a PagingHandler[T].Fetcher from a next link URL. +func FetcherForNextLink(ctx context.Context, pl Pipeline, nextLink string, createReq func(context.Context) (*policy.Request, error)) (*http.Response, error) { + var req *policy.Request + var err error + if nextLink == "" { + req, err = createReq(ctx) + } else if nextLink, err = EncodeQueryParams(nextLink); err == nil { + req, err = NewRequest(ctx, http.MethodGet, nextLink) + } + if err != nil { + return nil, err + } + resp, err := pl.Do(req) + if err != nil { + return nil, err + } + if !HasStatusCode(resp, http.StatusOK) { + return nil, NewResponseError(resp) + } + return resp, nil +} diff --git a/sdk/azcore/runtime/pager_test.go b/sdk/azcore/runtime/pager_test.go index 3d33330caa6e..9d205b297d60 100644 --- a/sdk/azcore/runtime/pager_test.go +++ b/sdk/azcore/runtime/pager_test.go @@ -14,6 +14,7 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/internal/mock" "github.com/stretchr/testify/require" ) @@ -256,3 +257,59 @@ func TestPagerResponderError(t *testing.T) { require.Error(t, err) require.Empty(t, page) } + +func TestFetcherForNextLink(t *testing.T) { + srv, close := mock.NewServer() + defer close() + pl := exported.NewPipeline(srv) + + srv.AppendResponse() + createReqCalled := false + resp, err := FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { + createReqCalled = true + return NewRequest(ctx, http.MethodGet, srv.URL()) + }) + require.NoError(t, err) + require.True(t, createReqCalled) + require.NotNil(t, resp) + require.EqualValues(t, http.StatusOK, resp.StatusCode) + + srv.AppendResponse() + createReqCalled = false + resp, err = FetcherForNextLink(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) { + createReqCalled = true + return NewRequest(ctx, http.MethodGet, srv.URL()) + }) + require.NoError(t, err) + require.False(t, createReqCalled) + require.NotNil(t, resp) + require.EqualValues(t, http.StatusOK, resp.StatusCode) + + resp, err = FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { + return nil, errors.New("failed") + }) + require.Error(t, err) + require.Nil(t, resp) + + srv.AppendError(errors.New("failed")) + resp, err = FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) { + createReqCalled = true + return NewRequest(ctx, http.MethodGet, srv.URL()) + }) + require.Error(t, err) + require.True(t, createReqCalled) + require.Nil(t, resp) + + srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest), mock.WithBody([]byte(`{ "error": { "code": "InvalidResource", "message": "doesn't exist" } }`))) + createReqCalled = false + resp, err = FetcherForNextLink(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) { + createReqCalled = true + return NewRequest(ctx, http.MethodGet, srv.URL()) + }) + require.Error(t, err) + var respErr *exported.ResponseError + require.ErrorAs(t, err, &respErr) + require.EqualValues(t, "InvalidResource", respErr.ErrorCode) + require.False(t, createReqCalled) + require.Nil(t, resp) +} diff --git a/sdk/azcore/runtime/request.go b/sdk/azcore/runtime/request.go index 7938b5fb7d26..6f3b8966998b 100644 --- a/sdk/azcore/runtime/request.go +++ b/sdk/azcore/runtime/request.go @@ -14,6 +14,7 @@ import ( "fmt" "io" "mime/multipart" + "net/url" "os" "path" "reflect" @@ -43,6 +44,19 @@ func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*polic return exported.NewRequest(ctx, httpMethod, endpoint) } +// EncodeQueryParams will parse and encode any query parameters in the specified URL. +func EncodeQueryParams(u string) (string, error) { + before, after, found := strings.Cut(u, "?") + if !found { + return u, nil + } + qp, err := url.ParseQuery(after) + if err != nil { + return "", err + } + return before + "?" + qp.Encode(), nil +} + // JoinPaths concatenates multiple URL path segments into one path, // inserting path separation characters as required. JoinPaths will preserve // query parameters in the root path diff --git a/sdk/azcore/runtime/request_test.go b/sdk/azcore/runtime/request_test.go index aa7c1b3ecb16..0a8da39cf021 100644 --- a/sdk/azcore/runtime/request_test.go +++ b/sdk/azcore/runtime/request_test.go @@ -740,3 +740,22 @@ func TestSetMultipartFormData(t *testing.T) { require.Equal(t, "second part", string(second)) require.Equal(t, "third part", string(third)) } + +func TestEncodeQueryParams(t *testing.T) { + const testURL = "https://contoso.com/" + nextLink, err := EncodeQueryParams(testURL + "query?$skip=5&$filter='foo eq bar'") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink) + nextLink, err = EncodeQueryParams(testURL + "query?%24filter=%27foo+eq+bar%27&%24skip=5") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink) + nextLink, err = EncodeQueryParams(testURL + "query?foo=bar&one=two") + require.NoError(t, err) + require.EqualValues(t, testURL+"query?foo=bar&one=two", nextLink) + nextLink, err = EncodeQueryParams(testURL) + require.NoError(t, err) + require.EqualValues(t, testURL, nextLink) + nextLink, err = EncodeQueryParams(testURL + "query?invalid=;semicolon") + require.Error(t, err) + require.Empty(t, nextLink) +}