Skip to content

Commit

Permalink
Fix ListSnapshots paging
Browse files Browse the repository at this point in the history
This changes provides the following fixes and improvements to
ListSnapshots:

- Use paging to collect snapshots beyond the first page. Previously, we
  would only return snapshots from the first page.
- Handle StartingToken and MaxEntries such that we use paging
  efficiently and skip initial, unneeded snapshots.
- Extend fake snapshot driver to support paging.
- Add tests.

Note that Kubernetes / the csi-snapshotter sidecar currently do not
invoke ListSnapshots without the snapshot ID parameter, which means that
the fixed code is not executed in production. However, it is used by
csi-test / the sanity package, and other COs (Container Orchestrators)
may potentially use it as well as Kubernetes going forward.
  • Loading branch information
Timo Reimann committed Apr 21, 2020
1 parent 0d8f2b5 commit e2bd611
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 46 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
## unreleased

* Fix ListSnapshots paging
[[GH-300]](https://github.com/digitalocean/csi-digitalocean/pull/300)
* Support filtering snapshots by ID
[[GH-299]](https://github.com/digitalocean/csi-digitalocean/pull/299)
* Return minimum disk size field from snapshot response
Expand Down
129 changes: 92 additions & 37 deletions driver/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -776,12 +776,14 @@ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsReques
log := d.log.WithFields(logrus.Fields{
"snapshot_id": req.SnapshotId,
"source_volume_id": req.SourceVolumeId,
"max_entries": req.MaxEntries,
"req_starting_token": req.StartingToken,
"method": "list_snapshots",
})
log.Info("list snapshots is called")

if req.SnapshotId != "" {
// Fetch snapshot directly by ID.
snapshot, resp, err := d.snapshots.Get(ctx, req.SnapshotId)
if err != nil {
if resp == nil || resp.StatusCode != http.StatusNotFound {
Expand All @@ -802,44 +804,97 @@ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsReques
}
}
} else {
// Pagination in the CSI world works different than at DO. CSI sends the
// `req.MaxEntries` to indicate how much snapshots it wants. The
// req.StartingToken is returned by us, if we somehow need to indicate that
// we couldn't fetch and need to fetch again. But it's NOT the page number.
// I.e: suppose CSI wants us to fetch 50 entries, we only fetch 30, we need to
// return NextToken as 31 (so req.StartingToken will be set to 31 when CSI
// calls us again), to indicate that we want to continue returning from the
// index 31 up to 50.

var nextToken int
var err error
// Paginate through snapshots and return results.

// Pagination is controlled by two request parameters:
// MaxEntries indicates how many entries should be returned at most. If
// more results are available, we must return a NextToken value
// indicating the index for the next snapshot to request.
// StartingToken defines the index of the first snapshot to return.
// The CSI request parameters are defined in terms of number of
// snapshots, not pages. It is up to the driver to translate the
// parameters into paged requests accordingly.

var (
startingToken int32
originalStartingToken int32
)
if req.StartingToken != "" {
nextToken, err = strconv.Atoi(req.StartingToken)
parsedToken, err := strconv.ParseInt(req.StartingToken, 10, 32)
if err != nil {
return nil, status.Errorf(codes.Aborted, "ListSnapshots starting token %s is not valid : %s",
req.StartingToken, err.Error())
return nil, status.Errorf(codes.Aborted, "ListSnapshots starting token %q is not valid: %s", req.StartingToken, err)
}
startingToken = int32(parsedToken)
originalStartingToken = startingToken
}

if nextToken != 0 && req.MaxEntries != 0 {
return nil, status.Errorf(codes.Aborted,
"ListSnapshots invalid arguments starting token: %d and max entries: %d can't be non null at the same time", nextToken, req.MaxEntries)
}

// fetch all entries
// Fetch snapshots until we have either collected req.MaxEntries (if
// positive) or all available ones, whichever comes first.
listOpts := &godo.ListOptions{
Page: 1,
PerPage: int(req.MaxEntries),
}
var snapshots []godo.Snapshot
if req.MaxEntries > 0 {
// MaxEntries also defines the page size so that we can skip over
// snapshots before the StartingToken and minimize the number of
// paged requests we need.
listOpts.Page = int(startingToken/req.MaxEntries) + 1
// Offset StartingToken to skip snapshots we do not want. This is
// needed when MaxEntries does not divide StartingToken without
// remainder.
startingToken = startingToken % req.MaxEntries
}

log = log.WithFields(logrus.Fields{
"page": listOpts.Page,
"computed_starting_token": startingToken,
})

var (
// remainingEntries keeps track of how much room is left to return
// as many as MaxEntries snapshots.
remainingEntries int = int(req.MaxEntries)
// hasMore indicates if NextToken must be set.
hasMore bool
snapshots []godo.Snapshot
)
for {
hasMore = false
snaps, resp, err := d.snapshots.ListVolume(ctx, listOpts)
if err != nil {
return nil, status.Errorf(codes.Aborted, "ListSnapshots listing volume snapshots has failed: %s", err.Error())
return nil, status.Errorf(codes.Internal, "ListSnapshots listing volume snapshots has failed: %s", err)
}

// Skip pre-StartingToken snapshots. This is required on the first
// page at most.
if startingToken > 0 {
if startingToken > int32(len(snaps)) {
startingToken = int32(len(snaps))
} else {
startingToken--
}
snaps = snaps[startingToken:]
}
startingToken = 0

// Do not return more than MaxEntries across pages.
if req.MaxEntries > 0 && len(snaps) > remainingEntries {
snaps = snaps[:remainingEntries]
hasMore = true
}

snapshots = append(snapshots, snaps...)
remainingEntries -= len(snaps)

isLastPage := resp.Links == nil || resp.Links.IsLastPage()
hasMore = hasMore || !isLastPage

// Stop paging if we have used up all of MaxEntries.
if req.MaxEntries > 0 && remainingEntries == 0 {
break
}

if resp.Links == nil || resp.Links.IsLastPage() {
if isLastPage {
break
}

Expand All @@ -849,20 +904,18 @@ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsReques
}

listOpts.Page = page + 1
listOpts.PerPage = len(snaps)
}

if nextToken > len(snapshots) {
return nil, status.Error(codes.Aborted, "ListSnapshots starting token is greater than total number of snapshots")
}

if nextToken != 0 {
snapshots = snapshots[nextToken:]
}

if req.MaxEntries != 0 {
nextToken = len(snapshots) - int(req.MaxEntries) - 1
snapshots = snapshots[:req.MaxEntries]
var nextToken int32
if hasMore {
// Compute NextToken, which is at least StartingToken plus
// MaxEntries. If StartingToken was zero, we need to add one because
// StartingToken defines the n-th snapshot we want but is not
// zero-based.
nextToken = originalStartingToken + req.MaxEntries
if originalStartingToken == 0 {
nextToken++
}
}

entries := make([]*csi.ListSnapshotsResponse_Entry, 0, len(snapshots))
Expand All @@ -878,8 +931,10 @@ func (d *Driver) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsReques
})
}
listResp = &csi.ListSnapshotsResponse{
Entries: entries,
NextToken: strconv.Itoa(nextToken),
Entries: entries,
}
if nextToken > 0 {
listResp.NextToken = strconv.FormatInt(int64(nextToken), 10)
}
}

Expand Down
145 changes: 145 additions & 0 deletions driver/controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package driver
import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
Expand Down Expand Up @@ -476,3 +477,147 @@ func TestWaitAction(t *testing.T) {
})
}
}

func TestListSnapshot(t *testing.T) {
createID := func(id int) string {
return fmt.Sprintf("%03d", id)
}

tests := []struct {
name string
inNumSnapshots int
maxEntries int32
startingToken int
wantNumSnapshots int
wantNextToken int
}{
{
name: "no constraints",
inNumSnapshots: 10,
wantNumSnapshots: 10,
},
{
name: "max entries set",
inNumSnapshots: 10,
maxEntries: 5,
wantNumSnapshots: 5,
wantNextToken: 6,
},
{
name: "starting token lower than number of snapshots",
inNumSnapshots: 10,
startingToken: 8,
wantNumSnapshots: 3,
},
{
name: "starting token larger than number of snapshots",
inNumSnapshots: 10,
startingToken: 50,
wantNumSnapshots: 0,
},
{
name: "starting token and max entries set with extra snapshots available",
inNumSnapshots: 10,
maxEntries: 5,
startingToken: 4,
wantNumSnapshots: 5,
wantNextToken: 9,
},
{
name: "starting token and max entries set with no extra snapshots available",
inNumSnapshots: 10,
maxEntries: 15,
startingToken: 8,
wantNumSnapshots: 3,
},
{
name: "single paging with extra snapshots available",
inNumSnapshots: 50,
maxEntries: 12,
startingToken: 30,
wantNumSnapshots: 12,
wantNextToken: 42,
},
{
name: "single paging with no extra snapshots available",
inNumSnapshots: 32,
maxEntries: 12,
startingToken: 30,
wantNumSnapshots: 3,
},
{
name: "multi-paging with extra snapshots available",
inNumSnapshots: 50,
maxEntries: 30,
startingToken: 12,
wantNumSnapshots: 30,
wantNextToken: 42,
},
{
name: "multi-paging with exact fit",
inNumSnapshots: 42,
maxEntries: 30,
startingToken: 13,
wantNumSnapshots: 30,
},
{
name: "maxEntries exceeding maximum page size limit",
inNumSnapshots: 300,
maxEntries: 250,
wantNumSnapshots: 250,
wantNextToken: 251,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
snapshots := map[string]*godo.Snapshot{}
for i := 1; i <= test.inNumSnapshots; i++ {
id := createID(i)
snap := createGodoSnapshot(id, fmt.Sprintf("snapshot-%d", i), "")
snapshots[id] = snap
}

d := Driver{
snapshots: &fakeSnapshotsDriver{
snapshots: snapshots,
},
log: logrus.New().WithField("test_enabed", true),
}

resp, err := d.ListSnapshots(context.Background(), &csi.ListSnapshotsRequest{
MaxEntries: test.maxEntries,
StartingToken: strconv.Itoa(test.startingToken),
})
if err != nil {
t.Fatalf("got error: %s", err)
}

if len(resp.Entries) != test.wantNumSnapshots {
t.Errorf("got %d snapshot(s), want %d", len(resp.Entries), test.wantNumSnapshots)
} else {
runningID := test.startingToken
if runningID == 0 {
runningID = 1
}
for i, entry := range resp.Entries {
wantID := createID(runningID)
gotID := entry.Snapshot.GetSnapshotId()
if gotID != wantID {
t.Errorf("got snapshot ID %q at position %d, want %q", gotID, i, wantID)
}
runningID++
}
}

if test.wantNextToken > 0 {
wantNextTokenStr := strconv.Itoa(test.wantNextToken)
if resp.NextToken != wantNextTokenStr {
t.Errorf("got next token %q, want %q", resp.NextToken, wantNextTokenStr)
}
} else if resp.NextToken != "" {
t.Errorf("got non-empty next token %q", resp.NextToken)
}
})
}
}
Loading

0 comments on commit e2bd611

Please sign in to comment.