Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/storage/azqueue/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#### Bugs Fixed

* Fixed service SAS creation where expiry time or permissions can be omitted when stored access policy is used.

#### Other Changes

### 1.0.0 (2023-05-09)
Expand Down
71 changes: 71 additions & 0 deletions sdk/storage/azqueue/queue_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1456,3 +1456,74 @@ func (s *UnrecordedTestSuite) TestServiceSASDequeueMessage() {
_require.Equal(0, len(resp.Messages))
_require.Nil(err)
}

func (s *UnrecordedTestSuite) TestQueueSASUsingAccessPolicy() {
_require := require.New(s.T())

cred, err := testcommon.GetGenericCredential(testcommon.TestAccountDefault)
_require.NoError(err)

svcClient, err := testcommon.GetServiceClient(s.T(), testcommon.TestAccountDefault, nil)
_require.NoError(err)

testName := s.T().Name()
queueName := testcommon.GenerateQueueName(testName)
queueClient := testcommon.GetQueueClient(queueName, svcClient)
defer testcommon.DeleteQueue(context.Background(), _require, queueClient)

_, err = queueClient.Create(context.Background(), nil)
_require.NoError(err)

id := "testAccessPolicy"
ps := azqueue.AccessPolicyPermission{Read: true, Add: true, Update: true, Process: true}
signedIdentifiers := make([]*azqueue.SignedIdentifier, 0)
signedIdentifiers = append(signedIdentifiers, &azqueue.SignedIdentifier{
AccessPolicy: &azqueue.AccessPolicy{
Expiry: to.Ptr(time.Now().Add(1 * time.Hour)),
Start: to.Ptr(time.Now()),
Permission: to.Ptr(ps.String()),
},
ID: &id,
})

_, err = queueClient.SetAccessPolicy(context.Background(), &azqueue.SetAccessPolicyOptions{
QueueACL: signedIdentifiers,
})
_require.NoError(err)

gResp, err := queueClient.GetAccessPolicy(context.Background(), nil)
_require.NoError(err)
_require.Len(gResp.SignedIdentifiers, 1)

time.Sleep(30 * time.Second)

sasQueryParams, err := sas.QueueSignatureValues{
Protocol: sas.ProtocolHTTPS,
Identifier: id,
QueueName: queueName,
}.SignWithSharedKey(cred)
_require.NoError(err)

queueSAS := queueClient.URL() + "?" + sasQueryParams.Encode()
queueClientSAS, err := azqueue.NewQueueClientWithNoCredential(queueSAS, nil)
_require.NoError(err)

_, err = queueClientSAS.GetProperties(context.Background(), nil)
_require.NoError(err)

// enqueue 4 messages
for i := 0; i < 4; i++ {
_, err = queueClientSAS.EnqueueMessage(context.Background(), fmt.Sprintf("%v : %v", testcommon.QueueDefaultData, i), nil)
_require.NoError(err)
}

// dequeue 4 messages
for i := 0; i < 4; i++ {
resp, err := queueClientSAS.DequeueMessage(context.Background(), nil)
_require.NoError(err)
_require.Equal(1, len(resp.Messages))
_require.NotNil(resp.Messages[0].MessageText)
_require.Equal(fmt.Sprintf("%v : %v", testcommon.QueueDefaultData, i), *resp.Messages[0].MessageText)
_require.NotNil(resp.Messages[0].MessageID)
}
}
5 changes: 3 additions & 2 deletions sdk/storage/azqueue/sas/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ type QueueSignatureValues struct {

// SignWithSharedKey uses an account's SharedKeyCredential to sign this signature values to produce the proper SAS query parameters.
func (v QueueSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCredential) (QueryParameters, error) {
if v.ExpiryTime.IsZero() || v.Permissions == "" {
if v.Identifier == "" && (v.ExpiryTime.IsZero() || v.Permissions == "") {
return QueryParameters{}, errors.New("service SAS is missing at least one of these: ExpiryTime or Permissions")
}

Expand Down Expand Up @@ -75,7 +75,8 @@ func (v QueueSignatureValues) SignWithSharedKey(sharedKeyCredential *SharedKeyCr
permissions: v.Permissions,
ipRange: v.IPRange,
// Calculated SAS signature
signature: signature,
signature: signature,
identifier: signedIdentifier,
}

return p, nil
Expand Down
45 changes: 45 additions & 0 deletions sdk/storage/azqueue/sas/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
package sas

import (
"errors"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azqueue/internal/exported"
"github.com/stretchr/testify/require"
"testing"
"time"
)

func TestQueuePermissions_String(t *testing.T) {
Expand Down Expand Up @@ -79,3 +82,45 @@ func TestGetCanonicalName(t *testing.T) {
require.Equal(t, c.expected, getCanonicalName(c.inputAccount, c.inputQueue))
}
}

func TestQueueSignatureValues_SignWithSharedKey(t *testing.T) {
cred, err := exported.NewSharedKeyCredential("fakeaccountname", "AKIAIOSFODNN7EXAMPLE")
require.Nil(t, err, "error creating valid shared key credentials.")

expiryDate, err := time.Parse("2006-01-02", "2023-07-20")
require.Nil(t, err, "error creating valid expiry date.")

testdata := []struct {
object QueueSignatureValues
expected QueryParameters
expectedError error
}{
{
object: QueueSignatureValues{QueueName: "fakestoragequeue", Permissions: "r", ExpiryTime: expiryDate},
expected: QueryParameters{version: Version, permissions: "r", expiryTime: expiryDate},
expectedError: nil,
},
{
object: QueueSignatureValues{QueueName: "fakestoragequeue", Permissions: "", ExpiryTime: expiryDate},
expected: QueryParameters{},
expectedError: errors.New("service SAS is missing at least one of these: ExpiryTime or Permissions"),
},
{
object: QueueSignatureValues{QueueName: "fakestoragequeue", Permissions: "r", ExpiryTime: *new(time.Time)},
expected: QueryParameters{},
expectedError: errors.New("service SAS is missing at least one of these: ExpiryTime or Permissions"),
},
{
object: QueueSignatureValues{QueueName: "fakestoragequeue", Permissions: "", ExpiryTime: *new(time.Time), Identifier: "fakepolicyname"},
expected: QueryParameters{version: Version, identifier: "fakepolicyname"},
expectedError: nil,
},
}
for _, c := range testdata {
act, err := c.object.SignWithSharedKey(cred)
// ignore signature value
act.signature = ""
require.Equal(t, c.expected, act)
require.Equal(t, c.expectedError, err)
}
}