Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ class AzureSigningError(ClientAuthenticationError):


class _HttpChallenge(object): # pylint:disable=too-few-public-methods
"""Represents a parsed HTTP WWW-Authentication Bearer challenge from a server."""
"""Represents a parsed HTTP WWW-Authentication Bearer challenge from a server.

:param challenge: The WWW-Authenticate header of the challenge response.
:type challenge: str
"""

def __init__(self, challenge):
if not challenge:
Expand Down Expand Up @@ -78,7 +82,6 @@ def __init__(self, challenge):
self.resource = self._parameters.get("resource") or self._parameters.get("resource_id") or ""


# pylint: disable=no-self-use
class SharedKeyCredentialPolicy(SansIOHTTPPolicy):
def __init__(self, credential, is_emulated=False):
self._credential = credential
Expand Down Expand Up @@ -139,7 +142,7 @@ def _add_authorization_header(self, request, string_to_sign):
except Exception as ex:
# Wrap any error that occurred as signing error
# Doing so will clarify/locate the source of problem
raise _wrap_exception(ex, AzureSigningError)
raise _wrap_exception(ex, AzureSigningError) from ex

def on_request(self, request: PipelineRequest) -> None:
self.sign_request(request)
Expand Down Expand Up @@ -198,6 +201,7 @@ def on_challenge(self, request: PipelineRequest, response: PipelineResponse) ->
:param ~azure.core.pipeline.PipelineRequest request: the request which elicited an authentication challenge
:param ~azure.core.pipeline.PipelineResponse response: the resource provider's response
:returns: a bool indicating whether the policy should send the request
:rtype: bool
"""
if not self._discover_tenant and not self._discover_scopes:
# We can't discover the tenant or use a different scope; the request will fail because it hasn't changed
Expand Down Expand Up @@ -268,5 +272,5 @@ def _configure_credential(
if isinstance(credential, AzureNamedKeyCredential):
return SharedKeyCredentialPolicy(credential)
if credential is not None:
raise TypeError("Unsupported credential: {}".format(credential))
raise TypeError(f"Unsupported credential: {credential}")
return None
112 changes: 63 additions & 49 deletions sdk/tables/azure-data-tables/azure/data/tables/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def get_api_version(kwargs: Dict[str, Any], default: str) -> str:
if api_version and api_version not in _SUPPORTED_API_VERSIONS:
versions = "\n".join(_SUPPORTED_API_VERSIONS)
raise ValueError(
"Unsupported API version '{}'. Please select from:\n{}".format(
api_version, versions
)
f"Unsupported API version '{api_version}'. Please select from:\n{versions}"
)
return api_version or default

Expand All @@ -81,11 +79,11 @@ def __init__(
try:
if not account_url.lower().startswith("http"):
account_url = "https://" + account_url
except AttributeError:
raise ValueError("Account URL must be a string.")
except AttributeError as exc:
raise ValueError("Account URL must be a string.") from exc
parsed_url = urlparse(account_url.rstrip("/"))
if not parsed_url.netloc:
raise ValueError("Invalid URL: {}".format(account_url))
raise ValueError(f"Invalid URL: {account_url}")

_, sas_token = parse_query(parsed_url.query)
if not sas_token and not credential:
Expand Down Expand Up @@ -118,10 +116,8 @@ def __init__(
raise ValueError("Token credential is only supported with HTTPS.")
if hasattr(self.credential, "named_key"):
self.account_name = self.credential.named_key.name # type: ignore
secondary_hostname = "{}-secondary.table.{}".format(
self.credential.named_key.name, # type: ignore
os.getenv("TABLES_STORAGE_ENDPOINT_SUFFIX", DEFAULT_STORAGE_ENDPOINT_SUFFIX)
)
endpoint_suffix = os.getenv("TABLES_STORAGE_ENDPOINT_SUFFIX", DEFAULT_STORAGE_ENDPOINT_SUFFIX)
secondary_hostname = f"{self.account_name}-secondary.table.{endpoint_suffix}"

if not self._hosts:
if len(account) > 1:
Expand All @@ -143,26 +139,31 @@ def __init__(
self._policies.insert(0, CosmosPatchTransformPolicy())

@property
def url(self):
def url(self) -> str:
"""The full endpoint URL to this entity, including SAS token if used.

This could be either the primary endpoint,
or the secondary endpoint depending on the current :func:`location_mode`.

:return: The full endpoint URL including SAS token if used.
:rtype: str
"""
return self._format_url(self._hosts[self._location_mode])
return self._format_url(self._hosts[self._location_mode]) # type: ignore

@property
def _primary_endpoint(self):
"""The full primary endpoint URL.

:type: str
:return: The full primary endpoint URL.
:rtype: str
"""
return self._format_url(self._hosts[LocationMode.PRIMARY])

@property
def _primary_hostname(self):
"""The hostname of the primary endpoint.

:return: The hostname of the primary endpoint.
:type: str
"""
return self._hosts[LocationMode.PRIMARY]
Expand All @@ -174,8 +175,9 @@ def _secondary_endpoint(self):
If not available a ValueError will be raised. To explicitly specify a secondary hostname, use the optional
`secondary_hostname` keyword argument on instantiation.

:return: The full secondary endpoint URL.
:type: str
:raise ValueError:
:raise ValueError: If the secondary endpoint URL is not configured.
"""
if not self._hosts[LocationMode.SECONDARY]:
raise ValueError("No secondary host configured.")
Expand All @@ -188,34 +190,28 @@ def _secondary_hostname(self):
If not available this will be None. To explicitly specify a secondary hostname, use the optional
`secondary_hostname` keyword argument on instantiation.

:return: The hostname of the secondary endpoint.
:type: str or None
"""
return self._hosts[LocationMode.SECONDARY]

@property
def api_version(self):
def api_version(self) -> str:
"""The version of the Storage API used for requests.

:return: The Storage API version.
:type: str
"""
return self._client._config.version # pylint: disable=protected-access
return self._client._config.version # type: ignore # pylint: disable=protected-access


class TablesBaseClient(AccountHostsMixin):
"""Base class for TableClient

:param str endpoint: A URL to an Azure Tables account.
:keyword credential:
The credentials with which to authenticate. This is optional if the
account URL already has a SAS token. The value can be one of AzureNamedKeyCredential (azure-core),
AzureSasCredential (azure-core), or TokenCredentials from azure-identity.
:paramtype credential:
:class:`~azure.core.credentials.AzureNamedKeyCredential` or
:class:`~azure.core.credentials.AzureSasCredential` or
:class:`~azure.core.credentials.TokenCredential`
:keyword api_version: Specifies the version of the operation to use for this request. Default value
is "2019-02-02". Note that overriding this default value may result in unsupported behavior.
:paramtype api_version: str
:ivar str account_name: The name of the Tables account.
:ivar str scheme: The scheme component in the full URL to the Tables account.
:ivar str url: The storage endpoint.
:ivar str api_version: The service API version.
"""

def __init__( # pylint: disable=missing-client-constructor-parameter-credential
Expand All @@ -225,6 +221,21 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential
credential: Optional[Union[AzureSasCredential, AzureNamedKeyCredential, TokenCredential]] = None,
**kwargs
) -> None:
"""Create TablesBaseClient from a Credential.

:param str endpoint: A URL to an Azure Tables account.
:keyword credential:
The credentials with which to authenticate. This is optional if the
account URL already has a SAS token. The value can be one of AzureNamedKeyCredential (azure-core),
AzureSasCredential (azure-core), or a TokenCredential implementation from azure-identity.
:paramtype credential:
~azure.core.credentials.AzureNamedKeyCredential or
~azure.core.credentials.AzureSasCredential or
~azure.core.credentials.TokenCredential or None
:keyword api_version: Specifies the version of the operation to use for this request. Default value
is "2019-02-02".
:paramtype api_version: str
"""
super(TablesBaseClient, self).__init__(endpoint, credential=credential, **kwargs) # type: ignore
self._client = AzureTable(
self.url,
Expand Down Expand Up @@ -259,16 +270,25 @@ def _configure_policies(self, **kwargs):
]

def _batch_send(self, table_name: str, *reqs: HttpRequest, **kwargs) -> List[Mapping[str, Any]]:
"""Given a series of request, do a Storage batch call."""
# pylint:disable=docstring-should-be-keyword
"""Given a series of request, do a Storage batch call.

:param table_name: The table name.
:type table_name: str
:param reqs: The HTTP request.
:type reqs: ~azure.core.pipeline.transport.HttpRequest
:return: A list of batch part metadata in response.
:rtype: list[Mapping[str, Any]]
"""
# Pop it here, so requests doesn't feel bad about additional kwarg
policies = [StorageHeadersPolicy()]

changeset = HttpRequest("POST", None) # type: ignore
changeset.set_multipart_mixed(
*reqs, policies=policies, boundary="changeset_{}".format(uuid4()) # type: ignore
*reqs, policies=policies, boundary=f"changeset_{uuid4()}" # type: ignore
)
request = self._client._client.post( # pylint: disable=protected-access
url="{}://{}/$batch".format(self.scheme, self._primary_hostname),
url=f"{self.scheme}://{self._primary_hostname}/$batch",
headers={
"x-ms-version": self.api_version,
"DataServiceVersion": "3.0",
Expand All @@ -281,7 +301,7 @@ def _batch_send(self, table_name: str, *reqs: HttpRequest, **kwargs) -> List[Map
changeset,
policies=policies,
enforce_https=False,
boundary="batch_{}".format(uuid4()),
boundary=f"batch_{uuid4()}",
)
pipeline_response = self._client._client._pipeline.run(request, **kwargs) # pylint: disable=protected-access
response = pipeline_response.http_response
Expand Down Expand Up @@ -322,6 +342,9 @@ class TransportWrapper(HttpTransport):
"""Wrapper class that ensures that an inner client created
by a `get_client` method does not close the outer transport for the parent
when used in a context manager.

:param transport: The Http Transport instance
:type transport: ~azure.core.pipeline.transport.HttpTransport
"""
def __init__(self, transport):
self._transport = transport
Expand Down Expand Up @@ -352,25 +375,19 @@ def parse_connection_str(conn_str, credential, keyword_args):
if not credential:
try:
credential = AzureNamedKeyCredential(name=conn_settings["accountname"], key=conn_settings["accountkey"])
except KeyError:
except KeyError as exc:
credential = conn_settings.get("sharedaccesssignature", None)
if not credential:
raise ValueError("Connection string missing required connection details.")
raise ValueError("Connection string missing required connection details.") from exc
credential = AzureSasCredential(credential)
primary = conn_settings.get("tableendpoint")
secondary = conn_settings.get("tablesecondaryendpoint")
if not primary:
if secondary:
raise ValueError("Connection string specifies only secondary endpoint.")
try:
primary = "{}://{}.table.{}".format(
conn_settings["defaultendpointsprotocol"],
conn_settings["accountname"],
conn_settings["endpointsuffix"],
)
secondary = "{}-secondary.table.{}".format(
conn_settings["accountname"], conn_settings["endpointsuffix"]
)
primary = f"{conn_settings['defaultendpointsprotocol']}://{conn_settings['accountname']}.table.{conn_settings['endpointsuffix']}" # pylint: disable=line-too-long
secondary = f"{conn_settings['accountname']}-secondary.table.{conn_settings['endpointsuffix']}"
except KeyError:
pass

Expand All @@ -380,12 +397,9 @@ def parse_connection_str(conn_str, credential, keyword_args):
else:
endpoint_suffix = os.getenv("TABLES_STORAGE_ENDPOINT_SUFFIX", DEFAULT_STORAGE_ENDPOINT_SUFFIX)
try:
primary = "https://{}.table.{}".format(
conn_settings["accountname"],
conn_settings.get("endpointsuffix", endpoint_suffix),
)
except KeyError:
raise ValueError("Connection string missing required connection details.")
primary = f"https://{conn_settings['accountname']}.table.{conn_settings.get('endpointsuffix', endpoint_suffix)}" # pylint: disable=line-too-long
except KeyError as exc:
raise ValueError("Connection string missing required connection details.") from exc

if "secondary_hostname" not in keyword_args:
keyword_args["secondary_hostname"] = secondary
Expand Down Expand Up @@ -416,7 +430,7 @@ def parse_query(query_str):
sas_values = QueryStringConstants.to_list()
parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
sas_params = [
"{}={}".format(k, quote(v, safe=""))
f"{k}={quote(v, safe='')}"
for k, v in parsed_query.items()
if k in sas_values
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,8 @@
# --------------------------------------------------------------------------
import base64
import hashlib
import datetime
import hmac


class UTC(datetime.tzinfo):
"""Time Zone info for handling UTC"""

def utcoffset(self, dt):
"""UTF offset for UTC is 0."""
return datetime.timedelta(0)

def tzname(self, dt):
"""Timestamp representation."""
return "Z"

def dst(self, dt):
"""No daylight saving for UTC."""
return datetime.timedelta(hours=1)


try:
from datetime import timezone
TZ_UTC = timezone.utc # type: ignore
except ImportError:
TZ_UTC = UTC() # type: ignore
from datetime import timezone


def _to_str(value):
Expand All @@ -38,7 +15,7 @@ def _to_str(value):

def _to_utc_datetime(value):
try:
value = value.astimezone(TZ_UTC)
value = value.astimezone(timezone.utc)
except ValueError:
# Before Python 3.8, this raised for a naive datetime.
pass
Expand Down
Loading