Skip to content

Commit

Permalink
Adding support for Managed Identities for Azure storage account (#85)
Browse files Browse the repository at this point in the history
Co-authored-by: zongsi.zhang <[email protected]>
Co-authored-by: Oscar Cassetti <[email protected]>
  • Loading branch information
3 people authored Dec 13, 2021
1 parent b5523c5 commit 829bfed
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 18 deletions.
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,12 @@ require (
gopkg.in/yaml.v2 v2.2.8
)

require github.com/Azure/go-autorest/autorest/adal v0.9.11

require (
github.com/Azure/azure-pipeline-go v0.2.3 // indirect
github.com/Azure/go-autorest v14.2.0+incompatible // indirect
github.com/Azure/go-autorest/autorest v0.11.17 // indirect
github.com/Azure/go-autorest/autorest/adal v0.9.11 // indirect
github.com/Azure/go-autorest/autorest/azure/cli v0.4.2 // indirect
github.com/Azure/go-autorest/autorest/date v0.3.0 // indirect
github.com/Azure/go-autorest/logger v0.2.0 // indirect
Expand Down
5 changes: 3 additions & 2 deletions internal/ingress/s3sqs/s3sqs.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ import (
"time"

awssqs "github.com/aws/aws-sdk-go/service/sqs"
"github.com/kelindar/loader"
"github.com/kelindar/talaria/internal/config"
"github.com/kelindar/talaria/internal/ingress/s3sqs/sqs"
"github.com/kelindar/talaria/internal/monitor"
"github.com/kelindar/talaria/internal/monitor/errors"
"github.com/kelindar/loader"
"golang.org/x/sync/semaphore"
)

Expand Down Expand Up @@ -51,7 +51,7 @@ type Reader interface {
// New creates a new ingestion with SQS/S3 files.
func New(conf *config.S3SQS, region string, monitor monitor.Monitor) (*Ingress, error) {
loader := loader.New()
reader, err := sqs.NewReader(conf, region)
reader, err := sqs.NewReader(conf, region, monitor)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -146,6 +146,7 @@ func (s *Ingress) ingest(bucket, key string, handler func(v []byte) bool) {
data, err := s.loader.Load(context.Background(), fmt.Sprintf("s3://%s/%s", bucket, key))
defer s.limit.Release(1)
if err != nil {
s.monitor.Count1(ctxTag, "s3readerror")
s.monitor.Error(err)
return
}
Expand Down
11 changes: 9 additions & 2 deletions internal/ingress/s3sqs/sqs/reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sqs"
"github.com/kelindar/talaria/internal/config"
"github.com/kelindar/talaria/internal/monitor"
"github.com/kelindar/talaria/internal/monitor/errors"
)

const (
Expand All @@ -21,7 +23,7 @@ const (
)

// NewReader returns a reader
func NewReader(c *config.S3SQS, region string) (*Reader, error) {
func NewReader(c *config.S3SQS, region string, monitor monitor.Monitor) (*Reader, error) {
const defaultVisibilityTimeout = time.Second * 30

conf := aws.NewConfig().
Expand All @@ -46,6 +48,7 @@ func NewReader(c *config.S3SQS, region string) (*Reader, error) {
visibilityTimeout: visibilityTimeout,
queueURL: c.Queue,
waitTimeSeconds: c.WaitTimeout,
monitor: monitor,
}, nil
}

Expand All @@ -56,6 +59,7 @@ type Reader struct {
visibilityTimeout time.Duration
queueURL string
waitTimeSeconds int64 // The duration (in seconds) for which the call will wait for a message to arrive
monitor monitor.Monitor
}

// StartPolling messages from SQS. User defines
Expand Down Expand Up @@ -100,7 +104,10 @@ func (r *Reader) receive(attributeNames, messageAttributeNames []*string, maxPer
WaitTimeSeconds: &r.waitTimeSeconds,
}

resp, _ := r.sqs.ReceiveMessage(input)
resp, err := r.sqs.ReceiveMessage(input)
if err != nil {
r.monitor.Error(errors.Internal("sqs: failed to receive messages", err))
}
return resp.Messages
}

Expand Down
44 changes: 31 additions & 13 deletions internal/storage/writer/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (

"github.com/Azure/azure-sdk-for-go/storage"
"github.com/Azure/azure-storage-blob-go/azblob"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
"github.com/Azure/go-autorest/autorest/azure/auth"
"github.com/kelindar/talaria/internal/encoding/key"
"github.com/kelindar/talaria/internal/monitor"
Expand Down Expand Up @@ -100,12 +102,10 @@ func NewMultiAccountWriter(monitor monitor.Monitor, blobServiceURL, container, p
if blobServiceURL == "" {
blobServiceURL = defaultBlobServiceURL
}

credential, err := GetAzureStorageCredentials(monitor)
if err != nil {
return nil, errors.Internal("azure: unable to get azure storage credential", err)
}

containerURLs := make([]azblob.ContainerURL, len(storageAccount))
for i, sa := range storageAccount {
azureStoragePipeline := azblob.NewPipeline(credential, azblob.PipelineOptions{
Expand All @@ -115,13 +115,15 @@ func NewMultiAccountWriter(monitor monitor.Monitor, blobServiceURL, container, p
})
u, _ := url.Parse(fmt.Sprintf(blobServiceURL, sa))
containerURLs[i] = azblob.NewServiceURL(*u, azureStoragePipeline).NewContainerURL(container)
monitor.Info(fmt.Sprintf("azure: new azure storage pipeline created for %s", u))

}

var chooser *weightedrand.Chooser
if weights != nil {

if len(storageAccount) != len(weights) {
return nil, fmt.Errorf("azure: Invalid configuration number of storage account %v != number of weights %v", len(storageAccount), len(weights))
return nil, fmt.Errorf("Invalid configuration number of storage account %v != number of weights %v", len(storageAccount), len(weights))
}

choices := make([]weightedrand.Choice, len(storageAccount))
Expand All @@ -130,6 +132,7 @@ func NewMultiAccountWriter(monitor monitor.Monitor, blobServiceURL, container, p
Item: &containerURLs[i],
Weight: w,
}
monitor.Info(fmt.Sprintf("azure: writer weights for %v set to %d", containerURLs[i], w))
}
chooser, err = weightedrand.NewChooser(choices...)
if err != nil {
Expand All @@ -150,17 +153,9 @@ func NewMultiAccountWriter(monitor monitor.Monitor, blobServiceURL, container, p
}

func GetAzureStorageCredentials(monitor monitor.Monitor) (azblob.Credential, error) {
settings, err := auth.GetSettingsFromEnvironment()
if err != nil {
return nil, err
}

cc, err := settings.GetClientCredentials()
if err != nil {
return nil, err
}
spt, err := getServicePrincipalToken(monitor)

spt, err := cc.ServicePrincipalToken()
if err != nil {
return nil, err
}
Expand All @@ -169,7 +164,6 @@ func GetAzureStorageCredentials(monitor monitor.Monitor) (azblob.Credential, err
if err := spt.Refresh(); err != nil {
return nil, err
}

// Token refresher function
var tokenRefresher azblob.TokenRefresher
tokenRefresher = func(credential azblob.TokenCredential) time.Duration {
Expand Down Expand Up @@ -230,3 +224,27 @@ func (m *MultiAccountWriter) getContainerURL() (*azblob.ContainerURL, error) {
i := rand.Intn(len(m.containerURLs))
return &m.containerURLs[i], nil
}

func getServicePrincipalToken(monitor monitor.Monitor) (*adal.ServicePrincipalToken, error) {

spt, err := adal.NewServicePrincipalTokenFromManagedIdentity(azure.PublicCloud.ResourceIdentifiers.Storage, &adal.ManagedIdentityOptions{})

if err == nil {
monitor.Info("azure: acquired Manange Identity Credentials")
return spt, err
}
monitor.Warning(errors.Internal("azure: unable to retrieve Manange Identity Credentials", err))

settings, err := auth.GetSettingsFromEnvironment()
if err != nil {
return nil, err
}
cc, err := settings.GetClientCredentials()
if err != nil {
return nil, err
}

spt, err = cc.ServicePrincipalToken()

return spt, err
}

0 comments on commit 829bfed

Please sign in to comment.