diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py
index 0bcd9952fcf3a..44d6c569a1bbe 100644
--- a/airflow/providers/microsoft/azure/hooks/wasb.py
+++ b/airflow/providers/microsoft/azure/hooks/wasb.py
@@ -121,31 +121,45 @@ def get_conn(self) -> BlobServiceClient:
             # Here we use anonymous public read
             # more info
             # https://docs.microsoft.com/en-us/azure/storage/blobs/storage-manage-access-to-resources
-            return BlobServiceClient(account_url=conn.host)
+            return BlobServiceClient(account_url=conn.host, **extra)
 
-        if extra.get('connection_string') or extra.get('extra__wasb__connection_string'):
+        def _pop_auth_field(name):
+            # Remove both spellings of an auth setting from ``extra`` so neither is
+            # forwarded to BlobServiceClient as a stray keyword argument. Prefer the
+            # first non-empty value, as the previous ``or`` fallback did: an empty
+            # short key must not mask a populated ``extra__wasb__`` key.
+            plain = extra.pop(name, None)
+            prefixed = extra.pop(f'extra__wasb__{name}', None)
+            return plain or prefixed
+
+        # Consume every auth setting up front; whatever remains in ``extra`` is
+        # passed through to the client constructors below.
+        connection_string = _pop_auth_field('connection_string')
+        shared_access_key = _pop_auth_field('shared_access_key')
+        tenant = _pop_auth_field('tenant_id')
+        sas_token = _pop_auth_field('sas_token')
+
+        if connection_string:
             # connection_string auth takes priority
-            connection_string = extra.get('connection_string') or extra.get('extra__wasb__connection_string')
-            return BlobServiceClient.from_connection_string(connection_string)
-        if extra.get('shared_access_key') or extra.get('extra__wasb__shared_access_key'):
-            shared_access_key = extra.get('shared_access_key') or extra.get('extra__wasb__shared_access_key')
+            return BlobServiceClient.from_connection_string(connection_string, **extra)
+
+        if shared_access_key:
             # using shared access key
-            return BlobServiceClient(account_url=conn.host, credential=shared_access_key)
-        if extra.get('tenant_id') or extra.get('extra__wasb__tenant_id'):
+            return BlobServiceClient(account_url=conn.host, credential=shared_access_key, **extra)
+
+        if tenant:
             # use Active Directory auth
             app_id = conn.login
             app_secret = conn.password
-            tenant = extra.get('tenant_id', extra.get('extra__wasb__tenant_id'))
             token_credential = ClientSecretCredential(tenant, app_id, app_secret)
-            return BlobServiceClient(account_url=conn.host, credential=token_credential)
+            return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra)
 
-        sas_token = extra.get('sas_token') or extra.get('extra__wasb__sas_token')
         if sas_token:
             if sas_token.startswith('https'):
-                return BlobServiceClient(account_url=sas_token)
+                return BlobServiceClient(account_url=sas_token, **extra)
             else:
                 return BlobServiceClient(
-                    account_url=f'https://{conn.login}.blob.core.windows.net/{sas_token}'
+                    account_url=f'https://{conn.login}.blob.core.windows.net/{sas_token}', **extra
                 )
 
         # Fall back to old auth (password) or use managed identity if not provided.
diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py
index b258c50ced4a4..e31a560ff0821 100644
--- a/tests/providers/microsoft/azure/hooks/test_wasb.py
+++ b/tests/providers/microsoft/azure/hooks/test_wasb.py
@@ -52,11 +52,14 @@ def setup(self):
         self.public_read_conn_id = 'pub_read_id'
         self.managed_identity_conn_id = 'managed_identity'
 
+        self.proxies = {'http': 'http_proxy_uri', 'https': 'https_proxy_uri'}
+
         db.merge_conn(
             Connection(
                 conn_id=self.public_read_conn_id,
                 conn_type=self.connection_type,
                 host='https://accountname.blob.core.windows.net',
+                extra=json.dumps({'proxies': self.proxies}),
             )
         )
 
@@ -64,7 +67,7 @@ def setup(self):
             Connection(
                 conn_id=self.connection_string_id,
                 conn_type=self.connection_type,
-                extra=json.dumps({'connection_string': CONN_STRING}),
+                extra=json.dumps({'connection_string': CONN_STRING, 'proxies': self.proxies}),
             )
         )
         db.merge_conn(
@@ -72,50 +75,59 @@ def setup(self):
                 conn_id=self.shared_key_conn_id,
                 conn_type=self.connection_type,
                 host='https://accountname.blob.core.windows.net',
-                extra=json.dumps({'shared_access_key': 'token'}),
+                extra=json.dumps({'shared_access_key': 'token', 'proxies': self.proxies}),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.ad_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps(
-                    {'tenant_id': 'token', 'application_id': 'appID', 'application_secret': "appsecret"}
-                ),
+                host='conn_host',
+                login='appID',
+                password='appsecret',
+                extra=json.dumps({'tenant_id': 'token', 'proxies': self.proxies}),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.managed_identity_conn_id,
                 conn_type=self.connection_type,
+                extra=json.dumps({'proxies': self.proxies}),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.sas_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps({'sas_token': 'token'}),
+                extra=json.dumps({'sas_token': 'token', 'proxies': self.proxies}),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.extra__wasb__sas_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps({'extra__wasb__sas_token': 'token'}),
+                extra=json.dumps({'extra__wasb__sas_token': 'token', 'proxies': self.proxies}),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.http_sas_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps({'sas_token': 'https://login.blob.core.windows.net/token'}),
+                extra=json.dumps(
+                    {'sas_token': 'https://login.blob.core.windows.net/token', 'proxies': self.proxies}
+                ),
             )
         )
         db.merge_conn(
             Connection(
                 conn_id=self.extra__wasb__http_sas_conn_id,
                 conn_type=self.connection_type,
-                extra=json.dumps({'extra__wasb__sas_token': 'https://login.blob.core.windows.net/token'}),
+                extra=json.dumps(
+                    {
+                        'extra__wasb__sas_token': 'https://login.blob.core.windows.net/token',
+                        'proxies': self.proxies,
+                    }
+                ),
             )
         )
 
@@ -160,6 +172,58 @@ def test_sas_token_connection(self, conn_id_str, extra_key):
         assert isinstance(conn, BlobServiceClient)
         assert conn.url.endswith(sas_token + '/')
 
+    @pytest.mark.parametrize(
+        argnames="conn_id_str",
+        argvalues=[
+            'connection_string_id',
+            'shared_key_conn_id',
+            'ad_conn_id',
+            'managed_identity_conn_id',
+            'sas_conn_id',
+            'extra__wasb__sas_conn_id',
+            'http_sas_conn_id',
+            'extra__wasb__http_sas_conn_id',
+        ],
+    )
+    def test_connection_extra_arguments(self, conn_id_str):
+        """Non-auth extras (here ``proxies``) reach the client for every auth mode."""
+        conn_id = getattr(self, conn_id_str)
+        hook = WasbHook(wasb_conn_id=conn_id)
+        conn = hook.get_conn()
+        assert conn._config.proxy_policy.proxies == self.proxies
+
+    def test_connection_extra_arguments_public_read(self):
+        """Non-auth extras also reach the anonymous (public read) client."""
+        conn_id = self.public_read_conn_id
+        hook = WasbHook(wasb_conn_id=conn_id, public_read=True)
+        conn = hook.get_conn()
+        assert conn._config.proxy_policy.proxies == self.proxies
+
+    @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
+    def test_auth_fields_are_not_forwarded_to_client(self, mock_service):
+        """Auth extras are consumed by the hook; an empty short key yields to the prefixed one."""
+        db.merge_conn(
+            Connection(
+                conn_id='wasb_mixed_auth_keys',
+                conn_type=self.connection_type,
+                host='https://accountname.blob.core.windows.net',
+                extra=json.dumps(
+                    {
+                        'shared_access_key': '',
+                        'extra__wasb__shared_access_key': 'token',
+                        'extra__wasb__sas_token': '',
+                        'proxies': self.proxies,
+                    }
+                ),
+            )
+        )
+        WasbHook(wasb_conn_id='wasb_mixed_auth_keys').get_conn()
+        mock_service.assert_called_with(
+            account_url='https://accountname.blob.core.windows.net',
+            credential='token',
+            proxies=self.proxies,
+        )
+
     @mock.patch("airflow.providers.microsoft.azure.hooks.wasb.BlobServiceClient")
     def test_check_for_blob(self, mock_service):
         hook = WasbHook(wasb_conn_id=self.shared_key_conn_id)