fix: remove some buggy validation of BaseFlow values (#167)
* fix: remove some buggy validation of BaseFlow values

* remove incorrect test asserts from upload tasks
matt-codecov committed Nov 2, 2023
1 parent 164ee64 commit 3f5918e
Showing 4 changed files with 55 additions and 56 deletions.
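
The gist of the fix, for context before the diffs: checkpoints data crosses the Celery broker as JSON, so the enum keys that one task logs arrive at the next task as plain strings. The old `from_kwargs` validated those raw keys against the enum members and raised `ValueError` on any mismatch; the new version downcasts each string back to an enum member and, unless `strict` is set, fails softly by logging and dropping the data. A minimal sketch of the round trip this relies on (`MyFlow` is illustrative, not part of the commit):

import json
from enum import Enum, auto


class MyFlow(str, Enum):
    # Mirrors BaseFlow._generate_next_value_ below: auto() value == member name
    def _generate_next_value_(name, start, count, last_values):
        return name

    A = auto()


# Enum keys decay to plain strings across the broker
wire_data = json.loads(json.dumps({MyFlow.A: 1337}))
assert wire_data == {"A": 1337}

# The new from_kwargs downcasts them back; unknown strings raise ValueError
restored = {MyFlow(k): v for k, v in wire_data.items()}
assert restored == {MyFlow.A: 1337}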
58 changes: 32 additions & 26 deletions helpers/checkpoint_logger/__init__.py
@@ -65,14 +65,27 @@
 )
 
 
+def _error(msg, flow, strict=False):
+    # When a new version of worker rolls out, it will pick up tasks that
+    # may have been enqueued by the old worker and be missing checkpoints
+    # data. At least for that reason, we want to allow failing softly.
+    metrics.incr("worker.checkpoint_logger.error")
+    CHECKPOINTS_ERRORS.labels(flow=flow.__name__).inc()
+    if strict:
+        raise ValueError(msg)
+    else:
+        logger.warning(msg)
+
+
 class BaseFlow(str, Enum):
     """
     Base class for a flow. Defines optional functions which are added by the
     @success_events, @failure_events, @subflows, and @reliability_counters
     decorators to (mostly) appease mypy.
     Inherits from `str` so a dictionary of checkpoints data can be serialized
-    between worker tasks.
+    between worker tasks. It overrides sort order functions so that it follows
+    enum declaration order instead of lexicographic order.
     """
 
     _subflows: Callable[[], TSubflows]
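
The sort-order sentence added to the docstring refers to comparison overrides defined elsewhere in BaseFlow; they are not part of this diff. As a rough illustration only (this hypothetical `__lt__` is not BaseFlow's actual code), declaration-order comparison for a str-based enum can look like:

from enum import Enum, auto


class Flow(str, Enum):
    def _generate_next_value_(name, start, count, last_values):
        return name

    # Hypothetical: compare by declaration position instead of string value
    def __lt__(self, other):
        members = list(type(self))
        return members.index(self) < members.index(other)

    NOTIFIED = auto()
    BEGIN = auto()


# Lexicographically "BEGIN" < "NOTIFIED", but declaration order wins here
assert Flow.NOTIFIED < Flow.BEGIN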
@@ -82,19 +95,9 @@ class BaseFlow(str, Enum):
     is_failure: ClassVar[Callable[[T], bool]]
     log_counters: ClassVar[Callable[[T], None]]
 
-    def __new__(cls: type[T], value: str) -> T:
-        """
-        Hook into the creation of each enum member and inject the class name
-        into the enum's value (e.g. "MEMBER_NAME" -> "MyEnum.MEMBER_NAME")
-        """
-        value = f"{cls.__name__}.{value}"
-        return super().__new__(cls, value)
-
     def _generate_next_value_(name: str, start: int, count: int, last_values: list[Any]):  # type: ignore[override]
         """
-        This powers `enum.auto()`. We want `MyEnum.MEMBER_NAME` as our value but
-        we don't have access to the name of `MyEnum` here so just return
-        `MEMBER_NAME` for now.
+        This powers `enum.auto()`. It sets the value of "MyEnum.A" to "A".
         """
         return name
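
With the `__new__` override gone, `auto()` now produces the bare member name as the enum's value, which is exactly what ends up in JSON and what `cls(...)` can look up again. A quick check of the new behavior (`MyEnum` here is illustrative):

from enum import Enum, auto


class MyEnum(str, Enum):
    def _generate_next_value_(name, start, count, last_values):
        return name

    A = auto()


assert MyEnum.A.value == "A"    # previously "MyEnum.A" under the removed __new__
assert MyEnum("A") is MyEnum.A  # value lookup, which the new from_kwargs relies on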

@@ -396,15 +399,7 @@ def __init__(
         self.strict = strict
 
     def _error(self: _Self, msg: str) -> None:
-        # When a new version of worker rolls out, it will pick up tasks that
-        # may have been enqueued by the old worker and be missing checkpoints
-        # data. At least for that reason, we want to allow failing softly.
-        metrics.incr("worker.checkpoint_logger.error")
-        CHECKPOINTS_ERRORS.labels(flow=self.cls.__name__).inc()
-        if self.strict:
-            raise ValueError(msg)
-        else:
-            logger.warning(msg)
+        _error(msg, self.cls, self.strict)
 
     def _validate_checkpoint(self: _Self, checkpoint: T) -> None:
         if checkpoint.__class__ != self.cls:
@@ -483,9 +478,20 @@ def from_kwargs(
 ) -> CheckpointLogger[T]:
     data = kwargs.get(_kwargs_key(cls), {})
 
-    # Make sure these checkpoints were made with the same flow
-    for key in data.keys():
-        if key not in cls.__members__.values():
-            raise ValueError(f"Checkpoint {key} not part of flow `{cls.__name__}`")
+    # kwargs has been deserialized into a Python dictionary, but our enum values
+    # are deserialized as simple strings. We need to ensure the strings are all
+    # proper enum values as best we can, and then downcast to enum instances.
+    deserialized_data = {}
+    for checkpoint, timestamp in data.items():
+        try:
+            deserialized_data[cls(checkpoint)] = timestamp
+        except ValueError:
+            _error(
+                f"Checkpoint {checkpoint} not part of flow `{cls.__name__}`",
+                cls,
+                strict,
+            )
+            deserialized_data = {}
+            break
 
-    return CheckpointLogger(cls, data, strict)
+    return CheckpointLogger(cls, deserialized_data, strict)
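
Taken together, the new `from_kwargs` behaves roughly like this from a caller's perspective. A usage sketch, assuming `BaseFlow`, `from_kwargs`, and `_kwargs_key` are importable from `helpers.checkpoint_logger` as the tests below suggest, with an illustrative `MyFlow`:

from enum import auto

from helpers.checkpoint_logger import BaseFlow, _kwargs_key, from_kwargs


class MyFlow(BaseFlow):
    A = auto()


# One valid checkpoint string plus a stray key that is not part of MyFlow
kwargs = {_kwargs_key(MyFlow): {"A": 1337, "NOT_A_CHECKPOINT": 9001}}

# Soft mode: the stray key is logged as an error and all data is dropped
checkpoints = from_kwargs(MyFlow, kwargs, strict=False)
assert checkpoints.data == {}

# Strict mode: the same condition raises instead
try:
    from_kwargs(MyFlow, kwargs, strict=True)
except ValueError:
    pass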
31 changes: 18 additions & 13 deletions helpers/tests/unit/test_checkpoint_logger.py
@@ -82,9 +82,9 @@ class TestEnum1(BaseFlow):
 
 
 class TestEnum2(BaseFlow):
-    A = auto()
-    B = auto()
-    C = auto()
+    D = auto()
+    E = auto()
+    F = auto()
 
 
 class SortOrderEnum(BaseFlow):
@@ -135,7 +135,7 @@ def test_log_checkpoint_wrong_enum_throws(self) -> None:
         checkpoints = CheckpointLogger(TestEnum1, strict=True)
 
         with self.assertRaises(ValueError):
-            checkpoints.log(TestEnum2.A)  # type: ignore[arg-type]
+            checkpoints.log(TestEnum2.D)  # type: ignore[arg-type]
 
     @patch("helpers.checkpoint_logger._get_milli_timestamp", side_effect=[1337, 9001])
     def test_subflow_duration(self, mocker):
@@ -181,11 +181,11 @@ def test_subflow_duration_wrong_enum(self, mocker):
 
         # Wrong enum for start checkpoint
         with self.assertRaises(ValueError):
-            checkpoints._subflow_duration(TestEnum2.A, TestEnum1.A)
+            checkpoints._subflow_duration(TestEnum2.D, TestEnum1.A)
 
         # Wrong enum for end checkpoint
         with self.assertRaises(ValueError):
-            checkpoints._subflow_duration(TestEnum1.A, TestEnum2.A)
+            checkpoints._subflow_duration(TestEnum1.A, TestEnum2.D)
 
     @pytest.mark.real_checkpoint_logger
     @patch("helpers.checkpoint_logger._get_milli_timestamp", side_effect=[1337, 9001])
@@ -221,23 +221,28 @@ def test_create_from_kwargs(self):
             TestEnum1.A: 1337,
             TestEnum1.B: 9001,
         }
+        deserialized_good_data = json.loads(json.dumps(good_data))
         good_kwargs = {
-            "checkpoints_TestEnum1": good_data,
+            "checkpoints_TestEnum1": deserialized_good_data,
         }
         checkpoints = from_kwargs(TestEnum1, good_kwargs, strict=True)
         assert checkpoints.data == good_data
 
         # Data is from TestEnum2 but we expected TestEnum1
         bad_data = {
-            TestEnum2.A: 1337,
-            TestEnum2.B: 9001,
+            TestEnum2.D: 1337,
+            TestEnum2.E: 9001,
         }
+        deserialized_bad_data = json.loads(json.dumps(bad_data))
         bad_kwargs = {
-            "checkpoints_TestEnum1": bad_data,
+            "checkpoints_TestEnum1": deserialized_bad_data,
        }
         with self.assertRaises(ValueError):
             checkpoints = from_kwargs(TestEnum1, bad_kwargs, strict=True)
 
+        checkpoints = from_kwargs(TestEnum1, bad_kwargs, strict=False)
+        assert checkpoints.data == {}
+
     @patch("helpers.checkpoint_logger._get_milli_timestamp", side_effect=[1337, 9001])
     def test_log_to_kwargs(self, mock_timestamp):
         kwargs = {}
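
The `json.loads(json.dumps(...))` idiom introduced above (and reused in the task tests below) stands in for what the Celery broker does to task kwargs: enum-keyed dicts go in, string-keyed dicts come out. A small sketch, reusing `TestEnum1` from this file:

import json


def simulate_celery_round_trip(kwargs):
    # Celery's default JSON serializer flattens str-enum keys to plain strings
    return json.loads(json.dumps(kwargs))


good_data = {TestEnum1.A: 1337, TestEnum1.B: 9001}
assert simulate_celery_round_trip(good_data) == {"A": 1337, "B": 9001}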
@@ -514,10 +519,10 @@ def test_serialize_between_tasks(self):
         serialized = json.dumps(original)
         deserialized = json.loads(serialized)
 
-        assert serialized == '{"TestEnum1.A": 1337, "TestEnum1.B": 9001}'
+        assert serialized == '{"A": 1337, "B": 9001}'
         assert deserialized == {
-            "TestEnum1.A": 1337,
-            "TestEnum1.B": 9001,
+            "A": 1337,
+            "B": 9001,
         }
 
     def test_sort_order(self):
12 changes: 3 additions & 9 deletions tasks/tests/unit/test_notify_task.py
@@ -1,3 +1,4 @@
+import json
 from unittest.mock import call
 
 import pytest
@@ -455,14 +456,6 @@ async def test_simple_call_yes_notifications_no_base(
         dbsession.refresh(commit)
         assert commit.notified is True
 
-        assert checkpoints.data == {
-            UploadFlow.UPLOAD_TASK_BEGIN: 1337,
-            UploadFlow.PROCESSING_BEGIN: 9001,
-            UploadFlow.INITIAL_PROCESSING_COMPLETE: 10000,
-            UploadFlow.BATCH_PROCESSING_COMPLETE: 15000,
-            UploadFlow.PROCESSING_COMPLETE: 20000,
-            UploadFlow.NOTIFIED: 25000,
-        }
         calls = [
             call(
                 "notification_latency",
@@ -909,7 +902,8 @@ async def test_run_async_can_run_logic(self, dbsession, mock_redis, mocker):
         task = NotifyTask()
         mock_redis.get.return_value = False
         checkpoints = _create_checkpoint_logger(mocker)
-        kwargs = {_kwargs_key(UploadFlow): checkpoints.data}
+        checkpoints_data = json.loads(json.dumps(checkpoints.data))
+        kwargs = {_kwargs_key(UploadFlow): checkpoints_data}
         res = await task.run_async(
             dbsession,
             repoid=commit.repoid,
10 changes: 2 additions & 8 deletions tasks/tests/unit/test_upload_finisher_task.py
@@ -73,7 +73,8 @@ async def test_upload_finisher_task_call(
         }
 
         checkpoints = _create_checkpoint_logger(mocker)
-        kwargs = {_kwargs_key(UploadFlow): checkpoints.data}
+        checkpoints_data = json.loads(json.dumps(checkpoints.data))
+        kwargs = {_kwargs_key(UploadFlow): checkpoints_data}
         result = await UploadFinisherTask().run_async(
             dbsession,
             previous_results,
@@ -94,13 +95,6 @@
             timeout=300,
         )
 
-        assert checkpoints.data == {
-            UploadFlow.UPLOAD_TASK_BEGIN: 1337,
-            UploadFlow.PROCESSING_BEGIN: 9001,
-            UploadFlow.INITIAL_PROCESSING_COMPLETE: 10000,
-            UploadFlow.BATCH_PROCESSING_COMPLETE: 15000,
-            UploadFlow.PROCESSING_COMPLETE: 20000,
-        }
         calls = [
             call(
                 "batch_processing_duration",