From ba1f58d8f487e61e38a663d495a3635aadba771c Mon Sep 17 00:00:00 2001 From: Heng Lu <79895375+ms-henglu@users.noreply.github.com> Date: Wed, 20 Nov 2024 15:27:09 +0800 Subject: [PATCH] bugfix: resource manager account is not initialized correctly (#675) * bugfix: resource manager account is not initialized correctly * fix context cancelled --- CHANGELOG.md | 1 + internal/clients/account.go | 77 ++++++++++++------- internal/clients/client.go | 2 +- .../azapi_client_config_data_source.go | 2 +- 4 files changed, 51 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a0fe6188b..b44e2e36f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## Unreleased FEATURES: - `azapi_resource` resource: Support resource move operation, it allows moving resources from `azurerm` provider. +- `azapi_client_config` data source: Support `object_id` field. BUG FIXES: - Fix a bug when `body` contains an unknown float number, the provider will crash. diff --git a/internal/clients/account.go b/internal/clients/account.go index 4f5c6c448..b754a2e27 100644 --- a/internal/clients/account.go +++ b/internal/clients/account.go @@ -2,6 +2,7 @@ package clients import ( "bytes" + "context" "encoding/base64" "encoding/json" "errors" @@ -11,30 +12,33 @@ import ( "strings" "sync" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" ) +type ObjectIDProvider func(ctx context.Context) (string, error) + type ResourceManagerAccount struct { - tenantId *string - subscriptionId *string - objectId *string - mutex *sync.Mutex - client *Client + tenantId *string + subscriptionId *string + objectId *string + mutex *sync.Mutex + objectIDProvider ObjectIDProvider } -func NewResourceManagerAccount(client *Client) ResourceManagerAccount { +func NewResourceManagerAccount(tenantId, subscriptionId string, provider ObjectIDProvider) ResourceManagerAccount { out := ResourceManagerAccount{ mutex: &sync.Mutex{}, } - if client != nil && client.Account.tenantId != nil && *client.Account.tenantId != "" { - out.tenantId = client.Account.tenantId + if tenantId != "" { + out.tenantId = &tenantId } - if client != nil && client.Account.subscriptionId != nil && *client.Account.subscriptionId != "" { - out.subscriptionId = client.Account.subscriptionId + if subscriptionId != "" { + out.subscriptionId = &subscriptionId } // We lazy load object ID because it's not always needed and could cause a performance hit - out.client = client + out.objectIDProvider = provider return out } @@ -80,7 +84,7 @@ func (account *ResourceManagerAccount) GetSubscriptionId() string { return *account.subscriptionId } -func (account *ResourceManagerAccount) GetObjectId() string { +func (account *ResourceManagerAccount) GetObjectId(ctx context.Context) string { account.mutex.Lock() defer account.mutex.Unlock() @@ -88,28 +92,21 @@ func (account *ResourceManagerAccount) GetObjectId() string { return *account.objectId } - tok, err := account.client.Option.Cred.GetToken(account.client.StopContext, policy.TokenRequestOptions{ - TenantID: account.client.Option.TenantId, - Scopes: []string{account.client.Option.CloudCfg.Services[cloud.ResourceManager].Endpoint + "/.default"}}) - if err != nil { - log.Printf("[DEBUG] Error getting requesting token from credentials: %s", err) - } - - if tok.Token == "" { - err = account.loadSignedInUserFromAzCmd() - if err != nil { - log.Printf("[DEBUG] Error getting user object ID from az cli: %s", err) - } - } else { - cl, err := parseTokenClaims(tok.Token) + if account.objectIDProvider != nil { + objectId, err := account.objectIDProvider(ctx) if err != nil { - log.Printf("[DEBUG] Error getting object id from token: %s", err) + log.Printf("[DEBUG] Error getting object ID: %s", err) } - if cl != nil && cl.ObjectId != "" { - account.objectId = &cl.ObjectId + if objectId != "" { + account.objectId = &objectId + return *account.objectId } } + err := account.loadSignedInUserFromAzCmd() + if err != nil { + log.Printf("[DEBUG] Error getting user object ID from az cli: %s", err) + } if account.objectId == nil { log.Printf("[DEBUG] No object ID found") return "" @@ -215,3 +212,25 @@ type tokenClaims struct { AppId string `json:"appid,omitempty"` IdType string `json:"idtyp,omitempty"` } + +func ParsedTokenClaimsObjectIDProvider(cred azcore.TokenCredential, cloudCfg cloud.Configuration) ObjectIDProvider { + return func(ctx context.Context) (string, error) { + tok, err := cred.GetToken(context.Background(), policy.TokenRequestOptions{ + EnableCAE: true, + Scopes: []string{cloudCfg.Services[cloud.ResourceManager].Audience + "/.default"}}) + if err != nil { + return "", fmt.Errorf("getting requesting token from credentials: %w", err) + } + if tok.Token == "" { + return "", errors.New("token is empty") + } + cl, err := parseTokenClaims(tok.Token) + if err != nil { + return "", fmt.Errorf("getting object id from token: %w", err) + } + if cl == nil || cl.ObjectId == "" { + return "", errors.New("object id is empty") + } + return cl.ObjectId, nil + } +} diff --git a/internal/clients/client.go b/internal/clients/client.go index 08c0d0d8a..f5f3354a4 100644 --- a/internal/clients/client.go +++ b/internal/clients/client.go @@ -133,7 +133,7 @@ func (client *Client) Build(ctx context.Context, o *Option) error { } client.DataPlaneClient = dataPlaneClient - client.Account = NewResourceManagerAccount(client) + client.Account = NewResourceManagerAccount(o.TenantId, o.SubscriptionId, ParsedTokenClaimsObjectIDProvider(o.Cred, o.CloudCfg)) return nil } diff --git a/internal/services/azapi_client_config_data_source.go b/internal/services/azapi_client_config_data_source.go index 4a6351b33..ca575dffe 100644 --- a/internal/services/azapi_client_config_data_source.go +++ b/internal/services/azapi_client_config_data_source.go @@ -90,7 +90,7 @@ func (r *ClientConfigDataSource) Read(ctx context.Context, request datasource.Re subscriptionId := r.ProviderData.Account.GetSubscriptionId() tenantId := r.ProviderData.Account.GetTenantId() - objectId := r.ProviderData.Account.GetObjectId() + objectId := r.ProviderData.Account.GetObjectId(ctx) model.ID = types.StringValue(fmt.Sprintf("clientConfigs/subscriptionId=%s;tenantId=%s", subscriptionId, tenantId)) model.SubscriptionID = types.StringValue(subscriptionId)