diff --git a/lib/events/gcssessions/gcshandler_test.go b/lib/events/gcssessions/gcshandler_test.go index f732ec77c48b5..cb6402fb6cb56 100644 --- a/lib/events/gcssessions/gcshandler_test.go +++ b/lib/events/gcssessions/gcshandler_test.go @@ -19,16 +19,25 @@ package gcssessions import ( + "bytes" "context" "fmt" + "net/http" + "net/http/httptest" "os" "testing" + "time" + "cloud.google.com/go/storage" "github.com/fsouza/fake-gcs-server/fakestorage" "github.com/google/uuid" + "github.com/googleapis/gax-go/v2" "github.com/stretchr/testify/require" + "google.golang.org/api/option" + "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/test" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" ) @@ -60,3 +69,40 @@ func TestFakeStreams(t *testing.T) { test.DownloadNotFound(t, handler) }) } + +func TestRetryOnRateLimit(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + t.Cleanup(cancel) + attempts := 0 + rateLimitedServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTooManyRequests) + attempts++ + if attempts == 5 { + cancel() + } + })) + + client, err := storage.NewClient( + ctx, + option.WithoutAuthentication(), + option.WithEndpoint(rateLimitedServer.URL), + ) + require.NoError(t, err) + // Shorten backoff to shorten the test. + client.SetRetry(storage.WithBackoff(gax.Backoff{Initial: time.Millisecond})) + handler, err := NewHandler(ctx, cancel, Config{ + Endpoint: rateLimitedServer.URL, + Bucket: fmt.Sprintf("teleport-test-%v", uuid.New().String()), + }, client) + require.NoError(t, err) + defer handler.Close() + + // Send a request that can trigger rate limiting. The client should retry + // automatically until the context is canceled. 
+ _, err = handler.UploadPart(ctx, events.StreamUpload{ + ID: uuid.NewString(), + SessionID: session.ID(uuid.NewString()), + }, 0, bytes.NewReader([]byte("foo"))) + require.ErrorIs(t, err, context.Canceled) + require.Equal(t, 5, attempts) +} diff --git a/lib/events/gcssessions/gcsstream.go b/lib/events/gcssessions/gcsstream.go index 8e1d47bbbeb5f..1301ce0232270 100644 --- a/lib/events/gcssessions/gcsstream.go +++ b/lib/events/gcssessions/gcsstream.go @@ -55,7 +55,8 @@ func (h *Handler) CreateUpload(ctx context.Context, sessionID session.ID) (*even h.logger.DebugContext(ctx, "Creating upload", "path", uploadPath) // Make sure we don't overwrite an existing upload - _, err := h.gcsClient.Bucket(h.Config.Bucket).Object(uploadPath).Attrs(ctx) + uploadObject := h.gcsClient.Bucket(h.Config.Bucket).Object(uploadPath) + _, err := uploadObject.Attrs(ctx) if !errors.Is(err, storage.ErrObjectNotExist) { if err != nil { return nil, convertGCSError(err) @@ -63,7 +64,9 @@ func (h *Handler) CreateUpload(ctx context.Context, sessionID session.ID) (*even return nil, trace.AlreadyExists("upload %v for session %q already exists in GCS", upload.ID, sessionID) } - writer := h.gcsClient.Bucket(h.Config.Bucket).Object(uploadPath).NewWriter(ctx) + // Perform a conditional write in order to make the request idempotent. + // Idempotent requests will be automatically retried when rate-limited. + writer := uploadObject.If(storage.Conditions{DoesNotExist: true}).NewWriter(ctx) start := time.Now() _, err = io.Copy(writer, strings.NewReader(string(sessionID))) // Always close the writer, even if upload failed. 
@@ -86,7 +89,7 @@ func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, pa } partPath := h.partPath(upload, partNumber) - writer := h.gcsClient.Bucket(h.Config.Bucket).Object(partPath).NewWriter(ctx) + writer := h.gcsClient.Bucket(h.Config.Bucket).Object(partPath).If(storage.Conditions{DoesNotExist: true}).NewWriter(ctx) start := time.Now() _, err := io.Copy(writer, partBody) // Always close the writer, even if upload failed. @@ -109,8 +112,9 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload } // If the session has been already created, move to cleanup - sessionPath := h.path(upload.SessionID) - _, err := h.gcsClient.Bucket(h.Config.Bucket).Object(sessionPath).Attrs(ctx) + bucket := h.gcsClient.Bucket(h.Config.Bucket) + sessionObject := bucket.Object(h.path(upload.SessionID)) + _, err := sessionObject.Attrs(ctx) if !errors.Is(err, storage.ErrObjectNotExist) { if err != nil { return convertGCSError(err) @@ -120,9 +124,7 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload // Makes sure that upload has been properly initiated, // checks the .upload file - uploadPath := h.uploadPath(upload) - bucket := h.gcsClient.Bucket(h.Config.Bucket) - _, err = bucket.Object(uploadPath).Attrs(ctx) + _, err = bucket.Object(h.uploadPath(upload)).Attrs(ctx) if err != nil { return convertGCSError(err) } @@ -139,14 +141,14 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload mergeID := hashOfNames(objectsToMerge) mergePath := h.mergePath(upload, mergeID) mergeObject := bucket.Object(mergePath) - composer := mergeObject.ComposerFrom(objectsToMerge...) + composer := mergeObject.If(storage.Conditions{DoesNotExist: true}).ComposerFrom(objectsToMerge...) _, err = h.OnComposerRun(ctx, composer) if err != nil { return convertGCSError(err) } objects = append([]*storage.ObjectHandle{mergeObject}, objects[maxParts:]...) 
} - composer := bucket.Object(sessionPath).ComposerFrom(objects...) + composer := sessionObject.If(storage.Conditions{DoesNotExist: true}).ComposerFrom(objects...) _, err = h.OnComposerRun(ctx, composer) if err != nil { return convertGCSError(err) @@ -176,7 +178,7 @@ func (h *Handler) cleanupUpload(ctx context.Context, upload events.StreamUpload) if err != nil { return convertGCSError(err) } - objects = append(objects, bucket.Object(attrs.Name)) + objects = append(objects, bucket.Object(attrs.Name).Generation(0)) } }