diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9dafeec0e7..b8de9b0c74 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,8 +22,10 @@ jobs: # Run pre-commit to lint and format check files that were changed (but not deleted) compared to master. # XXX: there is a very small chance that it'll expand to exceed Linux's limits # `getconf ARG_MAX` - max # bytes of args + environ for exec() + # we skip the `no-commit-to-branch` because in CI we are in fact on master already + # and we have merged to it run: | - pre-commit run --files $(git diff --diff-filter=d --name-only master) + SKIP=no-commit-to-branch pre-commit run --files $(git diff --diff-filter=d --name-only master) typing: name: "mypy typing" diff --git a/snuba/admin/clickhouse/common.py b/snuba/admin/clickhouse/common.py new file mode 100644 index 0000000000..0a67cbf839 --- /dev/null +++ b/snuba/admin/clickhouse/common.py @@ -0,0 +1,93 @@ +from typing import MutableMapping + +from snuba import settings +from snuba.clickhouse.native import ClickhousePool +from snuba.clusters.cluster import ClickhouseClientSettings, ClickhouseCluster +from snuba.datasets.storages import StorageKey +from snuba.datasets.storages.factory import get_storage +from snuba.utils.serializable_exception import SerializableException + + +class InvalidNodeError(SerializableException): + pass + + +class InvalidCustomQuery(SerializableException): + pass + + +class InvalidStorageError(SerializableException): + pass + + +def is_valid_node(host: str, port: int, cluster: ClickhouseCluster) -> bool: + nodes = cluster.get_local_nodes() + return any(node.host_name == host and node.port == port for node in nodes) + + +NODE_CONNECTIONS: MutableMapping[str, ClickhousePool] = {} + + +def get_ro_node_connection( + clickhouse_host: str, clickhouse_port: int, storage_name: str +) -> ClickhousePool: + storage_key = None + try: + storage_key = StorageKey(storage_name) + except ValueError: + raise InvalidStorageError( + f"storage {storage_name} is not a valid storage name", + extra_data={"storage_name": storage_name}, + ) + + key = f"{storage_key}-{clickhouse_host}" + if key in NODE_CONNECTIONS: + return NODE_CONNECTIONS[key] + + storage = get_storage(storage_key) + cluster = storage.get_cluster() + + if not is_valid_node(clickhouse_host, clickhouse_port, cluster): + raise InvalidNodeError( + f"host {clickhouse_host} and port {clickhouse_port} are not valid", + extra_data={"host": clickhouse_host, "port": clickhouse_port}, + ) + + database = cluster.get_database() + connection = ClickhousePool( + clickhouse_host, + clickhouse_port, + settings.CLICKHOUSE_READONLY_USER, + settings.CLICKHOUSE_READONLY_PASSWORD, + database, + max_pool_size=2, + # force read-only + client_settings=ClickhouseClientSettings.QUERY.value.settings, + ) + NODE_CONNECTIONS[key] = connection + return connection + + +CLUSTER_CONNECTIONS: MutableMapping[StorageKey, ClickhousePool] = {} + + +def get_ro_cluster_connection(storage_name: str) -> ClickhousePool: + + storage_key = None + try: + storage_key = StorageKey(storage_name) + except ValueError: + raise InvalidStorageError( + f"storage {storage_name} is not a valid storage name", + extra_data={"storage_name": storage_name}, + ) + + if storage_key in CLUSTER_CONNECTIONS: + return CLUSTER_CONNECTIONS[storage_key] + + storage = get_storage(storage_key) + cluster = storage.get_cluster() + connection = cluster.get_query_connection(ClickhouseClientSettings.QUERY) + + CLUSTER_CONNECTIONS[storage_key] = connection + return connection diff --git a/snuba/admin/clickhouse/system_queries.py b/snuba/admin/clickhouse/system_queries.py index a6afd03303..0d2ef31b4c 100644 --- a/snuba/admin/clickhouse/system_queries.py +++ b/snuba/admin/clickhouse/system_queries.py @@ -4,12 +4,9 @@ from clickhouse_driver.errors import ErrorCodes -from snuba import settings +from snuba.admin.clickhouse.common import InvalidCustomQuery, get_ro_node_connection from snuba.clickhouse.errors import ClickhouseError -from snuba.clickhouse.native import ClickhousePool, ClickhouseResult -from snuba.clusters.cluster import ClickhouseClientSettings, ClickhouseCluster -from snuba.datasets.storages import StorageKey -from snuba.datasets.storages.factory import get_storage +from snuba.clickhouse.native import ClickhouseResult from snuba.utils.serializable_exception import SerializableException @@ -17,22 +14,10 @@ class NonExistentSystemQuery(SerializableException): pass -class InvalidNodeError(SerializableException): - pass - - -class InvalidStorageError(SerializableException): - pass - - class InvalidResultError(SerializableException): pass -class InvalidCustomQuery(SerializableException): - pass - - class _QueryRegistry: """Keep a mapping of SystemQueries to their names""" @@ -101,44 +86,14 @@ class ActivePartitions(SystemQuery): """ -def _is_valid_node(host: str, port: int, cluster: ClickhouseCluster) -> bool: - nodes = cluster.get_local_nodes() - return any(node.host_name == host and node.port == port for node in nodes) - - def _run_sql_query_on_host( clickhouse_host: str, clickhouse_port: int, storage_name: str, sql: str ) -> ClickhouseResult: """ Run the SQL query. It should be validated before getting to this point """ - storage_key = None - try: - storage_key = StorageKey(storage_name) - except ValueError: - raise InvalidStorageError(extra_data={"storage_name": storage_name}) - - storage = get_storage(storage_key) - cluster = storage.get_cluster() - - if not _is_valid_node(clickhouse_host, clickhouse_port, cluster): - raise InvalidNodeError( - extra_data={"host": clickhouse_host, "port": clickhouse_port} - ) - - database = cluster.get_database() - - connection = ClickhousePool( - clickhouse_host, - clickhouse_port, - settings.CLICKHOUSE_READONLY_USER, - settings.CLICKHOUSE_READONLY_PASSWORD, - database, - # force read-only - client_settings=ClickhouseClientSettings.QUERY.value.settings, - ) + connection = get_ro_node_connection(clickhouse_host, clickhouse_port, storage_name) query_result = connection.execute(query=sql, with_column_types=True) - connection.close() return query_result diff --git a/snuba/admin/clickhouse/tracing.py b/snuba/admin/clickhouse/tracing.py new file mode 100644 index 0000000000..6c6baad044 --- /dev/null +++ b/snuba/admin/clickhouse/tracing.py @@ -0,0 +1,27 @@ +from snuba.admin.clickhouse.common import InvalidCustomQuery, get_ro_cluster_connection +from snuba.clickhouse.native import ClickhouseResult + + +def validate_trace_query(sql_query: str) -> None: + """ + Simple validation to ensure query only attempts read queries. + + Raises InvalidCustomQuery if query is invalid or not allowed. + """ + sql_query = " ".join(sql_query.split()) + lowered = sql_query.lower().strip() + + if not lowered.startswith("select"): + raise InvalidCustomQuery("Only SELECT queries are allowed") + + disallowed_keywords = ["insert", ";"] + for kw in disallowed_keywords: + if kw in lowered: + raise InvalidCustomQuery(f"{kw} is not allowed in the query") + + +def run_query_and_get_trace(storage_name: str, query: str) -> ClickhouseResult: + validate_trace_query(query) + connection = get_ro_cluster_connection(storage_name) + query_result = connection.execute(query=query, capture_trace=True) + return query_result diff --git a/snuba/admin/views.py b/snuba/admin/views.py index 2c65ef57a8..9f2bc75b44 100644 --- a/snuba/admin/views.py +++ b/snuba/admin/views.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from typing import Any, List, MutableMapping, Optional, cast @@ -6,17 +8,20 @@ from snuba import state from snuba.admin.auth import UnauthorizedException, authorize_request -from snuba.admin.clickhouse.nodes import get_storage_info -from snuba.admin.clickhouse.system_queries import ( +from snuba.admin.clickhouse.common import ( InvalidCustomQuery, InvalidNodeError, - InvalidResultError, InvalidStorageError, +) +from snuba.admin.clickhouse.nodes import get_storage_info +from snuba.admin.clickhouse.system_queries import ( + InvalidResultError, NonExistentSystemQuery, SystemQuery, run_system_query_on_host_by_name, run_system_query_on_host_with_sql, ) +from snuba.admin.clickhouse.tracing import run_query_and_get_trace from snuba.admin.notifications.base import RuntimeConfigAction, RuntimeConfigAutoClient from snuba.admin.runtime_config import ( ConfigChange, @@ -74,7 +79,6 @@ def clickhouse_queries() -> Response: @application.route("/run_clickhouse_system_query", methods=["POST"]) def clickhouse_system_query() -> Response: req = request.get_json() - try: host = req["host"] port = req["port"] @@ -142,6 +146,59 @@ def clickhouse_system_query() -> Response: return make_response(jsonify({"error": "Something went wrong"}), 400) +# Sample cURL command: +# +# curl -X POST \ +# -H 'Content-Type: application/json' \ +# http://localhost:1219/clickhouse_trace_query?query=SELECT+count()+FROM+errors_local +@application.route("/clickhouse_trace_query", methods=["POST"]) +def clickhouse_trace_query() -> Response: + req = json.loads(request.data) + try: + storage = req["storage"] + raw_sql = req["sql"] + except KeyError as e: + return make_response( + jsonify( + { + "error": { + "type": "request", + "message": f"Invalid request, missing key {e.args[0]}", + } + } + ), + 400, + ) + + try: + result = run_query_and_get_trace(storage, raw_sql) + trace_output = result.trace_output + return make_response(jsonify({"trace_output": trace_output}), 200) + except InvalidCustomQuery as err: + return make_response( + jsonify( + { + "error": { + "type": "validation", + "message": err.message or "Invalid query", + } + } + ), + 400, + ) + except ClickhouseError as err: + details = { + "type": "clickhouse", + "message": str(err), + "code": err.code, + } + return make_response(jsonify({"error": details}), 400) + except Exception as err: + return make_response( + jsonify({"error": {"type": "unknown", "message": str(err)}}), 500, + ) + + @application.route("/configs", methods=["GET", "POST"]) def configs() -> Response: if request.method == "POST": diff --git a/snuba/clickhouse/native.py b/snuba/clickhouse/native.py index 5328db7766..fec214ed9f 100644 --- a/snuba/clickhouse/native.py +++ b/snuba/clickhouse/native.py @@ -120,6 +120,13 @@ def execute( conn = self._create_conn() try: + if capture_trace: + settings = ( + {**settings, "send_logs_level": "trace"} + if settings + else {"send_logs_level": "trace"} + ) + query_execute = partial( conn.execute, query, @@ -134,12 +141,7 @@ def execute( trace_output = "" if capture_trace: with capture_logging() as buffer: - if settings: - settings = {**settings, "send_logs_level": "trace"} - else: - settings = {"send_logs_level": "trace"} - result_data = query_execute() - # In order to avoid exposing PII the results are discarded + query_execute() # In order to avoid exposing PII the results are discarded result_data = [[], []] if with_column_types else [] trace_output = buffer.getvalue() else: diff --git a/snuba/query/snql/parser.py b/snuba/query/snql/parser.py index e904169c78..5268f3979b 100644 --- a/snuba/query/snql/parser.py +++ b/snuba/query/snql/parser.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging from dataclasses import replace from datetime import datetime, timedelta @@ -159,10 +161,11 @@ low_pri_op = "+" / "-" high_pri_op = "/" / "*" - param_expression = low_pri_arithmetic / quoted_literal + param_expression = low_pri_arithmetic / quoted_literal / identifier parameters_list = parameter* (param_expression) - parameter = param_expression space* comma space* + parameter = (lambda / param_expression) space* comma space* function_call = function_name open_paren parameters_list? close_paren (open_paren parameters_list? close_paren)? (space+ "AS" space+ alias_literal)? + lambda = open_paren space* identifier (comma space* identifier)* space* close_paren space* arrow space* function_call aliased_tag_column = tag_column space+ "AS" space+ alias_literal aliased_subscriptable = subscriptable space+ "AS" space+ alias_literal @@ -181,11 +184,13 @@ subscriptable = column_name open_square column_name close_square column_name = ~r"[a-zA-Z_][a-zA-Z0-9_\.:]*" tag_column = "tags" open_square tag_name close_square - tag_name = ~r"[^\[\]]*" + tag_name = ~r"[^\[\]]*" + identifier = backtick ~r"[a-zA-Z_][a-zA-Z0-9_]*" backtick function_name = ~r"[a-zA-Z_][a-zA-Z0-9_]*" entity_alias = ~r"[a-zA-Z_][a-zA-Z0-9_]*" entity_name = ~r"[a-zA-Z_]+" relationship_name = ~r"[a-zA-Z_][a-zA-Z0-9_]*" + arrow = "->" open_brace = "{" close_brace = "}" open_paren = "(" @@ -195,6 +200,7 @@ space = ~r"\s" comma = "," colon = ":" + backtick = "`" """ ) @@ -822,6 +828,42 @@ def visit_aliased_column_name( column, _, _, _, alias = visited_children return SelectedExpression(alias.text, column) + def visit_identifier( + self, node: Node, visited_children: Tuple[Any, Node, Any] + ) -> Argument: + return Argument(None, visited_children[1].text) + + def visit_lambda( + self, + node: Node, + visited_children: Tuple[ + Any, + Any, + Argument, + Union[Node, List[Node | Argument]], + Any, + Any, + Any, + Any, + Any, + Expression, + ], + ) -> Lambda: + first_identifier = visited_children[2] + other_identifiers = visited_children[3] + functionCall = visited_children[-1] + parameters = [first_identifier.name] + if isinstance(other_identifiers, list): + for other in other_identifiers: + if isinstance(other, Argument): + parameters.append(other.name) + elif isinstance(other, list): + parameters.extend( + [o.name for o in other if isinstance(o, Argument)] + ) + + return Lambda(None, tuple(parameters), functionCall) + def generic_visit(self, node: Node, visited_children: Any) -> Any: return generic_visit(node, visited_children) @@ -1119,6 +1161,39 @@ def mangle_column_value(exp: Expression) -> Expression: query.transform_expressions(mangle_column_value) +def validate_identifiers_in_lambda( + query: Union[CompositeQuery[QueryEntity], LogicalQuery] +) -> None: + """ + Check to make sure that any identifiers referenced in a lambda were defined in that lambda + or in an outer lambda. + """ + identifiers: Set[str] = set() + unseen_identifiers: Set[str] = set() + + def validate_lambda(exp: Lambda) -> None: + for p in exp.parameters: + identifiers.add(p) + unseen_identifiers.discard(p) + + for inner_exp in exp.transformation: + if isinstance(inner_exp, Argument) and inner_exp.name not in identifiers: + unseen_identifiers.add(inner_exp.name) + elif isinstance(inner_exp, Lambda): + validate_lambda(inner_exp) + + for p in exp.parameters: + identifiers.discard(p) + + for exp in query.get_all_expressions(): + if isinstance(exp, Lambda): + validate_lambda(exp) + + if len(unseen_identifiers) > 0: + ident_str = ",".join(f"`{u}`" for u in unseen_identifiers) + raise InvalidExpressionException(f"identifier(s) {ident_str} not defined") + + def _replace_time_condition( query: Union[CompositeQuery[QueryEntity], LogicalQuery] ) -> None: @@ -1340,7 +1415,11 @@ def _post_process( _array_column_conditions, ] -VALIDATORS = [validate_query, validate_entities_with_query] +VALIDATORS = [ + validate_identifiers_in_lambda, + validate_query, + validate_entities_with_query, +] CustomProcessors = Sequence[ diff --git a/tests/admin/test_api.py b/tests/admin/test_api.py index a9a41cf78b..8bc6ec28d0 100644 --- a/tests/admin/test_api.py +++ b/tests/admin/test_api.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Any import pytest @@ -71,3 +73,81 @@ def test_post_configs(admin_api: Any) -> None: "/configs", data=json.dumps({"key": "test_string", "value": "bar"}) ) assert response.status_code == 400 + + +def get_node_for_table(admin_api: Any, storage_name: str) -> tuple[str, str, int]: + response = admin_api.get("/clickhouse_nodes") + assert response.status_code == 200, response + nodes = json.loads(response.data) + for node in nodes: + if node["storage_name"] == storage_name: + table = node["local_table_name"] + host = node["local_nodes"][0]["host"] + port = node["local_nodes"][0]["port"] + return str(table), str(host), int(port) + + raise Exception(f"{storage_name} does not have a local node") + + +def test_system_query(admin_api: Any) -> None: + _, host, port = get_node_for_table(admin_api, "errors") + response = admin_api.post( + "/run_clickhouse_system_query", + headers={"Content-Type": "application/json"}, + data=json.dumps( + { + "host": host, + "port": port, + "storage": "errors_ro", + "query_name": "CurrentMerges", + } + ), + ) + assert response.status_code == 200 + data = json.loads(response.data) + assert data["column_names"] == ["count()", "is_currently_executing"] + assert data["rows"] == [] + + +def test_query_trace(admin_api: Any) -> None: + table, _, _ = get_node_for_table(admin_api, "errors_ro") + response = admin_api.post( + "/clickhouse_trace_query", + headers={"Content-Type": "application/json"}, + data=json.dumps( + {"storage": "errors_ro", "sql": f"SELECT count() FROM {table}"} + ), + ) + assert response.status_code == 200 + data = json.loads(response.data) + assert " executeQuery" in data["trace_output"] + + +def test_query_trace_bad_query(admin_api: Any) -> None: + table, _, _ = get_node_for_table(admin_api, "errors_ro") + response = admin_api.post( + "/clickhouse_trace_query", + headers={"Content-Type": "application/json"}, + data=json.dumps( + {"storage": "errors_ro", "sql": f"SELECT count(asdasds) FROM {table}"} + ), + ) + assert response.status_code == 400 + data = json.loads(response.data) + assert "Exception: Missing columns" in data["error"]["message"] + assert "clickhouse" == data["error"]["type"] + + +def test_query_trace_invalid_query(admin_api: Any) -> None: + table, _, _ = get_node_for_table(admin_api, "errors_ro") + response = admin_api.post( + "/clickhouse_trace_query", + headers={"Content-Type": "application/json"}, + data=json.dumps( + {"storage": "errors_ro", "sql": f"SELECT count() FROM {table};"} + ), + ) + assert response.status_code == 400 + data = json.loads(response.data) + assert "; is not allowed in the query" in data["error"]["message"] + assert "validation" == data["error"]["type"] diff --git a/tests/admin/test_system_queries.py b/tests/admin/test_system_queries.py index dd89670ef6..d61a58cd88 100644 --- a/tests/admin/test_system_queries.py +++ b/tests/admin/test_system_queries.py @@ -1,9 +1,7 @@ import pytest -from snuba.admin.clickhouse.system_queries import ( - InvalidCustomQuery, - validate_system_query, -) +from snuba.admin.clickhouse.common import InvalidCustomQuery +from snuba.admin.clickhouse.system_queries import validate_system_query @pytest.mark.parametrize( diff --git a/tests/clickhouse/test_query_format.py b/tests/clickhouse/test_query_format.py index 7c1e873f8f..eea7261df1 100644 --- a/tests/clickhouse/test_query_format.py +++ b/tests/clickhouse/test_query_format.py @@ -553,12 +553,6 @@ def test_format_expressions( clickhouse_query_anonymized = format_query_anonymized(query) assert clickhouse_query.get_sql() == formatted_str assert clickhouse_query.structured() == formatted_seq - if clickhouse_query_anonymized.get_sql() != formatted_anonymized_str: - import pdb - - pdb.set_trace() - print(clickhouse_query_anonymized.get_sql()) - assert clickhouse_query_anonymized.get_sql() == formatted_anonymized_str diff --git a/tests/fixtures.py b/tests/fixtures.py index f5c2f40809..e5fdd38c38 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -253,6 +253,8 @@ def get_raw_transaction(span_id: str | None = None) -> InsertEvent: "description": "SELECT * FROM users", "data": {}, "timestamp": calendar.timegm(end_timestamp.timetuple()), + "hash": "5029609156d8133", + "exclusive_time": 1.2, } ], }, diff --git a/tests/query/snql/test_invalid_queries.py b/tests/query/snql/test_invalid_queries.py index 59da4bfdc2..5be044c02d 100644 --- a/tests/query/snql/test_invalid_queries.py +++ b/tests/query/snql/test_invalid_queries.py @@ -80,6 +80,16 @@ "Parsing error on line 1 at ' ORDER BY f '", id="aliases are only in the select", ), + pytest.param( + "MATCH (discover_events) SELECT arrayMap((x) -> identity(`x`), sdk_integrations) AS sdks WHERE project_id = 1 AND timestamp >= toDateTime('2021-01-01') AND timestamp < toDateTime('2021-01-02')", + "Parsing error on line 1 at 'ap((x) -> ide'", + id="identifiers have backticks", + ), + pytest.param( + "MATCH (discover_events) SELECT arrayMap((`x`) -> `x`, sdk_integrations) AS sdks WHERE project_id = 1 AND timestamp >= toDateTime('2021-01-01') AND timestamp < toDateTime('2021-01-02')", + "Parsing error on line 1 at 'ap((`x`) -> `'", + id="ensure function after arrow", + ), ] diff --git a/tests/query/snql/test_query.py b/tests/query/snql/test_query.py index 0657dc9ea4..501f7709af 100644 --- a/tests/query/snql/test_query.py +++ b/tests/query/snql/test_query.py @@ -1647,6 +1647,250 @@ def build_cond(tn: str) -> str: ), id="aliased columns in select and group by", ), + pytest.param( + f"""MATCH (discover_events) + SELECT arrayMap((`x`) -> identity(`x`), sdk_integrations) AS sdks + WHERE {added_condition} + """, + LogicalQuery( + QueryEntity( + EntityKey.DISCOVER_EVENTS, + get_entity(EntityKey.DISCOVER_EVENTS).get_data_model(), + ), + selected_columns=[ + SelectedExpression( + "sdks", + FunctionCall( + "_snuba_sdks", + "arrayMap", + ( + Lambda( + None, + ("x",), + FunctionCall(None, "identity", (Argument(None, "x"),)), + ), + Column("_snuba_sdk_integrations", None, "sdk_integrations"), + ), + ), + ), + ], + limit=1000, + condition=required_condition, + offset=0, + ), + id="higher order functions single identifier", + ), + pytest.param( + f"""MATCH (discover_events) + SELECT arrayMap((`x`, `y`) -> identity(tuple(`x`, `y`)), sdk_integrations) AS sdks + WHERE {added_condition} + """, + LogicalQuery( + QueryEntity( + EntityKey.DISCOVER_EVENTS, + get_entity(EntityKey.DISCOVER_EVENTS).get_data_model(), + ), + selected_columns=[ + SelectedExpression( + "sdks", + FunctionCall( + "_snuba_sdks", + "arrayMap", + ( + Lambda( + None, + ("x", "y"), + FunctionCall( + None, + "identity", + ( + FunctionCall( + None, + "tuple", + (Argument(None, "x"), Argument(None, "y")), + ), + ), + ), + ), + Column("_snuba_sdk_integrations", None, "sdk_integrations"), + ), + ), + ), + ], + limit=1000, + condition=required_condition, + offset=0, + ), + id="higher order function multiple identifier", + ), + pytest.param( + f"""MATCH (discover_events) + SELECT arrayMap((`x`, `y`, `z`) -> tuple(`x`, `y`, `z`), sdk_integrations) AS sdks + WHERE {added_condition} + """, + LogicalQuery( + QueryEntity( + EntityKey.DISCOVER_EVENTS, + get_entity(EntityKey.DISCOVER_EVENTS).get_data_model(), + ), + selected_columns=[ + SelectedExpression( + "sdks", + FunctionCall( + "_snuba_sdks", + "arrayMap", + ( + Lambda( + None, + ("x", "y", "z"), + FunctionCall( + None, + "tuple", + ( + Argument(None, "x"), + Argument(None, "y"), + Argument(None, "z"), + ), + ), + ), + Column("_snuba_sdk_integrations", None, "sdk_integrations"), + ), + ), + ), + ], + limit=1000, + condition=required_condition, + offset=0, + ), + id="higher order function lots of identifier", + ), + pytest.param( + f"""MATCH (discover_events) + SELECT arrayMap((`x`) -> identity(arrayMap((`y`) -> tuple(`x`, `y`), sdk_integrations)), sdk_integrations) AS sdks + WHERE {added_condition} + """, + LogicalQuery( + QueryEntity( + EntityKey.DISCOVER_EVENTS, + get_entity(EntityKey.DISCOVER_EVENTS).get_data_model(), + ), + selected_columns=[ + SelectedExpression( + "sdks", + FunctionCall( + "_snuba_sdks", + "arrayMap", + ( + Lambda( + None, + ("x",), + FunctionCall( + None, + "identity", + ( + FunctionCall( + None, + "arrayMap", + ( + Lambda( + None, + ("y",), + FunctionCall( + None, + "tuple", + ( + Argument(None, "x"), + Argument(None, "y"), + ), + ), + ), + Column( + "_snuba_sdk_integrations", + None, + "sdk_integrations", + ), + ), + ), + ), + ), + ), + Column("_snuba_sdk_integrations", None, "sdk_integrations"), + ), + ), + ), + ], + limit=1000, + condition=required_condition, + offset=0, + ), + id="higher order function with nested higher order function", + ), + pytest.param( + f"""MATCH (discover_events) + SELECT arrayReduce('sumIf', spans.op, arrayMap((`x`, `y`) -> if(equals(and(equals(`x`, 'db'), equals(`y`, 'ops')), 1), 1, 0), spans.op, spans.group)) AS spans + WHERE {added_condition} + """, + LogicalQuery( + QueryEntity( + EntityKey.DISCOVER_EVENTS, + get_entity(EntityKey.DISCOVER_EVENTS).get_data_model(), + ), + selected_columns=[ + SelectedExpression( + "spans", + FunctionCall( + "_snuba_spans", + "arrayReduce", + ( + Literal(None, "sumIf"), + Column("_snuba_spans.op", None, "spans.op"), + FunctionCall( + None, + "arrayMap", + ( + Lambda( + None, + ("x", "y"), + FunctionCall( + None, + "if", + ( + binary_condition( + "equals", + binary_condition( + "and", + binary_condition( + "equals", + Argument(None, "x"), + Literal(None, "db"), + ), + binary_condition( + "equals", + Argument(None, "y"), + Literal(None, "ops"), + ), + ), + Literal(None, 1), + ), + Literal(None, 1), + Literal(None, 0), + ), + ), + ), + Column("_snuba_spans.op", None, "spans.op"), + Column("_snuba_spans.group", None, "spans.group"), + ), + ), + ), + ), + ), + ], + limit=1000, + condition=required_condition, + offset=0, + ), + id="higher order function complex case", + ), ] diff --git a/tests/query/test_query_validation.py b/tests/query/test_query_validation.py index 2462a3fa40..019f4fcb86 100644 --- a/tests/query/test_query_validation.py +++ b/tests/query/test_query_validation.py @@ -1,6 +1,9 @@ +import re + import pytest from snuba.datasets.factory import get_dataset +from snuba.query.exceptions import InvalidExpressionException from snuba.query.parser.exceptions import ParsingException from snuba.query.snql.parser import parse_snql_query @@ -11,13 +14,24 @@ SELECT event_id WHERE timestamp LIKE 'carbonara' """, + ParsingException("missing >= condition on column timestamp for entity events"), id="Invalid LIKE param", ), + pytest.param( + "MATCH (discover_events) SELECT arrayMap((`x`) -> identity(`y`), sdk_integrations) AS sdks WHERE project_id = 1 AND timestamp >= toDateTime('2021-01-01') AND timestamp < toDateTime('2021-01-02')", + InvalidExpressionException("identifier(s) `y` not defined"), + id="invalid lambda identifier", + ), + pytest.param( + "MATCH (discover_events) SELECT arrayMap((`x`) -> arrayMap((`y`) -> identity(`z`), sdk_integrations), sdk_integrations) AS sdks WHERE project_id = 1 AND timestamp >= toDateTime('2021-01-01') AND timestamp < toDateTime('2021-01-02')", + InvalidExpressionException("identifier(s) `z` not defined"), + id="invalid nested lambda identifier", + ), ] -@pytest.mark.parametrize("query_body", test_cases) -def test_validation(query_body: str) -> None: +@pytest.mark.parametrize("query_body, exception", test_cases) +def test_validation(query_body: str, exception: Exception) -> None: events = get_dataset("events") - with pytest.raises(ParsingException): + with pytest.raises(type(exception), match=re.escape(str(exception))): parse_snql_query(query_body, events) diff --git a/tests/test_clickhouse.py b/tests/test_clickhouse.py index dfeb03f694..b1b44a8c56 100644 --- a/tests/test_clickhouse.py +++ b/tests/test_clickhouse.py @@ -75,7 +75,7 @@ def test_capture_trace() -> None: ) assert data.results == [] assert data.meta == [] - assert data.trace_output == "" + assert data.trace_output != "" assert data.profile is not None assert data.profile["elapsed"] > 0 assert data.profile["bytes"] > 0 diff --git a/tests/test_snql_api.py b/tests/test_snql_api.py index d3b2d54a25..0b4569e06a 100644 --- a/tests/test_snql_api.py +++ b/tests/test_snql_api.py @@ -21,6 +21,7 @@ def post(self, url: str, data: str) -> Any: return self.app.post(url, data=data, headers={"referer": "test"}) def setup_method(self, test_method: Callable[..., Any]) -> None: + state.set_config("write_span_columns_rollout_percentage", 100) super().setup_method(test_method) self.trace_id = uuid.UUID("7400045b-25c4-43b8-8591-4600aa83ad04") self.event = get_raw_event() @@ -551,6 +552,34 @@ def test_nullable_query(self) -> None: ) assert response.status_code == 200 + def test_suspect_spans_data(self) -> None: + response = self.post( + "/discover/snql", + data=json.dumps( + { + "query": f""" + MATCH (discover_transactions) + SELECT arrayReduce('sumIf', spans.exclusive_time_32, arrayMap((`x`, `y`) -> if(equals(and(equals(`x`, 'db'), equals(`y`, '05029609156d8133')), 1), 1, 0), spans.op, spans.group)) AS array_spans_exclusive_time + WHERE + transaction_name = '/api/do_things' AND + has(spans.op, 'db') = 1 AND + has(spans.group, '5029609156d8133') = 1 AND + duration < 900000.0 AND + finish_ts >= toDateTime('{self.base_time.isoformat()}') AND + finish_ts < toDateTime('{self.next_time.isoformat()}') AND + project_id IN tuple({self.project_id}) + ORDER BY array_spans_exclusive_time DESC + LIMIT 10 + """ + } + ), + ) + + assert response.status_code == 200 + data = json.loads(response.data)["data"] + assert len(data) == 1 + assert data[0]["array_spans_exclusive_time"] > 0 + def test_invalid_column(self) -> None: response = self.post( "/outcomes/snql",