diff --git a/api/utils/clientutils/resources.go b/api/utils/clientutils/resources.go index 5700f3d15d576..ee5d2cdff9de6 100644 --- a/api/utils/clientutils/resources.go +++ b/api/utils/clientutils/resources.go @@ -77,7 +77,7 @@ func rangeInternal[T any](ctx context.Context, params rangeParams[T]) iter.Seq2[ isLookingForEnd := params.end != "" && params.keyFunc != nil for { - page, nextToken, err := Page(ctx, pageSize, pageToken, params.pageFunc) + page, nextToken, lastPageSize, err := Page(ctx, pageSize, pageToken, params.pageFunc) if err != nil { yield(*new(T), trace.Wrap(err)) return @@ -97,6 +97,11 @@ func rangeInternal[T any](ctx context.Context, params rangeParams[T]) iter.Seq2[ if nextToken == "" { return } + + // Note that the server may return a smaller page at its own discretion, + // we use the last successful requested page size here to allow the server + // to temporarily lower the size if needed. + pageSize = lastPageSize } } @@ -170,8 +175,17 @@ func CollectWithFallback[T any](ctx context.Context, } // Page is a client side utility which implements auto page size adjustment. -func Page[T any](ctx context.Context, pageSize int, pageToken string, - pageFunc func(context.Context, int, string) ([]T, string, error)) ([]T, string, error) { +func Page[T any]( + ctx context.Context, + pageSize int, + pageToken string, + pageFunc func(context.Context, int, string) ([]T, string, error), +) ( + _ []T, + nextPageToken string, + lastPageSize int, + _ error, +) { for { page, nextToken, err := pageFunc(ctx, pageSize, pageToken) if err != nil { @@ -180,15 +194,15 @@ func Page[T any](ctx context.Context, pageSize int, pageToken string, pageSize /= 2 // This is an extremely unlikely scenario, but better to cover it anyways. if pageSize == 0 { - return nil, "", trace.Wrap(err, "resource is too large to retrieve, token: %q", pageToken) + return nil, "", 0, trace.Wrap(err, "resource is too large to retrieve, token: %q", pageToken) } continue } - return nil, "", trace.Wrap(err) + return nil, "", pageSize, trace.Wrap(err) } - return page, nextToken, nil + return page, nextToken, pageSize, nil } } diff --git a/api/utils/clientutils/resources_test.go b/api/utils/clientutils/resources_test.go index e3359c4a916d2..9398cdd8743e4 100644 --- a/api/utils/clientutils/resources_test.go +++ b/api/utils/clientutils/resources_test.go @@ -34,9 +34,9 @@ import ( const totalItems = defaults.DefaultChunkSize*2 + 5 type mockPaginator struct { - accessDenied bool - limitExceeded bool - pageCalls int + accessDenied bool + maxSupportedPageSize int + pageCalls int } func generatePage(start, count int) []int { @@ -78,7 +78,7 @@ func (m *mockPaginator) List(_ context.Context, pageSize int, token string) ([]i return nil, "", trace.AccessDenied("access denied") } - if m.limitExceeded { + if pageSize > m.maxSupportedPageSize { return nil, "", trace.LimitExceeded("page size %d exceeded the limit", pageSize) } @@ -95,8 +95,8 @@ func (m *mockPaginator) List(_ context.Context, pageSize int, token string) ([]i func TestIterateResources(t *testing.T) { t.Run("success", func(t *testing.T) { var count int - paginator := mockPaginator{} - err := IterateResources(context.Background(), paginator.List, func(int) error { + paginator := mockPaginator{maxSupportedPageSize: defaults.DefaultChunkSize} + err := IterateResources(t.Context(), paginator.List, func(int) error { count++ return nil }) @@ -104,15 +104,15 @@ func TestIterateResources(t *testing.T) { assert.Equal(t, totalItems, count) }) t.Run("paginator error", func(t *testing.T) { - paginator := mockPaginator{accessDenied: true} - err := IterateResources(context.Background(), paginator.List, func(int) error { + paginator := mockPaginator{accessDenied: true, maxSupportedPageSize: defaults.DefaultChunkSize} + err := IterateResources(t.Context(), paginator.List, func(int) error { return nil }) assert.Error(t, err) }) t.Run("callback error", func(t *testing.T) { - paginator := mockPaginator{} - err := IterateResources(context.Background(), paginator.List, func(int) error { + paginator := mockPaginator{maxSupportedPageSize: defaults.DefaultChunkSize} + err := IterateResources(t.Context(), paginator.List, func(int) error { return trace.BadParameter("error") }) assert.Error(t, err) @@ -121,9 +121,9 @@ func TestIterateResources(t *testing.T) { func TestResources(t *testing.T) { t.Run("success", func(t *testing.T) { - paginator := mockPaginator{} + paginator := mockPaginator{maxSupportedPageSize: defaults.DefaultChunkSize} var count int - for _, err := range Resources(context.Background(), paginator.List) { + for _, err := range Resources(t.Context(), paginator.List) { count++ require.NoError(t, err) } @@ -132,9 +132,9 @@ func TestResources(t *testing.T) { assert.Equal(t, 3, paginator.pageCalls) }) t.Run("paginator error", func(t *testing.T) { - paginator := mockPaginator{accessDenied: true} + paginator := mockPaginator{accessDenied: true, maxSupportedPageSize: defaults.DefaultChunkSize} var count int - for _, err := range Resources(context.Background(), paginator.List) { + for _, err := range Resources(t.Context(), paginator.List) { count++ require.Error(t, err) } @@ -143,9 +143,9 @@ func TestResources(t *testing.T) { }) t.Run("limit exceeded", func(t *testing.T) { - paginator := mockPaginator{limitExceeded: true} + paginator := mockPaginator{maxSupportedPageSize: 0} var count int - for _, err := range Resources(context.Background(), paginator.List) { + for _, err := range Resources(t.Context(), paginator.List) { count++ require.Error(t, err) } @@ -155,9 +155,9 @@ func TestResources(t *testing.T) { } func TestResourcesWithCustomPageSize(t *testing.T) { - paginator := mockPaginator{} + paginator := mockPaginator{maxSupportedPageSize: defaults.DefaultChunkSize} var count int - for _, err := range ResourcesWithPageSize(context.Background(), paginator.List, 10) { + for _, err := range ResourcesWithPageSize(t.Context(), paginator.List, 10) { count++ require.NoError(t, err) } @@ -172,70 +172,87 @@ func TestRangeResources(t *testing.T) { } tests := []struct { - name string - start string - end string - expectedItemCount int - expectedListCalls int - accessDenied bool - limitExceeded bool + name string + start string + end string + expectedItemCount int + expectedListCalls int + accessDenied bool + maxSupportedPageSize int + errFn func(require.TestingT, error, ...any) }{ { - name: "RangeAllItems", - expectedItemCount: totalItems, - expectedListCalls: 3, + name: "RangeAllItems", + expectedItemCount: totalItems, + expectedListCalls: 3, + maxSupportedPageSize: defaults.DefaultChunkSize, + errFn: require.NoError, }, { - name: "RangeAccessDenied", - expectedItemCount: 0, - expectedListCalls: 1, - accessDenied: true, + name: "RangeAccessDenied", + expectedItemCount: 0, + expectedListCalls: 1, + accessDenied: true, + maxSupportedPageSize: defaults.DefaultChunkSize, + errFn: require.Error, }, { - name: "RangeWithEnd", - expectedItemCount: 20, - expectedListCalls: 1, - end: keyFunc(20), + name: "RangeWithEnd", + expectedItemCount: 20, + expectedListCalls: 1, + end: keyFunc(20), + maxSupportedPageSize: defaults.DefaultChunkSize, + errFn: require.NoError, }, { - name: "RangeWithStart", - expectedItemCount: totalItems - 1337, - expectedListCalls: 1, - start: keyFunc(1337), + name: "RangeWithStart", + expectedItemCount: totalItems - 1337, + expectedListCalls: 1, + start: keyFunc(1337), + maxSupportedPageSize: defaults.DefaultChunkSize, + errFn: require.NoError, }, { name: "RangeSpan", expectedItemCount: 1500 - 500, // The end marker is not inclusive and the number of items falls on the pagesize, in this case 2 calls will be made. - expectedListCalls: 2, - start: keyFunc(500), - end: keyFunc(1500), + expectedListCalls: 2, + start: keyFunc(500), + end: keyFunc(1500), + maxSupportedPageSize: defaults.DefaultChunkSize, + errFn: require.NoError, }, { - name: "RangeLimitExceeded", - expectedItemCount: 0, - expectedListCalls: 10, - start: keyFunc(500), - end: keyFunc(1500), - limitExceeded: true, + name: "RangeLimitExceeded", + expectedItemCount: 0, + expectedListCalls: 10, + start: keyFunc(500), + end: keyFunc(1500), + maxSupportedPageSize: -1, + errFn: require.Error, + }, + { + name: "RangeLimitExceededWithRecovery", + expectedItemCount: 1000, + expectedListCalls: 4, + start: keyFunc(500), + end: keyFunc(1500), + maxSupportedPageSize: defaults.DefaultChunkSize / 2, + errFn: require.NoError, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - paginator := mockPaginator{accessDenied: tc.accessDenied, limitExceeded: tc.limitExceeded} + paginator := mockPaginator{accessDenied: tc.accessDenied, maxSupportedPageSize: tc.maxSupportedPageSize} var count int - for _, err := range RangeResources(context.Background(), tc.start, tc.end, paginator.List, keyFunc) { + for _, err := range RangeResources(t.Context(), tc.start, tc.end, paginator.List, keyFunc) { if err == nil { count++ } - if tc.accessDenied || tc.limitExceeded { - require.Error(t, err) - } else { - require.NoError(t, err) - } + tc.errFn(t, err) } assert.Equal(t, tc.expectedItemCount, count) diff --git a/lib/web/users.go b/lib/web/users.go index 6d800b34586cd..ee93cf31b9d80 100644 --- a/lib/web/users.go +++ b/lib/web/users.go @@ -80,7 +80,7 @@ func (h *Handler) listUsersHandle(w http.ResponseWriter, r *http.Request, params return nil, trace.Wrap(err) } - users, nextToken, err := clientutils.Page( + users, nextToken, _, err := clientutils.Page( r.Context(), int(limit), values.Get("startKey"),