Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(snql) Support higher order functions in SnQL #2333

Merged
merged 6 commits into from
Jan 12, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 77 additions & 4 deletions snuba/query/snql/parser.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
from dataclasses import replace
from datetime import datetime, timedelta
Expand Down Expand Up @@ -159,10 +161,11 @@

low_pri_op = "+" / "-"
high_pri_op = "/" / "*"
param_expression = low_pri_arithmetic / quoted_literal
param_expression = low_pri_arithmetic / quoted_literal / identifier
parameters_list = parameter* (param_expression)
parameter = param_expression space* comma space*
parameter = (lambda / param_expression) space* comma space*
function_call = function_name open_paren parameters_list? close_paren (open_paren parameters_list? close_paren)? (space+ "AS" space+ alias_literal)?
lambda = open_paren space* identifier (comma space* identifier)* space* close_paren space* arrow space* function_call

aliased_tag_column = tag_column space+ "AS" space+ alias_literal
aliased_subscriptable = subscriptable space+ "AS" space+ alias_literal
Expand All @@ -181,11 +184,13 @@
subscriptable = column_name open_square column_name close_square
column_name = ~r"[a-zA-Z_][a-zA-Z0-9_\.:]*"
tag_column = "tags" open_square tag_name close_square
tag_name = ~r"[^\[\]]*"
tag_name = ~r"[^\[\]]*"
identifier = backtick ~r"[a-zA-Z_][a-zA-Z0-9_]*" backtick
function_name = ~r"[a-zA-Z_][a-zA-Z0-9_]*"
entity_alias = ~r"[a-zA-Z_][a-zA-Z0-9_]*"
entity_name = ~r"[a-zA-Z_]+"
relationship_name = ~r"[a-zA-Z_][a-zA-Z0-9_]*"
arrow = "->"
open_brace = "{"
close_brace = "}"
open_paren = "("
Expand All @@ -195,6 +200,7 @@
space = ~r"\s"
comma = ","
colon = ":"
backtick = "`"

"""
)
Expand Down Expand Up @@ -822,6 +828,42 @@ def visit_aliased_column_name(
column, _, _, _, alias = visited_children
return SelectedExpression(alias.text, column)

def visit_identifier(
self, node: Node, visited_children: Tuple[Any, Node, Any]
) -> Argument:
return Argument(None, visited_children[1].text)

def visit_lambda(
self,
node: Node,
visited_children: Tuple[
Any,
Any,
Argument,
Union[Node, List[Node | Argument]],
Any,
Any,
Any,
Any,
Any,
Expression,
],
) -> Lambda:
first_identifier = visited_children[2]
other_identifiers = visited_children[3]
functionCall = visited_children[-1]
parameters = [first_identifier.name]
if isinstance(other_identifiers, list):
for other in other_identifiers:
if isinstance(other, Argument):
parameters.append(other.name)
elif isinstance(other, list):
parameters.extend(
[o.name for o in other if isinstance(o, Argument)]
)

return Lambda(None, tuple(parameters), functionCall)

def generic_visit(self, node: Node, visited_children: Any) -> Any:
return generic_visit(node, visited_children)

Expand Down Expand Up @@ -1119,6 +1161,33 @@ def mangle_column_value(exp: Expression) -> Expression:
query.transform_expressions(mangle_column_value)


def validate_identifiers_in_lambda(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one thing I remembered before you merge. You don' test the recursive case of this function, you just have the one test. You may want to do that

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion, I added a new test and tweaked the code as a result.

query: Union[CompositeQuery[QueryEntity], LogicalQuery]
) -> None:
"""
Check to make sure that any identifiers referenced in a lambda were defined in that lambda
or in an outer lambda.
"""
identifiers: Set[str] = set()

def validate_lambda(exp: Lambda) -> None:
for p in exp.parameters:
identifiers.add(p)

for inner_exp in exp.transformation:
if isinstance(inner_exp, Argument) and inner_exp.name not in identifiers:
raise InvalidExpressionException(f"identifier {inner_exp} not defined")
elif isinstance(inner_exp, Lambda):
validate_lambda(inner_exp)

for p in exp.parameters:
identifiers.discard(p)

for exp in query.get_all_expressions():
if isinstance(exp, Lambda):
validate_lambda(exp)


def _replace_time_condition(
query: Union[CompositeQuery[QueryEntity], LogicalQuery]
) -> None:
Expand Down Expand Up @@ -1340,7 +1409,11 @@ def _post_process(
_array_column_conditions,
]

VALIDATORS = [validate_query, validate_entities_with_query]
VALIDATORS = [
validate_identifiers_in_lambda,
validate_query,
validate_entities_with_query,
]


CustomProcessors = Sequence[
Expand Down
2 changes: 2 additions & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,8 @@ def get_raw_transaction(span_id: str | None = None) -> InsertEvent:
"description": "SELECT * FROM users",
"data": {},
"timestamp": calendar.timegm(end_timestamp.timetuple()),
"hash": "5029609156d8133",
"exclusive_time": 1.2,
evanh marked this conversation as resolved.
Show resolved Hide resolved
}
],
},
Expand Down
10 changes: 10 additions & 0 deletions tests/query/snql/test_invalid_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,16 @@
"Parsing error on line 1 at ' ORDER BY f '",
id="aliases are only in the select",
),
pytest.param(
"MATCH (discover_events) SELECT arrayMap((x) -> identity(`x`), sdk_integrations) AS sdks WHERE project_id = 1 AND timestamp >= toDateTime('2021-01-01') AND timestamp < toDateTime('2021-01-02')",
"Parsing error on line 1 at 'ap((x) -> ide'",
id="identifiers have backticks",
),
pytest.param(
"MATCH (discover_events) SELECT arrayMap((`x`) -> `x`, sdk_integrations) AS sdks WHERE project_id = 1 AND timestamp >= toDateTime('2021-01-01') AND timestamp < toDateTime('2021-01-02')",
"Parsing error on line 1 at 'ap((`x`) -> `'",
evanh marked this conversation as resolved.
Show resolved Hide resolved
id="ensure function after arrow",
),
]


Expand Down
183 changes: 183 additions & 0 deletions tests/query/snql/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,189 @@ def build_cond(tn: str) -> str:
),
id="aliased columns in select and group by",
),
pytest.param(
f"""MATCH (discover_events)
SELECT arrayMap((`x`) -> identity(`x`), sdk_integrations) AS sdks
WHERE {added_condition}
""",
LogicalQuery(
QueryEntity(
EntityKey.DISCOVER_EVENTS,
get_entity(EntityKey.DISCOVER_EVENTS).get_data_model(),
),
selected_columns=[
SelectedExpression(
"sdks",
FunctionCall(
"_snuba_sdks",
"arrayMap",
(
Lambda(
None,
("x",),
FunctionCall(None, "identity", (Argument(None, "x"),)),
),
Column("_snuba_sdk_integrations", None, "sdk_integrations"),
),
),
),
],
limit=1000,
condition=required_condition,
offset=0,
),
id="higher order functions single identifier",
),
pytest.param(
f"""MATCH (discover_events)
SELECT arrayMap((`x`, `y`) -> identity(tuple(`x`, `y`)), sdk_integrations) AS sdks
WHERE {added_condition}
""",
LogicalQuery(
QueryEntity(
EntityKey.DISCOVER_EVENTS,
get_entity(EntityKey.DISCOVER_EVENTS).get_data_model(),
),
selected_columns=[
SelectedExpression(
"sdks",
FunctionCall(
"_snuba_sdks",
"arrayMap",
(
Lambda(
None,
("x", "y"),
FunctionCall(
None,
"identity",
(
FunctionCall(
None,
"tuple",
(Argument(None, "x"), Argument(None, "y")),
),
),
),
),
Column("_snuba_sdk_integrations", None, "sdk_integrations"),
),
),
),
],
limit=1000,
condition=required_condition,
offset=0,
),
id="higher order function multiple identifier",
),
pytest.param(
f"""MATCH (discover_events)
SELECT arrayMap((`x`, `y`, `z`) -> tuple(`x`, `y`, `z`), sdk_integrations) AS sdks
WHERE {added_condition}
""",
LogicalQuery(
QueryEntity(
EntityKey.DISCOVER_EVENTS,
get_entity(EntityKey.DISCOVER_EVENTS).get_data_model(),
),
selected_columns=[
SelectedExpression(
"sdks",
FunctionCall(
"_snuba_sdks",
"arrayMap",
(
Lambda(
None,
("x", "y", "z"),
FunctionCall(
None,
"tuple",
(
Argument(None, "x"),
Argument(None, "y"),
Argument(None, "z"),
),
),
),
Column("_snuba_sdk_integrations", None, "sdk_integrations"),
),
),
),
],
limit=1000,
condition=required_condition,
offset=0,
),
id="higher order function lots of identifier",
),
pytest.param(
f"""MATCH (discover_events)
SELECT arrayReduce('sumIf', spans.op, arrayMap((`x`, `y`) -> if(equals(and(equals(`x`, 'db'), equals(`y`, 'ops')), 1), 1, 0), spans.op, spans.group)) AS spans
WHERE {added_condition}
""",
LogicalQuery(
QueryEntity(
EntityKey.DISCOVER_EVENTS,
get_entity(EntityKey.DISCOVER_EVENTS).get_data_model(),
),
selected_columns=[
SelectedExpression(
"spans",
FunctionCall(
"_snuba_spans",
"arrayReduce",
(
Literal(None, "sumIf"),
Column("_snuba_spans.op", None, "spans.op"),
FunctionCall(
None,
"arrayMap",
(
Lambda(
None,
("x", "y"),
FunctionCall(
None,
"if",
(
binary_condition(
"equals",
binary_condition(
"and",
binary_condition(
"equals",
Argument(None, "x"),
Literal(None, "db"),
),
binary_condition(
"equals",
Argument(None, "y"),
Literal(None, "ops"),
),
),
Literal(None, 1),
),
Literal(None, 1),
Literal(None, 0),
),
),
),
Column("_snuba_spans.op", None, "spans.op"),
Column("_snuba_spans.group", None, "spans.group"),
),
),
),
),
),
],
limit=1000,
condition=required_condition,
offset=0,
),
id="higher order function complex case",
),
]


Expand Down
15 changes: 12 additions & 3 deletions tests/query/test_query_validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import re

import pytest

from snuba.datasets.factory import get_dataset
from snuba.query.exceptions import InvalidExpressionException
from snuba.query.parser.exceptions import ParsingException
from snuba.query.snql.parser import parse_snql_query

Expand All @@ -11,13 +14,19 @@
SELECT event_id
WHERE timestamp LIKE 'carbonara'
""",
ParsingException("missing >= condition on column timestamp for entity events"),
id="Invalid LIKE param",
),
pytest.param(
"MATCH (discover_events) SELECT arrayMap((`x`) -> identity(`y`), sdk_integrations) AS sdks WHERE project_id = 1 AND timestamp >= toDateTime('2021-01-01') AND timestamp < toDateTime('2021-01-02')",
InvalidExpressionException("identifier y not defined"),
id="invalid lambda identifier",
),
]


@pytest.mark.parametrize("query_body", test_cases)
def test_validation(query_body: str) -> None:
@pytest.mark.parametrize("query_body, exception", test_cases)
def test_validation(query_body: str, exception: Exception) -> None:
events = get_dataset("events")
with pytest.raises(ParsingException):
with pytest.raises(type(exception), match=re.escape(str(exception))):
parse_snql_query(query_body, events)
Loading