Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ class UtilsTest(unittest.TestCase):
def test_convert_datetime_to_utc_int(self):
# UTC
utc_time_in_sec = _convert_datetime_to_utc_int(datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=dateutil.tz.tzutc()))
assert utc_time_in_sec == 0
assert utc_time_in_sec == 0
# UTC naive (without a timezone specified)
utc_naive_time_in_sec = _convert_datetime_to_utc_int(datetime(1970, 1, 1, 0, 0, 0, 0))
assert utc_naive_time_in_sec == 0
# PST is UTC-8
pst_time_in_sec = _convert_datetime_to_utc_int(datetime(1970, 1, 1, 0, 0, 0, 0, tzinfo=dateutil.tz.gettz('America/Vancouver')))
assert pst_time_in_sec == 8 * 3600
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
# license information.
# --------------------------------------------------------------------------
from threading import Lock, Condition
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
:param str token: The token used to authenticate to an Azure Communication service
Expand All @@ -24,9 +24,9 @@ class CommunicationTokenCredential(object):
_ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2

def __init__(self,
token, # type: str
**kwargs
):
token, # type: str
**kwargs
):
token_refresher = kwargs.pop('token_refresher', None)
communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token,
token_refresher=token_refresher)
Expand All @@ -35,8 +35,8 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""
Expand Down Expand Up @@ -79,12 +79,8 @@ def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES)
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()

def _is_currenttoken_valid(self):
return self._get_utc_now() < self._token.expires_on

@classmethod
def _get_utc_now(cls):
return datetime.now().replace(tzinfo=TZ_UTC)
return get_current_utc_as_int() < self._token.expires_on
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,27 @@
# license information.
# --------------------------------------------------------------------------
from asyncio import Condition, Lock
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
Any
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
:param str token: The token used to authenticate to an Azure Communication service
:keyword token_refresher: The token refresher to provide capacity to fetch fresh token
:keyword token_refresher: The async token refresher to provide capacity to fetch fresh token
:raises: TypeError
"""

_ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2

def __init__(self,
token, # type: str
**kwargs
):
def __init__(self, token: str, **kwargs: Any):
token_refresher = kwargs.pop('token_refresher', None)
communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token,
token_refresher=token_refresher)
Expand All @@ -35,25 +33,24 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
async def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""

if not self._token_refresher or not self._token_expiring():
return self._token

should_this_thread_refresh = False

with self._lock:
async with self._lock:

while self._token_expiring():
if self._some_thread_refreshing:
if self._is_currenttoken_valid():
return self._token

self._wait_till_inprogress_thread_finish_refreshing()
await self._wait_till_inprogress_thread_finish_refreshing()
else:
should_this_thread_refresh = True
self._some_thread_refreshing = True
Expand All @@ -62,32 +59,37 @@ def get_token(self):

if should_this_thread_refresh:
try:
newtoken = self._token_refresher() # pylint:disable=not-callable
newtoken = await self._token_refresher() # pylint:disable=not-callable

with self._lock:
async with self._lock:
self._token = newtoken
self._some_thread_refreshing = False
self._lock.notify_all()
except:
with self._lock:
async with self._lock:
self._some_thread_refreshing = False
self._lock.notify_all()

raise

return self._token

def _wait_till_inprogress_thread_finish_refreshing(self):
async def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.release()
self._lock.acquire()
await self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES)
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()

def _is_currenttoken_valid(self):
return self._get_utc_now() < self._token.expires_on
return get_current_utc_as_int() < self._token.expires_on

async def close(self) -> None:
pass

async def __aenter__(self):
return self

@classmethod
def _get_utc_now(cls):
return datetime.now().replace(tzinfo=TZ_UTC)
async def __aexit__(self, *args):
await self.close()
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import base64
import json
import time
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
Expand All @@ -17,6 +16,14 @@
from azure.core.credentials import AccessToken

def _convert_datetime_to_utc_int(expires_on):
"""
Converts DateTime in local time to the Epoch in UTC in second.

:param input_datetime: Input datetime
Comment thread
petrsvihlik marked this conversation as resolved.
:type input_datetime: datetime
:return: Integer
:rtype: int
"""
return int(calendar.timegm(expires_on.utctimetuple()))

def parse_connection_str(conn_str):
Expand Down Expand Up @@ -50,9 +57,10 @@ def get_current_utc_time():
# type: () -> str
return str(datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S ")) + "GMT"


def get_current_utc_as_int():
# type: () -> int
current_utc_datetime = datetime.utcnow().replace(tzinfo=TZ_UTC)
current_utc_datetime = datetime.utcnow()
return _convert_datetime_to_utc_int(current_utc_datetime)

def create_access_token(token):
Expand Down Expand Up @@ -83,10 +91,6 @@ def create_access_token(token):
except ValueError:
raise ValueError(token_parse_err_msg)

def _convert_expires_on_datetime_to_utc_int(expires_on):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
return epoch-time.mktime(expires_on.timetuple())

def get_authentication_policy(
endpoint, # type: str
credential, # type: TokenCredential or str
Expand Down Expand Up @@ -122,7 +126,3 @@ def get_authentication_policy(

raise TypeError("Unsupported credential: {}. Use an access token string to use HMACCredentialsPolicy"
"or a token credential from azure.identity".format(type(credential)))

def _convert_expires_on_datetime_to_utc_int(expires_on):
epoch = time.mktime(datetime(1970, 1, 1).timetuple())
return epoch-time.mktime(expires_on.timetuple())
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def __init__(self,
access_key, # type: str
decode_url=False # type: bool
):
# pylint: disable=bad-option-value,useless-object-inheritance,disable=super-with-arguments
# type: (...) -> None
super(HMACCredentialsPolicy, self).__init__()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@
# license information.
# --------------------------------------------------------------------------
from threading import Lock, Condition
from datetime import datetime, timedelta
from datetime import timedelta
from typing import ( # pylint: disable=unused-import
cast,
Tuple,
)

from msrest.serialization import TZ_UTC

from .utils import get_current_utc_as_int
from .user_token_refresh_options import CommunicationTokenRefreshOptions


class CommunicationTokenCredential(object):
"""Credential type used for authenticating to an Azure Communication service.
:param str token: The token used to authenticate to an Azure Communication service
Expand All @@ -24,9 +24,9 @@ class CommunicationTokenCredential(object):
_ON_DEMAND_REFRESHING_INTERVAL_MINUTES = 2

def __init__(self,
token, # type: str
**kwargs
):
token, # type: str
**kwargs
):
token_refresher = kwargs.pop('token_refresher', None)
communication_token_refresh_options = CommunicationTokenRefreshOptions(token=token,
token_refresher=token_refresher)
Expand All @@ -35,8 +35,8 @@ def __init__(self,
self._lock = Condition(Lock())
self._some_thread_refreshing = False

def get_token(self):
# type () -> ~azure.core.credentials.AccessToken
def get_token(self, *scopes, **kwargs): # pylint: disable=unused-argument
# type (*str, **Any) -> AccessToken
"""The value of the configured token.
:rtype: ~azure.core.credentials.AccessToken
"""
Expand Down Expand Up @@ -79,12 +79,8 @@ def _wait_till_inprogress_thread_finish_refreshing(self):
self._lock.acquire()

def _token_expiring(self):
return self._token.expires_on - self._get_utc_now() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES)
return self._token.expires_on - get_current_utc_as_int() <\
timedelta(minutes=self._ON_DEMAND_REFRESHING_INTERVAL_MINUTES).total_seconds()
Comment thread
petrsvihlik marked this conversation as resolved.

def _is_currenttoken_valid(self):
return self._get_utc_now() < self._token.expires_on

@classmethod
def _get_utc_now(cls):
return datetime.now().replace(tzinfo=TZ_UTC)
return get_current_utc_as_int() < self._token.expires_on
Loading