From 383299e626d461959336db03541289fb4fd3e942 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Tue, 8 Jun 2021 17:43:29 -0700 Subject: [PATCH] update inFlight cache to avoid race condition on volume operation --- pkg/driver/controller.go | 180 ++++++++++++++++++++++---------- pkg/driver/controller_test.go | 111 +++++++++++++++++--- pkg/driver/internal/inflight.go | 4 + 3 files changed, 227 insertions(+), 68 deletions(-) diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 03a550deff..21f37eb97f 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -99,27 +99,21 @@ func newControllerService(driverOptions *DriverOptions) controllerService { func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVolumeRequest) (*csi.CreateVolumeResponse, error) { klog.V(4).Infof("CreateVolume: called with args %+v", *req) - volName := req.GetName() - if len(volName) == 0 { - return nil, status.Error(codes.InvalidArgument, "Volume name not provided") + if err := validateCreateVolumeRequest(req); err != nil { + return nil, err } - volSizeBytes, err := getVolSizeBytes(req) if err != nil { return nil, err } + volName := req.GetName() - volCaps := req.GetVolumeCapabilities() - if len(volCaps) == 0 { - return nil, status.Error(codes.InvalidArgument, "Volume capabilities not provided") - } - - if !isValidVolumeCapabilities(volCaps) { - modes := util.GetAccessModes(volCaps) - stringModes := strings.Join(*modes, ", ") - errString := "Volume capabilities " + stringModes + " not supported. Only AccessModes[ReadWriteOnce] supported." - return nil, status.Error(codes.InvalidArgument, errString) + // check if a request is already in-flight + if ok := d.inFlight.Insert(volName); !ok { + msg := fmt.Sprintf("Create volume request for %s is already in progress", volName) + return nil, status.Error(codes.Aborted, msg) } + defer d.inFlight.Delete(volName) disk, err := d.cloud.GetDiskByName(ctx, volName, volSizeBytes) if err != nil { @@ -217,13 +211,6 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol return newCreateVolumeResponse(disk), nil } - // check if a request is already in-flight because the CreateVolume API is not idempotent - if ok := d.inFlight.Insert(req.String()); !ok { - msg := fmt.Sprintf("Create volume request for %s is already in progress", volName) - return nil, status.Error(codes.Aborted, msg) - } - defer d.inFlight.Delete(req.String()) - // create a new volume zone := pickAvailabilityZone(req.GetAccessibilityRequirements()) outpostArn := getOutpostArn(req.GetAccessibilityRequirements()) @@ -264,12 +251,40 @@ func (d *controllerService) CreateVolume(ctx context.Context, req *csi.CreateVol return newCreateVolumeResponse(disk), nil } +func validateCreateVolumeRequest(req *csi.CreateVolumeRequest) error { + volName := req.GetName() + if len(volName) == 0 { + return status.Error(codes.InvalidArgument, "Volume name not provided") + } + + volCaps := req.GetVolumeCapabilities() + if len(volCaps) == 0 { + return status.Error(codes.InvalidArgument, "Volume capabilities not provided") + } + + if !isValidVolumeCapabilities(volCaps) { + modes := util.GetAccessModes(volCaps) + stringModes := strings.Join(*modes, ", ") + errString := "Volume capabilities " + stringModes + " not supported. Only AccessModes[ReadWriteOnce] supported." + return status.Error(codes.InvalidArgument, errString) + } + return nil +} + func (d *controllerService) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { klog.V(4).Infof("DeleteVolume: called with args: %+v", *req) + if err := validateDeleteVolumeRequest(req); err != nil { + return nil, err + } + volumeID := req.GetVolumeId() - if len(volumeID) == 0 { - return nil, status.Error(codes.InvalidArgument, "Volume ID not provided") + + // check if a request is already in-flight + if ok := d.inFlight.Insert(volumeID); !ok { + msg := fmt.Sprintf(internal.VolumeOperationAlreadyExistsErrorMsg, volumeID) + return nil, status.Error(codes.Aborted, msg) } + defer d.inFlight.Delete(volumeID) if _, err := d.cloud.DeleteDisk(ctx, volumeID); err != nil { if err == cloud.ErrNotFound { @@ -282,30 +297,21 @@ func (d *controllerService) DeleteVolume(ctx context.Context, req *csi.DeleteVol return &csi.DeleteVolumeResponse{}, nil } +func validateDeleteVolumeRequest(req *csi.DeleteVolumeRequest) error { + if len(req.GetVolumeId()) == 0 { + return status.Error(codes.InvalidArgument, "Volume ID not provided") + } + return nil +} + func (d *controllerService) ControllerPublishVolume(ctx context.Context, req *csi.ControllerPublishVolumeRequest) (*csi.ControllerPublishVolumeResponse, error) { klog.V(4).Infof("ControllerPublishVolume: called with args %+v", *req) - volumeID := req.GetVolumeId() - if len(volumeID) == 0 { - return nil, status.Error(codes.InvalidArgument, "Volume ID not provided") + if err := validateControllerPublishVolumeRequest(req); err != nil { + return nil, err } + volumeID := req.GetVolumeId() nodeID := req.GetNodeId() - if len(nodeID) == 0 { - return nil, status.Error(codes.InvalidArgument, "Node ID not provided") - } - - volCap := req.GetVolumeCapability() - if volCap == nil { - return nil, status.Error(codes.InvalidArgument, "Volume capability not provided") - } - - caps := []*csi.VolumeCapability{volCap} - if !isValidVolumeCapabilities(caps) { - modes := util.GetAccessModes(caps) - stringModes := strings.Join(*modes, ", ") - errString := "Volume capabilities " + stringModes + " not supported. Only AccessModes[ReadWriteOnce] supported." - return nil, status.Error(codes.InvalidArgument, errString) - } if !d.cloud.IsExistInstance(ctx, nodeID) { return nil, status.Errorf(codes.NotFound, "Instance %q not found", nodeID) @@ -333,17 +339,38 @@ func (d *controllerService) ControllerPublishVolume(ctx context.Context, req *cs return &csi.ControllerPublishVolumeResponse{PublishContext: pvInfo}, nil } +func validateControllerPublishVolumeRequest(req *csi.ControllerPublishVolumeRequest) error { + if len(req.GetVolumeId()) == 0 { + return status.Error(codes.InvalidArgument, "Volume ID not provided") + } + + if len(req.GetNodeId()) == 0 { + return status.Error(codes.InvalidArgument, "Node ID not provided") + } + + volCap := req.GetVolumeCapability() + if volCap == nil { + return status.Error(codes.InvalidArgument, "Volume capability not provided") + } + + caps := []*csi.VolumeCapability{volCap} + if !isValidVolumeCapabilities(caps) { + modes := util.GetAccessModes(caps) + stringModes := strings.Join(*modes, ", ") + errString := "Volume capabilities " + stringModes + " not supported. Only AccessModes[ReadWriteOnce] supported." + return status.Error(codes.InvalidArgument, errString) + } + return nil +} + func (d *controllerService) ControllerUnpublishVolume(ctx context.Context, req *csi.ControllerUnpublishVolumeRequest) (*csi.ControllerUnpublishVolumeResponse, error) { klog.V(4).Infof("ControllerUnpublishVolume: called with args %+v", *req) - volumeID := req.GetVolumeId() - if len(volumeID) == 0 { - return nil, status.Error(codes.InvalidArgument, "Volume ID not provided") + if err := validateControllerUnpublishVolumeRequest(req); err != nil { + return nil, err } + volumeID := req.GetVolumeId() nodeID := req.GetNodeId() - if len(nodeID) == 0 { - return nil, status.Error(codes.InvalidArgument, "Node ID not provided") - } if err := d.cloud.DetachDisk(ctx, volumeID, nodeID); err != nil { if err == cloud.ErrNotFound { @@ -356,6 +383,18 @@ func (d *controllerService) ControllerUnpublishVolume(ctx context.Context, req * return &csi.ControllerUnpublishVolumeResponse{}, nil } +func validateControllerUnpublishVolumeRequest(req *csi.ControllerUnpublishVolumeRequest) error { + if len(req.GetVolumeId()) == 0 { + return status.Error(codes.InvalidArgument, "Volume ID not provided") + } + + if len(req.GetNodeId()) == 0 { + return status.Error(codes.InvalidArgument, "Node ID not provided") + } + + return nil +} + func (d *controllerService) ControllerGetCapabilities(ctx context.Context, req *csi.ControllerGetCapabilitiesRequest) (*csi.ControllerGetCapabilitiesResponse, error) { klog.V(4).Infof("ControllerGetCapabilities: called with args %+v", *req) var caps []*csi.ControllerServiceCapability @@ -489,15 +528,20 @@ func isValidVolumeContext(volContext map[string]string) bool { func (d *controllerService) CreateSnapshot(ctx context.Context, req *csi.CreateSnapshotRequest) (*csi.CreateSnapshotResponse, error) { klog.V(4).Infof("CreateSnapshot: called with args %+v", req) - snapshotName := req.GetName() - if len(snapshotName) == 0 { - return nil, status.Error(codes.InvalidArgument, "Snapshot name not provided") + if err := validateCreateSnapshotRequest(req); err != nil { + return nil, err } + snapshotName := req.GetName() volumeID := req.GetSourceVolumeId() - if len(volumeID) == 0 { - return nil, status.Error(codes.InvalidArgument, "Snapshot volume source ID not provided") + + // check if a request is already in-flight + if ok := d.inFlight.Insert(snapshotName); !ok { + msg := fmt.Sprintf(internal.VolumeOperationAlreadyExistsErrorMsg, snapshotName) + return nil, status.Error(codes.Aborted, msg) } + defer d.inFlight.Delete(snapshotName) + snapshot, err := d.cloud.GetSnapshotByName(ctx, snapshotName) if err != nil && err != cloud.ErrNotFound { klog.Errorf("Error looking for the snapshot %s: %v", snapshotName, err) @@ -535,12 +579,31 @@ func (d *controllerService) CreateSnapshot(ctx context.Context, req *csi.CreateS return newCreateSnapshotResponse(snapshot) } +func validateCreateSnapshotRequest(req *csi.CreateSnapshotRequest) error { + if len(req.GetName()) == 0 { + return status.Error(codes.InvalidArgument, "Snapshot name not provided") + } + + if len(req.GetSourceVolumeId()) == 0 { + return status.Error(codes.InvalidArgument, "Snapshot volume source ID not provided") + } + return nil +} + func (d *controllerService) DeleteSnapshot(ctx context.Context, req *csi.DeleteSnapshotRequest) (*csi.DeleteSnapshotResponse, error) { klog.V(4).Infof("DeleteSnapshot: called with args %+v", req) + if err := validateDeleteSnapshotRequest(req); err != nil { + return nil, err + } + snapshotID := req.GetSnapshotId() - if len(snapshotID) == 0 { - return nil, status.Error(codes.InvalidArgument, "Snapshot ID not provided") + + // check if a request is already in-flight + if ok := d.inFlight.Insert(snapshotID); !ok { + msg := fmt.Sprintf("DeleteSnapshot for Snapshot %s is already in progress", snapshotID) + return nil, status.Error(codes.Aborted, msg) } + defer d.inFlight.Delete(snapshotID) if _, err := d.cloud.DeleteSnapshot(ctx, snapshotID); err != nil { if err == cloud.ErrNotFound { @@ -553,6 +616,13 @@ func (d *controllerService) DeleteSnapshot(ctx context.Context, req *csi.DeleteS return &csi.DeleteSnapshotResponse{}, nil } +func validateDeleteSnapshotRequest(req *csi.DeleteSnapshotRequest) error { + if len(req.GetSnapshotId()) == 0 { + return status.Error(codes.InvalidArgument, "Snapshot ID not provided") + } + return nil +} + func (d *controllerService) ListSnapshots(ctx context.Context, req *csi.ListSnapshotsRequest) (*csi.ListSnapshotsResponse, error) { klog.V(4).Infof("ListSnapshots: called with args %+v", req) var snapshots []*cloud.Snapshot diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index 84d78db1b0..ca2f96115a 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -1536,11 +1536,10 @@ func TestCreateVolume(t *testing.T) { defer mockCtl.Finish() mockCloud := mocks.NewMockCloud(mockCtl) - mockCloud.EXPECT().GetDiskByName(gomock.Eq(ctx), gomock.Eq(req.Name), gomock.Eq(stdVolSize)).Return(nil, cloud.ErrNotFound) inFlight := internal.NewInFlight() - inFlight.Insert(req.String()) - defer inFlight.Delete(req.String()) + inFlight.Insert(req.GetName()) + defer inFlight.Delete(req.GetName()) awsDriver := controllerService{ cloud: mockCloud, @@ -1549,17 +1548,8 @@ func TestCreateVolume(t *testing.T) { } _, err := awsDriver.CreateVolume(ctx, req) - if err == nil { - t.Fatalf("Expected CreateVolume to fail but got no error") - } - srvErr, ok := status.FromError(err) - if !ok { - t.Fatalf("Could not get error status code from error: %v", srvErr) - } - if srvErr.Code() != codes.Aborted { - t.Fatalf("Expected Aborted but got: %s", srvErr.Code()) - } + checkAbortedErrorCode(t, err) }, }, { @@ -1714,6 +1704,31 @@ func TestDeleteVolume(t *testing.T) { } }, }, + { + name: "fail another request already in-flight", + testFunc: func(t *testing.T) { + req := &csi.DeleteVolumeRequest{ + VolumeId: "vol-test", + } + + ctx := context.Background() + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + inFlight := internal.NewInFlight() + inFlight.Insert(req.GetVolumeId()) + defer inFlight.Delete(req.GetVolumeId()) + awsDriver := controllerService{ + cloud: mockCloud, + inFlight: inFlight, + driverOptions: &DriverOptions{}, + } + _, err := awsDriver.DeleteVolume(ctx, req) + + checkAbortedErrorCode(t, err) + }, + }, } for _, tc := range testCases { @@ -2259,6 +2274,34 @@ func TestCreateSnapshot(t *testing.T) { } }, }, + { + name: "fail with another request in-flight", + testFunc: func(t *testing.T) { + req := &csi.CreateSnapshotRequest{ + Name: "test-snapshot", + Parameters: nil, + SourceVolumeId: "vol-test", + } + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + + inFlight := internal.NewInFlight() + inFlight.Insert(req.GetName()) + defer inFlight.Delete(req.GetName()) + + awsDriver := controllerService{ + cloud: mockCloud, + inFlight: inFlight, + driverOptions: &DriverOptions{}, + } + _, err := awsDriver.CreateSnapshot(context.Background(), req) + + checkAbortedErrorCode(t, err) + }, + }, } for _, tc := range testCases { @@ -2321,6 +2364,34 @@ func TestDeleteSnapshot(t *testing.T) { } }, }, + { + name: "fail with another request in-flight", + testFunc: func(t *testing.T) { + ctx := context.Background() + + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + + mockCloud := mocks.NewMockCloud(mockCtl) + + req := &csi.DeleteSnapshotRequest{ + SnapshotId: "test-snapshotID", + } + inFlight := internal.NewInFlight() + inFlight.Insert(req.GetSnapshotId()) + defer inFlight.Delete(req.GetSnapshotId()) + + awsDriver := controllerService{ + cloud: mockCloud, + inFlight: inFlight, + driverOptions: &DriverOptions{}, + } + + _, err := awsDriver.DeleteSnapshot(ctx, req) + + checkAbortedErrorCode(t, err) + }, + }, } for _, tc := range testCases { @@ -3082,3 +3153,17 @@ func TestControllerExpandVolume(t *testing.T) { }) } } + +func checkAbortedErrorCode(t *testing.T, err error) { + if err == nil { + t.Fatalf("Expected operation to fail but got no error") + } + + srvErr, ok := status.FromError(err) + if !ok { + t.Fatalf("Could not get error status code from error: %v", srvErr) + } + if srvErr.Code() != codes.Aborted { + t.Fatalf("Expected Aborted but got: %s", srvErr.Code()) + } +} diff --git a/pkg/driver/internal/inflight.go b/pkg/driver/internal/inflight.go index 5f0d2a9ad7..9b45680fbc 100644 --- a/pkg/driver/internal/inflight.go +++ b/pkg/driver/internal/inflight.go @@ -30,6 +30,10 @@ type Idempotent interface { String() string } +const ( + VolumeOperationAlreadyExistsErrorMsg = "An operation with the given Volume %s already exists" +) + // InFlight is a struct used to manage in flight requests per volumeId. type InFlight struct { mux *sync.Mutex