Skip to content

Commit

Permalink
Add retries to opensearch requests (#623)
Browse files Browse the repository at this point in the history
Retry logic added to all OpenSearch requests in search, bulk search, and add docs. Tests added.
  • Loading branch information
VitusAcabado authored Oct 24, 2023
1 parent 53fd0cf commit 1360c5c
Show file tree
Hide file tree
Showing 9 changed files with 858 additions and 67 deletions.
122 changes: 95 additions & 27 deletions src/marqo/_httprequests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import json
import time
import pprint
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Union
Expand All @@ -20,6 +21,11 @@
)
from urllib3.exceptions import InsecureRequestWarning
import warnings
from marqo.tensor_search.tensor_search_logging import get_logger
from marqo.tensor_search.utils import read_env_vars_and_defaults_ints
from marqo.tensor_search.enums import EnvVars

logger = get_logger(__name__)

ALLOWED_OPERATIONS = {requests.delete, requests.get, requests.post, requests.put}

Expand All @@ -38,7 +44,14 @@ def send_request(
path: str,
body: Optional[Union[Dict[str, Any], List[Dict[str, Any]], List[str], str]] = None,
content_type: Optional[str] = None,
max_retry_attempts: Optional[int] = None,
max_retry_backoff_seconds: Optional[int] = None
) -> Any:
if max_retry_attempts is None:
max_retry_attempts = read_env_vars_and_defaults_ints(EnvVars.DEFAULT_MARQO_MAX_BACKEND_RETRY_ATTEMPTS)
if max_retry_backoff_seconds is None:
max_retry_backoff_seconds = read_env_vars_and_defaults_ints(EnvVars.DEFAULT_MARQO_MAX_BACKEND_RETRY_ATTEMPTS)

to_verify = False # self.config.cluster_is_remote

if http_method not in ALLOWED_OPERATIONS:
Expand All @@ -52,64 +65,106 @@ def send_request(
with warnings.catch_warnings():
if not self.config.cluster_is_remote:
warnings.simplefilter('ignore', InsecureRequestWarning)
try:
request_path = self.config.url + '/' + path
if isinstance(body, (bytes, str)):
response = http_method(
request_path,
timeout=self.config.timeout,
headers=req_headers,
data=body,
verify=to_verify
)
else:
response = http_method(
request_path,
timeout=self.config.timeout,
headers=req_headers,
data=json.dumps(body) if body else None,
verify=to_verify
)
return self.__validate(response)
except requests.exceptions.Timeout as err:
raise BackendTimeoutError(str(err)) from err
except requests.exceptions.ConnectionError as err:
raise BackendCommunicationError(str(err)) from err

            # One initial attempt plus up to max_retry_attempts retries.
            for attempt in range(max_retry_attempts + 1):
                try:
                    request_path = self.config.url + '/' + path
                    if isinstance(body, (bytes, str)):
                        # Pre-serialised payload: forward it unchanged.
                        response = http_method(
                            request_path,
                            timeout=self.config.timeout,
                            headers=req_headers,
                            data=body,
                            verify=to_verify
                        )
                    else:
                        # Dict/list payload: JSON-encode; a falsy body sends no data at all.
                        response = http_method(
                            request_path,
                            timeout=self.config.timeout,
                            headers=req_headers,
                            data=json.dumps(body) if body else None,
                            verify=to_verify
                        )
                    return self.__validate(response)
                except requests.exceptions.Timeout as err:
                    # NOTE(review): timeouts are deliberately NOT retried here — only
                    # connection errors are; raised to the caller immediately.
                    raise BackendTimeoutError(str(err)) from err
                except requests.exceptions.ConnectionError as err:
                    if (attempt == max_retry_attempts):
                        # Retry budget exhausted: surface the failure.
                        raise BackendCommunicationError(str(err)) from err
                    else:
                        logger.info(f"BackendCommunicationError encountered... Retrying request to {request_path}. Attempt {attempt + 1} of {max_retry_attempts}")
                        # Exponential backoff, capped at max_retry_backoff_seconds.
                        backoff_sleep = self.calculate_backoff_sleep(attempt, max_retry_backoff_seconds)
                        time.sleep(backoff_sleep)

def get(
self, path: str,
body: Optional[Union[Dict[str, Any], List[Dict[str, Any]], List[str], str]] = None,
max_retry_attempts: Optional[int] = None,
max_retry_backoff_seconds: Optional[int] = None
) -> Any:
content_type = None
if body is not None:
content_type = 'application/json'
res = self.send_request(requests.get, path=path, body=body, content_type=content_type)
res = self.send_request(
http_method=requests.get,
path=path,
body=body,
content_type=content_type,
max_retry_attempts=max_retry_attempts,
max_retry_backoff_seconds=max_retry_backoff_seconds
)
return res

def post(
self,
path: str,
body: Optional[Union[Dict[str, Any], List[Dict[str, Any]], List[str], str]] = None,
content_type: Optional[str] = 'application/json',
max_retry_attempts: Optional[int] = None,
max_retry_backoff_seconds: Optional[int] = None
) -> Any:
return self.send_request(requests.post, path, body, content_type)
return self.send_request(
http_method=requests.post,
path=path,
body=body,
content_type=content_type,
max_retry_attempts=max_retry_attempts,
max_retry_backoff_seconds=max_retry_backoff_seconds
)

def put(
self,
path: str,
body: Optional[Union[Dict[str, Any], List[Dict[str, Any]], List[str], str]] = None,
content_type: Optional[str] = None,
max_retry_attempts: Optional[int] = None,
max_retry_backoff_seconds: Optional[int] = None
) -> Any:
if body is not None:
content_type = 'application/json'
return self.send_request(requests.put, path, body, content_type)
return self.send_request(
http_method=requests.put,
path=path,
body=body,
content_type=content_type,
max_retry_attempts=max_retry_attempts,
max_retry_backoff_seconds=max_retry_backoff_seconds
)

def delete(
self,
path: str,
body: Optional[Union[Dict[str, Any], List[Dict[str, Any]], List[str]]] = None,
max_retry_attempts: Optional[int] = None,
max_retry_backoff_seconds: Optional[int] = None
) -> Any:
return self.send_request(requests.delete, path, body)
return self.send_request(
http_method=requests.delete,
path=path,
body=body,
max_retry_attempts=max_retry_attempts,
max_retry_backoff_seconds=max_retry_backoff_seconds
)

@staticmethod
def __to_json(
Expand All @@ -129,6 +184,19 @@ def __validate(
except requests.exceptions.HTTPError as err:
convert_to_marqo_web_error_and_raise(response=request, err=err)

def calculate_backoff_sleep(self, attempt: int, cap: int) -> float:
"""Calculates the backoff sleep time for a given attempt
Args:
attempt (int): the attempt number
Returns:
float: the backoff sleep time
"""
return min(
cap * 1000, # convert to milliseconds
(2 ** attempt) * 10 # start at 10ms for first attempt
) / 1000 # convert to seconds

def convert_to_marqo_web_error_and_raise(response: requests.Response, err: requests.exceptions.HTTPError):
"""Translates OpenSearch errors into Marqo errors, which are then raised
Expand Down
24 changes: 20 additions & 4 deletions src/marqo/tensor_search/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
import pprint


def get_index_info(config: Config, index_name: str) -> IndexInfo:
def get_index_info(
config: Config,
index_name: str,
max_retry_attempts: int = None,
max_retry_backoff_seconds: int = None
) -> IndexInfo:
"""Gets useful information about the index. Also updates the IndexInfo cache
Args:
Expand All @@ -31,7 +36,11 @@ def get_index_info(config: Config, index_name: str) -> IndexInfo:
NonTensorIndexError: If the index's mapping doesn't conform to a Tensor Search index.
IndexNotFoundError: If index does not exist.
"""
res = HttpRequests(config).get(path=F"{index_name}/_mapping")
res = HttpRequests(config).get(
path=F"{index_name}/_mapping",
max_retry_attempts=max_retry_attempts,
max_retry_backoff_seconds=max_retry_backoff_seconds
)

if not (index_name in res and "mappings" in res[index_name]
and "_meta" in res[index_name]["mappings"]):
Expand Down Expand Up @@ -62,7 +71,9 @@ def get_index_info(config: Config, index_name: str) -> IndexInfo:

def add_customer_field_properties(config: Config, index_name: str,
customer_field_names: Iterable[Tuple[str, enums.OpenSearchDataType]],
multimodal_combination_fields: Dict[str, Iterable[Tuple[str, enums.OpenSearchDataType]]]):
multimodal_combination_fields: Dict[str, Iterable[Tuple[str, enums.OpenSearchDataType]]],
max_retry_attempts: int = None,
max_retry_backoff_seconds: int = None) -> None:
"""Adds new customer fields to index mapping.
Pushes the updated mapping to OpenSearch, and updates the local cache.
Expand Down Expand Up @@ -110,7 +121,12 @@ def add_customer_field_properties(config: Config, index_name: str,
},
}

mapping_res = HttpRequests(config).put(path=F"{index_name}/_mapping", body=json.dumps(body))
mapping_res = HttpRequests(config).put(
path=F"{index_name}/_mapping",
body=json.dumps(body),
max_retry_attempts=max_retry_attempts,
max_retry_backoff_seconds=max_retry_backoff_seconds
)

merged_chunk_properties = {
**existing_info.properties[enums.TensorField.chunks]["properties"],
Expand Down
6 changes: 6 additions & 0 deletions src/marqo/tensor_search/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,11 @@ def default_env_vars() -> dict:
EnvVars.MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES: None,
EnvVars.MARQO_MAX_NUMBER_OF_REPLICAS: 1,
EnvVars.MARQO_MAX_ADD_DOCS_COUNT: 64,
EnvVars.DEFAULT_MARQO_MAX_BACKEND_RETRY_ATTEMPTS: 0,
EnvVars.DEFAULT_MARQO_MAX_BACKEND_RETRY_BACKOFF: 1,
EnvVars.MARQO_MAX_BACKEND_SEARCH_RETRY_ATTEMPTS: 0,
EnvVars.MARQO_MAX_BACKEND_SEARCH_RETRY_BACKOFF: 1,
EnvVars.MARQO_MAX_BACKEND_ADD_DOCS_RETRY_ATTEMPTS: 0,
EnvVars.MARQO_MAX_BACKEND_ADD_DOCS_RETRY_BACKOFF: 1
}

2 changes: 1 addition & 1 deletion src/marqo/tensor_search/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@
'eb': 6,
'zb': 7,
'yb': 8
}
}
6 changes: 6 additions & 0 deletions src/marqo/tensor_search/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ class EnvVars:
MARQO_MAX_NUMBER_OF_REPLICAS = "MARQO_MAX_NUMBER_OF_REPLICAS"
MARQO_BEST_AVAILABLE_DEVICE = "MARQO_BEST_AVAILABLE_DEVICE"
MARQO_MAX_ADD_DOCS_COUNT = "MARQO_MAX_ADD_DOCS_COUNT"
MARQO_MAX_BACKEND_SEARCH_RETRY_ATTEMPTS = "MARQO_MAX_BACKEND_SEARCH_RETRY_ATTEMPTS"
MARQO_MAX_BACKEND_SEARCH_RETRY_BACKOFF = "MARQO_MAX_BACKEND_SEARCH_RETRY_BACKOFF"
MARQO_MAX_BACKEND_ADD_DOCS_RETRY_ATTEMPTS = "MARQO_MAX_BACKEND_ADD_DOCS_RETRY_ATTEMPTS"
MARQO_MAX_BACKEND_ADD_DOCS_RETRY_BACKOFF = "MARQO_MAX_BACKEND_ADD_DOCS_RETRY_BACKOFF"
DEFAULT_MARQO_MAX_BACKEND_RETRY_ATTEMPTS = "DEFAULT_MARQO_MAX_BACKEND_RETRY_ATTEMPTS"
DEFAULT_MARQO_MAX_BACKEND_RETRY_BACKOFF = "DEFAULT_MARQO_MAX_BACKEND_RETRY_BACKOFF"


class RequestType:
Expand Down
14 changes: 12 additions & 2 deletions src/marqo/tensor_search/index_meta_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ def empty_cache():
index_info_cache = dict()


def get_index_info(config: Config, index_name: str) -> IndexInfo:
def get_index_info(
config: Config,
index_name: str,
max_retry_attempts: int = None,
max_retry_backoff_seconds: int = None
) -> IndexInfo:
"""Looks for the index name in the cache.
If it isn't found there, it will try searching the cluster
Expand All @@ -51,7 +56,12 @@ def get_index_info(config: Config, index_name: str) -> IndexInfo:
if index_name in index_info_cache:
return index_info_cache[index_name]
else:
found_index_info = backend.get_index_info(config=config, index_name=index_name)
found_index_info = backend.get_index_info(
config=config,
index_name=index_name,
max_retry_attempts=max_retry_attempts,
max_retry_backoff_seconds=max_retry_backoff_seconds
)
index_info_cache[index_name] = found_index_info
return index_info_cache[index_name]

Expand Down
Loading

0 comments on commit 1360c5c

Please sign in to comment.