diff --git a/sdk/storage/azfile/assets.json b/sdk/storage/azfile/assets.json new file mode 100644 index 000000000000..6561d915b461 --- /dev/null +++ b/sdk/storage/azfile/assets.json @@ -0,0 +1,6 @@ +{ + "AssetsRepo": "Azure/azure-sdk-assets", + "AssetsRepoPrefixPath": "go", + "TagPrefix": "go/storage/azfile", + "Tag": "go/storage/azfile_33b8efd383" +} diff --git a/sdk/storage/azfile/directory/client.go b/sdk/storage/azfile/directory/client.go index 579d917991af..dc7716ad1337 100644 --- a/sdk/storage/azfile/directory/client.go +++ b/sdk/storage/azfile/directory/client.go @@ -23,14 +23,6 @@ type ClientOptions struct { // Client represents a URL to the Azure Storage directory allowing you to manipulate its directories and files. type Client base.Client[generated.DirectoryClient] -// NewClient creates an instance of Client with the specified values. -// - directoryURL - the URL of the directory e.g. https://.file.core.windows.net/share/directory -// - cred - an Azure AD credential, typically obtained via the azidentity module -// - options - client options; pass nil to accept the default values -func NewClient(directoryURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) { - return nil, nil -} - // NewClientWithNoCredential creates an instance of Client with the specified values. // This is used to anonymously access a directory or with a shared access signature (SAS) token. // - directoryURL - the URL of the directory e.g. https://.file.core.windows.net/share/directory? @@ -66,7 +58,7 @@ func (d *Client) sharedKey() *SharedKeyCredential { // URL returns the URL endpoint used by the Client object. func (d *Client) URL() string { - return "s.generated().Endpoint()" + return d.generated().Endpoint() } // NewSubdirectoryClient creates a new Client object by concatenating subDirectoryName to the end of this Client's URL. diff --git a/sdk/storage/azfile/file/client.go b/sdk/storage/azfile/file/client.go index 9c259804f4e8..dcc22ae55ab6 100644 --- a/sdk/storage/azfile/file/client.go +++ b/sdk/storage/azfile/file/client.go @@ -23,14 +23,6 @@ type ClientOptions struct { // Client represents a URL to the Azure Storage file. type Client base.Client[generated.FileClient] -// NewClient creates an instance of Client with the specified values. -// - fileURL - the URL of the file e.g. https://.file.core.windows.net/share/directoryPath/file -// - cred - an Azure AD credential, typically obtained via the azidentity module -// - options - client options; pass nil to accept the default values -func NewClient(fileURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) { - return nil, nil -} - // NewClientWithNoCredential creates an instance of Client with the specified values. // This is used to anonymously access a file or with a shared access signature (SAS) token. // - fileURL - the URL of the file e.g. https://.file.core.windows.net/share/directoryPath/file? @@ -67,7 +59,7 @@ func (f *Client) sharedKey() *SharedKeyCredential { // URL returns the URL endpoint used by the Client object. func (f *Client) URL() string { - return "s.generated().Endpoint()" + return f.generated().Endpoint() } // Create operation creates a new file or replaces a file. Note it only initializes the file with no content. @@ -124,12 +116,12 @@ func (f *Client) DownloadStream(ctx context.Context, options *DownloadStreamOpti return DownloadStreamResponse{}, nil } -// DownloadBuffer downloads an Azure blob to a buffer with parallel. +// DownloadBuffer downloads an Azure file to a buffer with parallel. func (f *Client) DownloadBuffer(ctx context.Context, buffer []byte, o *DownloadBufferOptions) (int64, error) { return 0, nil } -// DownloadFile downloads an Azure blob to a local file. +// DownloadFile downloads an Azure file to a local file. // The file would be truncated if the size doesn't match. func (f *Client) DownloadFile(ctx context.Context, file *os.File, o *DownloadFileOptions) (int64, error) { return 0, nil diff --git a/sdk/storage/azfile/fileerror/error_codes.go b/sdk/storage/azfile/fileerror/error_codes.go new file mode 100644 index 000000000000..3f91c984c711 --- /dev/null +++ b/sdk/storage/azfile/fileerror/error_codes.go @@ -0,0 +1,102 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package fileerror + +import ( + "errors" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/generated" +) + +// HasCode returns true if the provided error is an *azcore.ResponseError +// with its ErrorCode field equal to one of the specified Codes. +func HasCode(err error, codes ...Code) bool { + var respErr *azcore.ResponseError + if !errors.As(err, &respErr) { + return false + } + + for _, code := range codes { + if respErr.ErrorCode == string(code) { + return true + } + } + + return false +} + +// Code - Error codes returned by the service +type Code = generated.StorageErrorCode + +const ( + AccountAlreadyExists Code = "AccountAlreadyExists" + AccountBeingCreated Code = "AccountBeingCreated" + AccountIsDisabled Code = "AccountIsDisabled" + AuthenticationFailed Code = "AuthenticationFailed" + AuthorizationFailure Code = "AuthorizationFailure" + AuthorizationPermissionMismatch Code = "AuthorizationPermissionMismatch" + AuthorizationProtocolMismatch Code = "AuthorizationProtocolMismatch" + AuthorizationResourceTypeMismatch Code = "AuthorizationResourceTypeMismatch" + AuthorizationServiceMismatch Code = "AuthorizationServiceMismatch" + AuthorizationSourceIPMismatch Code = "AuthorizationSourceIPMismatch" + CannotDeleteFileOrDirectory Code = "CannotDeleteFileOrDirectory" + ClientCacheFlushDelay Code = "ClientCacheFlushDelay" + ConditionHeadersNotSupported Code = "ConditionHeadersNotSupported" + ConditionNotMet Code = "ConditionNotMet" + DeletePending Code = "DeletePending" + DirectoryNotEmpty Code = "DirectoryNotEmpty" + EmptyMetadataKey Code = "EmptyMetadataKey" + FeatureVersionMismatch Code = "FeatureVersionMismatch" + FileLockConflict Code = "FileLockConflict" + InsufficientAccountPermissions Code = "InsufficientAccountPermissions" + InternalError Code = "InternalError" + InvalidAuthenticationInfo Code = "InvalidAuthenticationInfo" + InvalidFileOrDirectoryPathName Code = "InvalidFileOrDirectoryPathName" + InvalidHTTPVerb Code = "InvalidHttpVerb" + InvalidHeaderValue Code = "InvalidHeaderValue" + InvalidInput Code = "InvalidInput" + InvalidMD5 Code = "InvalidMd5" + InvalidMetadata Code = "InvalidMetadata" + InvalidQueryParameterValue Code = "InvalidQueryParameterValue" + InvalidRange Code = "InvalidRange" + InvalidResourceName Code = "InvalidResourceName" + InvalidURI Code = "InvalidUri" + InvalidXMLDocument Code = "InvalidXmlDocument" + InvalidXMLNodeValue Code = "InvalidXmlNodeValue" + MD5Mismatch Code = "Md5Mismatch" + MetadataTooLarge Code = "MetadataTooLarge" + MissingContentLengthHeader Code = "MissingContentLengthHeader" + MissingRequiredHeader Code = "MissingRequiredHeader" + MissingRequiredQueryParameter Code = "MissingRequiredQueryParameter" + MissingRequiredXMLNode Code = "MissingRequiredXmlNode" + MultipleConditionHeadersNotSupported Code = "MultipleConditionHeadersNotSupported" + OperationTimedOut Code = "OperationTimedOut" + OutOfRangeInput Code = "OutOfRangeInput" + OutOfRangeQueryParameterValue Code = "OutOfRangeQueryParameterValue" + ParentNotFound Code = "ParentNotFound" + ReadOnlyAttribute Code = "ReadOnlyAttribute" + RequestBodyTooLarge Code = "RequestBodyTooLarge" + RequestURLFailedToParse Code = "RequestUrlFailedToParse" + ResourceAlreadyExists Code = "ResourceAlreadyExists" + ResourceNotFound Code = "ResourceNotFound" + ResourceTypeMismatch Code = "ResourceTypeMismatch" + ServerBusy Code = "ServerBusy" + ShareAlreadyExists Code = "ShareAlreadyExists" + ShareBeingDeleted Code = "ShareBeingDeleted" + ShareDisabled Code = "ShareDisabled" + ShareHasSnapshots Code = "ShareHasSnapshots" + ShareNotFound Code = "ShareNotFound" + ShareSnapshotCountExceeded Code = "ShareSnapshotCountExceeded" + ShareSnapshotInProgress Code = "ShareSnapshotInProgress" + ShareSnapshotOperationNotSupported Code = "ShareSnapshotOperationNotSupported" + SharingViolation Code = "SharingViolation" + UnsupportedHTTPVerb Code = "UnsupportedHttpVerb" + UnsupportedHeader Code = "UnsupportedHeader" + UnsupportedQueryParameter Code = "UnsupportedQueryParameter" + UnsupportedXMLNode Code = "UnsupportedXmlNode" +) diff --git a/sdk/storage/azfile/go.mod b/sdk/storage/azfile/go.mod index 1f6c597d5272..fe84eb01a79c 100644 --- a/sdk/storage/azfile/go.mod +++ b/sdk/storage/azfile/go.mod @@ -2,10 +2,18 @@ module github.com/Azure/azure-sdk-for-go/sdk/storage/azfile go 1.18 -require github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.0 +require ( + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.0 + github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.1 + github.com/stretchr/testify v1.7.0 +) require ( - github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.1 // indirect + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/dnaeon/go-vcr v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 // indirect golang.org/x/text v0.3.7 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) diff --git a/sdk/storage/azfile/go.sum b/sdk/storage/azfile/go.sum index c6cbff29d2f1..67ee617668d0 100644 --- a/sdk/storage/azfile/go.sum +++ b/sdk/storage/azfile/go.sum @@ -3,10 +3,23 @@ github.com/Azure/azure-sdk-for-go/sdk/azcore v1.3.0/go.mod h1:tZoQYdDZNOiIjdSn0d github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.1 h1:Oj853U9kG+RLTCQXpjvOnrv0WaZHxgmZz1TlLywgOPY= github.com/Azure/azure-sdk-for-go/sdk/internal v1.1.1/go.mod h1:eWRD7oawr1Mu1sLCawqVc0CUiF43ia3qQMxLscsKQ9w= github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dnaeon/go-vcr v1.1.0 h1:ReYa/UBrRyQdant9B4fNHGoCNKw6qh6P0fsdGmZpR7c= +github.com/dnaeon/go-vcr v1.1.0/go.mod h1:M7tiix8f0r6mKKJ3Yq/kqU1OYf3MnfmBWVbPx/yU9ko= +github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4 h1:HVyaeDAYux4pnY+D/SiwmLOR36ewZ4iGQIIrtnuCjFA= golang.org/x/net v0.0.0-20220425223048-2871e0cb64e4/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/sdk/storage/azfile/internal/base/clients.go b/sdk/storage/azfile/internal/base/clients.go index 5661feff2f01..c89ab5aa49f5 100644 --- a/sdk/storage/azfile/internal/base/clients.go +++ b/sdk/storage/azfile/internal/base/clients.go @@ -7,7 +7,9 @@ package base import ( + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/exported" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/generated" ) type Client[T any] struct { @@ -22,3 +24,31 @@ func InnerClient[T any](client *Client[T]) *T { func SharedKey[T any](client *Client[T]) *exported.SharedKeyCredential { return client.sharedKey } + +func NewServiceClient(serviceURL string, pipeline runtime.Pipeline, sharedKey *exported.SharedKeyCredential) *Client[generated.ServiceClient] { + return &Client[generated.ServiceClient]{ + inner: generated.NewServiceClient(serviceURL, pipeline), + sharedKey: sharedKey, + } +} + +func NewShareClient(shareURL string, pipeline runtime.Pipeline, sharedKey *exported.SharedKeyCredential) *Client[generated.ShareClient] { + return &Client[generated.ShareClient]{ + inner: generated.NewShareClient(shareURL, pipeline), + sharedKey: sharedKey, + } +} + +func NewDirectoryClient(directoryURL string, pipeline runtime.Pipeline, sharedKey *exported.SharedKeyCredential) *Client[generated.DirectoryClient] { + return &Client[generated.DirectoryClient]{ + inner: generated.NewDirectoryClient(directoryURL, pipeline), + sharedKey: sharedKey, + } +} + +func NewFileClient(fileURL string, pipeline runtime.Pipeline, sharedKey *exported.SharedKeyCredential) *Client[generated.FileClient] { + return &Client[generated.FileClient]{ + inner: generated.NewFileClient(fileURL, pipeline), + sharedKey: sharedKey, + } +} diff --git a/sdk/storage/azfile/internal/exported/shared_key_credential.go b/sdk/storage/azfile/internal/exported/shared_key_credential.go index 44f2bc7b652b..439617d07ba1 100644 --- a/sdk/storage/azfile/internal/exported/shared_key_credential.go +++ b/sdk/storage/azfile/internal/exported/shared_key_credential.go @@ -7,9 +7,22 @@ package exported import ( + "bytes" + "crypto/hmac" + "crypto/sha256" "encoding/base64" "fmt" + "net/http" + "net/url" + "sort" + "strings" "sync/atomic" + "time" + + azlog "github.com/Azure/azure-sdk-for-go/sdk/azcore/log" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/internal/log" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/shared" ) // SharedKeyCredential contains an account's name and its primary or secondary key. @@ -43,3 +56,163 @@ func (c *SharedKeyCredential) SetAccountKey(accountKey string) error { c.accountKey.Store(_bytes) return nil } + +// ComputeHMACSHA256 generates a hash signature for an HTTP request or for a SAS. +func (c *SharedKeyCredential) computeHMACSHA256(message string) (string, error) { + h := hmac.New(sha256.New, c.accountKey.Load().([]byte)) + _, err := h.Write([]byte(message)) + return base64.StdEncoding.EncodeToString(h.Sum(nil)), err +} + +func (c *SharedKeyCredential) buildStringToSign(req *http.Request) (string, error) { + // https://docs.microsoft.com/en-us/rest/api/storageservices/authentication-for-the-azure-storage-services + headers := req.Header + contentLength := getHeader(shared.HeaderContentLength, headers) + if contentLength == "0" { + contentLength = "" + } + + canonicalizedResource, err := c.buildCanonicalizedResource(req.URL) + if err != nil { + return "", err + } + + stringToSign := strings.Join([]string{ + req.Method, + getHeader(shared.HeaderContentEncoding, headers), + getHeader(shared.HeaderContentLanguage, headers), + contentLength, + getHeader(shared.HeaderContentMD5, headers), + getHeader(shared.HeaderContentType, headers), + "", // Empty date because x-ms-date is expected (as per web page above) + getHeader(shared.HeaderIfModifiedSince, headers), + getHeader(shared.HeaderIfMatch, headers), + getHeader(shared.HeaderIfNoneMatch, headers), + getHeader(shared.HeaderIfUnmodifiedSince, headers), + getHeader(shared.HeaderRange, headers), + c.buildCanonicalizedHeader(headers), + canonicalizedResource, + }, "\n") + return stringToSign, nil +} + +func getHeader(key string, headers map[string][]string) string { + if headers == nil { + return "" + } + if v, ok := headers[key]; ok { + if len(v) > 0 { + return v[0] + } + } + + return "" +} + +func (c *SharedKeyCredential) buildCanonicalizedHeader(headers http.Header) string { + cm := map[string][]string{} + for k, v := range headers { + headerName := strings.TrimSpace(strings.ToLower(k)) + if strings.HasPrefix(headerName, "x-ms-") { + cm[headerName] = v // NOTE: the value must not have any whitespace around it. + } + } + if len(cm) == 0 { + return "" + } + + keys := make([]string, 0, len(cm)) + for key := range cm { + keys = append(keys, key) + } + sort.Strings(keys) + ch := bytes.NewBufferString("") + for i, key := range keys { + if i > 0 { + ch.WriteRune('\n') + } + ch.WriteString(key) + ch.WriteRune(':') + ch.WriteString(strings.Join(cm[key], ",")) + } + return ch.String() +} + +func (c *SharedKeyCredential) buildCanonicalizedResource(u *url.URL) (string, error) { + // https://docs.microsoft.com/en-us/rest/api/storageservices/authentication-for-the-azure-storage-services + cr := bytes.NewBufferString("/") + cr.WriteString(c.accountName) + + if len(u.Path) > 0 { + // Any portion of the CanonicalizedResource string that is derived from + // the resource's URI should be encoded exactly as it is in the URI. + // -- https://msdn.microsoft.com/en-gb/library/azure/dd179428.aspx + cr.WriteString(u.EscapedPath()) + } else { + // a slash is required to indicate the root path + cr.WriteString("/") + } + + // params is a map[string][]string; param name is key; params values is []string + params, err := url.ParseQuery(u.RawQuery) // Returns URL decoded values + if err != nil { + return "", fmt.Errorf("failed to parse query params: %w", err) + } + + if len(params) > 0 { // There is at least 1 query parameter + var paramNames []string // We use this to sort the parameter key names + for paramName := range params { + paramNames = append(paramNames, paramName) // paramNames must be lowercase + } + sort.Strings(paramNames) + + for _, paramName := range paramNames { + paramValues := params[paramName] + sort.Strings(paramValues) + + // Join the sorted key values separated by ',' + // Then prepend "keyName:"; then add this string to the buffer + cr.WriteString("\n" + paramName + ":" + strings.Join(paramValues, ",")) + } + } + return cr.String(), nil +} + +// ComputeHMACSHA256 is a helper for computing the signed string outside of this package. +func ComputeHMACSHA256(cred *SharedKeyCredential, message string) (string, error) { + return cred.computeHMACSHA256(message) +} + +// the following content isn't actually exported but must live +// next to SharedKeyCredential as it uses its unexported methods + +type SharedKeyCredPolicy struct { + cred *SharedKeyCredential +} + +func NewSharedKeyCredPolicy(cred *SharedKeyCredential) *SharedKeyCredPolicy { + return &SharedKeyCredPolicy{cred: cred} +} + +func (s *SharedKeyCredPolicy) Do(req *policy.Request) (*http.Response, error) { + if d := getHeader(shared.HeaderXmsDate, req.Raw().Header); d == "" { + req.Raw().Header.Set(shared.HeaderXmsDate, time.Now().UTC().Format(http.TimeFormat)) + } + stringToSign, err := s.cred.buildStringToSign(req.Raw()) + if err != nil { + return nil, err + } + signature, err := s.cred.computeHMACSHA256(stringToSign) + if err != nil { + return nil, err + } + authHeader := strings.Join([]string{"SharedKey ", s.cred.AccountName(), ":", signature}, "") + req.Raw().Header.Set(shared.HeaderAuthorization, authHeader) + + response, err := req.Next() + if err != nil && response != nil && response.StatusCode == http.StatusForbidden { + // Service failed to authenticate request, log it + log.Write(azlog.EventResponse, "===== HTTP Forbidden status, String-to-Sign:\n"+stringToSign+"\n===============================\n") + } + return response, err +} diff --git a/sdk/storage/azfile/internal/exported/version.go b/sdk/storage/azfile/internal/exported/version.go new file mode 100644 index 000000000000..8e130784dbf2 --- /dev/null +++ b/sdk/storage/azfile/internal/exported/version.go @@ -0,0 +1,12 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package exported + +const ( + ModuleName = "azfile" + ModuleVersion = "v0.1.0" +) diff --git a/sdk/storage/azfile/internal/generated/autorest.md b/sdk/storage/azfile/internal/generated/autorest.md index 3ebe6634b6e8..89e32751eb7f 100644 --- a/sdk/storage/azfile/internal/generated/autorest.md +++ b/sdk/storage/azfile/internal/generated/autorest.md @@ -231,3 +231,16 @@ directive: return $. replace(/xml:"CORS>CORSRule"/g, "xml:\"Cors>CorsRule\""); ``` + +### Remove pager methods and export various generated methods in service client + +``` yaml +directive: + - from: zz_service_client.go + where: $ + transform: >- + return $. + replace(/func \(client \*ServiceClient\) NewListSharesSegmentPager\(.+\/\/ listSharesSegmentCreateRequest creates the ListSharesSegment request/s, `//\n// listSharesSegmentCreateRequest creates the ListSharesSegment request`). + replace(/\(client \*ServiceClient\) listSharesSegmentCreateRequest\(/, `(client *ServiceClient) ListSharesSegmentCreateRequest(`). + replace(/\(client \*ServiceClient\) listSharesSegmentHandleResponse\(/, `(client *ServiceClient) ListSharesSegmentHandleResponse(`); +``` diff --git a/sdk/storage/azfile/internal/generated/directory_client.go b/sdk/storage/azfile/internal/generated/directory_client.go new file mode 100644 index 000000000000..1f07400bee9c --- /dev/null +++ b/sdk/storage/azfile/internal/generated/directory_client.go @@ -0,0 +1,17 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package generated + +import "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + +func (client *DirectoryClient) Endpoint() string { + return client.endpoint +} + +func (client *DirectoryClient) Pipeline() runtime.Pipeline { + return client.pl +} diff --git a/sdk/storage/azfile/internal/generated/file_client.go b/sdk/storage/azfile/internal/generated/file_client.go new file mode 100644 index 000000000000..f4a01a783938 --- /dev/null +++ b/sdk/storage/azfile/internal/generated/file_client.go @@ -0,0 +1,17 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package generated + +import "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + +func (client *FileClient) Endpoint() string { + return client.endpoint +} + +func (client *FileClient) Pipeline() runtime.Pipeline { + return client.pl +} diff --git a/sdk/storage/azfile/internal/generated/service_client.go b/sdk/storage/azfile/internal/generated/service_client.go new file mode 100644 index 000000000000..1f449b955e82 --- /dev/null +++ b/sdk/storage/azfile/internal/generated/service_client.go @@ -0,0 +1,17 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package generated + +import "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + +func (client *ServiceClient) Endpoint() string { + return client.endpoint +} + +func (client *ServiceClient) Pipeline() runtime.Pipeline { + return client.pl +} diff --git a/sdk/storage/azfile/internal/generated/share_client.go b/sdk/storage/azfile/internal/generated/share_client.go new file mode 100644 index 000000000000..040785814606 --- /dev/null +++ b/sdk/storage/azfile/internal/generated/share_client.go @@ -0,0 +1,17 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package generated + +import "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + +func (client *ShareClient) Endpoint() string { + return client.endpoint +} + +func (client *ShareClient) Pipeline() runtime.Pipeline { + return client.pl +} diff --git a/sdk/storage/azfile/internal/generated/zz_service_client.go b/sdk/storage/azfile/internal/generated/zz_service_client.go index 8b388a79c3e3..efd5f4708912 100644 --- a/sdk/storage/azfile/internal/generated/zz_service_client.go +++ b/sdk/storage/azfile/internal/generated/zz_service_client.go @@ -97,36 +97,9 @@ func (client *ServiceClient) getPropertiesHandleResponse(resp *http.Response) (S // Generated from API version 2020-10-02 // - options - ServiceClientListSharesSegmentOptions contains the optional parameters for the ServiceClient.NewListSharesSegmentPager // method. -func (client *ServiceClient) NewListSharesSegmentPager(options *ServiceClientListSharesSegmentOptions) *runtime.Pager[ServiceClientListSharesSegmentResponse] { - return runtime.NewPager(runtime.PagingHandler[ServiceClientListSharesSegmentResponse]{ - More: func(page ServiceClientListSharesSegmentResponse) bool { - return page.NextMarker != nil && len(*page.NextMarker) > 0 - }, - Fetcher: func(ctx context.Context, page *ServiceClientListSharesSegmentResponse) (ServiceClientListSharesSegmentResponse, error) { - var req *policy.Request - var err error - if page == nil { - req, err = client.listSharesSegmentCreateRequest(ctx, options) - } else { - req, err = runtime.NewRequest(ctx, http.MethodGet, *page.NextMarker) - } - if err != nil { - return ServiceClientListSharesSegmentResponse{}, err - } - resp, err := client.pl.Do(req) - if err != nil { - return ServiceClientListSharesSegmentResponse{}, err - } - if !runtime.HasStatusCode(resp, http.StatusOK) { - return ServiceClientListSharesSegmentResponse{}, runtime.NewResponseError(resp) - } - return client.listSharesSegmentHandleResponse(resp) - }, - }) -} - +// // listSharesSegmentCreateRequest creates the ListSharesSegment request. -func (client *ServiceClient) listSharesSegmentCreateRequest(ctx context.Context, options *ServiceClientListSharesSegmentOptions) (*policy.Request, error) { +func (client *ServiceClient) ListSharesSegmentCreateRequest(ctx context.Context, options *ServiceClientListSharesSegmentOptions) (*policy.Request, error) { req, err := runtime.NewRequest(ctx, http.MethodGet, client.endpoint) if err != nil { return nil, err @@ -155,7 +128,7 @@ func (client *ServiceClient) listSharesSegmentCreateRequest(ctx context.Context, } // listSharesSegmentHandleResponse handles the ListSharesSegment response. -func (client *ServiceClient) listSharesSegmentHandleResponse(resp *http.Response) (ServiceClientListSharesSegmentResponse, error) { +func (client *ServiceClient) ListSharesSegmentHandleResponse(resp *http.Response) (ServiceClientListSharesSegmentResponse, error) { result := ServiceClientListSharesSegmentResponse{} if val := resp.Header.Get("x-ms-request-id"); val != "" { result.RequestID = &val diff --git a/sdk/storage/azfile/internal/shared/shared.go b/sdk/storage/azfile/internal/shared/shared.go new file mode 100644 index 000000000000..e201782fc0b2 --- /dev/null +++ b/sdk/storage/azfile/internal/shared/shared.go @@ -0,0 +1,110 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package shared + +import ( + "errors" + "fmt" + "strings" +) + +const ( + TokenScope = "https://storage.azure.com/.default" + StorageAnalyticsVersion = "1.0" + + HeaderAuthorization = "Authorization" + HeaderXmsDate = "x-ms-date" + HeaderContentLength = "Content-Length" + HeaderContentEncoding = "Content-Encoding" + HeaderContentLanguage = "Content-Language" + HeaderContentType = "Content-Type" + HeaderContentMD5 = "Content-MD5" + HeaderIfModifiedSince = "If-Modified-Since" + HeaderIfMatch = "If-Match" + HeaderIfNoneMatch = "If-None-Match" + HeaderIfUnmodifiedSince = "If-Unmodified-Since" + HeaderRange = "Range" +) + +func GetClientOptions[T any](o *T) *T { + if o == nil { + return new(T) + } + return o +} + +var errConnectionString = errors.New("connection string is either blank or malformed. The expected connection string " + + "should contain key value pairs separated by semicolons. For example 'DefaultEndpointsProtocol=https;AccountName=;" + + "AccountKey=;EndpointSuffix=core.windows.net'") + +type ParsedConnectionString struct { + ServiceURL string + AccountName string + AccountKey string +} + +func ParseConnectionString(connectionString string) (ParsedConnectionString, error) { + const ( + defaultScheme = "https" + defaultSuffix = "core.windows.net" + ) + + connStrMap := make(map[string]string) + connectionString = strings.TrimRight(connectionString, ";") + + splitString := strings.Split(connectionString, ";") + if len(splitString) == 0 { + return ParsedConnectionString{}, errConnectionString + } + for _, stringPart := range splitString { + parts := strings.SplitN(stringPart, "=", 2) + if len(parts) != 2 { + return ParsedConnectionString{}, errConnectionString + } + connStrMap[parts[0]] = parts[1] + } + + accountName, ok := connStrMap["AccountName"] + if !ok { + return ParsedConnectionString{}, errors.New("connection string missing AccountName") + } + + accountKey, ok := connStrMap["AccountKey"] + if !ok { + sharedAccessSignature, ok := connStrMap["SharedAccessSignature"] + if !ok { + return ParsedConnectionString{}, errors.New("connection string missing AccountKey and SharedAccessSignature") + } + return ParsedConnectionString{ + ServiceURL: fmt.Sprintf("%v://%v.file.%v/?%v", defaultScheme, accountName, defaultSuffix, sharedAccessSignature), + }, nil + } + + protocol, ok := connStrMap["DefaultEndpointsProtocol"] + if !ok { + protocol = defaultScheme + } + + suffix, ok := connStrMap["EndpointSuffix"] + if !ok { + suffix = defaultSuffix + } + + if fileEndpoint, ok := connStrMap["FileEndpoint"]; ok { + return ParsedConnectionString{ + ServiceURL: fileEndpoint, + AccountName: accountName, + AccountKey: accountKey, + }, nil + } + + return ParsedConnectionString{ + ServiceURL: fmt.Sprintf("%v://%v.file.%v", protocol, accountName, suffix), + AccountName: accountName, + AccountKey: accountKey, + }, nil +} diff --git a/sdk/storage/azfile/internal/testcommon/clients_auth.go b/sdk/storage/azfile/internal/testcommon/clients_auth.go new file mode 100644 index 000000000000..6142b23ddb7c --- /dev/null +++ b/sdk/storage/azfile/internal/testcommon/clients_auth.go @@ -0,0 +1,124 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +// Contains common helpers for TESTS ONLY +package testcommon + +import ( + "errors" + "fmt" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/service" + "github.com/stretchr/testify/require" + "testing" +) + +type TestAccountType string + +const ( + TestAccountDefault TestAccountType = "" + TestAccountSecondary TestAccountType = "SECONDARY_" + TestAccountPremium TestAccountType = "PREMIUM_" + TestAccountSoftDelete TestAccountType = "SOFT_DELETE_" +) + +const ( + DefaultEndpointSuffix = "core.windows.net/" + DefaultFileEndpointSuffix = "file.core.windows.net/" + AccountNameEnvVar = "AZURE_STORAGE_ACCOUNT_NAME" + AccountKeyEnvVar = "AZURE_STORAGE_ACCOUNT_KEY" + DefaultEndpointSuffixEnvVar = "AZURE_STORAGE_ENDPOINT_SUFFIX" +) + +const ( + FakeStorageAccount = "fakestorage" + FakeStorageURL = "https://fakestorage.file.core.windows.net" +) + +func SetClientOptions(t *testing.T, opts *azcore.ClientOptions) { + opts.Logging.AllowedHeaders = append(opts.Logging.AllowedHeaders, "X-Request-Mismatch", "X-Request-Mismatch-Error") + + transport, err := recording.NewRecordingHTTPClient(t, nil) + require.NoError(t, err) + opts.Transport = transport +} + +func GetServiceClient(t *testing.T, accountType TestAccountType, options *service.ClientOptions) (*service.Client, error) { + if options == nil { + options = &service.ClientOptions{} + } + + SetClientOptions(t, &options.ClientOptions) + + cred, err := GetGenericSharedKeyCredential(accountType) + if err != nil { + return nil, err + } + + serviceClient, err := service.NewClientWithSharedKeyCredential("https://"+cred.AccountName()+".file.core.windows.net/", cred, options) + + return serviceClient, err +} + +func GetServiceClientNoCredential(t *testing.T, sasUrl string, options *service.ClientOptions) (*service.Client, error) { + if options == nil { + options = &service.ClientOptions{} + } + + SetClientOptions(t, &options.ClientOptions) + + serviceClient, err := service.NewClientWithNoCredential(sasUrl, options) + + return serviceClient, err +} + +func GetGenericAccountInfo(accountType TestAccountType) (string, string) { + if recording.GetRecordMode() == recording.PlaybackMode { + return FakeStorageAccount, "ZmFrZQ==" + } + accountNameEnvVar := string(accountType) + AccountNameEnvVar + accountKeyEnvVar := string(accountType) + AccountKeyEnvVar + accountName, _ := GetRequiredEnv(accountNameEnvVar) + accountKey, _ := GetRequiredEnv(accountKeyEnvVar) + return accountName, accountKey +} + +func GetGenericSharedKeyCredential(accountType TestAccountType) (*service.SharedKeyCredential, error) { + accountName, accountKey := GetGenericAccountInfo(accountType) + if accountName == "" || accountKey == "" { + return nil, errors.New(string(accountType) + AccountNameEnvVar + " and/or " + string(accountType) + AccountKeyEnvVar + " environment variables not specified.") + } + return service.NewSharedKeyCredential(accountName, accountKey) +} + +func GetGenericConnectionString(accountType TestAccountType) (*string, error) { + accountName, accountKey := GetGenericAccountInfo(accountType) + if accountName == "" || accountKey == "" { + return nil, errors.New(string(accountType) + AccountNameEnvVar + " and/or " + string(accountType) + AccountKeyEnvVar + " environment variables not specified.") + } + connectionString := fmt.Sprintf("DefaultEndpointsProtocol=https;AccountName=%s;AccountKey=%s;EndpointSuffix=core.windows.net/", + accountName, accountKey) + return &connectionString, nil +} + +func GetServiceClientFromConnectionString(t *testing.T, accountType TestAccountType, options *service.ClientOptions) (*service.Client, error) { + if options == nil { + options = &service.ClientOptions{} + } + SetClientOptions(t, &options.ClientOptions) + + transport, err := recording.NewRecordingHTTPClient(t, nil) + require.NoError(t, err) + options.Transport = transport + + cred, err := GetGenericConnectionString(accountType) + if err != nil { + return nil, err + } + svcClient, err := service.NewClientFromConnectionString(*cred, options) + return svcClient, err +} diff --git a/sdk/storage/azfile/internal/testcommon/common.go b/sdk/storage/azfile/internal/testcommon/common.go new file mode 100644 index 000000000000..d9db882d00c0 --- /dev/null +++ b/sdk/storage/azfile/internal/testcommon/common.go @@ -0,0 +1,58 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +// Contains common helpers for TESTS ONLY +package testcommon + +import ( + "errors" + "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" + "github.com/stretchr/testify/require" + "os" + "strings" + "testing" +) + +const ( + SharePrefix = "gos" +) + +func GenerateShareName(testName string) string { + return SharePrefix + GenerateEntityName(testName) +} + +func GenerateEntityName(testName string) string { + return strings.ReplaceAll(strings.ReplaceAll(strings.ToLower(testName), "/", ""), "test", "") +} + +// GetRequiredEnv gets an environment variable by name and returns an error if it is not found +func GetRequiredEnv(name string) (string, error) { + env, ok := os.LookupEnv(name) + if ok { + return env, nil + } else { + return "", errors.New("Required environment variable not set: " + name) + } +} + +func BeforeTest(t *testing.T, suite string, test string) { + const urlRegex = `https://\S+\.file\.core\.windows\.net` + require.NoError(t, recording.AddURISanitizer(FakeStorageURL, urlRegex, nil)) + require.NoError(t, recording.AddHeaderRegexSanitizer("x-ms-copy-source", FakeStorageURL, urlRegex, nil)) + // we freeze request IDs and timestamps to avoid creating noisy diffs + // NOTE: we can't freeze time stamps as that breaks some tests that use if-modified-since etc (maybe it can be fixed?) + //testframework.AddHeaderRegexSanitizer("X-Ms-Date", "Wed, 10 Aug 2022 23:34:14 GMT", "", nil) + require.NoError(t, recording.AddHeaderRegexSanitizer("x-ms-request-id", "00000000-0000-0000-0000-000000000000", "", nil)) + //testframework.AddHeaderRegexSanitizer("Date", "Wed, 10 Aug 2022 23:34:14 GMT", "", nil) + // TODO: more freezing + //testframework.AddBodyRegexSanitizer("RequestId:00000000-0000-0000-0000-000000000000", `RequestId:\w{8}-\w{4}-\w{4}-\w{4}-\w{12}`, nil) + //testframework.AddBodyRegexSanitizer("Time:2022-08-11T00:21:56.4562741Z", `Time:\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d*)?Z`, nil) + require.NoError(t, recording.Start(t, "sdk/storage/azfile/testdata", nil)) +} + +func AfterTest(t *testing.T, suite string, test string) { + require.NoError(t, recording.Stop(t, nil)) +} diff --git a/sdk/storage/azfile/service/client.go b/sdk/storage/azfile/service/client.go index 5c3c3564e40f..35455293cd67 100644 --- a/sdk/storage/azfile/service/client.go +++ b/sdk/storage/azfile/service/client.go @@ -9,10 +9,14 @@ package service import ( "context" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/base" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/generated" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/share" + "net/http" ) // ClientOptions contains the optional parameters when creating a Client. @@ -23,20 +27,15 @@ type ClientOptions struct { // Client represents a URL to the Azure File Storage service allowing you to manipulate file shares. type Client base.Client[generated.ServiceClient] -// NewClient creates an instance of Client with the specified values. -// - serviceURL - the URL of the storage account e.g. https://.file.core.windows.net/ -// - cred - an Azure AD credential, typically obtained via the azidentity module -// - options - client options; pass nil to accept the default values -func NewClient(serviceURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) { - return nil, nil -} - // NewClientWithNoCredential creates an instance of Client with the specified values. // This is used to anonymously access a storage account or with a shared access signature (SAS) token. // - serviceURL - the URL of the storage account e.g. https://.file.core.windows.net/? // - options - client options; pass nil to accept the default values func NewClientWithNoCredential(serviceURL string, options *ClientOptions) (*Client, error) { - return nil, nil + conOptions := shared.GetClientOptions(options) + pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions) + + return (*Client)(base.NewServiceClient(serviceURL, pl, nil)), nil } // NewClientWithSharedKeyCredential creates an instance of Client with the specified values. @@ -44,14 +43,32 @@ func NewClientWithNoCredential(serviceURL string, options *ClientOptions) (*Clie // - cred - a SharedKeyCredential created with the matching storage account and access key // - options - client options; pass nil to accept the default values func NewClientWithSharedKeyCredential(serviceURL string, cred *SharedKeyCredential, options *ClientOptions) (*Client, error) { - return nil, nil + authPolicy := exported.NewSharedKeyCredPolicy(cred) + conOptions := shared.GetClientOptions(options) + conOptions.PerRetryPolicies = append(conOptions.PerRetryPolicies, authPolicy) + pl := runtime.NewPipeline(exported.ModuleName, exported.ModuleVersion, runtime.PipelineOptions{}, &conOptions.ClientOptions) + + return (*Client)(base.NewServiceClient(serviceURL, pl, cred)), nil } // NewClientFromConnectionString creates an instance of Client with the specified values. // - connectionString - a connection string for the desired storage account // - options - client options; pass nil to accept the default values func NewClientFromConnectionString(connectionString string, options *ClientOptions) (*Client, error) { - return nil, nil + parsed, err := shared.ParseConnectionString(connectionString) + if err != nil { + return nil, err + } + + if parsed.AccountKey != "" && parsed.AccountName != "" { + credential, err := exported.NewSharedKeyCredential(parsed.AccountName, parsed.AccountKey) + if err != nil { + return nil, err + } + return NewClientWithSharedKeyCredential(parsed.ServiceURL, credential, options) + } + + return NewClientWithNoCredential(parsed.ServiceURL, options) } func (s *Client) generated() *generated.ServiceClient { @@ -64,13 +81,14 @@ func (s *Client) sharedKey() *SharedKeyCredential { // URL returns the URL endpoint used by the Client object. func (s *Client) URL() string { - return "s.generated().Endpoint()" + return s.generated().Endpoint() } // NewShareClient creates a new share.Client object by concatenating shareName to the end of this Client's URL. // The new share.Client uses the same request policy pipeline as the Client. func (s *Client) NewShareClient(shareName string) *share.Client { - return nil + shareURL := runtime.JoinPaths(s.generated().Endpoint(), shareName) + return (*share.Client)(base.NewShareClient(shareURL, s.generated().Pipeline(), s.sharedKey())) } // CreateShare is a lifecycle method to creates a new share under the specified account. @@ -113,11 +131,54 @@ func (s *Client) GetProperties(ctx context.Context, options *GetPropertiesOption // SetProperties operation sets properties for a storage account's File service endpoint. // For more information, see https://learn.microsoft.com/en-us/rest/api/storageservices/set-file-service-properties. func (s *Client) SetProperties(ctx context.Context, options *SetPropertiesOptions) (SetPropertiesResponse, error) { - return SetPropertiesResponse{}, nil + svcProperties, o := options.format() + resp, err := s.generated().SetProperties(ctx, svcProperties, o) + return resp, err } // NewListSharesPager operation returns a pager of the shares under the specified account. // For more information, see https://learn.microsoft.com/en-us/rest/api/storageservices/list-shares func (s *Client) NewListSharesPager(options *ListSharesOptions) *runtime.Pager[ListSharesSegmentResponse] { - return nil + listOptions := generated.ServiceClientListSharesSegmentOptions{} + if options != nil { + if options.Include.Deleted { + listOptions.Include = append(listOptions.Include, ListSharesIncludeTypeDeleted) + } + if options.Include.Metadata { + listOptions.Include = append(listOptions.Include, ListSharesIncludeTypeMetadata) + } + if options.Include.Snapshots { + listOptions.Include = append(listOptions.Include, ListSharesIncludeTypeSnapshots) + } + listOptions.Marker = options.Marker + listOptions.Maxresults = options.MaxResults + listOptions.Prefix = options.Prefix + } + + return runtime.NewPager(runtime.PagingHandler[ListSharesSegmentResponse]{ + More: func(page ListSharesSegmentResponse) bool { + return page.NextMarker != nil && len(*page.NextMarker) > 0 + }, + Fetcher: func(ctx context.Context, page *ListSharesSegmentResponse) (ListSharesSegmentResponse, error) { + var req *policy.Request + var err error + if page == nil { + req, err = s.generated().ListSharesSegmentCreateRequest(ctx, &listOptions) + } else { + listOptions.Marker = page.NextMarker + req, err = s.generated().ListSharesSegmentCreateRequest(ctx, &listOptions) + } + if err != nil { + return ListSharesSegmentResponse{}, err + } + resp, err := s.generated().Pipeline().Do(req) + if err != nil { + return ListSharesSegmentResponse{}, err + } + if !runtime.HasStatusCode(resp, http.StatusOK) { + return ListSharesSegmentResponse{}, runtime.NewResponseError(resp) + } + return s.generated().ListSharesSegmentHandleResponse(resp) + }, + }) } diff --git a/sdk/storage/azfile/service/client_test.go b/sdk/storage/azfile/service/client_test.go new file mode 100644 index 000000000000..9e69f16e30f2 --- /dev/null +++ b/sdk/storage/azfile/service/client_test.go @@ -0,0 +1,226 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package service_test + +import ( + "context" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/testcommon" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/service" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "os" + "testing" + "time" +) + +func Test(t *testing.T) { + recordMode := recording.GetRecordMode() + t.Logf("Running service Tests in %s mode\n", recordMode) + if recordMode == recording.LiveMode { + suite.Run(t, &ServiceRecordedTestsSuite{}) + suite.Run(t, &ServiceUnrecordedTestsSuite{}) + } else if recordMode == recording.PlaybackMode { + suite.Run(t, &ServiceRecordedTestsSuite{}) + } else if recordMode == recording.RecordingMode { + suite.Run(t, &ServiceRecordedTestsSuite{}) + } +} + +func (s *ServiceRecordedTestsSuite) BeforeTest(suite string, test string) { + testcommon.BeforeTest(s.T(), suite, test) +} + +func (s *ServiceRecordedTestsSuite) AfterTest(suite string, test string) { + testcommon.AfterTest(s.T(), suite, test) +} + +func (s *ServiceUnrecordedTestsSuite) BeforeTest(suite string, test string) { + +} + +func (s *ServiceUnrecordedTestsSuite) AfterTest(suite string, test string) { + +} + +type ServiceRecordedTestsSuite struct { + suite.Suite +} + +type ServiceUnrecordedTestsSuite struct { + suite.Suite +} + +func (s *ServiceUnrecordedTestsSuite) TestAccountNewServiceURLValidName() { + _require := require.New(s.T()) + + svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + correctURL := "https://" + os.Getenv("AZURE_STORAGE_ACCOUNT_NAME") + "." + testcommon.DefaultFileEndpointSuffix + _require.Equal(svcClient.URL(), correctURL) +} + +func (s *ServiceUnrecordedTestsSuite) TestAccountNewShareURLValidName() { + _require := require.New(s.T()) + testName := s.T().Name() + + svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + shareName := testcommon.GenerateShareName(testName) + shareClient := svcClient.NewShareClient(shareName) + _require.NoError(err) + + correctURL := "https://" + os.Getenv("AZURE_STORAGE_ACCOUNT_NAME") + "." + testcommon.DefaultFileEndpointSuffix + shareName + _require.Equal(shareClient.URL(), correctURL) +} + +func (s *ServiceUnrecordedTestsSuite) TestServiceClientFromConnectionString() { + _require := require.New(s.T()) + + svcClient, err := testcommon.GetServiceClientFromConnectionString(s.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + resp, err := svcClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.NotNil(resp.RequestID) +} + +func (s *ServiceUnrecordedTestsSuite) TestAccountProperties() { + _require := require.New(s.T()) + + svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + setPropertiesOptions := &service.SetPropertiesOptions{ + HourMetrics: &service.Metrics{ + Enabled: to.Ptr(true), + IncludeAPIs: to.Ptr(true), + RetentionPolicy: &service.RetentionPolicy{ + Enabled: to.Ptr(true), + Days: to.Ptr(int32(2)), + }, + }, + MinuteMetrics: &service.Metrics{ + Enabled: to.Ptr(true), + IncludeAPIs: to.Ptr(false), + RetentionPolicy: &service.RetentionPolicy{ + Enabled: to.Ptr(true), + Days: to.Ptr(int32(2)), + }, + }, + CORS: []*service.CORSRule{ + { + AllowedOrigins: to.Ptr("*"), + AllowedMethods: to.Ptr("PUT"), + AllowedHeaders: to.Ptr("x-ms-client-request-id"), + ExposedHeaders: to.Ptr("x-ms-*"), + MaxAgeInSeconds: to.Ptr(int32(2)), + }, + }, + } + + setPropsResp, err := svcClient.SetProperties(context.Background(), setPropertiesOptions) + _require.NoError(err) + _require.NotNil(setPropsResp.RequestID) + + time.Sleep(time.Second * 30) + + getPropsResp, err := svcClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.NotNil(getPropsResp.RequestID) + _require.EqualValues(getPropsResp.HourMetrics.RetentionPolicy.Enabled, setPropertiesOptions.HourMetrics.RetentionPolicy.Enabled) + _require.EqualValues(getPropsResp.HourMetrics.RetentionPolicy.Days, setPropertiesOptions.HourMetrics.RetentionPolicy.Days) + _require.EqualValues(getPropsResp.MinuteMetrics.RetentionPolicy.Enabled, setPropertiesOptions.MinuteMetrics.RetentionPolicy.Enabled) + _require.EqualValues(getPropsResp.MinuteMetrics.RetentionPolicy.Days, setPropertiesOptions.MinuteMetrics.RetentionPolicy.Days) + _require.EqualValues(len(getPropsResp.CORS), len(setPropertiesOptions.CORS)) +} + +func (s *ServiceRecordedTestsSuite) TestAccountHourMetrics() { + _require := require.New(s.T()) + + svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + setPropertiesOptions := &service.SetPropertiesOptions{ + HourMetrics: &service.Metrics{ + Enabled: to.Ptr(true), + IncludeAPIs: to.Ptr(true), + RetentionPolicy: &service.RetentionPolicy{ + Enabled: to.Ptr(true), + Days: to.Ptr(int32(5)), + }, + }, + } + _, err = svcClient.SetProperties(context.Background(), setPropertiesOptions) + _require.NoError(err) +} + +func (s *ServiceUnrecordedTestsSuite) TestAccountListSharesNonDefault() { + _require := require.New(s.T()) + testName := s.T().Name() + + svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil) + _require.NoError(err) + + mySharePrefix := testcommon.GenerateEntityName(testName) + pager := svcClient.NewListSharesPager(&service.ListSharesOptions{Prefix: to.Ptr(mySharePrefix)}) + for pager.More() { + resp, err := pager.NextPage(context.Background()) + _require.NoError(err) + _require.NotNil(resp.Prefix) + _require.Equal(*resp.Prefix, mySharePrefix) + _require.NotNil(resp.ServiceEndpoint) + _require.NotNil(resp.Version) + _require.Len(resp.Shares, 0) + } + + /*shareClients := map[string]*share.Client{} + for i := 0; i < 4; i++ { + shareName := mySharePrefix + "share" + strconv.Itoa(i) + shareClients[shareName] = createNewShare(_require, shareName, svcClient) + + _, err := shareClients[shareName].SetMetadata(context.Background(), basicMetadata, nil) + _require.NoError(err) + + _, err = shareClients[shareName].CreateSnapshot(context.Background(), nil) + _require.NoError(err) + + defer delShare(_require, shareClients[shareName], &ShareDeleteOptions{ + DeleteSnapshots: to.Ptr(DeleteSnapshotsOptionTypeInclude), + }) + } + + pager = svcClient.NewListSharesPager(&service.ListSharesOptions{ + Include: service.ListSharesInclude{Metadata: true, Snapshots: true}, + Prefix: to.Ptr(mySharePrefix), + MaxResults: to.Ptr(int32(2)), + }) + + for pager.More() { + resp, err := pager.NextPage(context.Background()) + _require.NoError(err) + if len(resp.Shares) > 0 { + _require.Len(resp.Shares, 2) + } + for _, shareItem := range resp.Shares { + _require.NotNil(shareItem.Properties) + _require.NotNil(shareItem.Properties.LastModified) + _require.NotNil(shareItem.Properties.ETag) + _require.Len(shareItem.Metadata, len(basicMetadata)) + for key, val1 := range basicMetadata { + if val2, ok := shareItem.Metadata[key]; !(ok && val1 == *val2) { + _require.Fail("metadata mismatch") + } + } + _require.NotNil(resp.Shares[0].Snapshot) + _require.Nil(resp.Shares[1].Snapshot) + } + }*/ +} diff --git a/sdk/storage/azfile/service/models.go b/sdk/storage/azfile/service/models.go index 5d3412dbaa3e..4ea97710f359 100644 --- a/sdk/storage/azfile/service/models.go +++ b/sdk/storage/azfile/service/models.go @@ -7,14 +7,22 @@ package service import ( + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/generated" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/share" ) // SharedKeyCredential contains an account's name and its primary or secondary key. type SharedKeyCredential = exported.SharedKeyCredential +// NewSharedKeyCredential creates an immutable SharedKeyCredential containing the +// storage account's name and either its primary or secondary key. +func NewSharedKeyCredential(accountName, accountKey string) (*SharedKeyCredential, error) { + return exported.NewSharedKeyCredential(accountName, accountKey) +} + // CreateShareOptions contains the optional parameters for the share.Client.Create method. type CreateShareOptions = share.CreateOptions @@ -52,6 +60,31 @@ type SetPropertiesOptions struct { Protocol *ProtocolSettings } +func (o *SetPropertiesOptions) format() (generated.StorageServiceProperties, *generated.ServiceClientSetPropertiesOptions) { + if o == nil { + return generated.StorageServiceProperties{}, nil + } + + formatMetrics(o.HourMetrics) + formatMetrics(o.MinuteMetrics) + + return generated.StorageServiceProperties{ + CORS: o.CORS, + HourMetrics: o.HourMetrics, + MinuteMetrics: o.MinuteMetrics, + Protocol: o.Protocol, + }, nil +} + +// update version of Storage Analytics to configure. Use 1.0 for this value. +func formatMetrics(m *Metrics) { + if m == nil { + return + } + + m.Version = to.Ptr(shared.StorageAnalyticsVersion) +} + // StorageServiceProperties - Storage service properties. type StorageServiceProperties = generated.StorageServiceProperties @@ -81,19 +114,34 @@ type SMBMultichannel = generated.SMBMultichannel // ListSharesOptions contains the optional parameters for the Client.NewListSharesPager method. type ListSharesOptions struct { // Include this parameter to specify one or more datasets to include in the responseBody. - Include []ListSharesIncludeType + Include ListSharesInclude + // A string value that identifies the portion of the list to be returned with the next list operation. The operation returns // a marker value within the responseBody body if the list returned was not complete. // The marker value may then be used in a subsequent call to request the next set of list items. The marker value is opaque // to the client. Marker *string + // Specifies the maximum number of entries to return. If the request does not specify maxresults, or specifies a value greater // than 5,000, the server will return up to 5,000 items. MaxResults *int32 + // Filters the results to return only entries whose name begins with the specified prefix. Prefix *string } +// ListSharesInclude indicates what additional information the service should return with each share. +type ListSharesInclude struct { + // Tells the service whether to return metadata for each share. + Metadata bool + + // Tells the service whether to return soft-deleted shares. + Deleted bool + + // Tells the service whether to return share snapshots. + Snapshots bool +} + // Share - A listed Azure Storage share item. type Share = generated.Share diff --git a/sdk/storage/azfile/share/client.go b/sdk/storage/azfile/share/client.go index 9ccf095bdfd8..56ca68a7604b 100644 --- a/sdk/storage/azfile/share/client.go +++ b/sdk/storage/azfile/share/client.go @@ -22,14 +22,6 @@ type ClientOptions struct { // Client represents a URL to the Azure Storage share allowing you to manipulate its directories and files. type Client base.Client[generated.ShareClient] -// NewClient creates an instance of Client with the specified values. -// - shareURL - the URL of the share e.g. https://.file.core.windows.net/share -// - cred - an Azure AD credential, typically obtained via the azidentity module -// - options - client options; pass nil to accept the default values -func NewClient(shareURL string, cred azcore.TokenCredential, options *ClientOptions) (*Client, error) { - return nil, nil -} - // NewClientWithNoCredential creates an instance of Client with the specified values. // This is used to anonymously access a share or with a shared access signature (SAS) token. // - shareURL - the URL of the share e.g. https://.file.core.windows.net/share? @@ -64,7 +56,7 @@ func (s *Client) sharedKey() *SharedKeyCredential { // URL returns the URL endpoint used by the Client object. func (s *Client) URL() string { - return "s.generated().Endpoint()" + return s.generated().Endpoint() } // NewDirectoryClient creates a new directory.Client object by concatenating directoryName to the end of this Client's URL.