diff --git a/devservices/config.yml b/devservices/config.yml
index 4d12d0777d..52c549c22f 100644
--- a/devservices/config.yml
+++ b/devservices/config.yml
@@ -65,6 +65,7 @@ services:
       ENABLE_ISSUE_OCCURRENCE_CONSUMER: ${ENABLE_ISSUE_OCCURRENCE_CONSUMER:-}
       ENABLE_AUTORUN_MIGRATION_SEARCH_ISSUES: 1
       ENABLE_GROUP_ATTRIBUTES_CONSUMER: ${ENABLE_GROUP_ATTRIBUTES_CONSUMER:-}
+      ENABLE_LW_DELETIONS_CONSUMER: ${ENABLE_LW_DELETIONS_CONSUMER:-}
     platform: linux/amd64
     extra_hosts:
       host.docker.internal: host-gateway
diff --git a/snuba/cli/devserver.py b/snuba/cli/devserver.py
index b3246fc3b4..0b77f8a974 100644
--- a/snuba/cli/devserver.py
+++ b/snuba/cli/devserver.py
@@ -481,6 +481,24 @@ def devserver(*, bootstrap: bool, workers: bool) -> None:
             ),
         ]

+    if settings.ENABLE_LW_DELETIONS_CONSUMER:
+        daemons += [
+            (
+                "lw-deletions-consumer",
+                [
+                    "snuba",
+                    "lw-deletions-consumer",
+                    "--storage=search_issues",
+                    "--consumer-group=search_issues_deletes_group",
+                    "--max-rows-batch-size=10",
+                    "--max-batch-time-ms=1000",
+                    "--auto-offset-reset=latest",
+                    "--no-strict-offset-reset",
+                    "--log-level=debug",
+                ],
+            ),
+        ]
+
     manager = Manager()
     for name, cmd in daemons:
         manager.add_process(
diff --git a/snuba/cli/lw_deletions_consumer.py b/snuba/cli/lw_deletions_consumer.py
new file mode 100644
index 0000000000..ba214394f6
--- /dev/null
+++ b/snuba/cli/lw_deletions_consumer.py
@@ -0,0 +1,174 @@
+import logging
+import signal
+from typing import Any, Optional, Sequence
+
+import click
+import sentry_sdk
+from arroyo import configure_metrics
+from arroyo.backends.kafka import KafkaPayload
+from arroyo.processing import StreamProcessor
+
+from snuba import environment, settings
+from snuba.consumers.consumer_builder import (
+    ConsumerBuilder,
+    KafkaParameters,
+    ProcessingParameters,
+)
+from snuba.consumers.consumer_config import resolve_consumer_config
+from snuba.datasets.deletion_settings import MAX_ROWS_TO_DELETE_DEFAULT
+from snuba.datasets.storages.factory import get_writable_storage
+from snuba.datasets.storages.storage_key import StorageKey
+from snuba.environment import setup_logging, setup_sentry
+from snuba.lw_deletions.formatters import STORAGE_FORMATTER
+from snuba.lw_deletions.strategy import LWDeletionsConsumerStrategyFactory
+from snuba.utils.metrics.wrapper import MetricsWrapper
+from snuba.utils.streams.metrics_adapter import StreamMetricsAdapter
+from snuba.web.bulk_delete_query import STORAGE_TOPIC
+
+# A longer batch time for deletes is reasonable
+# since we want fewer mutations
+DEFAULT_DELETIONS_MAX_BATCH_TIME_MS = 60000 * 2
+
+logger = logging.getLogger(__name__)
+
+
+@click.command()
+@click.option(
+    "--consumer-group",
+    help="Consumer group used for consuming the deletion topic.",
+    required=True,
+)
+@click.option(
+    "--bootstrap-server",
+    multiple=True,
+    help="Kafka bootstrap server to use for consuming.",
+)
+@click.option("--storage", help="Storage name to consume from", required=True)
+@click.option(
+    "--max-rows-batch-size",
+    default=MAX_ROWS_TO_DELETE_DEFAULT,
+    type=int,
+    help="Max number of rows to delete at one time.",
+)
+@click.option(
+    "--max-batch-time-ms",
+    default=DEFAULT_DELETIONS_MAX_BATCH_TIME_MS,
+    type=int,
+    help="Max duration to buffer messages in memory for.",
+)
+@click.option(
+    "--auto-offset-reset",
+    default="earliest",
+    type=click.Choice(["error", "earliest", "latest"]),
+    help="Kafka consumer auto offset reset.",
+)
+@click.option(
+    "--no-strict-offset-reset",
+    is_flag=True,
+    help="Forces the kafka consumer auto offset reset.",
+)
+@click.option(
+    "--queued-max-messages-kbytes",
+    default=settings.DEFAULT_QUEUED_MAX_MESSAGE_KBYTES,
+    type=int,
+    help="Maximum number of kilobytes per topic+partition in the local consumer queue.",
+)
+@click.option(
+    "--queued-min-messages",
+    default=settings.DEFAULT_QUEUED_MIN_MESSAGES,
+    type=int,
+    help="Minimum number of messages per topic+partition the local consumer queue should contain before messages are sent to kafka.",
+)
+@click.option("--log-level", help="Logging level to use.")
+def lw_deletions_consumer(
+    *,
+    consumer_group: str,
+    bootstrap_server: Sequence[str],
+    storage: str,
+    max_rows_batch_size: int,
+    max_batch_time_ms: int,
+    auto_offset_reset: str,
+    no_strict_offset_reset: bool,
+    queued_max_messages_kbytes: int,
+    queued_min_messages: int,
+    log_level: str,
+) -> None:
+    setup_logging(log_level)
+    setup_sentry()
+
+    logger.info("Consumer Starting")
+
+    sentry_sdk.set_tag("storage", storage)
+    shutdown_requested = False
+    consumer: Optional[StreamProcessor[KafkaPayload]] = None
+
+    def handler(signum: int, frame: Any) -> None:
+        nonlocal shutdown_requested
+        shutdown_requested = True
+
+        if consumer is not None:
+            consumer.signal_shutdown()
+
+    signal.signal(signal.SIGINT, handler)
+    signal.signal(signal.SIGTERM, handler)
+
+    topic = STORAGE_TOPIC[storage]
+
+    while not shutdown_requested:
+        metrics_tags = {
+            "consumer_group": consumer_group,
+            "storage": storage,
+        }
+        metrics = MetricsWrapper(
+            environment.metrics, "lw_deletions_consumer", tags=metrics_tags
+        )
+        configure_metrics(StreamMetricsAdapter(metrics), force=True)
+        consumer_config = resolve_consumer_config(
+            storage_names=[storage],
+            raw_topic=topic.value,
+            commit_log_topic=None,
+            replacements_topic=None,
+            bootstrap_servers=bootstrap_server,
+            commit_log_bootstrap_servers=[],
+            replacement_bootstrap_servers=[],
+            slice_id=None,
+            max_batch_size=max_rows_batch_size,
+            max_batch_time_ms=max_batch_time_ms,
+            group_instance_id=consumer_group,
+        )
+
+        consumer_builder = ConsumerBuilder(
+            consumer_config=consumer_config,
+            kafka_params=KafkaParameters(
+                group_id=consumer_group,
+                auto_offset_reset=auto_offset_reset,
+                strict_offset_reset=not no_strict_offset_reset,
+                queued_max_messages_kbytes=queued_max_messages_kbytes,
+                queued_min_messages=queued_min_messages,
+            ),
+            processing_params=ProcessingParameters(None, None, None),
+            max_batch_size=max_rows_batch_size,
+            max_batch_time_ms=max_batch_time_ms,
+            max_insert_batch_size=0,
+            max_insert_batch_time_ms=0,
+            metrics=metrics,
+            slice_id=None,
+            join_timeout=None,
+            enforce_schema=False,
+            metrics_tags=metrics_tags,
+        )
+
+        writable_storage = get_writable_storage(StorageKey(storage))
+        formatter = STORAGE_FORMATTER[storage]()
+        strategy_factory = LWDeletionsConsumerStrategyFactory(
+            max_batch_size=max_rows_batch_size,
+            max_batch_time_ms=max_batch_time_ms,
+            storage=writable_storage,
+            formatter=formatter,
+            metrics=metrics,
+        )
+
+        consumer = consumer_builder.build_lw_deletions_consumer(strategy_factory)
+
+        consumer.run()
+        consumer_builder.flush()
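Note: a quick, illustrative smoke check (not part of the patch) showing that the new `lw-deletions-consumer` command registers its options; it only invokes `--help`, since running the command for real starts a long-lived consumer loop. The import path simply mirrors the new module above.

```python
from click.testing import CliRunner

from snuba.cli.lw_deletions_consumer import lw_deletions_consumer

# --help is eager in click, so required options like --consumer-group are not needed here.
result = CliRunner().invoke(lw_deletions_consumer, ["--help"])
assert result.exit_code == 0
assert "--max-rows-batch-size" in result.output
assert "--storage" in result.output
```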
diff --git a/snuba/consumers/consumer_builder.py b/snuba/consumers/consumer_builder.py
index 41a5f8a6b9..1c39c97bea 100644
--- a/snuba/consumers/consumer_builder.py
+++ b/snuba/consumers/consumer_builder.py
@@ -375,6 +375,15 @@ def build_dlq_consumer(
             dlq_policy,
         )

+    def build_lw_deletions_consumer(
+        self, strategy_factory: ProcessingStrategyFactory[KafkaPayload]
+    ) -> StreamProcessor[KafkaPayload]:
+        return self.__build_consumer(
+            strategy_factory,
+            self.raw_topic,
+            self.__build_default_dlq_policy(),
+        )
+
     def __build_default_dlq_policy(self) -> Optional[DlqPolicy[KafkaPayload]]:
         """
         Default DLQ policy applies to the base consumer or the DLQ consumer when
diff --git a/snuba/lw_deletions/__init__.py b/snuba/lw_deletions/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/snuba/lw_deletions/batching.py b/snuba/lw_deletions/batching.py
new file mode 100644
index 0000000000..7729e93856
--- /dev/null
+++ b/snuba/lw_deletions/batching.py
@@ -0,0 +1,157 @@
+from __future__ import annotations
+
+import time
+from typing import Callable, Generic, MutableSequence, Optional, TypeVar, Union
+
+from arroyo.processing.strategies.abstract import ProcessingStrategy
+from arroyo.processing.strategies.buffer import Buffer
+from arroyo.types import BaseValue, FilteredPayload, Message, TStrategyPayload
+
+ValuesBatch = MutableSequence[BaseValue[TStrategyPayload]]
+
+
+TPayload = TypeVar("TPayload")
+TResult = TypeVar("TResult")
+
+
+Accumulator = Callable[[TResult, BaseValue[TPayload]], TResult]
+
+
+class ReduceRowsBuffer(Generic[TPayload, TResult]):
+    def __init__(
+        self,
+        accumulator: Accumulator[TResult, TPayload],
+        initial_value: Callable[[], TResult],
+        max_batch_size: int,
+        max_batch_time: float,
+        increment_by: Optional[Callable[[BaseValue[TPayload]], int]] = None,
+    ):
+        self.accumulator = accumulator
+        self.initial_value = initial_value
+        self.max_batch_size = max_batch_size
+        self.max_batch_time = max_batch_time
+        self.increment_by = increment_by
+
+        self._buffer = initial_value()
+        self._buffer_size = 0
+        self._buffer_until = time.time() + max_batch_time
+
+    @property
+    def buffer(self) -> TResult:
+        return self._buffer
+
+    @property
+    def is_empty(self) -> bool:
+        return self._buffer_size == 0
+
+    @property
+    def is_ready(self) -> bool:
+        return (
+            self._buffer_size >= self.max_batch_size
+            or time.time() >= self._buffer_until
+        )
+
+    def append(self, message: BaseValue[TPayload]) -> None:
+        """
+        Instead of increasing the buffer size based on the number
+        of messages, we use the `rows_to_delete` attribute in the
+        message payload so we can batch by the number of rows we
+        want to delete.
+        """
+        self._buffer = self.accumulator(self._buffer, message)
+        if self.increment_by:
+            buffer_increment = self.increment_by(message)
+        else:
+            buffer_increment = 1
+        self._buffer_size += buffer_increment
+
+    def new(self) -> "ReduceRowsBuffer[TPayload, TResult]":
+        return ReduceRowsBuffer(
+            accumulator=self.accumulator,
+            initial_value=self.initial_value,
+            max_batch_size=self.max_batch_size,
+            max_batch_time=self.max_batch_time,
+            increment_by=self.increment_by,
+        )
+
+
+class ReduceCustom(
+    ProcessingStrategy[Union[FilteredPayload, TPayload]], Generic[TPayload, TResult]
+):
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_batch_time: float,
+        accumulator: Accumulator[TResult, TPayload],
+        initial_value: Callable[[], TResult],
+        next_step: ProcessingStrategy[TResult],
+        increment_by: Optional[Callable[[BaseValue[TPayload]], int]] = None,
+    ) -> None:
+        self.__buffer_step = Buffer(
+            buffer=ReduceRowsBuffer(
+                max_batch_size=max_batch_size,
+                max_batch_time=max_batch_time,
+                accumulator=accumulator,
+                initial_value=initial_value,
+                increment_by=increment_by,
+            ),
+            next_step=next_step,
+        )
+
+    def submit(self, message: Message[Union[FilteredPayload, TPayload]]) -> None:
+        self.__buffer_step.submit(message)
+
+    def poll(self) -> None:
+        self.__buffer_step.poll()
+
+    def close(self) -> None:
+        self.__buffer_step.close()
+
+    def terminate(self) -> None:
+        self.__buffer_step.terminate()
+
+    def join(self, timeout: Optional[float] = None) -> None:
+        self.__buffer_step.join(timeout)
+
+
+class BatchStepCustom(ProcessingStrategy[Union[FilteredPayload, TStrategyPayload]]):
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_batch_time: float,
+        next_step: ProcessingStrategy[ValuesBatch[TStrategyPayload]],
+        increment_by: Optional[Callable[[BaseValue[TStrategyPayload]], int]] = None,
+    ) -> None:
+        def accumulator(
+            result: ValuesBatch[TStrategyPayload], value: BaseValue[TStrategyPayload]
+        ) -> ValuesBatch[TStrategyPayload]:
+            result.append(value)
+            return result
+
+        self.__reduce_step: ReduceCustom[
+            TStrategyPayload, ValuesBatch[TStrategyPayload]
+        ] = ReduceCustom(
+            max_batch_size,
+            max_batch_time,
+            accumulator,
+            lambda: [],
+            next_step,
+            increment_by,
+        )
+
+    def submit(
+        self, message: Message[Union[FilteredPayload, TStrategyPayload]]
+    ) -> None:
+        self.__reduce_step.submit(message)
+
+    def poll(self) -> None:
+        self.__reduce_step.poll()
+
+    def close(self) -> None:
+        self.__reduce_step.close()
+
+    def terminate(self) -> None:
+        self.__reduce_step.terminate()
+
+    def join(self, timeout: Optional[float] = None) -> None:
+        self.__reduce_step.join(timeout)
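Note: a minimal sketch (not part of the patch) of how `ReduceRowsBuffer` fills by row count rather than message count. `BrokerValue` is borrowed from arroyo purely as a convenient `BaseValue` implementation, the integer payloads stand in for decoded delete messages, and the topic/helper names are illustrative only.

```python
from datetime import datetime

from arroyo.types import BrokerValue, Partition, Topic

from snuba.lw_deletions.batching import ReduceRowsBuffer


def _value(rows: int, offset: int) -> BrokerValue[int]:
    # The payload here is just the number of rows this hypothetical message would delete.
    return BrokerValue(rows, Partition(Topic("deletes"), 0), offset, datetime(1970, 1, 1))


buf: ReduceRowsBuffer[int, list[int]] = ReduceRowsBuffer(
    accumulator=lambda acc, value: acc + [value.payload],
    initial_value=lambda: [],
    max_batch_size=10,    # interpreted as rows, not messages
    max_batch_time=60.0,  # seconds
    increment_by=lambda value: value.payload,
)

buf.append(_value(4, offset=0))
assert not buf.is_ready  # only 4 "rows" buffered so far
buf.append(_value(7, offset=1))
assert buf.is_ready      # 4 + 7 >= 10, so the batch is due to flush
```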
diff --git a/snuba/lw_deletions/formatters.py b/snuba/lw_deletions/formatters.py
new file mode 100644
index 0000000000..c65095e13b
--- /dev/null
+++ b/snuba/lw_deletions/formatters.py
@@ -0,0 +1,63 @@
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from typing import Mapping, MutableMapping, Sequence, Type
+
+from snuba.datasets.storages.storage_key import StorageKey
+from snuba.web.bulk_delete_query import DeleteQueryMessage
+from snuba.web.delete_query import ConditionsType
+
+
+class Formatter(ABC):
+    """
+    Simple class with just a format method, which should
+    be implemented for each storage type used for deletes.
+
+    The `format` method takes a list of batched up messages
+    and formats the conditions for the storage, if needed.
+    """
+
+    @abstractmethod
+    def format(
+        self, messages: Sequence[DeleteQueryMessage]
+    ) -> Sequence[ConditionsType]:
+        raise NotImplementedError
+
+
+class SearchIssuesFormatter(Formatter):
+    def format(
+        self, messages: Sequence[DeleteQueryMessage]
+    ) -> Sequence[ConditionsType]:
+        """
+        For the search issues storage we want the additional
+        formatting step of combining group ids for messages
+        that have the same project id.
+
+        ex.
+            project_id [1] and group_id [1, 2]
+            project_id [1] and group_id [3, 4]
+
+        would be grouped into one condition:
+
+            project_id [1] and group_id [1, 2, 3, 4]
+        """
+        mapping: MutableMapping[int, set[int]] = defaultdict(set)
+        conditions = [message["conditions"] for message in messages]
+        for condition in conditions:
+            project_id = condition["project_id"][0]
+            # appease mypy
+            assert isinstance(project_id, int)
+            mapping[project_id] = mapping[project_id].union(
+                # using int() to make mypy happy
+                set([int(g_id) for g_id in condition["group_id"]])
+            )
+
+        return [
+            {"project_id": [project_id], "group_id": list(group_ids)}
+            for project_id, group_ids in mapping.items()
+        ]
+
+
+STORAGE_FORMATTER: Mapping[str, Type[Formatter]] = {
+    StorageKey.SEARCH_ISSUES.value: SearchIssuesFormatter
+}
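Note: a small illustration (not part of the patch) of the merge behaviour described in `SearchIssuesFormatter.format`: conditions that share a project id are collapsed into a single condition with de-duplicated group ids. The values are made up for the example.

```python
from snuba.lw_deletions.formatters import SearchIssuesFormatter
from snuba.web.bulk_delete_query import DeleteQueryMessage

messages = [
    DeleteQueryMessage(
        rows_to_delete=2,
        storage_name="search_issues",
        tenant_ids={},
        conditions={"project_id": [1], "group_id": [1, 2]},
    ),
    DeleteQueryMessage(
        rows_to_delete=3,
        storage_name="search_issues",
        tenant_ids={},
        conditions={"project_id": [1], "group_id": [2, 3]},
    ),
]

assert SearchIssuesFormatter().format(messages) == [
    {"project_id": [1], "group_id": [1, 2, 3]}
]
```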
diff --git a/snuba/lw_deletions/strategy.py b/snuba/lw_deletions/strategy.py
new file mode 100644
index 0000000000..077b4d96ba
--- /dev/null
+++ b/snuba/lw_deletions/strategy.py
@@ -0,0 +1,178 @@
+import logging
+import time
+import typing
+from typing import Mapping, Optional, Sequence, TypeVar
+
+import rapidjson
+from arroyo.backends.kafka import KafkaPayload
+from arroyo.processing.strategies import (
+    CommitOffsets,
+    ProcessingStrategy,
+    ProcessingStrategyFactory,
+)
+from arroyo.processing.strategies.abstract import MessageRejected
+from arroyo.types import BaseValue, Commit, Message, Partition
+
+from snuba import settings
+from snuba.attribution import AppID
+from snuba.attribution.attribution_info import AttributionInfo
+from snuba.datasets.storage import WritableTableStorage
+from snuba.lw_deletions.batching import BatchStepCustom, ValuesBatch
+from snuba.lw_deletions.formatters import Formatter
+from snuba.query.query_settings import HTTPQuerySettings
+from snuba.state import get_int_config
+from snuba.utils.metrics import MetricsBackend
+from snuba.web.bulk_delete_query import construct_or_conditions, construct_query
+from snuba.web.delete_query import (
+    ConditionsType,
+    TooManyOngoingMutationsError,
+    _execute_query,
+    _num_ongoing_mutations,
+)
+
+TPayload = TypeVar("TPayload")
+
+logger = logging.getLogger(__name__)
+
+
+class FormatQuery(ProcessingStrategy[ValuesBatch[KafkaPayload]]):
+    def __init__(
+        self,
+        next_step: ProcessingStrategy[ValuesBatch[KafkaPayload]],
+        storage: WritableTableStorage,
+        formatter: Formatter,
+        metrics: MetricsBackend,
+    ) -> None:
+        self.__next_step = next_step
+        self.__storage = storage
+        self.__cluster_name = self.__storage.get_cluster().get_clickhouse_cluster_name()
+        self.__tables = storage.get_deletion_settings().tables
+        self.__formatter: Formatter = formatter
+        self.__metrics = metrics
+
+    def poll(self) -> None:
+        self.__next_step.poll()
+
+    def submit(self, message: Message[ValuesBatch[KafkaPayload]]) -> None:
+        decode_messages = [
+            rapidjson.loads(m.payload.value) for m in message.value.payload
+        ]
+        conditions = self.__formatter.format(decode_messages)
+
+        try:
+            self._execute_delete(conditions)
+        except TooManyOngoingMutationsError:
+            # backpressure is applied while we wait for the
+            # currently ongoing mutations to finish
+            self.__metrics.increment("too_many_ongoing_mutations")
+            raise MessageRejected
+
+        self.__next_step.submit(message)
+
+    def _get_attribute_info(self) -> AttributionInfo:
+        return AttributionInfo(
+            app_id=AppID("lw-deletes"),
+            tenant_ids={},
+            referrer="lw-deletes",
+            team=None,
+            feature=None,
+            parent_api=None,
+        )
+
+    def _execute_delete(self, conditions: Sequence[ConditionsType]) -> None:
+        self._check_ongoing_mutations()
+        query_settings = HTTPQuerySettings()
+        for table in self.__tables:
+            query = construct_query(
+                self.__storage, table, construct_or_conditions(conditions)
+            )
+            start = time.time()
+            _execute_query(
+                query=query,
+                storage=self.__storage,
+                cluster_name=self.__cluster_name,
+                table=table,
+                attribution_info=self._get_attribute_info(),
+                query_settings=query_settings,
+            )
+            self.__metrics.timing(
+                "execute_delete_query_ms",
+                (time.time() - start) * 1000,
+                tags={"table": table},
+            )
+
+    def _check_ongoing_mutations(self) -> None:
+        start = time.time()
+        ongoing_mutations = _num_ongoing_mutations(
+            self.__storage.get_cluster(), self.__tables
+        )
+        max_ongoing_mutations = typing.cast(
+            int,
+            get_int_config(
+                "max_ongoing_mutatations_for_delete",
+                default=settings.MAX_ONGOING_MUTATIONS_FOR_DELETE,
+            ),
+        )
+        self.__metrics.timing(
+            "ongoing_mutations_query_ms", (time.time() - start) * 1000
+        )
+        if ongoing_mutations > max_ongoing_mutations:
+            raise TooManyOngoingMutationsError(
+                f"{ongoing_mutations} mutations for {self.__tables} table(s) are above the max allowed ongoing mutations: {max_ongoing_mutations}"
+            )
+
+    def close(self) -> None:
+        self.__next_step.close()
+
+    def terminate(self) -> None:
+        self.__next_step.terminate()
+
+    def join(self, timeout: Optional[float] = None) -> None:
+        self.__next_step.join(timeout)
+
+
+def increment_by(message: BaseValue[KafkaPayload]) -> int:
+    rows_to_delete = rapidjson.loads(message.payload.value)["rows_to_delete"]
+    assert isinstance(rows_to_delete, int)
+    return rows_to_delete
+
+
+class LWDeletionsConsumerStrategyFactory(ProcessingStrategyFactory[KafkaPayload]):
+    """
+    The factory manages the lifecycle of the `ProcessingStrategy`.
+    A strategy is created every time new partitions are assigned to the
+    consumer, while it is destroyed when partitions are revoked or the
+    consumer is closed.
+    """
+
+    def __init__(
+        self,
+        max_batch_size: int,
+        max_batch_time_ms: int,
+        storage: WritableTableStorage,
+        formatter: Formatter,
+        metrics: MetricsBackend,
+    ) -> None:
+        self.max_batch_size = max_batch_size
+        self.max_batch_time_ms = max_batch_time_ms
+        self.storage = storage
+        self.formatter = formatter
+        self.metrics = metrics
+
+    def create_with_partitions(
+        self,
+        commit: Commit,
+        partitions: Mapping[Partition, int],
+    ) -> ProcessingStrategy[KafkaPayload]:
+        batch_step = BatchStepCustom(
+            max_batch_size=self.max_batch_size,
+            max_batch_time=self.max_batch_time_ms / 1000,
+            next_step=FormatQuery(
+                CommitOffsets(commit), self.storage, self.formatter, self.metrics
+            ),
+            increment_by=increment_by,
+        )
+        return batch_step
diff --git a/snuba/settings/__init__.py b/snuba/settings/__init__.py
index 0505276885..20ff6f3100 100644
--- a/snuba/settings/__init__.py
+++ b/snuba/settings/__init__.py
@@ -354,6 +354,9 @@ class RedisClusters(TypedDict):
     "ENABLE_GROUP_ATTRIBUTES_CONSUMER", False
 )

+# Enable lw deletions consumer (search issues only for now)
+ENABLE_LW_DELETIONS_CONSUMER = os.environ.get("ENABLE_LW_DELETIONS_CONSUMER", False)
+
 # Cutoff time from UTC 00:00:00 to stop running optimize jobs to
 # avoid spilling over to the next day.
 OPTIMIZE_JOB_CUTOFF_TIME = 23
diff --git a/tests/lw_deletions/__init__.py b/tests/lw_deletions/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/lw_deletions/test_formatters.py b/tests/lw_deletions/test_formatters.py
new file mode 100644
index 0000000000..c24811f814
--- /dev/null
+++ b/tests/lw_deletions/test_formatters.py
@@ -0,0 +1,67 @@
+from typing import Sequence, Type
+
+import pytest
+
+from snuba.lw_deletions.formatters import Formatter, SearchIssuesFormatter
+from snuba.web.bulk_delete_query import DeleteQueryMessage
+from snuba.web.delete_query import ConditionsType
+
+
+def create_delete_query_message(conditions: ConditionsType) -> DeleteQueryMessage:
+    return DeleteQueryMessage(
+        rows_to_delete=1,
+        tenant_ids={},
+        conditions=conditions,
+        storage_name="search_issues",
+    )
+
+
+SEARCH_ISSUES_FORMATTER = SearchIssuesFormatter
+
+
+@pytest.mark.parametrize(
+    "messages, expected_formatted, formatter",
+    [
+        pytest.param(
+            [
+                create_delete_query_message({"project_id": [1], "group_id": [1, 2, 3]}),
+                create_delete_query_message({"project_id": [1], "group_id": [4, 5, 6]}),
+            ],
+            [
+                {"project_id": [1], "group_id": [1, 2, 3, 4, 5, 6]},
+            ],
+            SEARCH_ISSUES_FORMATTER,
+            id="search_issues_combine_group_ids_same_project",
+        ),
+        pytest.param(
+            [
+                create_delete_query_message({"project_id": [1], "group_id": [1, 2, 3]}),
+                create_delete_query_message({"project_id": [2], "group_id": [3]}),
+            ],
+            [
+                {"project_id": [1], "group_id": [1, 2, 3]},
+                {"project_id": [2], "group_id": [3]},
+            ],
+            SEARCH_ISSUES_FORMATTER,
+            id="search_issues_diff_projects_dont_combine",
+        ),
+        pytest.param(
+            [
+                create_delete_query_message({"project_id": [1], "group_id": [1, 2, 3]}),
+                create_delete_query_message({"project_id": [1], "group_id": [2, 3, 4]}),
+            ],
+            [
+                {"project_id": [1], "group_id": [1, 2, 3, 4]},
+            ],
+            SEARCH_ISSUES_FORMATTER,
+            id="search_issues_dedupe_group_ids_in_same_project",
+        ),
+    ],
+)
+def test_search_issues_formatter(
+    messages: Sequence[DeleteQueryMessage],
+    expected_formatted: Sequence[ConditionsType],
+    formatter: Type[Formatter],
+) -> None:
+    formatted = formatter().format(messages)
+    assert formatted == expected_formatted
diff --git a/tests/lw_deletions/test_lw_deletions.py b/tests/lw_deletions/test_lw_deletions.py
new file mode 100644
index 0000000000..721a9b0f7a
--- /dev/null
+++ b/tests/lw_deletions/test_lw_deletions.py
@@ -0,0 +1,160 @@
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Iterator
+from unittest.mock import Mock, patch
+
+import pytest
+import rapidjson
+from arroyo.backends.kafka import KafkaPayload
+from arroyo.types import BrokerValue, Message, Partition, Topic
+
+from snuba.datasets.storages.factory import get_writable_storage
+from snuba.datasets.storages.storage_key import StorageKey
+from snuba.lw_deletions.batching import BatchStepCustom
+from snuba.lw_deletions.formatters import SearchIssuesFormatter
+from snuba.lw_deletions.strategy import FormatQuery, increment_by
+from snuba.utils.streams.topics import Topic as SnubaTopic
+from snuba.web.bulk_delete_query import DeleteQueryMessage
+from snuba.web.delete_query import ConditionsType
+
+ROWS_CONDITIONS = {
+    5: {"project_id": [1], "group_id": [1, 2, 3, 4]},
+    6: {"project_id": [2], "group_id": [1, 2, 3, 4]},
+    1: {"project_id": [2], "group_id": [4, 5, 6, 7]},
+    8: {"project_id": [2], "group_id": [8, 9]},
+}
+
+
+def _get_message(rows: int, conditions: ConditionsType) -> DeleteQueryMessage:
+    return {
+        "rows_to_delete": rows,
+        "storage_name": "search_issues",
+        "conditions": conditions,
+        "tenant_ids": {"project_id": 1, "organization_id": 1},
+    }
+
+
+def generate_message() -> Iterator[Message[KafkaPayload]]:
+    epoch = datetime(1970, 1, 1)
+    i = 0
+    messages = [
+        rapidjson.dumps(_get_message(rows_to_delete, conditions)).encode("utf-8")
+        for rows_to_delete, conditions in ROWS_CONDITIONS.items()
+    ]
+
+    while True:
+        yield Message(
+            BrokerValue(
+                KafkaPayload(None, messages[i], []),
+                Partition(Topic(SnubaTopic.LW_DELETIONS_GENERIC_EVENTS.value), 0),
+                i,
+                epoch,
+            )
+        )
+        i += 1
+
+
+@patch("snuba.lw_deletions.strategy._num_ongoing_mutations", return_value=1)
+@patch("snuba.lw_deletions.strategy._execute_query")
+@pytest.mark.redis_db
+def test_multiple_batches_strategies(
+    mock_execute: Mock, mock_num_mutations: Mock
+) -> None:
+    commit_step = Mock()
+    metrics = Mock()
+    storage = get_writable_storage(StorageKey("search_issues"))
+
+    strategy = BatchStepCustom(
+        max_batch_size=8,
+        max_batch_time=1000,
+        next_step=FormatQuery(commit_step, storage, SearchIssuesFormatter(), metrics),
+        increment_by=increment_by,
+    )
+    make_message = generate_message()
+    strategy.submit(next(make_message))
+    strategy.submit(next(make_message))
+    strategy.submit(next(make_message))
+    strategy.submit(next(make_message))
+    strategy.close()
+    strategy.join()
+
+    assert mock_execute.call_count == 2
+    assert commit_step.submit.call_count == 2
+
+
+@patch("snuba.lw_deletions.strategy._num_ongoing_mutations", return_value=1)
+@patch("snuba.lw_deletions.strategy._execute_query")
+@pytest.mark.redis_db
+def test_single_batch(mock_execute: Mock, mock_num_mutations: Mock) -> None:
+    commit_step = Mock()
+    metrics = Mock()
+    storage = get_writable_storage(StorageKey("search_issues"))
+
+    strategy = BatchStepCustom(
+        max_batch_size=8,
+        max_batch_time=1000,
+        next_step=FormatQuery(commit_step, storage, SearchIssuesFormatter(), metrics),
+        increment_by=increment_by,
+    )
+    message = Message(
+        BrokerValue(
+            KafkaPayload(
+                None,
+                rapidjson.dumps(
+                    _get_message(10, {"project_id": [1], "group_id": [1]})
+                ).encode("utf-8"),
+                [],
+            ),
+            Partition(Topic(SnubaTopic.LW_DELETIONS_GENERIC_EVENTS.value), 0),
+            0,
+            datetime(1970, 1, 1),
+        )
+    )
+    strategy.submit(message)
+    strategy.join(2.0)
+
+    assert mock_execute.call_count == 1
+    assert commit_step.submit.call_count == 1
+
+
+@patch("snuba.lw_deletions.strategy._num_ongoing_mutations", return_value=10)
+@patch("snuba.lw_deletions.strategy._execute_query")
+@pytest.mark.redis_db
+def test_too_many_mutations(mock_execute: Mock, mock_num_mutations: Mock) -> None:
+    """
+    Before we execute the DELETE FROM query, we check to see how many
+    ongoing mutations there are. If there are more ongoing mutations than
+    the max allows, we raise MessageRejected and backpressure is applied.
+
+    The max is 5 (the default) and our mocked ongoing mutations is 10.
+    """
+    commit_step = Mock()
+    metrics = Mock()
+    storage = get_writable_storage(StorageKey("search_issues"))
+
+    strategy = BatchStepCustom(
+        max_batch_size=8,
+        max_batch_time=1000,
+        next_step=FormatQuery(commit_step, storage, SearchIssuesFormatter(), metrics),
+        increment_by=increment_by,
+    )
+    message = Message(
+        BrokerValue(
+            KafkaPayload(
+                None,
+                rapidjson.dumps(
+                    _get_message(10, {"project_id": [2], "group_id": [2]})
+                ).encode("utf-8"),
+                [],
+            ),
+            Partition(Topic(SnubaTopic.LW_DELETIONS_GENERIC_EVENTS.value), 0),
+            1,
+            datetime(1970, 1, 1),
+        )
+    )
+    strategy.submit(message)
+    strategy.join(2.0)
+
+    assert mock_execute.call_count == 0
+    assert commit_step.submit.call_count == 0
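Note: for reference, a short sketch (not part of the patch) of how `increment_by` derives the batching weight from the same delete message payloads the tests above construct; this is what lets the consumer batch by `rows_to_delete` rather than by message count. The payload values are made up for the example.

```python
from datetime import datetime

import rapidjson
from arroyo.backends.kafka import KafkaPayload
from arroyo.types import BrokerValue, Partition, Topic

from snuba.lw_deletions.strategy import increment_by
from snuba.utils.streams.topics import Topic as SnubaTopic

payload = rapidjson.dumps(
    {
        "rows_to_delete": 7,
        "storage_name": "search_issues",
        "conditions": {"project_id": [1], "group_id": [10]},
        "tenant_ids": {"project_id": 1, "organization_id": 1},
    }
).encode("utf-8")

value = BrokerValue(
    KafkaPayload(None, payload, []),
    Partition(Topic(SnubaTopic.LW_DELETIONS_GENERIC_EVENTS.value), 0),
    0,
    datetime(1970, 1, 1),
)

# The buffer grows by 7 "rows" for this single message.
assert increment_by(value) == 7
```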