Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/azcore/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

### Features Added

* Added function `FetcherForNextLink` to the `runtime` package to centralize creation of `Pager[T].Fetcher` from a next link URL.

### Breaking Changes

### Bugs Fixed
Expand Down
24 changes: 24 additions & 0 deletions sdk/azcore/runtime/pager.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import (
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/tracing"
)

Expand Down Expand Up @@ -88,3 +90,25 @@ func (p *Pager[T]) NextPage(ctx context.Context) (T, error) {
func (p *Pager[T]) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, &p.current)
}

// FetcherForNextLink is a helper containing boilerplate code to simplify creating a PagingHandler[T].Fetcher from a next link URL.
func FetcherForNextLink(ctx context.Context, pl Pipeline, nextLink string, createReq func(context.Context) (*policy.Request, error)) (*http.Response, error) {
var req *policy.Request
var err error
if nextLink == "" {
req, err = createReq(ctx)
} else if nextLink, err = EncodeQueryParams(nextLink); err == nil {
Comment thread
jhendrixMSFT marked this conversation as resolved.
req, err = NewRequest(ctx, http.MethodGet, nextLink)
}
if err != nil {
return nil, err
}
resp, err := pl.Do(req)
if err != nil {
return nil, err
}
if !HasStatusCode(resp, http.StatusOK) {
return nil, NewResponseError(resp)
}
return resp, nil
}
57 changes: 57 additions & 0 deletions sdk/azcore/runtime/pager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/internal/exported"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/internal/mock"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -256,3 +257,59 @@ func TestPagerResponderError(t *testing.T) {
require.Error(t, err)
require.Empty(t, page)
}

func TestFetcherForNextLink(t *testing.T) {
srv, close := mock.NewServer()
defer close()
pl := exported.NewPipeline(srv)

srv.AppendResponse()
createReqCalled := false
resp, err := FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) {
createReqCalled = true
return NewRequest(ctx, http.MethodGet, srv.URL())
})
require.NoError(t, err)
require.True(t, createReqCalled)
require.NotNil(t, resp)
require.EqualValues(t, http.StatusOK, resp.StatusCode)

srv.AppendResponse()
createReqCalled = false
resp, err = FetcherForNextLink(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) {
createReqCalled = true
return NewRequest(ctx, http.MethodGet, srv.URL())
})
require.NoError(t, err)
require.False(t, createReqCalled)
require.NotNil(t, resp)
require.EqualValues(t, http.StatusOK, resp.StatusCode)

resp, err = FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) {
return nil, errors.New("failed")
})
require.Error(t, err)
require.Nil(t, resp)

srv.AppendError(errors.New("failed"))
resp, err = FetcherForNextLink(context.Background(), pl, "", func(ctx context.Context) (*policy.Request, error) {
createReqCalled = true
return NewRequest(ctx, http.MethodGet, srv.URL())
})
require.Error(t, err)
require.True(t, createReqCalled)
require.Nil(t, resp)

srv.AppendResponse(mock.WithStatusCode(http.StatusBadRequest), mock.WithBody([]byte(`{ "error": { "code": "InvalidResource", "message": "doesn't exist" } }`)))
createReqCalled = false
resp, err = FetcherForNextLink(context.Background(), pl, srv.URL(), func(ctx context.Context) (*policy.Request, error) {
createReqCalled = true
return NewRequest(ctx, http.MethodGet, srv.URL())
})
require.Error(t, err)
var respErr *exported.ResponseError
require.ErrorAs(t, err, &respErr)
require.EqualValues(t, "InvalidResource", respErr.ErrorCode)
require.False(t, createReqCalled)
require.Nil(t, resp)
}
14 changes: 14 additions & 0 deletions sdk/azcore/runtime/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"fmt"
"io"
"mime/multipart"
"net/url"
"os"
"path"
"reflect"
Expand Down Expand Up @@ -43,6 +44,19 @@ func NewRequest(ctx context.Context, httpMethod string, endpoint string) (*polic
return exported.NewRequest(ctx, httpMethod, endpoint)
}

// EncodeQueryParams will parse and encode any query parameters in the specified URL.
func EncodeQueryParams(u string) (string, error) {
before, after, found := strings.Cut(u, "?")
if !found {
return u, nil
}
qp, err := url.ParseQuery(after)
if err != nil {
return "", err
}
return before + "?" + qp.Encode(), nil
}

// JoinPaths concatenates multiple URL path segments into one path,
// inserting path separation characters as required. JoinPaths will preserve
// query parameters in the root path
Expand Down
19 changes: 19 additions & 0 deletions sdk/azcore/runtime/request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -740,3 +740,22 @@ func TestSetMultipartFormData(t *testing.T) {
require.Equal(t, "second part", string(second))
require.Equal(t, "third part", string(third))
}

func TestEncodeQueryParams(t *testing.T) {
const testURL = "https://contoso.com/"
nextLink, err := EncodeQueryParams(testURL + "query?$skip=5&$filter='foo eq bar'")
require.NoError(t, err)
require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink)
nextLink, err = EncodeQueryParams(testURL + "query?%24filter=%27foo+eq+bar%27&%24skip=5")
require.NoError(t, err)
require.EqualValues(t, testURL+"query?%24filter=%27foo+eq+bar%27&%24skip=5", nextLink)
nextLink, err = EncodeQueryParams(testURL + "query?foo=bar&one=two")
require.NoError(t, err)
require.EqualValues(t, testURL+"query?foo=bar&one=two", nextLink)
nextLink, err = EncodeQueryParams(testURL)
require.NoError(t, err)
require.EqualValues(t, testURL, nextLink)
nextLink, err = EncodeQueryParams(testURL + "query?invalid=;semicolon")
require.Error(t, err)
require.Empty(t, nextLink)
}