Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
)
import logging
try:
from urllib.parse import parse_qs
from urllib.parse import parse_qs, quote
except ImportError:
from urlparse import parse_qs # type: ignore
from urllib2 import quote # type: ignore

import six

Expand Down Expand Up @@ -156,7 +157,7 @@ def _create_pipeline(self, credential, **kwargs):
config = kwargs.get('_configuration') or create_configuration(**kwargs)
if kwargs.get('_pipeline'):
return config, kwargs['_pipeline']
config.transport = kwargs.get('transport') # type: HttpTransport
config.transport = kwargs.get('transport') # type: ignore
if not config.transport:
config.transport = RequestsTransport(config)
policies = [
Expand Down Expand Up @@ -276,7 +277,7 @@ def create_configuration(**kwargs):
def parse_query(query_str):
sas_values = QueryStringConstants.to_list()
parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
sas_params = ["{}={}".format(k, v) for k, v in parsed_query.items() if k in sas_values]
sas_params = ["{}={}".format(k, quote(v)) for k, v in parsed_query.items() if k in sas_values]
sas_token = None
if sas_params:
sas_token = '&'.join(sas_params)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
QueueMessagePolicy)
from .policies_async import AsyncStorageResponseHook


if TYPE_CHECKING:
from azure.core.pipeline import Pipeline
from azure.core import Configuration
_LOGGER = logging.getLogger(__name__)


Expand Down Expand Up @@ -60,11 +62,11 @@ def _create_pipeline(self, credential, **kwargs):
raise TypeError("Unsupported credential: {}".format(credential))

if 'connection_timeout' not in kwargs:
kwargs['connection_timeout'] = DEFAULT_SOCKET_TIMEOUT[0]
kwargs['connection_timeout'] = DEFAULT_SOCKET_TIMEOUT[0] # type: ignore
config = kwargs.get('_configuration') or create_configuration(**kwargs)
if kwargs.get('_pipeline'):
return config, kwargs['_pipeline']
config.transport = kwargs.get('transport') # type: HttpTransport
config.transport = kwargs.get('transport') # type: ignore
if not config.transport:
config.transport = AsyncTransport(config)
policies = [
Expand All @@ -76,7 +78,7 @@ def _create_pipeline(self, credential, **kwargs):
credential_policy,
ContentDecodePolicy(),
AsyncRedirectPolicy(**kwargs),
StorageHosts(hosts=self._hosts, **kwargs),
StorageHosts(hosts=self._hosts, **kwargs), # type: ignore
config.retry_policy,
config.logging_policy,
AsyncStorageResponseHook(**kwargs),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
AccountPermissions,
StorageErrorCode
)
from ._queue_utils import (
from ._message_encoding import (
TextBase64EncodePolicy,
TextBase64DecodePolicy,
BinaryBase64EncodePolicy,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# pylint: disable=unused-argument

from azure.core.exceptions import ResourceExistsError

from ._shared.models import StorageErrorCode
from .models import QueueProperties


def deserialize_metadata(response, obj, headers):
raw_metadata = {k: v for k, v in response.headers.items() if k.startswith("x-ms-meta-")}
return {k[10:]: v for k, v in raw_metadata.items()}


def deserialize_queue_properties(response, obj, headers):
metadata = deserialize_metadata(response, obj, headers)
queue_properties = QueueProperties(
metadata=metadata,
**headers
)
return queue_properties


def deserialize_queue_creation(response, obj, headers):
if response.status_code == 204:
error_code = StorageErrorCode.queue_already_exists
error = ResourceExistsError(
message="Queue already exists\nRequestId:{}\nTime:{}\nErrorCode:{}".format(
headers['x-ms-request-id'],
headers['Date'],
error_code
),
response=response)
error.error_code = error_code
error.additional_info = {}
raise error
return headers
Original file line number Diff line number Diff line change
Expand Up @@ -10,41 +10,9 @@
from xml.sax.saxutils import unescape as xml_unescape

import six
from azure.core.exceptions import ResourceExistsError, DecodeError
from azure.core.exceptions import DecodeError

from ._shared.models import StorageErrorCode
from ._shared.encryption import decrypt_queue_message, encrypt_queue_message
from .models import QueueProperties


def deserialize_metadata(response, obj, headers):
raw_metadata = {k: v for k, v in response.headers.items() if k.startswith("x-ms-meta-")}
return {k[10:]: v for k, v in raw_metadata.items()}


def deserialize_queue_properties(response, obj, headers):
metadata = deserialize_metadata(response, obj, headers)
queue_properties = QueueProperties(
metadata=metadata,
**headers
)
return queue_properties


def deserialize_queue_creation(response, obj, headers):
if response.status_code == 204:
error_code = StorageErrorCode.queue_already_exists
error = ResourceExistsError(
message="Queue already exists\nRequestId:{}\nTime:{}\nErrorCode:{}".format(
headers['x-ms-request-id'],
headers['Date'],
error_code
),
response=response)
error.error_code = error_code
error.additional_info = {}
raise error
return headers


class MessageEncodePolicy(object):
Expand Down Expand Up @@ -104,7 +72,7 @@ def decode(self, content, response):

class TextBase64EncodePolicy(MessageEncodePolicy):
"""Base 64 message encoding policy for text messages.

Encodes text (unicode) messages to base 64. If the input content
is not text, a TypeError will be raised. Input text must support UTF-8.
"""
Expand All @@ -117,7 +85,7 @@ def encode(self, content):

class TextBase64DecodePolicy(MessageDecodePolicy):
"""Message decoding policy for base 64-encoded messages into text.

Decodes base64-encoded messages to text (unicode). If the input content
is not valid base 64, a DecodeError will be raised. Message data must
support UTF-8.
Expand All @@ -136,7 +104,7 @@ def decode(self, content, response):

class BinaryBase64EncodePolicy(MessageEncodePolicy):
"""Base 64 message encoding policy for binary messages.

Encodes binary messages to base 64. If the input content
is not bytes, a TypeError will be raised.
"""
Expand All @@ -149,7 +117,7 @@ def encode(self, content):

class BinaryBase64DecodePolicy(MessageDecodePolicy):
"""Message decoding policy for base 64-encoded messages into bytes.

Decodes base64-encoded messages to bytes. If the input content
is not valid base 64, a DecodeError will be raised.
"""
Expand All @@ -167,7 +135,7 @@ def decode(self, content, response):

class TextXMLEncodePolicy(MessageEncodePolicy):
"""XML message encoding policy for text messages.

Encodes text (unicode) messages to XML. If the input content
is not text, a TypeError will be raised.
"""
Expand All @@ -180,7 +148,7 @@ def encode(self, content):

class TextXMLDecodePolicy(MessageDecodePolicy):
"""Message decoding policy for XML-encoded messages into text.

Decodes XML-encoded messages to text (unicode).
"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ def _add_authorization_header(self, request, string_to_sign):
raise _wrap_exception(ex, AzureSigningError)

def on_request(self, request, **kwargs):
if not 'content-type' in request.http_request.headers:
request.http_request.headers['content-type'] = 'application/xml; charset=utf-8'

string_to_sign = \
self._get_verb(request) + \
self._get_headers(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
)
import logging
try:
from urllib.parse import parse_qs
from urllib.parse import parse_qs, quote
except ImportError:
from urlparse import parse_qs # type: ignore
from urllib2 import quote # type: ignore

import six

Expand Down Expand Up @@ -87,9 +88,7 @@ def __init__(
self.require_encryption = kwargs.get('require_encryption', False)
self.key_encryption_key = kwargs.get('key_encryption_key')
self.key_resolver_function = kwargs.get('key_resolver_function')

self._config, self._pipeline = create_pipeline(
self.credential, storage_sdk=service, hosts=self._hosts, **kwargs)
self._config, self._pipeline = self._create_pipeline(self.credential, storage_sdk=service, **kwargs)

def __enter__(self):
self._client.__enter__()
Expand Down Expand Up @@ -145,6 +144,38 @@ def _format_query_string(self, sas_token, credential, snapshot=None, share_snaps
credential = None
return query_str.rstrip('?&'), credential

def _create_pipeline(self, credential, **kwargs):
# type: (Any, **Any) -> Tuple[Configuration, Pipeline]
credential_policy = None
if hasattr(credential, 'get_token'):
credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
elif isinstance(credential, SharedKeyCredentialPolicy):
credential_policy = credential
elif credential is not None:
raise TypeError("Unsupported credential: {}".format(credential))

config = kwargs.get('_configuration') or create_configuration(**kwargs)
if kwargs.get('_pipeline'):
return config, kwargs['_pipeline']
config.transport = kwargs.get('transport') # type: ignore
if not config.transport:
config.transport = RequestsTransport(config)
policies = [
QueueMessagePolicy(),
config.headers_policy,
config.user_agent_policy,
StorageContentValidation(),
StorageRequestHook(**kwargs),
credential_policy,
ContentDecodePolicy(),
RedirectPolicy(**kwargs),
StorageHosts(hosts=self._hosts, **kwargs),
config.retry_policy,
config.logging_policy,
StorageResponseHook(**kwargs),
]
return config, Pipeline(config.transport, policies=policies)


def format_shared_key_credential(account, credential):
if isinstance(credential, six.string_types):
Expand Down Expand Up @@ -219,7 +250,6 @@ def create_configuration(**kwargs):
config.headers_policy = StorageHeadersPolicy(**kwargs)
config.user_agent_policy = StorageUserAgentPolicy(**kwargs)
config.retry_policy = kwargs.get('retry_policy') or ExponentialRetry(**kwargs)
config.redirect_policy = RedirectPolicy(**kwargs)
config.logging_policy = StorageLoggingPolicy(**kwargs)
config.proxy_policy = ProxyPolicy(**kwargs)

Expand All @@ -244,43 +274,10 @@ def create_configuration(**kwargs):
return config


def create_pipeline(credential, **kwargs):
# type: (Any, **Any) -> Tuple[Configuration, Pipeline]
credential_policy = None
if hasattr(credential, 'get_token'):
credential_policy = BearerTokenCredentialPolicy(credential, STORAGE_OAUTH_SCOPE)
elif isinstance(credential, SharedKeyCredentialPolicy):
credential_policy = credential
elif credential is not None:
raise TypeError("Unsupported credential: {}".format(credential))

config = kwargs.get('_configuration') or create_configuration(**kwargs)
if kwargs.get('_pipeline'):
return config, kwargs['_pipeline']
transport = kwargs.get('transport') # type: HttpTransport
if not transport:
transport = RequestsTransport(config)
policies = [
QueueMessagePolicy(),
config.headers_policy,
config.user_agent_policy,
StorageContentValidation(),
StorageRequestHook(**kwargs),
credential_policy,
ContentDecodePolicy(),
config.redirect_policy,
StorageHosts(**kwargs),
config.retry_policy,
config.logging_policy,
StorageResponseHook(**kwargs),
]
return config, Pipeline(transport, policies=policies)


def parse_query(query_str):
sas_values = QueryStringConstants.to_list()
parsed_query = {k: v[0] for k, v in parse_qs(query_str).items()}
sas_params = ["{}={}".format(k, v) for k, v in parsed_query.items() if k in sas_values]
sas_params = ["{}={}".format(k, quote(v)) for k, v in parsed_query.items() if k in sas_values]
sas_token = None
if sas_params:
sas_token = '&'.join(sas_params)
Expand Down
Loading