
feat(features) Use dataclasses for flagpole instead of pydantic #75859

Merged
merged 9 commits on Aug 15, 2024
84 changes: 54 additions & 30 deletions src/flagpole/__init__.py
@@ -61,12 +61,14 @@
"""
from __future__ import annotations

from datetime import datetime
import dataclasses
import functools
import os
from typing import Any

import jsonschema
import orjson
import yaml
from pydantic import BaseModel, Field, ValidationError, constr

from flagpole.conditions import ConditionBase, Segment
from flagpole.evaluation_context import ContextBuilder, EvaluationContext
@@ -76,26 +78,29 @@ class InvalidFeatureFlagConfiguration(Exception):
pass


class Feature(BaseModel):
name: constr(min_length=1, to_lower=True) = Field( # type:ignore[valid-type]
description="The feature name."
)
@functools.cache
def load_json_schema() -> dict[str, Any]:
path = os.path.join(os.path.dirname(__file__), "flagpole-schema.json")
with open(path, "rb") as json_file:
data = orjson.loads(json_file.read())
return data
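
Side note: functools.cache memoizes the loader, so the schema file is read from disk once per process and the same dict object is handed to every caller. A minimal sketch of the pattern, using stdlib json in place of orjson so it runs without third-party packages (the path is illustrative):

```python
import functools
import json
from typing import Any

@functools.cache
def load_schema(path: str = "flagpole-schema.json") -> dict[str, Any]:
    # Cached per unique argument tuple; later calls skip disk I/O entirely.
    with open(path, "rb") as fh:
        return json.loads(fh.read())
```

Because the cached dict is shared and mutable, callers should treat it as read-only.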


@dataclasses.dataclass(frozen=True)
class Feature:
name: str
"The feature name."

owner: constr(min_length=1) = Field( # type:ignore[valid-type]
description="The owner of this feature. Either an email address or team name, preferably."
)
owner: str
"The owner of this feature. Either an email address or team name, preferably."

segments: list[Segment] = Field(
description="The list of segments to evaluate for the feature. An empty list will always evaluate to False."
)
"The list of segments to evaluate for the feature. An empty list will always evaluate to False."

enabled: bool = Field(default=True, description="Whether or not the feature is enabled.")
enabled: bool = dataclasses.field(default=True)
"Whether or not the feature is enabled."

created_at: datetime = Field(description="The datetime when this feature was created.")
segments: list[Segment] = dataclasses.field(default_factory=list)
"The list of segments to evaluate for the feature. An empty list will always evaluate to False."

created_at: str | None = None
"The datetime when this feature was created."

def match(self, context: EvaluationContext) -> bool:
@@ -109,24 +114,40 @@ def match(self, context: EvaluationContext) -> bool:

return False

@classmethod
def dump_schema_to_file(cls, file_path: str) -> None:
with open(file_path, "w") as file:
file.write(cls.schema_json(indent=2))
def validate(self) -> bool:
"""
Validate a feature against the JSON schema.
Will raise if the current dict form of a feature does not match the schema.
"""
dict_data = dataclasses.asdict(self)
spec = load_json_schema()
jsonschema.validate(dict_data, spec)

return True
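
A sketch of how validation might be exercised (field values are made up; whether the defaults satisfy the schema depends on flagpole-schema.json, which is not shown in this diff):

```python
import jsonschema

feature = Feature(name="my-feature", owner="issues-team@example.com")
try:
    feature.validate()
except jsonschema.exceptions.ValidationError as exc:
    # validate() raises when the dataclass's dict form fails the schema
    print(f"feature config rejected: {exc.message}")
```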

@classmethod
def from_feature_dictionary(cls, name: str, config_dict: dict[str, Any]) -> Feature:
segment_data = config_dict.get("segments")
if not isinstance(segment_data, list):
raise InvalidFeatureFlagConfiguration("Feature has no segments defined")
try:
feature = cls(name=name, **config_dict)
except ValidationError as exc:
raise InvalidFeatureFlagConfiguration("Provided JSON is not a valid feature") from exc
segments = [Segment.from_dict(segment) for segment in segment_data]
feature = cls(
name=name,
owner=str(config_dict.get("owner", "")),
enabled=bool(config_dict.get("enabled", True)),
created_at=str(config_dict.get("created_at")),
segments=segments,
)
except Exception as exc:
raise InvalidFeatureFlagConfiguration(
"Provided config_dict is not a valid feature"
) from exc

return feature
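
For illustration, a hypothetical config dict this constructor would accept (all values invented):

```python
config = {
    "owner": "issues-team@example.com",
    "enabled": True,
    "created_at": "2024-08-01T00:00:00+00:00",
    "segments": [
        {
            "name": "internal-users",
            "rollout": 100,
            "conditions": [
                {"property": "organization_slug", "operator": "in", "value": ["acme"]},
            ],
        }
    ],
}
feature = Feature.from_feature_dictionary(name="my-feature", config_dict=config)
```

One subtlety worth noting: str(config_dict.get("created_at")) yields the literal string "None" when the key is absent, rather than leaving created_at as None.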

@classmethod
def from_feature_config_json(
cls, name: str, config_json: str, context_builder: ContextBuilder | None = None
) -> Feature:
def from_feature_config_json(cls, name: str, config_json: str) -> Feature:
try:
config_data_dict = orjson.loads(config_json)
except orjson.JSONDecodeError as decode_error:
@@ -135,6 +156,9 @@ def from_feature_config_json(
if not isinstance(config_data_dict, dict):
raise InvalidFeatureFlagConfiguration("Feature JSON is not a valid feature")

if not name:
raise InvalidFeatureFlagConfiguration("Feature name is required")

return cls.from_feature_dictionary(name=name, config_dict=config_data_dict)
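
The JSON entry point behaves the same way once decoding succeeds; for example (values illustrative):

```python
feature = Feature.from_feature_config_json(
    "my-feature",
    '{"owner": "issues-team@example.com", "segments": []}',
)
assert feature.enabled  # enabled defaults to True when omitted
```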

@classmethod
@@ -148,7 +172,7 @@ def from_bulk_json(cls, json: str) -> list[Feature]:
return features

@classmethod
def from_bulk_yaml(cls, yaml_str) -> list[Feature]:
def from_bulk_yaml(cls, yaml_str: str) -> list[Feature]:
features: list[Feature] = []
parsed_yaml = yaml.safe_load(yaml_str)
for feature, yaml_dict in parsed_yaml.items():
@@ -157,9 +181,9 @@ def from_bulk_yaml(cls, yaml_str) -> list[Feature]:
return features
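
Bulk YAML documents are keyed by feature name. The loop body is partially elided above, but assuming each entry is built via from_feature_dictionary as the surrounding code suggests, a document like this (names and owners invented) yields two features:

```python
yaml_str = """
feature-one:
  owner: team-a
  segments: []
feature-two:
  owner: team-b
  enabled: false
  segments: []
"""
features = Feature.from_bulk_yaml(yaml_str)
assert [f.name for f in features] == ["feature-one", "feature-two"]
```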

def to_dict(self) -> dict[str, Any]:
json_dict = dict(orjson.loads(self.json()))
json_dict.pop("name")
return {self.name: json_dict}
dict_data = dataclasses.asdict(self)
dict_data.pop("name")
return {self.name: dict_data}

def to_yaml_str(self) -> str:
return yaml.dump(self.to_dict())
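
Serialization nests everything except the name under the name key, so a default feature round-trips like this (with PyYAML's default alphabetical key sorting):

```python
feature = Feature(name="my-feature", owner="team-a")
print(feature.to_yaml_str())
# my-feature:
#   created_at: null
#   enabled: true
#   owner: team-a
#   segments: []
```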
132 changes: 57 additions & 75 deletions src/flagpole/conditions.py
@@ -1,8 +1,8 @@
import dataclasses
from abc import abstractmethod
from collections.abc import Mapping
from enum import Enum
from typing import Annotated, Any, Literal, TypeVar

from pydantic import BaseModel, Field, StrictBool, StrictFloat, StrictInt, StrictStr, constr
from typing import Any, Self, TypeVar

from flagpole.evaluation_context import EvaluationContext

@@ -48,20 +48,20 @@ def create_case_insensitive_set_from_list(values: list[T]) -> set[T]:
return case_insensitive_set


class ConditionBase(BaseModel):
property: str = Field(description="The evaluation context property to match against.")
@dataclasses.dataclass(frozen=True)
class ConditionBase:
property: str
"""The evaluation context property to match against."""

operator: ConditionOperatorKind = Field(
description="The operator to use when comparing the evaluation context property to the condition's value."
)
"""The operator to use when comparing the evaluation context property to the condition's value."""

value: Any = Field(
description="The value to compare against the condition's evaluation context property."
)
value: Any
"""The value to compare against the condition's evaluation context property."""

operator: str = dataclasses.field(default="")
Member Author commented:
I don't love this, but it is less gross than having to always provide operator to every condition constructor throughout tests.

"""
The name of the operator to use when comparing the evaluation context property to the condition's value.
Values must be a valid ConditionOperatorKind.
"""

def match(self, context: EvaluationContext, segment_name: str) -> bool:
return self._operator_match(
condition_property=context.get(self.property), segment_name=segment_name
@@ -99,12 +99,8 @@ def _evaluate_contains(self, condition_property: Any, segment_name: str) -> bool

return value in create_case_insensitive_set_from_list(condition_property)
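
The contains evaluators lean on create_case_insensitive_set_from_list; its body is elided above, but the name and return value imply string members are lowered on insert, making membership checks case-insensitive. A small illustration (values invented):

```python
values = create_case_insensitive_set_from_list(["Beta-Testers", "Admins"])
assert "beta-testers" in values  # lowered on insert, so lowered probes match
```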

def _evaluate_equals(
self, condition_property: Any, segment_name: str, strict_validation: bool = False
) -> bool:
# Strict validation enforces that a property exists when used in an
# equals condition
if condition_property is None and not strict_validation:
def _evaluate_equals(self, condition_property: Any, segment_name: str) -> bool:
if condition_property is None:
return False

if not isinstance(condition_property, type(self.value)):
@@ -121,33 +117,33 @@ def _evaluate_equals(
return condition_property == self.value


InOperatorValueTypes = list[StrictInt] | list[StrictFloat] | list[StrictStr]
InOperatorValueTypes = list[int] | list[float] | list[str]


class InCondition(ConditionBase):
operator: Literal[ConditionOperatorKind.IN] = ConditionOperatorKind.IN
value: InOperatorValueTypes
operator: str = dataclasses.field(default="in")

def _operator_match(self, condition_property: Any, segment_name: str):
return self._evaluate_in(condition_property=condition_property, segment_name=segment_name)


class NotInCondition(ConditionBase):
operator: Literal[ConditionOperatorKind.NOT_IN] = ConditionOperatorKind.NOT_IN
value: InOperatorValueTypes
operator: str = dataclasses.field(default="not_in")

def _operator_match(self, condition_property: Any, segment_name: str):
return not self._evaluate_in(
condition_property=condition_property, segment_name=segment_name
)


ContainsOperatorValueTypes = StrictInt | StrictStr | StrictFloat
ContainsOperatorValueTypes = int | str | float


class ContainsCondition(ConditionBase):
operator: Literal[ConditionOperatorKind.CONTAINS] = ConditionOperatorKind.CONTAINS
value: ContainsOperatorValueTypes
operator: str = dataclasses.field(default="contains")

def _operator_match(self, condition_property: Any, segment_name: str):
return self._evaluate_contains(
@@ -156,101 +152,87 @@ def _operator_match(self, condition_property: Any, segment_name: str):


class NotContainsCondition(ConditionBase):
operator: Literal[ConditionOperatorKind.NOT_CONTAINS] = ConditionOperatorKind.NOT_CONTAINS
value: ContainsOperatorValueTypes
operator: str = dataclasses.field(default="not_contains")

def _operator_match(self, condition_property: Any, segment_name: str):
return not self._evaluate_contains(
condition_property=condition_property, segment_name=segment_name
)


EqualsOperatorValueTypes = (
StrictInt
| StrictFloat
| StrictStr
| StrictBool
| list[StrictInt]
| list[StrictFloat]
| list[StrictStr]
)
EqualsOperatorValueTypes = int | float | str | bool | list[int] | list[float] | list[str]


class EqualsCondition(ConditionBase):
operator: Literal[ConditionOperatorKind.EQUALS] = ConditionOperatorKind.EQUALS
value: EqualsOperatorValueTypes
strict_validation: bool = Field(
description="Whether the condition should enable strict validation, raising an exception if the evaluation context property is missing",
default=False,
)
"""Whether the condition should enable strict validation, raising an exception if the evaluation context property is missing"""
operator: str = dataclasses.field(default="equals")

def _operator_match(self, condition_property: Any, segment_name: str):
return self._evaluate_equals(
condition_property=condition_property,
segment_name=segment_name,
strict_validation=self.strict_validation,
)


class NotEqualsCondition(ConditionBase):
operator: Literal[ConditionOperatorKind.NOT_EQUALS] = ConditionOperatorKind.NOT_EQUALS
value: EqualsOperatorValueTypes
strict_validation: bool = Field(
description="Whether the condition should enable strict validation, raising an exception if the evaluation context property is missing",
default=False,
)
"""Whether the condition should enable strict validation, raising an exception if the evaluation context property is missing"""
operator: str = dataclasses.field(default="not_equals")

def _operator_match(self, condition_property: Any, segment_name: str):
return not self._evaluate_equals(
condition_property=condition_property,
segment_name=segment_name,
strict_validation=self.strict_validation,
)


# We have to group and annotate all the different subclasses of Operator
# in order for Pydantic to be able to discern between the different types
# when parsing a dict or JSON.
AvailableConditions = Annotated[
InCondition
| NotInCondition
| ContainsCondition
| NotContainsCondition
| EqualsCondition
| NotEqualsCondition,
Field(discriminator="operator"),
]
OPERATOR_LOOKUP: Mapping[ConditionOperatorKind, type[ConditionBase]] = {
ConditionOperatorKind.IN: InCondition,
ConditionOperatorKind.NOT_IN: NotInCondition,
ConditionOperatorKind.CONTAINS: ContainsCondition,
ConditionOperatorKind.NOT_CONTAINS: NotContainsCondition,
ConditionOperatorKind.EQUALS: EqualsCondition,
ConditionOperatorKind.NOT_EQUALS: NotEqualsCondition,
}


class Segment(BaseModel):
name: constr(min_length=1) = Field( # type:ignore[valid-type]
description="A brief description or identifier for the segment"
def condition_from_dict(data: Mapping[str, Any]) -> ConditionBase:
operator_kind = ConditionOperatorKind(data.get("operator", "invalid"))
if operator_kind not in OPERATOR_LOOKUP:
valid = ", ".join(OPERATOR_LOOKUP.keys())
raise ValueError(f"The {operator_kind} is not a known operator. Choose from {valid}")

condition_cls = OPERATOR_LOOKUP[operator_kind]
Member commented:
Do we want to raise explicitly here if we get an invalid operator type? Otherwise, won't this be an attribute error on NoneType?

Member Author replied:
I can make the raise explicit. Currently it would raise KeyError which isn't as informative as it could be.

return condition_cls(
property=str(data.get("property")), operator=operator_kind.value, value=data.get("value")
)
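
A sketch of building a condition from configuration data (property and values invented):

```python
condition = condition_from_dict(
    {"property": "organization_slug", "operator": "in", "value": ["acme", "initech"]}
)
assert isinstance(condition, InCondition)
```

Note that an unknown operator string makes the ConditionOperatorKind constructor itself raise ValueError before the lookup guard runs, which is the behavior the review thread above is discussing.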


@dataclasses.dataclass
class Segment:
name: str
"A brief description or identifier for the segment"

conditions: list[AvailableConditions] = Field(
description="The list of conditions that the segment must be matched in order for this segment to be active"
)
conditions: list[ConditionBase] = dataclasses.field(default_factory=list)
"The list of conditions that the segment must be matched in order for this segment to be active"

rollout: int | None = Field(
default=0,
description="""
Rollout rate controls how many buckets will be granted a feature when this segment matches.

Rollout rates range from 0 (off) to 100 (all users). Rollout rates use `context.id`
to determine bucket membership consistently over time.
""",
)
rollout: int | None = dataclasses.field(default=0)
"""
Rollout rate controls how many buckets will be granted a feature when this segment matches.

Rollout rates range from 0 (off) to 100 (all users). Rollout rates use `context.id`
to determine bucket membership consistently over time.
"""

@classmethod
def from_dict(cls, data: Mapping[str, Any]) -> Self:
conditions = [condition_from_dict(condition) for condition in data.get("conditions", [])]
return cls(
name=str(data.get("name", "")),
rollout=int(data.get("rollout", 0)),
conditions=conditions,
)
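
Putting the pieces together, a segment with a partial rollout might be declared like so (all values invented):

```python
segment = Segment.from_dict(
    {
        "name": "early-adopters",
        "rollout": 25,  # grant roughly a quarter of buckets when conditions match
        "conditions": [
            {"property": "user_email", "operator": "contains", "value": "@example.com"},
        ],
    }
)
```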

def match(self, context: EvaluationContext) -> bool:
for condition in self.conditions:
match_condition = condition.match(context, segment_name=self.name)