diff --git a/sdk/storage/azfile/directory/client.go b/sdk/storage/azfile/directory/client.go index dc7716ad1337..e01e7d0c960e 100644 --- a/sdk/storage/azfile/directory/client.go +++ b/sdk/storage/azfile/directory/client.go @@ -42,9 +42,9 @@ func NewClientWithSharedKeyCredential(directoryURL string, cred *SharedKeyCreden // NewClientFromConnectionString creates an instance of Client with the specified values. // - connectionString - a connection string for the desired storage account // - shareName - the name of the share within the storage account -// - directoryName - the name of the directory within the storage account +// - directoryPath - the path of the directory within the share // - options - client options; pass nil to accept the default values -func NewClientFromConnectionString(connectionString string, shareName string, directoryName string, options *ClientOptions) (*Client, error) { +func NewClientFromConnectionString(connectionString string, shareName string, directoryPath string, options *ClientOptions) (*Client, error) { return nil, nil } diff --git a/sdk/storage/azfile/file/client.go b/sdk/storage/azfile/file/client.go index dcc22ae55ab6..27587ad4f455 100644 --- a/sdk/storage/azfile/file/client.go +++ b/sdk/storage/azfile/file/client.go @@ -42,10 +42,9 @@ func NewClientWithSharedKeyCredential(fileURL string, cred *SharedKeyCredential, // NewClientFromConnectionString creates an instance of Client with the specified values. // - connectionString - a connection string for the desired storage account // - shareName - the name of the share within the storage account -// - directoryName - the name of the directory within the storage account -// - fileName - the name of the file within the storage account +// - filePath - the path of the file within the share // - options - client options; pass nil to accept the default values -func NewClientFromConnectionString(connectionString string, shareName string, directoryName string, fileName string, options *ClientOptions) (*Client, error) { +func NewClientFromConnectionString(connectionString string, shareName string, filePath string, options *ClientOptions) (*Client, error) { return nil, nil } diff --git a/sdk/storage/azfile/fileerror/error_codes.go b/sdk/storage/azfile/fileerror/error_codes.go index 3f91c984c711..c897c0953828 100644 --- a/sdk/storage/azfile/fileerror/error_codes.go +++ b/sdk/storage/azfile/fileerror/error_codes.go @@ -100,3 +100,8 @@ const ( UnsupportedQueryParameter Code = "UnsupportedQueryParameter" UnsupportedXMLNode Code = "UnsupportedXmlNode" ) + +var ( + // MissingSharedKeyCredential - Error is returned when SAS URL is being created without SharedKeyCredential. + MissingSharedKeyCredential = errors.New("SAS can only be signed with a SharedKeyCredential") +) diff --git a/sdk/storage/azfile/internal/shared/shared.go b/sdk/storage/azfile/internal/shared/shared.go index e201782fc0b2..b2e04301a09f 100644 --- a/sdk/storage/azfile/internal/shared/shared.go +++ b/sdk/storage/azfile/internal/shared/shared.go @@ -9,6 +9,7 @@ package shared import ( "errors" "fmt" + "net" "strings" ) @@ -108,3 +109,22 @@ func ParseConnectionString(connectionString string) (ParsedConnectionString, err AccountKey: accountKey, }, nil } + +// IsIPEndpointStyle checks if URL's host is IP, in this case the storage account endpoint will be composed as: +// http(s)://IP(:port)/storageaccount/share(||container||etc)/... +// As url's Host property, host could be both host or host:port +func IsIPEndpointStyle(host string) bool { + if host == "" { + return false + } + if h, _, err := net.SplitHostPort(host); err == nil { + host = h + } + // For IPv6, there could be case where SplitHostPort fails for cannot finding port. + // In this case, eliminate the '[' and ']' in the URL. + // For details about IPv6 URL, please refer to https://tools.ietf.org/html/rfc2732 + if host[0] == '[' && host[len(host)-1] == ']' { + host = host[1 : len(host)-1] + } + return net.ParseIP(host) != nil +} diff --git a/sdk/storage/azfile/sas/account.go b/sdk/storage/azfile/sas/account.go new file mode 100644 index 000000000000..8e3ee410aef0 --- /dev/null +++ b/sdk/storage/azfile/sas/account.go @@ -0,0 +1,183 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package sas + +import ( + "bytes" + "errors" + "fmt" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/exported" +) + +// SharedKeyCredential contains an account's name and its primary or secondary key. +type SharedKeyCredential = exported.SharedKeyCredential + +// AccountSignatureValues is used to generate a Shared Access Signature (SAS) for an Azure Storage account. +// For more information, see https://docs.microsoft.com/rest/api/storageservices/constructing-an-account-sas +type AccountSignatureValues struct { + Version string `param:"sv"` // If not specified, this format to SASVersion + Protocol Protocol `param:"spr"` // See the SASProtocol* constants + StartTime time.Time `param:"st"` // Not specified if IsZero + ExpiryTime time.Time `param:"se"` // Not specified if IsZero + Permissions string `param:"sp"` // Create by initializing a AccountSASPermissions and then call String() + IPRange IPRange `param:"sip"` + ResourceTypes string `param:"srt"` // Create by initializing AccountSASResourceTypes and then call String() +} + +// SignWithSharedKey uses an account's shared key credential to sign this signature values to produce +// the proper SAS query parameters. +func (v AccountSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCredential) (QueryParameters, error) { + // https://docs.microsoft.com/en-us/rest/api/storageservices/Constructing-an-Account-SAS + if v.ExpiryTime.IsZero() || v.Permissions == "" || v.ResourceTypes == "" { + return QueryParameters{}, errors.New("account SAS is missing at least one of these: ExpiryTime, Permissions, Service, or ResourceType") + } + if v.Version == "" { + v.Version = Version + } + perms, err := parseAccountPermissions(v.Permissions) + if err != nil { + return QueryParameters{}, err + } + v.Permissions = perms.String() + + resources, err := parseAccountResourceTypes(v.ResourceTypes) + if err != nil { + return QueryParameters{}, err + } + v.ResourceTypes = resources.String() + + startTime, expiryTime, _ := formatTimesForSigning(v.StartTime, v.ExpiryTime, time.Time{}) + + stringToSign := strings.Join([]string{ + sharedKeyCredential.AccountName(), + v.Permissions, + "f", // file service + v.ResourceTypes, + startTime, + expiryTime, + v.IPRange.String(), + string(v.Protocol), + v.Version, + ""}, // That is right, the account SAS requires a terminating extra newline + "\n") + + signature, err := exported.ComputeHMACSHA256(sharedKeyCredential, stringToSign) + if err != nil { + return QueryParameters{}, err + } + p := QueryParameters{ + // Common SAS parameters + version: v.Version, + protocol: v.Protocol, + startTime: v.StartTime, + expiryTime: v.ExpiryTime, + permissions: v.Permissions, + ipRange: v.IPRange, + + // Account-specific SAS parameters + services: "f", // will always be "f" for Azure File + resourceTypes: v.ResourceTypes, + + // Calculated SAS signature + signature: signature, + } + + return p, nil +} + +// AccountPermissions type simplifies creating the permissions string for an Azure Storage Account SAS. +// Initialize an instance of this type and then call its String method to set AccountSASSignature value's Permissions field. +type AccountPermissions struct { + Read, Write, Delete, List, Create bool +} + +// String produces the SAS permissions string for an Azure Storage account. +// Call this method to set AccountSignatureValues' Permissions field. +func (p *AccountPermissions) String() string { + var buffer bytes.Buffer + if p.Read { + buffer.WriteRune('r') + } + if p.Write { + buffer.WriteRune('w') + } + if p.Delete { + buffer.WriteRune('d') + } + if p.List { + buffer.WriteRune('l') + } + if p.Create { + buffer.WriteRune('c') + } + return buffer.String() +} + +// parseAccountPermissions initializes the AccountSASPermissions' fields from a string. +func parseAccountPermissions(s string) (AccountPermissions, error) { + p := AccountPermissions{} // Clear out the flags + for _, r := range s { + switch r { + case 'r': + p.Read = true + case 'w': + p.Write = true + case 'd': + p.Delete = true + case 'l': + p.List = true + case 'c': + p.Create = true + default: + return AccountPermissions{}, fmt.Errorf("invalid permission character: '%v'", r) + } + } + return p, nil +} + +// AccountResourceTypes type simplifies creating the resource types string for an Azure Storage Account SAS. +// Initialize an instance of this type and then call its String method to set AccountSignatureValues' ResourceTypes field. +type AccountResourceTypes struct { + Service, Container, Object bool +} + +// String produces the SAS resource types string for an Azure Storage account. +// Call this method to set AccountSignatureValues' ResourceTypes field. +func (rt *AccountResourceTypes) String() string { + var buffer bytes.Buffer + if rt.Service { + buffer.WriteRune('s') + } + if rt.Container { + buffer.WriteRune('c') + } + if rt.Object { + buffer.WriteRune('o') + } + return buffer.String() +} + +// parseAccountResourceTypes initializes the AccountResourceTypes' fields from a string. +func parseAccountResourceTypes(s string) (AccountResourceTypes, error) { + rt := AccountResourceTypes{} + for _, r := range s { + switch r { + case 's': + rt.Service = true + case 'c': + rt.Container = true + case 'o': + rt.Object = true + default: + return AccountResourceTypes{}, fmt.Errorf("invalid resource type character: '%v'", r) + } + } + return rt, nil +} diff --git a/sdk/storage/azfile/sas/account_test.go b/sdk/storage/azfile/sas/account_test.go new file mode 100644 index 000000000000..d22d645185ed --- /dev/null +++ b/sdk/storage/azfile/sas/account_test.go @@ -0,0 +1,124 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package sas + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestAccountPermissions_String(t *testing.T) { + testdata := []struct { + input AccountPermissions + expected string + }{ + {input: AccountPermissions{Read: true}, expected: "r"}, + {input: AccountPermissions{Write: true}, expected: "w"}, + {input: AccountPermissions{Delete: true}, expected: "d"}, + {input: AccountPermissions{List: true}, expected: "l"}, + {input: AccountPermissions{Create: true}, expected: "c"}, + {input: AccountPermissions{ + Read: true, + Write: true, + Delete: true, + List: true, + Create: true, + }, expected: "rwdlc"}, + } + for _, c := range testdata { + require.Equal(t, c.expected, c.input.String()) + } +} + +func TestAccountPermissions_Parse(t *testing.T) { + testdata := []struct { + input string + expected AccountPermissions + }{ + {expected: AccountPermissions{Read: true}, input: "r"}, + {expected: AccountPermissions{Write: true}, input: "w"}, + {expected: AccountPermissions{Delete: true}, input: "d"}, + {expected: AccountPermissions{List: true}, input: "l"}, + {expected: AccountPermissions{Create: true}, input: "c"}, + {expected: AccountPermissions{ + Read: true, + Write: true, + Delete: true, + List: true, + Create: true, + }, input: "rwdlc"}, + {expected: AccountPermissions{ + Read: true, + Write: true, + Delete: true, + List: true, + Create: true, + }, input: "rcdlw"}, + } + for _, c := range testdata { + permissions, err := parseAccountPermissions(c.input) + require.Nil(t, err) + require.Equal(t, c.expected, permissions) + } +} + +func TestAccountPermissions_ParseNegative(t *testing.T) { + _, err := parseAccountPermissions("rwldcz") // Here 'z' is invalid + require.NotNil(t, err) + require.Contains(t, err.Error(), "122") +} + +func TestAccountResourceTypes_String(t *testing.T) { + testdata := []struct { + input AccountResourceTypes + expected string + }{ + {input: AccountResourceTypes{Service: true}, expected: "s"}, + {input: AccountResourceTypes{Container: true}, expected: "c"}, + {input: AccountResourceTypes{Object: true}, expected: "o"}, + {input: AccountResourceTypes{ + Service: true, + Container: true, + Object: true, + }, expected: "sco"}, + } + for _, c := range testdata { + require.Equal(t, c.expected, c.input.String()) + } +} + +func TestAccountResourceTypes_Parse(t *testing.T) { + testdata := []struct { + input string + expected AccountResourceTypes + }{ + {expected: AccountResourceTypes{Service: true}, input: "s"}, + {expected: AccountResourceTypes{Container: true}, input: "c"}, + {expected: AccountResourceTypes{Object: true}, input: "o"}, + {expected: AccountResourceTypes{ + Service: true, + Container: true, + Object: true, + }, input: "sco"}, + {expected: AccountResourceTypes{ + Service: true, + Container: true, + Object: true, + }, input: "osc"}, + } + for _, c := range testdata { + permissions, err := parseAccountResourceTypes(c.input) + require.Nil(t, err) + require.Equal(t, c.expected, permissions) + } +} + +func TestAccountResourceTypes_ParseNegative(t *testing.T) { + _, err := parseAccountResourceTypes("scoz") // Here 'z' is invalid + require.NotNil(t, err) + require.Contains(t, err.Error(), "122") +} diff --git a/sdk/storage/azfile/sas/query_params.go b/sdk/storage/azfile/sas/query_params.go new file mode 100644 index 000000000000..a7d53c41aade --- /dev/null +++ b/sdk/storage/azfile/sas/query_params.go @@ -0,0 +1,339 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package sas + +import ( + "errors" + "net" + "net/url" + "strings" + "time" +) + +// timeFormat represents the format of a SAS start or expiry time. Use it when formatting/parsing a time.Time. +const ( + timeFormat = "2006-01-02T15:04:05Z" // "2017-07-27T00:00:00Z" // ISO 8601 + snapshotTimeFormat = "2006-01-02T15:04:05.0000000Z07:00" +) + +var ( + // Version is the default version encoded in the SAS token. + Version = "2020-02-10" +) + +// TimeFormats ISO 8601 format. +// Please refer to https://docs.microsoft.com/en-us/rest/api/storageservices/constructing-a-service-sas for more details. +var timeFormats = []string{"2006-01-02T15:04:05.0000000Z", timeFormat, "2006-01-02T15:04Z", "2006-01-02"} + +// Protocol indicates the http/https. +type Protocol string + +const ( + // ProtocolHTTPS can be specified for a SAS protocol. + ProtocolHTTPS Protocol = "https" + + // ProtocolHTTPSandHTTP can be specified for a SAS protocol. + ProtocolHTTPSandHTTP Protocol = "https,http" +) + +// FormatTimesForSigning converts a time.Time to a snapshotTimeFormat string suitable for a +// Field's StartTime or ExpiryTime fields. Returns "" if value.IsZero(). +func formatTimesForSigning(startTime, expiryTime, snapshotTime time.Time) (string, string, string) { + ss := "" + if !startTime.IsZero() { + ss = formatTimeWithDefaultFormat(&startTime) + } + se := "" + if !expiryTime.IsZero() { + se = formatTimeWithDefaultFormat(&expiryTime) + } + sh := "" + if !snapshotTime.IsZero() { + sh = snapshotTime.Format(snapshotTimeFormat) + } + return ss, se, sh +} + +// formatTimeWithDefaultFormat format time with ISO 8601 in "yyyy-MM-ddTHH:mm:ssZ". +func formatTimeWithDefaultFormat(t *time.Time) string { + return formatTime(t, timeFormat) // By default, "yyyy-MM-ddTHH:mm:ssZ" is used +} + +// formatTime format time with given format, use ISO 8601 in "yyyy-MM-ddTHH:mm:ssZ" by default. +func formatTime(t *time.Time, format string) string { + if format != "" { + return t.Format(format) + } + return t.Format(timeFormat) // By default, "yyyy-MM-ddTHH:mm:ssZ" is used +} + +// ParseTime try to parse a SAS time string. +func parseTime(val string) (t time.Time, timeFormat string, err error) { + for _, sasTimeFormat := range timeFormats { + t, err = time.Parse(sasTimeFormat, val) + if err == nil { + timeFormat = sasTimeFormat + break + } + } + + if err != nil { + err = errors.New("fail to parse time with IOS 8601 formats, please refer to https://docs.microsoft.com/en-us/rest/api/storageservices/constructing-a-service-sas for more details") + } + + return +} + +// IPRange represents a SAS IP range's start IP and (optionally) end IP. +type IPRange struct { + Start net.IP // Not specified if length = 0 + End net.IP // Not specified if length = 0 +} + +// String returns a string representation of an IPRange. +func (ipr *IPRange) String() string { + if len(ipr.Start) == 0 { + return "" + } + start := ipr.Start.String() + if len(ipr.End) == 0 { + return start + } + return start + "-" + ipr.End.String() +} + +// https://docs.microsoft.com/en-us/rest/api/storageservices/constructing-a-service-sas + +// QueryParameters object represents the components that make up an Azure Storage SAS' query parameters. +// You parse a map of query parameters into its fields by calling NewQueryParameters(). You add the components +// to a query parameter map by calling AddToValues(). +// NOTE: Changing any field requires computing a new SAS signature using a XxxSASSignatureValues type. +// This type defines the components used by all Azure Storage resources (Containers, Blobs, Files, & Queues). +type QueryParameters struct { + // All members are immutable or values so copies of this struct are goroutine-safe. + version string `param:"sv"` + services string `param:"ss"` + resourceTypes string `param:"srt"` + protocol Protocol `param:"spr"` + startTime time.Time `param:"st"` + expiryTime time.Time `param:"se"` + snapshotTime time.Time `param:"snapshot"` + ipRange IPRange `param:"sip"` + identifier string `param:"si"` + resource string `param:"sr"` + permissions string `param:"sp"` + signature string `param:"sig"` + cacheControl string `param:"rscc"` + contentDisposition string `param:"rscd"` + contentEncoding string `param:"rsce"` + contentLanguage string `param:"rscl"` + contentType string `param:"rsct"` + // private member used for startTime and expiryTime formatting. + stTimeFormat string + seTimeFormat string +} + +// SnapshotTime returns snapshotTime. +func (p *QueryParameters) SnapshotTime() time.Time { + return p.snapshotTime +} + +// Version returns version. +func (p *QueryParameters) Version() string { + return p.version +} + +// Services returns services. +func (p *QueryParameters) Services() string { + return p.services +} + +// ResourceTypes returns resourceTypes. +func (p *QueryParameters) ResourceTypes() string { + return p.resourceTypes +} + +// Protocol returns protocol. +func (p *QueryParameters) Protocol() Protocol { + return p.protocol +} + +// StartTime returns startTime. +func (p *QueryParameters) StartTime() time.Time { + return p.startTime +} + +// ExpiryTime returns expiryTime. +func (p *QueryParameters) ExpiryTime() time.Time { + return p.expiryTime +} + +// IPRange returns ipRange. +func (p *QueryParameters) IPRange() IPRange { + return p.ipRange +} + +// Identifier returns identifier. +func (p *QueryParameters) Identifier() string { + return p.identifier +} + +// Resource returns resource. +func (p *QueryParameters) Resource() string { + return p.resource +} + +// Permissions returns permissions. +func (p *QueryParameters) Permissions() string { + return p.permissions +} + +// Signature returns signature. +func (p *QueryParameters) Signature() string { + return p.signature +} + +// CacheControl returns cacheControl. +func (p *QueryParameters) CacheControl() string { + return p.cacheControl +} + +// ContentDisposition returns contentDisposition. +func (p *QueryParameters) ContentDisposition() string { + return p.contentDisposition +} + +// ContentEncoding returns contentEncoding. +func (p *QueryParameters) ContentEncoding() string { + return p.contentEncoding +} + +// ContentLanguage returns contentLanguage. +func (p *QueryParameters) ContentLanguage() string { + return p.contentLanguage +} + +// ContentType returns contentType. +func (p *QueryParameters) ContentType() string { + return p.contentType +} + +// Encode encodes the SAS query parameters into URL encoded form sorted by key. +func (p *QueryParameters) Encode() string { + v := url.Values{} + + if p.version != "" { + v.Add("sv", p.version) + } + if p.services != "" { + v.Add("ss", p.services) + } + if p.resourceTypes != "" { + v.Add("srt", p.resourceTypes) + } + if p.protocol != "" { + v.Add("spr", string(p.protocol)) + } + if !p.startTime.IsZero() { + v.Add("st", formatTime(&(p.startTime), p.stTimeFormat)) + } + if !p.expiryTime.IsZero() { + v.Add("se", formatTime(&(p.expiryTime), p.seTimeFormat)) + } + if len(p.ipRange.Start) > 0 { + v.Add("sip", p.ipRange.String()) + } + if p.identifier != "" { + v.Add("si", p.identifier) + } + if p.resource != "" { + v.Add("sr", p.resource) + } + if p.permissions != "" { + v.Add("sp", p.permissions) + } + if p.signature != "" { + v.Add("sig", p.signature) + } + if p.cacheControl != "" { + v.Add("rscc", p.cacheControl) + } + if p.contentDisposition != "" { + v.Add("rscd", p.contentDisposition) + } + if p.contentEncoding != "" { + v.Add("rsce", p.contentEncoding) + } + if p.contentLanguage != "" { + v.Add("rscl", p.contentLanguage) + } + if p.contentType != "" { + v.Add("rsct", p.contentType) + } + + return v.Encode() +} + +// NewQueryParameters creates and initializes a QueryParameters object based on the +// query parameter map's passed-in values. If deleteSASParametersFromValues is true, +// all SAS-related query parameters are removed from the passed-in map. If +// deleteSASParametersFromValues is false, the map passed-in map is unaltered. +func NewQueryParameters(values url.Values, deleteSASParametersFromValues bool) QueryParameters { + p := QueryParameters{} + for k, v := range values { + val := v[0] + isSASKey := true + switch strings.ToLower(k) { + case "sv": + p.version = val + case "ss": + p.services = val + case "srt": + p.resourceTypes = val + case "spr": + p.protocol = Protocol(val) + case "snapshot": + p.snapshotTime, _ = time.Parse(snapshotTimeFormat, val) + case "st": + p.startTime, p.stTimeFormat, _ = parseTime(val) + case "se": + p.expiryTime, p.seTimeFormat, _ = parseTime(val) + case "sip": + dashIndex := strings.Index(val, "-") + if dashIndex == -1 { + p.ipRange.Start = net.ParseIP(val) + } else { + p.ipRange.Start = net.ParseIP(val[:dashIndex]) + p.ipRange.End = net.ParseIP(val[dashIndex+1:]) + } + case "si": + p.identifier = val + case "sr": + p.resource = val + case "sp": + p.permissions = val + case "sig": + p.signature = val + case "rscc": + p.cacheControl = val + case "rscd": + p.contentDisposition = val + case "rsce": + p.contentEncoding = val + case "rscl": + p.contentLanguage = val + case "rsct": + p.contentType = val + default: + isSASKey = false // We didn't recognize the query parameter + } + if isSASKey && deleteSASParametersFromValues { + delete(values, k) + } + } + return p +} diff --git a/sdk/storage/azfile/sas/query_params_test.go b/sdk/storage/azfile/sas/query_params_test.go new file mode 100644 index 000000000000..b7d00f63981c --- /dev/null +++ b/sdk/storage/azfile/sas/query_params_test.go @@ -0,0 +1,211 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package sas + +import ( + "fmt" + "net" + "net/url" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestFormatTimesForSigning(t *testing.T) { + testdata := []struct { + inputStart time.Time + inputEnd time.Time + inputSnapshot time.Time + expectedStart string + expectedEnd string + expectedSnapshot string + }{ + {expectedStart: "", expectedEnd: "", expectedSnapshot: ""}, + {inputStart: time.Date(1955, 6, 25, 22, 15, 56, 345456, time.UTC), expectedStart: "1955-06-25T22:15:56Z", expectedEnd: "", expectedSnapshot: ""}, + {inputEnd: time.Date(2023, 4, 5, 8, 50, 27, 4500, time.UTC), expectedStart: "", expectedEnd: "2023-04-05T08:50:27Z", expectedSnapshot: ""}, + {inputSnapshot: time.Date(2021, 1, 5, 22, 15, 33, 1234879, time.UTC), expectedStart: "", expectedEnd: "", expectedSnapshot: "2021-01-05T22:15:33.0012348Z"}, + { + inputStart: time.Date(1955, 6, 25, 22, 15, 56, 345456, time.UTC), + inputEnd: time.Date(2023, 4, 5, 8, 50, 27, 4500, time.UTC), + inputSnapshot: time.Date(2021, 1, 5, 22, 15, 33, 1234879, time.UTC), + expectedStart: "1955-06-25T22:15:56Z", + expectedEnd: "2023-04-05T08:50:27Z", + expectedSnapshot: "2021-01-05T22:15:33.0012348Z", + }, + } + for _, c := range testdata { + start, end, ss := formatTimesForSigning(c.inputStart, c.inputEnd, c.inputSnapshot) + require.Equal(t, c.expectedStart, start) + require.Equal(t, c.expectedEnd, end) + require.Equal(t, c.expectedSnapshot, ss) + } +} + +func TestFormatTimeWithDefaultFormat(t *testing.T) { + testdata := []struct { + input time.Time + expectedTime string + }{ + {input: time.Date(1955, 4, 5, 8, 50, 27, 4500, time.UTC), expectedTime: "1955-04-05T08:50:27Z"}, + {input: time.Date(1917, 3, 9, 16, 22, 56, 0, time.UTC), expectedTime: "1917-03-09T16:22:56Z"}, + {input: time.Date(2021, 1, 5, 22, 15, 0, 0, time.UTC), expectedTime: "2021-01-05T22:15:00Z"}, + {input: time.Date(2023, 6, 25, 0, 0, 0, 0, time.UTC), expectedTime: "2023-06-25T00:00:00Z"}, + } + for _, c := range testdata { + formattedTime := formatTimeWithDefaultFormat(&c.input) + require.Equal(t, c.expectedTime, formattedTime) + } +} + +func TestFormatTime(t *testing.T) { + testdata := []struct { + input time.Time + format string + expectedTime string + }{ + {input: time.Date(1955, 4, 5, 8, 50, 27, 4500, time.UTC), format: "2006-01-02T15:04:05.0000000Z", expectedTime: "1955-04-05T08:50:27.0000045Z"}, + {input: time.Date(1955, 4, 5, 8, 50, 27, 4500, time.UTC), format: "", expectedTime: "1955-04-05T08:50:27Z"}, + {input: time.Date(1917, 3, 9, 16, 22, 56, 0, time.UTC), format: "2006-01-02T15:04:05Z", expectedTime: "1917-03-09T16:22:56Z"}, + {input: time.Date(1917, 3, 9, 16, 22, 56, 0, time.UTC), format: "", expectedTime: "1917-03-09T16:22:56Z"}, + {input: time.Date(2021, 1, 5, 22, 15, 0, 0, time.UTC), format: "2006-01-02T15:04Z", expectedTime: "2021-01-05T22:15Z"}, + {input: time.Date(2021, 1, 5, 22, 15, 0, 0, time.UTC), format: "", expectedTime: "2021-01-05T22:15:00Z"}, + {input: time.Date(2023, 6, 25, 0, 0, 0, 0, time.UTC), format: "2006-01-02", expectedTime: "2023-06-25"}, + {input: time.Date(2023, 6, 25, 0, 0, 0, 0, time.UTC), format: "", expectedTime: "2023-06-25T00:00:00Z"}, + } + for _, c := range testdata { + formattedTime := formatTime(&c.input, c.format) + require.Equal(t, c.expectedTime, formattedTime) + } +} + +func TestParseTime(t *testing.T) { + testdata := []struct { + input string + expectedTime time.Time + expectedFormat string + }{ + {input: "1955-04-05T08:50:27.0000045Z", expectedTime: time.Date(1955, 4, 5, 8, 50, 27, 4500, time.UTC), expectedFormat: "2006-01-02T15:04:05.0000000Z"}, + {input: "1917-03-09T16:22:56Z", expectedTime: time.Date(1917, 3, 9, 16, 22, 56, 0, time.UTC), expectedFormat: "2006-01-02T15:04:05Z"}, + {input: "2021-01-05T22:15Z", expectedTime: time.Date(2021, 1, 5, 22, 15, 0, 0, time.UTC), expectedFormat: "2006-01-02T15:04Z"}, + {input: "2023-06-25", expectedTime: time.Date(2023, 6, 25, 0, 0, 0, 0, time.UTC), expectedFormat: "2006-01-02"}, + } + for _, c := range testdata { + parsedTime, format, err := parseTime(c.input) + require.Nil(t, err) + require.Equal(t, c.expectedTime, parsedTime) + require.Equal(t, c.expectedFormat, format) + } +} + +func TestParseTimeNegative(t *testing.T) { + _, _, err := parseTime("notatime") + require.Error(t, err, "fail to parse time with IOS 8601 formats, please refer to https://docs.microsoft.com/en-us/rest/api/storageservices/constructing-a-service-sas for more details") +} + +func TestIPRange_String(t *testing.T) { + testdata := []struct { + inputStart net.IP + inputEnd net.IP + expected string + }{ + {expected: ""}, + {inputStart: net.IPv4(10, 255, 0, 0), expected: "10.255.0.0"}, + {inputStart: net.IPv4(10, 255, 0, 0), inputEnd: net.IPv4(10, 255, 0, 50), expected: "10.255.0.0-10.255.0.50"}, + } + for _, c := range testdata { + var ipRange IPRange + if c.inputStart != nil { + ipRange.Start = c.inputStart + } + if c.inputEnd != nil { + ipRange.End = c.inputEnd + } + require.Equal(t, c.expected, ipRange.String()) + } +} + +func TestSAS(t *testing.T) { + // Note: This is a totally invalid fake SAS, this is just testing our ability to parse different query parameters on a SAS + const sas = "sv=2019-12-12&sr=b&st=2111-01-09T01:42:34.936Z&se=2222-03-09T01:42:34.936Z&sp=rw&sip=168.1.5.60-168.1.5.70&spr=https,http&si=myIdentifier&ss=bf&srt=s&rscc=cc&rscd=cd&rsce=ce&rscl=cl&rsct=ct&sig=clNxbtnkKSHw7f3KMEVVc4agaszoRFdbZr%2FWBmPNsrw%3D" + _url := fmt.Sprintf("https://teststorageaccount.file.core.windows.net/testshare/testpath?%s", sas) + _uri, err := url.Parse(_url) + require.NoError(t, err) + sasQueryParams := NewQueryParameters(_uri.Query(), true) + validateSAS(t, sas, sasQueryParams) +} + +func validateSAS(t *testing.T, sas string, parameters QueryParameters) { + sasCompMap := make(map[string]string) + for _, sasComp := range strings.Split(sas, "&") { + comp := strings.Split(sasComp, "=") + sasCompMap[comp[0]] = comp[1] + } + + require.Equal(t, parameters.Version(), sasCompMap["sv"]) + require.Equal(t, parameters.Services(), sasCompMap["ss"]) + require.Equal(t, parameters.ResourceTypes(), sasCompMap["srt"]) + require.Equal(t, string(parameters.Protocol()), sasCompMap["spr"]) + if _, ok := sasCompMap["st"]; ok { + startTime, _, err := parseTime(sasCompMap["st"]) + require.NoError(t, err) + require.Equal(t, parameters.StartTime(), startTime) + } + if _, ok := sasCompMap["se"]; ok { + endTime, _, err := parseTime(sasCompMap["se"]) + require.NoError(t, err) + require.Equal(t, parameters.ExpiryTime(), endTime) + } + + if _, ok := sasCompMap["snapshot"]; ok { + snapshotTime, _, err := parseTime(sasCompMap["snapshot"]) + require.NoError(t, err) + require.Equal(t, parameters.SnapshotTime(), snapshotTime) + } + ipRange := parameters.IPRange() + require.Equal(t, ipRange.String(), sasCompMap["sip"]) + require.Equal(t, parameters.Identifier(), sasCompMap["si"]) + require.Equal(t, parameters.Resource(), sasCompMap["sr"]) + require.Equal(t, parameters.Permissions(), sasCompMap["sp"]) + + sign, err := url.QueryUnescape(sasCompMap["sig"]) + require.NoError(t, err) + + require.Equal(t, parameters.Signature(), sign) + require.Equal(t, parameters.CacheControl(), sasCompMap["rscc"]) + require.Equal(t, parameters.ContentDisposition(), sasCompMap["rscd"]) + require.Equal(t, parameters.ContentEncoding(), sasCompMap["rsce"]) + require.Equal(t, parameters.ContentLanguage(), sasCompMap["rscl"]) + require.Equal(t, parameters.ContentType(), sasCompMap["rsct"]) +} + +func TestSASInvalidQueryParameter(t *testing.T) { + // Signature is invalid below + const sas = "sv=2019-12-12&signature=clNxbtnkKSHw7f3KMEVVc4agaszoRFdbZr%2FWBmPNsrw%3D&sr=b" + _url := fmt.Sprintf("https://teststorageaccount.file.core.windows.net/testshare/testpath?%s", sas) + _uri, err := url.Parse(_url) + require.NoError(t, err) + NewQueryParameters(_uri.Query(), true) + // NewQueryParameters should not delete signature + require.Contains(t, _uri.Query(), "signature") +} + +func TestEncode(t *testing.T) { + // Note: This is a totally invalid fake SAS, this is just testing our ability to parse different query parameters on a SAS + expected := "rscc=cc&rscd=cd&rsce=ce&rscl=cl&rsct=ct&se=2222-03-09T01%3A42%3A34Z&si=myIdentifier&sig=clNxbtnkKSHw7f3KMEVVc4agaszoRFdbZr%2FWBmPNsrw%3D&sip=168.1.5.60-168.1.5.70&sp=rw&spr=https%2Chttp&sr=b&srt=sco&ss=bf&st=2111-01-09T01%3A42%3A34Z&sv=2019-12-12" + randomOrder := "se=2222-03-09T01:42:34.936Z&rsce=ce&ss=bf&si=myIdentifier&sip=168.1.5.60-168.1.5.70&rscc=cc&srt=sco&sig=clNxbtnkKSHw7f3KMEVVc4agaszoRFdbZr%2FWBmPNsrw%3D&rsct=ct&rscl=cl&sv=2019-12-12&sr=b&st=2111-01-09T01:42:34.936Z&rscd=cd&sp=rw&spr=https,http" + testdata := []string{expected, randomOrder} + + for _, sas := range testdata { + _url := fmt.Sprintf("https://teststorageaccount.file.core.windows.net/testshare/testpath?%s", sas) + _uri, err := url.Parse(_url) + require.NoError(t, err) + queryParams := NewQueryParameters(_uri.Query(), true) + require.Equal(t, expected, queryParams.Encode()) + } +} diff --git a/sdk/storage/azfile/sas/service.go b/sdk/storage/azfile/sas/service.go new file mode 100644 index 000000000000..05f1851bce58 --- /dev/null +++ b/sdk/storage/azfile/sas/service.go @@ -0,0 +1,228 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package sas + +import ( + "bytes" + "fmt" + "strings" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/exported" +) + +// FileSignatureValues is used to generate a Shared Access Signature (SAS) for an Azure Storage file or share. +// For more information on creating service sas, see https://docs.microsoft.com/rest/api/storageservices/constructing-a-service-sas +// User Delegation SAS not supported for files service +type FileSignatureValues struct { + Version string `param:"sv"` // If not specified, this defaults to Version + Protocol Protocol `param:"spr"` // See the Protocol* constants + StartTime time.Time `param:"st"` // Not specified if IsZero + ExpiryTime time.Time `param:"se"` // Not specified if IsZero + SnapshotTime time.Time + Permissions string `param:"sp"` // Create by initializing a SharePermissions or FilePermissions and then call String() + IPRange IPRange `param:"sip"` + Identifier string `param:"si"` + ShareName string + DirectoryOrFilePath string // Ex: "directory/FileName". Use "" to create a Share SAS, directory path for Directory SAS and file path for File SAS. + CacheControl string // rscc + ContentDisposition string // rscd + ContentEncoding string // rsce + ContentLanguage string // rscl + ContentType string // rsct +} + +// SignWithSharedKey uses an account's SharedKeyCredential to sign this signature values to produce the proper SAS query parameters. +func (v FileSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCredential) (QueryParameters, error) { + if sharedKeyCredential == nil { + return QueryParameters{}, fmt.Errorf("cannot sign SAS query without Shared Key Credential") + } + + resource := "s" + if v.DirectoryOrFilePath == "" { + //Make sure the permission characters are in the correct order + perms, err := parseSharePermissions(v.Permissions) + if err != nil { + return QueryParameters{}, err + } + v.Permissions = perms.String() + } else { + resource = "f" + // Make sure the permission characters are in the correct order + perms, err := parseFilePermissions(v.Permissions) + if err != nil { + return QueryParameters{}, err + } + v.Permissions = perms.String() + } + + if v.Version == "" { + v.Version = Version + } + startTime, expiryTime, snapshotTime := formatTimesForSigning(v.StartTime, v.ExpiryTime, v.SnapshotTime) + + // String to sign: http://msdn.microsoft.com/en-us/library/azure/dn140255.aspx + stringToSign := strings.Join([]string{ + v.Permissions, + startTime, + expiryTime, + getCanonicalName(sharedKeyCredential.AccountName(), v.ShareName, v.DirectoryOrFilePath), + v.Identifier, + v.IPRange.String(), + string(v.Protocol), + v.Version, + resource, + snapshotTime, + v.CacheControl, // rscc + v.ContentDisposition, // rscd + v.ContentEncoding, // rsce + v.ContentLanguage, // rscl + v.ContentType}, // rsct + "\n") + + signature, err := exported.ComputeHMACSHA256(sharedKeyCredential, stringToSign) + if err != nil { + return QueryParameters{}, err + } + + p := QueryParameters{ + // Common SAS parameters + version: v.Version, + protocol: v.Protocol, + startTime: v.StartTime, + expiryTime: v.ExpiryTime, + permissions: v.Permissions, + ipRange: v.IPRange, + + // Share/File-specific SAS parameters + resource: resource, + identifier: v.Identifier, + cacheControl: v.CacheControl, + contentDisposition: v.ContentDisposition, + contentEncoding: v.ContentEncoding, + contentLanguage: v.ContentLanguage, + contentType: v.ContentType, + snapshotTime: v.SnapshotTime, + // Calculated SAS signature + signature: signature, + } + + return p, nil +} + +// getCanonicalName computes the canonical name for a share or file resource for SAS signing. +func getCanonicalName(account string, shareName string, filePath string) string { + // Share: "/file/account/sharename" + // File: "/file/account/sharename/filename" + // File: "/file/account/sharename/directoryname/filename" + elements := []string{"/file/", account, "/", shareName} + if filePath != "" { + dfp := strings.Replace(filePath, "\\", "/", -1) + if dfp[0] == '/' { + dfp = dfp[1:] + } + elements = append(elements, "/", dfp) + } + return strings.Join(elements, "") +} + +// SharePermissions type simplifies creating the permissions string for an Azure Storage share SAS. +// Initialize an instance of this type and then call its String method to set FileSignatureValues' Permissions field. +// All permissions descriptions can be found here: https://docs.microsoft.com/en-us/rest/api/storageservices/create-service-sas#permissions-for-a-share +type SharePermissions struct { + Read, Create, Write, Delete, List bool +} + +// String produces the SAS permissions string for an Azure Storage share. +// Call this method to set FileSignatureValues' FilePermissions field. +func (p *SharePermissions) String() string { + var b bytes.Buffer + if p.Read { + b.WriteRune('r') + } + if p.Create { + b.WriteRune('c') + } + if p.Write { + b.WriteRune('w') + } + if p.Delete { + b.WriteRune('d') + } + if p.List { + b.WriteRune('l') + } + return b.String() +} + +// parseSharePermissions initializes SharePermissions' fields from a string. +func parseSharePermissions(s string) (SharePermissions, error) { + p := SharePermissions{} // Clear the flags + for _, r := range s { + switch r { + case 'r': + p.Read = true + case 'c': + p.Create = true + case 'w': + p.Write = true + case 'd': + p.Delete = true + case 'l': + p.List = true + default: + return SharePermissions{}, fmt.Errorf("invalid permission: '%v'", r) + } + } + return p, nil +} + +// FilePermissions type simplifies creating the permissions string for an Azure Storage file SAS. +// Initialize an instance of this type and then call its String method to set FileSignatureValues' Permissions field. +// All permissions descriptions can be found here: https://docs.microsoft.com/en-us/rest/api/storageservices/create-service-sas#permissions-for-a-file +type FilePermissions struct { + Read, Create, Write, Delete bool +} + +// String produces the SAS permissions string for an Azure Storage file. +// Call this method to set FileSASSignatureValues' FilePermissions field. +func (p *FilePermissions) String() string { + var b bytes.Buffer + if p.Read { + b.WriteRune('r') + } + if p.Create { + b.WriteRune('c') + } + if p.Write { + b.WriteRune('w') + } + if p.Delete { + b.WriteRune('d') + } + return b.String() +} + +// parseFilePermissions initializes the FilePermissions' fields from a string. +func parseFilePermissions(s string) (FilePermissions, error) { + p := FilePermissions{} // Clear the flags + for _, r := range s { + switch r { + case 'r': + p.Read = true + case 'c': + p.Create = true + case 'w': + p.Write = true + case 'd': + p.Delete = true + default: + return FilePermissions{}, fmt.Errorf("invalid permission: '%v'", r) + } + } + return p, nil +} diff --git a/sdk/storage/azfile/sas/service_test.go b/sdk/storage/azfile/sas/service_test.go new file mode 100644 index 000000000000..dd640be0e4fc --- /dev/null +++ b/sdk/storage/azfile/sas/service_test.go @@ -0,0 +1,147 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package sas + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestSharePermissions_String(t *testing.T) { + testdata := []struct { + input SharePermissions + expected string + }{ + {input: SharePermissions{Read: true}, expected: "r"}, + {input: SharePermissions{Create: true}, expected: "c"}, + {input: SharePermissions{Write: true}, expected: "w"}, + {input: SharePermissions{Delete: true}, expected: "d"}, + {input: SharePermissions{List: true}, expected: "l"}, + {input: SharePermissions{ + Read: true, + Create: true, + Write: true, + Delete: true, + List: true, + }, expected: "rcwdl"}, + } + for _, c := range testdata { + require.Equal(t, c.expected, c.input.String()) + } +} + +func TestSharePermissions_Parse(t *testing.T) { + testdata := []struct { + input string + expected SharePermissions + }{ + {expected: SharePermissions{Read: true}, input: "r"}, + {expected: SharePermissions{Create: true}, input: "c"}, + {expected: SharePermissions{Write: true}, input: "w"}, + {expected: SharePermissions{Delete: true}, input: "d"}, + {expected: SharePermissions{List: true}, input: "l"}, + {expected: SharePermissions{ + Read: true, + Create: true, + Write: true, + Delete: true, + List: true, + }, input: "rcwdl"}, + {expected: SharePermissions{ + Read: true, + Create: true, + Write: true, + Delete: true, + List: true, + }, input: "cwrdl"}, // Wrong order parses correctly + } + for _, c := range testdata { + permissions, err := parseSharePermissions(c.input) + require.Nil(t, err) + require.Equal(t, c.expected, permissions) + } +} + +func TestSharePermissions_ParseNegative(t *testing.T) { + _, err := parseSharePermissions("cwtrdl") // Here 't' is invalid + require.NotNil(t, err) + require.Contains(t, err.Error(), "116") +} + +func TestFilePermissions_String(t *testing.T) { + testdata := []struct { + input FilePermissions + expected string + }{ + {input: FilePermissions{Read: true}, expected: "r"}, + {input: FilePermissions{Create: true}, expected: "c"}, + {input: FilePermissions{Write: true}, expected: "w"}, + {input: FilePermissions{Delete: true}, expected: "d"}, + {input: FilePermissions{ + Read: true, + Create: true, + Write: true, + Delete: true, + }, expected: "rcwd"}, + } + for _, c := range testdata { + require.Equal(t, c.expected, c.input.String()) + } +} + +func TestFilePermissions_Parse(t *testing.T) { + testdata := []struct { + expected FilePermissions + input string + }{ + {expected: FilePermissions{Read: true}, input: "r"}, + {expected: FilePermissions{Create: true}, input: "c"}, + {expected: FilePermissions{Write: true}, input: "w"}, + {expected: FilePermissions{Delete: true}, input: "d"}, + {expected: FilePermissions{ + Read: true, + Create: true, + Write: true, + Delete: true, + }, input: "rcwd"}, + {expected: FilePermissions{ + Read: true, + Create: true, + Write: true, + Delete: true, + }, input: "wcrd"}, // Wrong order parses correctly + } + for _, c := range testdata { + permissions, err := parseFilePermissions(c.input) + require.Nil(t, err) + require.Equal(t, c.expected, permissions) + } +} + +func TestFilePermissions_ParseNegative(t *testing.T) { + _, err := parseFilePermissions("wcrdf") // Here 'f' is invalid + require.NotNil(t, err) + require.Contains(t, err.Error(), "102") +} + +func TestGetCanonicalName(t *testing.T) { + testdata := []struct { + inputAccount string + inputShare string + inputFilePath string + expected string + }{ + {inputAccount: "fakestorageaccount", inputShare: "fakestorageshare", expected: "/file/fakestorageaccount/fakestorageshare"}, + {inputAccount: "fakestorageaccount", inputShare: "fakestorageshare", inputFilePath: "fakestoragefile", expected: "/file/fakestorageaccount/fakestorageshare/fakestoragefile"}, + {inputAccount: "fakestorageaccount", inputShare: "fakestorageshare", inputFilePath: "fakestoragedirectory/fakestoragefile", expected: "/file/fakestorageaccount/fakestorageshare/fakestoragedirectory/fakestoragefile"}, + {inputAccount: "fakestorageaccount", inputShare: "fakestorageshare", inputFilePath: "fakestoragedirectory\\fakestoragefile", expected: "/file/fakestorageaccount/fakestorageshare/fakestoragedirectory/fakestoragefile"}, + {inputAccount: "fakestorageaccount", inputShare: "fakestorageshare", inputFilePath: "fakestoragedirectory", expected: "/file/fakestorageaccount/fakestorageshare/fakestoragedirectory"}, + } + for _, c := range testdata { + require.Equal(t, c.expected, getCanonicalName(c.inputAccount, c.inputShare, c.inputFilePath)) + } +} diff --git a/sdk/storage/azfile/sas/url_parts.go b/sdk/storage/azfile/sas/url_parts.go new file mode 100644 index 000000000000..3aa950dae942 --- /dev/null +++ b/sdk/storage/azfile/sas/url_parts.go @@ -0,0 +1,147 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package sas + +import ( + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/shared" + "net/url" + "strings" +) + +const ( + shareSnapshot = "sharesnapshot" +) + +// IPEndpointStyleInfo is used for IP endpoint style URL when working with Azure storage emulator. +// Ex: "https://10.132.141.33/accountname/sharename" +type IPEndpointStyleInfo struct { + AccountName string // "" if not using IP endpoint style +} + +// URLParts object represents the components that make up an Azure Storage Share/Directory/File URL. You parse an +// existing URL into its parts by calling NewFileURLParts(). You construct a URL from parts by calling URL(). +// NOTE: Changing any SAS-related field requires computing a new SAS signature. +type URLParts struct { + Scheme string // Ex: "https://" + Host string // Ex: "account.share.core.windows.net", "10.132.141.33", "10.132.141.33:80" + IPEndpointStyleInfo IPEndpointStyleInfo // Useful Parts for IP endpoint style URL. + ShareName string // Share name, Ex: "myshare" + DirectoryOrFilePath string // Path of directory or file, Ex: "mydirectory/myfile" + ShareSnapshot string // IsZero is true if not a snapshot + SAS QueryParameters + UnparsedParams string +} + +// ParseURL parses a URL initializing URLParts' fields including any SAS-related & sharesnapshot query parameters. +// Any other query parameters remain in the UnparsedParams field. +func ParseURL(u string) (URLParts, error) { + uri, err := url.Parse(u) + if err != nil { + return URLParts{}, err + } + + up := URLParts{ + Scheme: uri.Scheme, + Host: uri.Host, + } + + if uri.Path != "" { + path := uri.Path + if path[0] == '/' { + path = path[1:] + } + if shared.IsIPEndpointStyle(up.Host) { + if accountEndIndex := strings.Index(path, "/"); accountEndIndex == -1 { // Slash not found; path has account name & no share, path of directory or file + up.IPEndpointStyleInfo.AccountName = path + path = "" // no ShareName present in the URL so path should be empty + } else { + up.IPEndpointStyleInfo.AccountName = path[:accountEndIndex] // The account name is the part between the slashes + path = path[accountEndIndex+1:] + } + } + + shareEndIndex := strings.Index(path, "/") // Find the next slash (if it exists) + if shareEndIndex == -1 { // Slash not found; path has share name & no path of directory or file + up.ShareName = path + } else { // Slash found; path has share name & path of directory or file + up.ShareName = path[:shareEndIndex] + up.DirectoryOrFilePath = path[shareEndIndex+1:] + } + } + + // Convert the query parameters to a case-sensitive map & trim whitespace + paramsMap := uri.Query() + + up.ShareSnapshot = "" // Assume no snapshot + if snapshotStr, ok := caseInsensitiveValues(paramsMap).Get(shareSnapshot); ok { + up.ShareSnapshot = snapshotStr[0] + // If we recognized the query parameter, remove it from the map + delete(paramsMap, shareSnapshot) + } + + up.SAS = NewQueryParameters(paramsMap, true) + up.UnparsedParams = paramsMap.Encode() + return up, nil +} + +// String returns a URL object whose fields are initialized from the URLParts fields. The URL's RawQuery +// field contains the SAS, snapshot, and unparsed query parameters. +func (up URLParts) String() string { + path := "" + // Concatenate account name for IP endpoint style URL + if shared.IsIPEndpointStyle(up.Host) && up.IPEndpointStyleInfo.AccountName != "" { + path += "/" + up.IPEndpointStyleInfo.AccountName + } + // Concatenate share & path of directory or file (if they exist) + if up.ShareName != "" { + path += "/" + up.ShareName + if up.DirectoryOrFilePath != "" { + path += "/" + up.DirectoryOrFilePath + } + } + + rawQuery := up.UnparsedParams + + //If no snapshot is initially provided, fill it in from the SAS query properties to help the user + if up.ShareSnapshot == "" && !up.SAS.SnapshotTime().IsZero() { + up.ShareSnapshot = up.SAS.SnapshotTime().Format(snapshotTimeFormat) + } + + // Concatenate share snapshot query parameter (if it exists) + if up.ShareSnapshot != "" { + if len(rawQuery) > 0 { + rawQuery += "&" + } + rawQuery += shareSnapshot + "=" + up.ShareSnapshot + } + sas := up.SAS.Encode() + if sas != "" { + if len(rawQuery) > 0 { + rawQuery += "&" + } + rawQuery += sas + } + u := url.URL{ + Scheme: up.Scheme, + Host: up.Host, + Path: path, + RawQuery: rawQuery, + } + return u.String() +} + +type caseInsensitiveValues url.Values // map[string][]string + +func (values caseInsensitiveValues) Get(key string) ([]string, bool) { + key = strings.ToLower(key) + for k, v := range values { + if strings.ToLower(k) == key { + return v, true + } + } + return []string{}, false +} diff --git a/sdk/storage/azfile/sas/url_parts_test.go b/sdk/storage/azfile/sas/url_parts_test.go new file mode 100644 index 000000000000..21691e0a7ae7 --- /dev/null +++ b/sdk/storage/azfile/sas/url_parts_test.go @@ -0,0 +1,75 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +package sas + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseURLIPStyle(t *testing.T) { + urlWithIP := "https://127.0.0.1:5000/fakestorageaccount" + fileURLParts, err := ParseURL(urlWithIP) + require.NoError(t, err) + require.Equal(t, fileURLParts.Scheme, "https") + require.Equal(t, fileURLParts.Host, "127.0.0.1:5000") + require.Equal(t, fileURLParts.IPEndpointStyleInfo.AccountName, "fakestorageaccount") + + urlWithIP = "https://127.0.0.1:5000/fakestorageaccount/fakeshare" + fileURLParts, err = ParseURL(urlWithIP) + require.NoError(t, err) + require.Equal(t, fileURLParts.Scheme, "https") + require.Equal(t, fileURLParts.Host, "127.0.0.1:5000") + require.Equal(t, fileURLParts.IPEndpointStyleInfo.AccountName, "fakestorageaccount") + require.Equal(t, fileURLParts.ShareName, "fakeshare") + + urlWithIP = "https://127.0.0.1:5000/fakestorageaccount/fakeshare/fakefile" + fileURLParts, err = ParseURL(urlWithIP) + require.NoError(t, err) + require.Equal(t, fileURLParts.Scheme, "https") + require.Equal(t, fileURLParts.Host, "127.0.0.1:5000") + require.Equal(t, fileURLParts.IPEndpointStyleInfo.AccountName, "fakestorageaccount") + require.Equal(t, fileURLParts.ShareName, "fakeshare") + require.Equal(t, fileURLParts.DirectoryOrFilePath, "fakefile") +} + +func TestParseURL(t *testing.T) { + testStorageAccount := "fakestorageaccount" + host := fmt.Sprintf("%s.file.core.windows.net", testStorageAccount) + testShare := "fakeshare" + fileNames := []string{"/._.TESTT.txt", "/.gitignore/dummyfile1"} + + const sasStr = "sv=2019-12-12&sr=b&st=2111-01-09T01:42:34.936Z&se=2222-03-09T01:42:34.936Z&sp=rw&sip=168.1.5.60-168.1.5.70&spr=https,http&si=myIdentifier&ss=bf&srt=s&sig=clNxbtnkKSHw7f3KMEVVc4agaszoRFdbZr%2FWBmPNsrw%3D" + + for _, fileName := range fileNames { + sasURL := fmt.Sprintf("https://%s.file.core.windows.net/%s%s?%s", testStorageAccount, testShare, fileName, sasStr) + fileURLParts, err := ParseURL(sasURL) + require.NoError(t, err) + + require.Equal(t, fileURLParts.Scheme, "https") + require.Equal(t, fileURLParts.Host, host) + require.Equal(t, fileURLParts.ShareName, testShare) + + validateSAS(t, sasStr, fileURLParts.SAS) + } + + for _, fileName := range fileNames { + shareSnapshotID := "2011-03-09T01:42:34Z" + sasWithShareSnapshotID := "?sharesnapshot=" + shareSnapshotID + "&" + sasStr + urlWithShareSnapshot := fmt.Sprintf("https://%s.file.core.windows.net/%s%s%s", testStorageAccount, testShare, fileName, sasWithShareSnapshotID) + fileURLParts, err := ParseURL(urlWithShareSnapshot) + require.NoError(t, err) + + require.Equal(t, fileURLParts.Scheme, "https") + require.Equal(t, fileURLParts.Host, host) + require.Equal(t, fileURLParts.ShareName, testShare) + + validateSAS(t, sasStr, fileURLParts.SAS) + } +} diff --git a/sdk/storage/azfile/service/client.go b/sdk/storage/azfile/service/client.go index 35455293cd67..7b758ab0ab7a 100644 --- a/sdk/storage/azfile/service/client.go +++ b/sdk/storage/azfile/service/client.go @@ -11,12 +11,16 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/fileerror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/base" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/exported" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/generated" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/shared" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/sas" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/share" "net/http" + "strings" + "time" ) // ClientOptions contains the optional parameters when creating a Client. @@ -182,3 +186,32 @@ func (s *Client) NewListSharesPager(options *ListSharesOptions) *runtime.Pager[L }, }) } + +// GetSASURL is a convenience method for generating a SAS token for the currently pointed at account. +// It can only be used if the credential supplied during creation was a SharedKeyCredential. +func (s *Client) GetSASURL(resources sas.AccountResourceTypes, permissions sas.AccountPermissions, expiry time.Time, o *GetSASURLOptions) (string, error) { + if s.sharedKey() == nil { + return "", fileerror.MissingSharedKeyCredential + } + st := o.format() + qps, err := sas.AccountSignatureValues{ + Version: sas.Version, + Protocol: sas.ProtocolHTTPS, + Permissions: permissions.String(), + ResourceTypes: resources.String(), + StartTime: st, + ExpiryTime: expiry.UTC(), + }.SignWithSharedKey(s.sharedKey()) + if err != nil { + return "", err + } + + endpoint := s.URL() + if !strings.HasSuffix(endpoint, "/") { + // add a trailing slash to be consistent with the portal + endpoint += "/" + } + endpoint += "?" + qps.Encode() + + return endpoint, nil +} diff --git a/sdk/storage/azfile/service/client_test.go b/sdk/storage/azfile/service/client_test.go index 9e69f16e30f2..e97d82b3f1ae 100644 --- a/sdk/storage/azfile/service/client_test.go +++ b/sdk/storage/azfile/service/client_test.go @@ -8,9 +8,12 @@ package service_test import ( "context" + "fmt" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Azure/azure-sdk-for-go/sdk/internal/recording" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/fileerror" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/testcommon" + "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/sas" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/service" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -224,3 +227,95 @@ func (s *ServiceUnrecordedTestsSuite) TestAccountListSharesNonDefault() { } }*/ } + +func (s *ServiceUnrecordedTestsSuite) TestSASServiceClient() { + _require := require.New(s.T()) + // testName := s.T().Name() + cred, _ := testcommon.GetGenericSharedKeyCredential(testcommon.TestAccountDefault) + + serviceClient, err := service.NewClientWithSharedKeyCredential(fmt.Sprintf("https://%s.file.core.windows.net/", cred.AccountName()), cred, nil) + _require.Nil(err) + + // shareName := testcommon.GenerateShareName(testName) + + // Note: Always set all permissions, services, types to true to ensure order of string formed is correct. + resources := sas.AccountResourceTypes{ + Object: true, + Service: true, + Container: true, + } + permissions := sas.AccountPermissions{ + Read: true, + Write: true, + Delete: true, + List: true, + Create: true, + } + expiry := time.Now().Add(time.Hour) + sasUrl, err := serviceClient.GetSASURL(resources, permissions, expiry, nil) + _require.Nil(err) + + svcClient, err := testcommon.GetServiceClientNoCredential(s.T(), sasUrl, nil) + _require.Nil(err) + + // create share using SAS + //_, err = svcClient.CreateShare(context.Background(), shareName, nil) + //_require.Nil(err) + // + //_, err = svcClient.DeleteShare(context.Background(), shareName, nil) + //_require.Nil(err) + + resp, err := svcClient.GetProperties(context.Background(), nil) + _require.NoError(err) + _require.NotNil(resp.RequestID) +} + +func (s *ServiceUnrecordedTestsSuite) TestSASServiceClientNoKey() { + _require := require.New(s.T()) + accountName := os.Getenv("AZURE_STORAGE_ACCOUNT_NAME") + + serviceClient, err := service.NewClientWithNoCredential(fmt.Sprintf("https://%s.file.core.windows.net/", accountName), nil) + _require.Nil(err) + resources := sas.AccountResourceTypes{ + Object: true, + Service: true, + Container: true, + } + permissions := sas.AccountPermissions{ + Read: true, + Write: true, + Delete: true, + List: true, + Create: true, + } + + expiry := time.Now().Add(time.Hour) + _, err = serviceClient.GetSASURL(resources, permissions, expiry, nil) + _require.Equal(err, fileerror.MissingSharedKeyCredential) +} + +func (s *ServiceUnrecordedTestsSuite) TestSASServiceClientSignNegative() { + _require := require.New(s.T()) + accountName := os.Getenv("AZURE_STORAGE_ACCOUNT_NAME") + accountKey := os.Getenv("AZURE_STORAGE_ACCOUNT_KEY") + cred, err := service.NewSharedKeyCredential(accountName, accountKey) + _require.Nil(err) + + serviceClient, err := service.NewClientWithSharedKeyCredential(fmt.Sprintf("https://%s.file.core.windows.net/", accountName), cred, nil) + _require.Nil(err) + resources := sas.AccountResourceTypes{ + Object: true, + Service: true, + Container: true, + } + permissions := sas.AccountPermissions{ + Read: true, + Write: true, + Delete: true, + List: true, + Create: true, + } + expiry := time.Time{} + _, err = serviceClient.GetSASURL(resources, permissions, expiry, nil) + _require.Equal(err.Error(), "account SAS is missing at least one of these: ExpiryTime, Permissions, Service, or ResourceType") +} diff --git a/sdk/storage/azfile/service/models.go b/sdk/storage/azfile/service/models.go index 4ea97710f359..0a529af87248 100644 --- a/sdk/storage/azfile/service/models.go +++ b/sdk/storage/azfile/service/models.go @@ -12,6 +12,7 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/generated" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/internal/shared" "github.com/Azure/azure-sdk-for-go/sdk/storage/azfile/share" + "time" ) // SharedKeyCredential contains an account's name and its primary or secondary key. @@ -147,3 +148,24 @@ type Share = generated.Share // ShareProperties - Properties of a share. type ShareProperties = generated.ShareProperties + +// --------------------------------------------------------------------------------------------------------------------- + +// GetSASURLOptions contains the optional parameters for the Client.GetSASURL method. +type GetSASURLOptions struct { + StartTime *time.Time +} + +func (o *GetSASURLOptions) format() time.Time { + if o == nil { + return time.Time{} + } + + var st time.Time + if o.StartTime != nil { + st = o.StartTime.UTC() + } else { + st = time.Time{} + } + return st +}