Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions azure/durable_functions/models/DurableHttpRequest.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from typing import Dict, Any
from typing import Dict, Union, Optional

from azure.durable_functions.models import TokenSource
from azure.durable_functions.models.TokenSource import TokenSource
from azure.durable_functions.models.utils.json_utils import add_attrib, add_json_attrib


class DurableHttpRequest:
"""Data structure representing a durable HTTP request."""

def __init__(self, method: str, uri: str, content: str = None, headers: Dict[str, str] = None,
token_source: TokenSource = None):
def __init__(self, method: str, uri: str, content: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
token_source: Optional[TokenSource] = None):
self._method: str = method
self._uri: str = uri
self._content: str = content
self._headers: Dict[str, str] = headers
self._token_source: TokenSource = token_source
self._content: Optional[str] = content
self._headers: Optional[Dict[str, str]] = headers
self._token_source: Optional[TokenSource] = token_source

@property
def method(self) -> str:
Expand All @@ -26,29 +27,29 @@ def uri(self) -> str:
return self._uri

@property
def content(self) -> str:
def content(self) -> Optional[str]:
"""Get the HTTP request content."""
return self._content

@property
def headers(self) -> Dict[str, str]:
def headers(self) -> Optional[Dict[str, str]]:
"""Get the HTTP request headers."""
return self._headers

@property
def token_source(self) -> TokenSource:
def token_source(self) -> Optional[TokenSource]:
"""Get the source of OAuth token to add to the request."""
return self._token_source

def to_json(self) -> Dict[str, Any]:
def to_json(self) -> Dict[str, Union[str, int]]:
"""Convert object into a json dictionary.

Returns
-------
Dict[str, Any]
Dict[str, Union[str, int]]
The instance of the class converted into a json dictionary
"""
json_dict = {}
json_dict: Dict[str, Union[str, int]] = {}
add_attrib(json_dict, self, 'method')
add_attrib(json_dict, self, 'uri')
add_attrib(json_dict, self, 'content')
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Dict
from typing import Dict, Optional

from azure.durable_functions.models.FunctionContext import FunctionContext

Expand All @@ -13,11 +13,12 @@ class DurableOrchestrationBindings:

# parameter names are as defined by JSON schema and do not conform to PEP8 naming conventions
def __init__(self, taskHubName: str, creationUrls: Dict[str, str],
managementUrls: Dict[str, str], rpcBaseUrl: str = None, **kwargs):
managementUrls: Dict[str, str], rpcBaseUrl: Optional[str] = None, **kwargs):
self._task_hub_name: str = taskHubName
self._creation_urls: Dict[str, str] = creationUrls
self._management_urls: Dict[str, str] = managementUrls
self._rpc_base_url: str = rpcBaseUrl
# TODO: we can remove this once we drop support for 1.x, this is always provided in 2.x
self._rpc_base_url: Optional[str] = rpcBaseUrl
self._client_data = FunctionContext(**kwargs)

@property
Expand All @@ -36,7 +37,7 @@ def management_urls(self) -> Dict[str, str]:
return self._management_urls

@property
def rpc_base_url(self) -> str:
def rpc_base_url(self) -> Optional[str]:
"""Get the base url communication between out of proc workers and the function host."""
return self._rpc_base_url

Expand Down
99 changes: 59 additions & 40 deletions azure/durable_functions/models/DurableOrchestrationClient.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from datetime import datetime
from typing import List, Any, Awaitable, Optional, Dict
from typing import List, Any, Optional, Dict, Union
from time import time
from asyncio import sleep
from urllib.parse import urlparse, quote
Expand All @@ -11,7 +11,7 @@
from .DurableOrchestrationStatus import DurableOrchestrationStatus
from .RpcManagementOptions import RpcManagementOptions
from .OrchestrationRuntimeStatus import OrchestrationRuntimeStatus
from ..models import DurableOrchestrationBindings
from ..models.DurableOrchestrationBindings import DurableOrchestrationBindings
from .utils.http_utils import get_async_request, post_async_request, delete_async_request
from azure.functions._durable_functions import _serialize_custom_object

Expand Down Expand Up @@ -44,8 +44,8 @@ def __init__(self, context: str):

async def start_new(self,
orchestration_function_name: str,
instance_id: str = None,
client_input: object = None) -> Awaitable[str]:
instance_id: Optional[str] = None,
client_input: Optional[Any] = None) -> str:
"""Start a new instance of the specified orchestrator function.

If an orchestration instance with the specified ID already exists, the
Expand All @@ -55,10 +55,10 @@ async def start_new(self,
----------
orchestration_function_name : str
The name of the orchestrator function to start.
instance_id : str
instance_id : Optional[str]
The ID to use for the new orchestration instance. If no instance id is specified,
the Durable Functions extension will generate a random GUID (recommended).
client_input : object
client_input : Optional[Any]
JSON-serializable input value for the orchestrator function.

Returns
Expand All @@ -69,22 +69,25 @@ async def start_new(self,
request_url = self._get_start_new_url(
instance_id=instance_id, orchestration_function_name=orchestration_function_name)

response = await self._post_async_request(request_url, self._get_json_input(client_input))
response: List[Any] = await self._post_async_request(
request_url, self._get_json_input(client_input))

if response[0] <= 202 and response[1]:
status_code: int = response[0]
if status_code <= 202 and response[1]:
return response[1]["id"]
elif response[0] == 400:
elif status_code == 400:
# Orchestrator not found, report clean exception
exception_data = response[1]
exception_data: Dict[str, str] = response[1]
exception_message = exception_data["ExceptionMessage"]
raise Exception(exception_message)
else:
# Catch all: simply surfacing the durable-extension exception
# we surface the stack trace too, since this may be a more involed exception
exception_message = response[1]
raise Exception(exception_message)
ex_message: Any = response[1]
raise Exception(ex_message)

def create_check_status_response(self, request, instance_id):
def create_check_status_response(
self, request: func.HttpRequest, instance_id: str) -> func.HttpResponse:
"""Create a HttpResponse that contains useful information for \
checking the status of the specified instance.

Expand Down Expand Up @@ -148,16 +151,16 @@ def get_client_response_links(
payload = self._orchestration_bindings.management_urls.copy()

for key, _ in payload.items():
request_is_not_none = not (request is None)
if request_is_not_none and request.url:
if not(request is None) and request.url:
payload[key] = self._replace_url_origin(request.url, payload[key])
payload[key] = payload[key].replace(
self._orchestration_bindings.management_urls["id"], instance_id)

return payload

async def raise_event(self, instance_id, event_name, event_data=None,
task_hub_name=None, connection_name=None):
async def raise_event(
self, instance_id: str, event_name: str, event_data: Any = None,
task_hub_name: str = None, connection_name: str = None) -> None:
"""Send an event notification message to a waiting orchestration instance.

In order to handle the event, the target orchestration instance must be
Expand All @@ -169,7 +172,7 @@ async def raise_event(self, instance_id, event_name, event_data=None,
The ID of the orchestration instance that will handle the event.
event_name : str
The name of the event.
event_data : any, optional
event_data : Any, optional
The JSON-serializable data associated with the event.
task_hub_name : str, optional
The TaskHubName of the orchestration that will handle the event.
Expand All @@ -183,8 +186,8 @@ async def raise_event(self, instance_id, event_name, event_data=None,
Exception
Raises an exception if the status code is 404 or 400 when raising the event.
"""
if not event_name:
raise ValueError("event_name must be a valid string.")
if event_name == "":
raise ValueError("event_name must be a non-empty string.")

request_url = self._get_raise_event_url(
instance_id, event_name, task_hub_name, connection_name)
Expand All @@ -203,9 +206,9 @@ async def raise_event(self, instance_id, event_name, event_data=None,
if error_message:
raise Exception(error_message)

async def get_status(self, instance_id: str, show_history: bool = None,
show_history_output: bool = None,
show_input: bool = None) -> DurableOrchestrationStatus:
async def get_status(self, instance_id: str, show_history: bool = False,
show_history_output: bool = False,
show_input: bool = False) -> DurableOrchestrationStatus:
"""Get the status of the specified orchestration instance.

Parameters
Expand Down Expand Up @@ -268,7 +271,8 @@ async def get_status_all(self) -> List[DurableOrchestrationStatus]:
if error_message:
raise Exception(error_message)
else:
return [DurableOrchestrationStatus.from_json(o) for o in response[1]]
statuses: List[Any] = response[1]
return [DurableOrchestrationStatus.from_json(o) for o in statuses]

async def get_status_by(self, created_time_from: datetime = None,
created_time_to: datetime = None,
Expand All @@ -291,6 +295,7 @@ async def get_status_by(self, created_time_from: datetime = None,
DurableOrchestrationStatus
The status of the requested orchestration instances
"""
# TODO: do we really want folks to us this without specifying all the args?
options = RpcManagementOptions(created_time_from=created_time_from,
created_time_to=created_time_to,
runtime_status=runtime_status)
Expand Down Expand Up @@ -326,19 +331,20 @@ async def purge_instance_history(self, instance_id: str) -> PurgeHistoryResult:
response = await self._delete_async_request(request_url)
return self._parse_purge_instance_history_response(response)

async def purge_instance_history_by(self, created_time_from: datetime = None,
created_time_to: datetime = None,
runtime_status: List[OrchestrationRuntimeStatus] = None) \
async def purge_instance_history_by(
self, created_time_from: Optional[datetime] = None,
created_time_to: Optional[datetime] = None,
runtime_status: Optional[List[OrchestrationRuntimeStatus]] = None) \
-> PurgeHistoryResult:
"""Delete the history of all orchestration instances that match the specified conditions.

Parameters
----------
created_time_from : datetime
created_time_from : Optional[datetime]
Delete orchestration history which were created after this Date.
created_time_to: datetime
created_time_to: Optional[datetime]
Delete orchestration history which were created before this Date.
runtime_status: List[OrchestrationRuntimeStatus]
runtime_status: Optional[List[OrchestrationRuntimeStatus]]
Delete orchestration instances which match any of the runtimeStatus values
in this list.

Expand All @@ -347,14 +353,15 @@ async def purge_instance_history_by(self, created_time_from: datetime = None,
PurgeHistoryResult
The results of the request to purge history
"""
# TODO: do we really want folks to us this without specifying all the args?
options = RpcManagementOptions(created_time_from=created_time_from,
created_time_to=created_time_to,
runtime_status=runtime_status)
request_url = options.to_url(self._orchestration_bindings.rpc_base_url)
response = await self._delete_async_request(request_url)
return self._parse_purge_instance_history_response(response)

async def terminate(self, instance_id: str, reason: str):
async def terminate(self, instance_id: str, reason: str) -> None:
"""Terminate the specified orchestration instance.

Parameters
Expand All @@ -364,6 +371,11 @@ async def terminate(self, instance_id: str, reason: str):
reason: str
The reason for terminating the instance.

Raises
------
Exception:
When the terminate call failed with an unexpected status code

Returns
-------
None
Expand Down Expand Up @@ -446,7 +458,8 @@ async def wait_for_completion_or_create_check_status_response(
return self.create_check_status_response(request, instance_id)

@staticmethod
def _create_http_response(status_code: int, body: Any) -> func.HttpResponse:
def _create_http_response(
status_code: int, body: Union[str, Any]) -> func.HttpResponse:
body_as_json = body if isinstance(body, str) else json.dumps(body)
response_args = {
"status_code": status_code,
Expand All @@ -459,7 +472,7 @@ def _create_http_response(status_code: int, body: Any) -> func.HttpResponse:
return func.HttpResponse(**response_args)

@staticmethod
def _get_json_input(client_input: object) -> str:
def _get_json_input(client_input: object) -> Optional[str]:
"""Serialize the orchestrator input.

Parameters
Expand All @@ -469,8 +482,10 @@ def _get_json_input(client_input: object) -> str:

Returns
-------
str
A string representing the JSON-serialization of `client_input`
Optional[str]
If `client_input` is not None, return a string representing
the JSON-serialization of `client_input`. Otherwise, returns
None

Exceptions
----------
Expand All @@ -482,7 +497,7 @@ def _get_json_input(client_input: object) -> str:
return None

@staticmethod
def _replace_url_origin(request_url, value_url):
def _replace_url_origin(request_url: str, value_url: str) -> str:
request_parsed_url = urlparse(request_url)
value_parsed_url = urlparse(value_url)
request_url_origin = '{url.scheme}://{url.netloc}/'.format(url=request_parsed_url)
Expand All @@ -491,7 +506,8 @@ def _replace_url_origin(request_url, value_url):
return value_url

@staticmethod
def _parse_purge_instance_history_response(response: [int, Any]):
def _parse_purge_instance_history_response(
response: List[Any]) -> PurgeHistoryResult:
switch_statement = {
200: lambda: PurgeHistoryResult.from_json(response[1]), # instance completed
404: lambda: PurgeHistoryResult(instancesDeleted=0), # instance not found
Expand All @@ -506,17 +522,20 @@ def _parse_purge_instance_history_response(response: [int, Any]):
else:
raise Exception(result)

def _get_start_new_url(self, instance_id, orchestration_function_name):
def _get_start_new_url(
self, instance_id: Optional[str], orchestration_function_name: str) -> str:
instance_path = f'/{instance_id}' if instance_id is not None else ''
request_url = f'{self._orchestration_bindings.rpc_base_url}orchestrators/' \
f'{orchestration_function_name}{instance_path}'
return request_url

def _get_raise_event_url(self, instance_id, event_name, task_hub_name, connection_name):
def _get_raise_event_url(
self, instance_id: str, event_name: str,
task_hub_name: Optional[str], connection_name: Optional[str]) -> str:
request_url = f'{self._orchestration_bindings.rpc_base_url}' \
f'instances/{instance_id}/raiseEvent/{event_name}'

query = []
query: List[str] = []
if task_hub_name:
query.append(f'taskHub={task_hub_name}')

Expand Down
Loading