Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import os
import typing
from enum import Enum
Expand All @@ -19,10 +20,13 @@ class CacheType(Enum):
DEFAULT_HOST = "localhost"
DEFAULT_KEEP_ALIVE = 0
DEFAULT_OFFLINE_SOURCE_PATH: typing.Optional[str] = None
DEFAULT_OFFLINE_POLL_MS = 5000
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this to be compliant with the config gherkin test — usage of this will be fixed as soon as we touch the in-process provider.

DEFAULT_PORT_IN_PROCESS = 8015
DEFAULT_PORT_RPC = 8013
DEFAULT_RESOLVER_TYPE = ResolverType.RPC
DEFAULT_RETRY_BACKOFF = 1000
DEFAULT_RETRY_BACKOFF_MAX = 120000
DEFAULT_RETRY_GRACE_ATTEMPTS = 5
DEFAULT_STREAM_DEADLINE = 600000
DEFAULT_TLS = False

Expand All @@ -32,9 +36,12 @@ class CacheType(Enum):
ENV_VAR_HOST = "FLAGD_HOST"
ENV_VAR_KEEP_ALIVE_TIME_MS = "FLAGD_KEEP_ALIVE_TIME_MS"
ENV_VAR_OFFLINE_FLAG_SOURCE_PATH = "FLAGD_OFFLINE_FLAG_SOURCE_PATH"
ENV_VAR_OFFLINE_POLL_MS = "FLAGD_OFFLINE_POLL_MS"
ENV_VAR_PORT = "FLAGD_PORT"
ENV_VAR_RESOLVER_TYPE = "FLAGD_RESOLVER"
ENV_VAR_RETRY_BACKOFF_MS = "FLAGD_RETRY_BACKOFF_MS"
ENV_VAR_RETRY_BACKOFF_MAX_MS = "FLAGD_RETRY_BACKOFF_MAX_MS"
ENV_VAR_RETRY_GRACE_ATTEMPTS = "FLAGD_RETRY_GRACE_ATTEMPTS"
ENV_VAR_STREAM_DEADLINE_MS = "FLAGD_STREAM_DEADLINE_MS"
ENV_VAR_TLS = "FLAGD_TLS"

Expand Down Expand Up @@ -62,19 +69,23 @@ def env_or_default(
return val if cast is None else cast(val)


@dataclasses.dataclass
class Config:
def __init__( # noqa: PLR0913
self,
host: typing.Optional[str] = None,
port: typing.Optional[int] = None,
tls: typing.Optional[bool] = None,
resolver_type: typing.Optional[ResolverType] = None,
resolver: typing.Optional[ResolverType] = None,
offline_flag_source_path: typing.Optional[str] = None,
offline_poll_interval_ms: typing.Optional[int] = None,
retry_backoff_ms: typing.Optional[int] = None,
deadline: typing.Optional[int] = None,
retry_backoff_max_ms: typing.Optional[int] = None,
retry_grace_attempts: typing.Optional[int] = None,
deadline_ms: typing.Optional[int] = None,
stream_deadline_ms: typing.Optional[int] = None,
keep_alive: typing.Optional[int] = None,
cache_type: typing.Optional[CacheType] = None,
keep_alive_time: typing.Optional[int] = None,
cache: typing.Optional[CacheType] = None,
max_cache_size: typing.Optional[int] = None,
):
self.host = env_or_default(ENV_VAR_HOST, DEFAULT_HOST) if host is None else host
Expand All @@ -94,18 +105,37 @@ def __init__( # noqa: PLR0913
if retry_backoff_ms is None
else retry_backoff_ms
)
self.retry_backoff_max_ms: int = (
int(
env_or_default(
ENV_VAR_RETRY_BACKOFF_MAX_MS, DEFAULT_RETRY_BACKOFF_MAX, cast=int
)
)
if retry_backoff_max_ms is None
else retry_backoff_max_ms
)

self.resolver_type = (
self.retry_grace_attempts: int = (
int(
env_or_default(
ENV_VAR_RETRY_GRACE_ATTEMPTS, DEFAULT_RETRY_GRACE_ATTEMPTS, cast=int
)
)
if retry_grace_attempts is None
else retry_grace_attempts
)

self.resolver = (
env_or_default(
ENV_VAR_RESOLVER_TYPE, DEFAULT_RESOLVER_TYPE, cast=convert_resolver_type
)
if resolver_type is None
else resolver_type
if resolver is None
else resolver
)

default_port = (
DEFAULT_PORT_RPC
if self.resolver_type is ResolverType.RPC
if self.resolver is ResolverType.RPC
else DEFAULT_PORT_IN_PROCESS
)

Expand All @@ -123,10 +153,20 @@ def __init__( # noqa: PLR0913
else offline_flag_source_path
)

self.deadline: int = (
self.offline_poll_interval_ms: int = (
int(
env_or_default(
ENV_VAR_OFFLINE_POLL_MS, DEFAULT_OFFLINE_POLL_MS, cast=int
)
)
if offline_poll_interval_ms is None
else offline_poll_interval_ms
)

self.deadline_ms: int = (
int(env_or_default(ENV_VAR_DEADLINE_MS, DEFAULT_DEADLINE, cast=int))
if deadline is None
else deadline
if deadline_ms is None
else deadline_ms
)

self.stream_deadline_ms: int = (
Expand All @@ -139,18 +179,18 @@ def __init__( # noqa: PLR0913
else stream_deadline_ms
)

self.keep_alive: int = (
self.keep_alive_time: int = (
int(
env_or_default(ENV_VAR_KEEP_ALIVE_TIME_MS, DEFAULT_KEEP_ALIVE, cast=int)
)
if keep_alive is None
else keep_alive
if keep_alive_time is None
else keep_alive_time
)

self.cache_type = (
self.cache = (
CacheType(env_or_default(ENV_VAR_CACHE_TYPE, DEFAULT_CACHE))
if cache_type is None
else cache_type
if cache is None
else cache
)

self.max_cache_size: int = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def __init__( # noqa: PLR0913
keep_alive_time: typing.Optional[int] = None,
cache_type: typing.Optional[CacheType] = None,
max_cache_size: typing.Optional[int] = None,
retry_backoff_max_ms: typing.Optional[int] = None,
retry_grace_attempts: typing.Optional[int] = None,
):
"""
Create an instance of the FlagdProvider
Expand Down Expand Up @@ -79,31 +81,34 @@ def __init__( # noqa: PLR0913
host=host,
port=port,
tls=tls,
deadline=deadline,
deadline_ms=deadline,
retry_backoff_ms=retry_backoff_ms,
resolver_type=resolver_type,
retry_backoff_max_ms=retry_backoff_max_ms,
retry_grace_attempts=retry_grace_attempts,
resolver=resolver_type,
offline_flag_source_path=offline_flag_source_path,
stream_deadline_ms=stream_deadline_ms,
keep_alive=keep_alive_time,
cache_type=cache_type,
keep_alive_time=keep_alive_time,
cache=cache_type,
max_cache_size=max_cache_size,
)

self.resolver = self.setup_resolver()

def setup_resolver(self) -> AbstractResolver:
if self.config.resolver_type == ResolverType.RPC:
if self.config.resolver == ResolverType.RPC:
return GrpcResolver(
self.config,
self.emit_provider_ready,
self.emit_provider_error,
self.emit_provider_stale,
self.emit_provider_configuration_changed,
)
elif self.config.resolver_type == ResolverType.IN_PROCESS:
elif self.config.resolver == ResolverType.IN_PROCESS:
return InProcessResolver(self.config, self)
else:
raise ValueError(
f"`resolver_type` parameter invalid: {self.config.resolver_type}"
f"`resolver_type` parameter invalid: {self.config.resolver}"
)

def initialize(self, evaluation_context: EvaluationContext) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,32 @@


class GrpcResolver:
MAX_BACK_OFF = 120

def __init__(
self,
config: Config,
emit_provider_ready: typing.Callable[[ProviderEventDetails], None],
emit_provider_error: typing.Callable[[ProviderEventDetails], None],
emit_provider_stale: typing.Callable[[ProviderEventDetails], None],
emit_provider_configuration_changed: typing.Callable[
[ProviderEventDetails], None
],
):
self.config = config
self.emit_provider_ready = emit_provider_ready
self.emit_provider_error = emit_provider_error
self.emit_provider_stale = emit_provider_stale
self.emit_provider_configuration_changed = emit_provider_configuration_changed
self.cache: typing.Optional[BaseCacheImpl] = (
LRUCache(maxsize=self.config.max_cache_size)
if self.config.cache_type == CacheType.LRU
if self.config.cache == CacheType.LRU
else None
)
self.stub, self.channel = self._create_stub()
self.retry_backoff_seconds = config.retry_backoff_ms * 0.001
self.retry_backoff_max_seconds = config.retry_backoff_ms * 0.001
self.retry_grace_attempts = config.retry_grace_attempts
self.streamline_deadline_seconds = config.stream_deadline_ms * 0.001
self.deadline = config.deadline * 0.001
self.deadline = config.deadline_ms * 0.001
self.connected = False

def _create_stub(
Expand All @@ -70,13 +72,10 @@ def _create_stub(
channel_factory = grpc.secure_channel if config.tls else grpc.insecure_channel
channel = channel_factory(
f"{config.host}:{config.port}",
options=(("grpc.keepalive_time_ms", config.keep_alive),),
options=(("grpc.keepalive_time_ms", config.keep_alive_time),),
)
stub = evaluation_pb2_grpc.ServiceStub(channel)

if self.cache:
self.cache.clear()

return stub, channel

def initialize(self, evaluation_context: EvaluationContext) -> None:
Expand Down Expand Up @@ -113,8 +112,10 @@ def listen(self) -> None:
if self.streamline_deadline_seconds > 0
else {}
)
retry_counter = 0
while self.active:
request = evaluation_pb2.EventStreamRequest()

try:
logger.debug("Setting up gRPC sync flags connection")
for message in self.stub.EventStream(request, **call_args):
Expand All @@ -126,6 +127,7 @@ def listen(self) -> None:
)
)
self.connected = True
retry_counter = 0
# reset retry delay after successful read
retry_delay = self.retry_backoff_seconds

Expand All @@ -146,15 +148,37 @@ def listen(self) -> None:
)

self.connected = False
self.handle_error(retry_counter, retry_delay)

retry_delay = self.handle_retry(retry_counter, retry_delay)

retry_counter = retry_counter + 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for my understanding. The retry_counter seems to be incremented (and the error and retry handling executed) in every loop iteration, even if the connection succeeds, or am I missing something?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right. handle_error is a bad name for what I am actually doing. Two things are hidden behind it.

  1. if we hit our grace_attempts with the error_counter we emit an error event
  2. if our immediate reconnect did not work, we're emitting a stale event on the first time.


def handle_retry(self, retry_counter: int, retry_delay: float) -> float:
if retry_counter == 0:
logger.info("gRPC sync disconnected, reconnecting immediately")
else:
logger.info(f"gRPC sync disconnected, reconnecting in {retry_delay}s")
time.sleep(retry_delay)
retry_delay = min(1.1 * retry_delay, self.retry_backoff_max_seconds)
return retry_delay

def handle_error(self, retry_counter: int, retry_delay: float) -> None:
if retry_counter == self.retry_grace_attempts:
if self.cache:
self.cache.clear()
self.emit_provider_error(
ProviderEventDetails(
message=f"gRPC sync disconnected, reconnecting in {retry_delay}s",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we increase retry_delay in handle_retry(), so we wouldn't print the correct value here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We increment the retry_delay after the sleep, so it will be increased in the next iteration; hence the order of execution ensures the right values. But for clarity, I'll reorder the methods.

error_code=ErrorCode.GENERAL,
)
)
logger.info(f"gRPC sync disconnected, reconnecting in {retry_delay}s")
time.sleep(retry_delay)
retry_delay = min(1.1 * retry_delay, self.MAX_BACK_OFF)
elif retry_counter == 1:
self.emit_provider_stale(
ProviderEventDetails(
message=f"gRPC sync disconnected, reconnecting in {retry_delay}s",
)
)

def handle_changed_flags(self, data: typing.Any) -> None:
changed_flags = list(data["flags"].keys())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

# running all gherkin tests, except the ones, not implemented
def pytest_collection_modifyitems(config):
marker = "not customCert and not unixsocket and not sync"
marker = "not customCert and not unixsocket and not sync and not targetURI"

# this seems to not work with python 3.8
if hasattr(config.option, "markexpr") and config.option.markexpr == "":
Expand Down
2 changes: 1 addition & 1 deletion providers/openfeature-provider-flagd/tests/e2e/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def assert_handlers(
)
)
def assert_handler_run(event_type: ProviderEvent, event_handles):
assert_handlers(event_handles, event_type, max_wait=6)
assert_handlers(event_handles, event_type, max_wait=30)


@then(
Expand Down
16 changes: 8 additions & 8 deletions providers/openfeature-provider-flagd/tests/e2e/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest
from asserts import assert_equal
from pytest_bdd import parsers, scenarios, then, when
from pytest_bdd import given, parsers, scenarios, then, when
from tests.e2e.conftest import TEST_HARNESS_PATH

from openfeature.contrib.provider.flagd.config import CacheType, Config, ResolverType
Expand Down Expand Up @@ -47,19 +47,19 @@ def option_values() -> dict:
return {}


@when(
@given(
parsers.cfparse(
'we have an option "{option}" of type "{type_info}" with value "{value}"',
'an option "{option}" of type "{type_info}" with value "{value}"',
),
)
def option_with_value(option: str, value: str, type_info: str, option_values: dict):
value = type_cast[type_info](value)
option_values[camel_to_snake(option)] = value


@when(
@given(
parsers.cfparse(
'we have an environment variable "{env}" with value "{value}"',
'an environment variable "{env}" with value "{value}"',
),
)
def env_with_value(monkeypatch, env: str, value: str):
Expand All @@ -68,7 +68,7 @@ def env_with_value(monkeypatch, env: str, value: str):

@when(
parsers.cfparse(
"we initialize a config",
"a config was initialized",
),
target_fixture="config",
)
Expand All @@ -78,12 +78,12 @@ def initialize_config(option_values):

@when(
parsers.cfparse(
'we initialize a config for "{resolver_type}"',
'a config was initialized for "{resolver_type}"',
),
target_fixture="config",
)
def initialize_config_for(resolver_type: str, option_values: dict):
return Config(resolver_type=ResolverType(resolver_type), **option_values)
return Config(resolver=ResolverType(resolver_type), **option_values)


@then(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ def image():

scenarios(
f"{TEST_HARNESS_PATH}/gherkin/flagd-reconnect.feature",
f"{TEST_HARNESS_PATH}/gherkin/events.feature",
)
Loading
Loading