[FR] Add support for sample base queries in elasticsearch (#74)
Mikaayenson authored Sep 1, 2023
1 parent c5ffa2c commit 5b57dab
Showing 10 changed files with 131 additions and 9 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -1,6 +1,14 @@
# Event Query Language - Changelog
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

# Version 0.9.18

_Released 2023-09-01_

### Added

* Support for [sample](https://www.elastic.co/guide/en/elasticsearch/reference/current/eql-syntax.html#eql-samples) base query type for Elasticsearch queries

# Version 0.9.17

_Released 2023-08-02_
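For context, the new base query pairs the `sample` keyword with an optional `by` clause of join keys and two or more bracketed subqueries. A minimal parsing sketch, using query strings taken from the tests added in this commit (both the `elasticsearch_syntax` and `allow_sample` flags are required by the parser change further down):

```python
from eql.parser import allow_sample, elasticsearch_syntax, parse_query

# Both flags must be active, per the check added to the sample() callback in eql/parser.py.
with elasticsearch_syntax, allow_sample:
    # Two subqueries correlated on a single join key.
    parse_query('sample by user [process where opcode == 1] [process where opcode == 1]')
    # Multiple join keys are also allowed by the grammar (join_values?).
    parse_query('sample by x,y,z [event where a == 1] [event where b == 2]')
```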
2 changes: 1 addition & 1 deletion eql/__init__.py
@@ -66,7 +66,7 @@
Walker,
)

__version__ = '0.9.17'
__version__ = '0.9.18'
__all__ = (
"__version__",
"AnalyticOutput",
22 changes: 22 additions & 0 deletions eql/ast.py
@@ -48,6 +48,7 @@
"SubqueryBy",
"Join",
"Sequence",
"Sample",

# pipes
"PipeCommand",
@@ -944,6 +945,27 @@ def _render(self):
return text


class Sample(EqlNode):
"""Sample finds events matching the defined filters, regardless of their temporal order.
Sample supports defining one or more join keys.
"""

__slots__ = 'queries',

def __init__(self, queries):
"""Create a Sample of multiple events.
:param list[SubqueryBy] queries: List of queries to be sampled
"""
self.queries = queries

def _render(self):
text = 'sample'
text += self.indent('\n'.join(query.render() for query in self.queries))
return text


class Sequence(EqlNode):
"""Sequence is very similar to join, but enforces an ordering.
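As a quick illustration of the new node, the base query of a parsed sample is a `Sample` whose `_render()` emits the `sample` keyword followed by each indented subquery. A hedged sketch, assuming `parse_query` returns the usual `PipedQuery` with a `.first` attribute (the same attribute `get_query_type()` in `eql/utils.py` relies on):

```python
from eql.ast import Sample
from eql.parser import allow_sample, elasticsearch_syntax, parse_query

with elasticsearch_syntax, allow_sample:
    query = parse_query('sample by user [process where opcode == 1] [process where opcode == 1]')

# The base query of the parsed tree is a Sample node holding its subqueries.
assert isinstance(query.first, Sample)
print(query.first.render())  # "sample" followed by the indented, re-rendered subqueries
```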
46 changes: 46 additions & 0 deletions eql/engine.py
@@ -934,6 +934,49 @@ def close_join_callback(event): # type: (Event) -> None
for pos, query in enumerate(node.queries):
convert_join_term(query, pos)

def _convert_sample_term(self, subquery, size, samples, next_pipe=None):
check_event = self.convert(subquery.query)

# Determine if there's a join_value present
has_join_value = True if subquery.join_values else False

# If there's a join value, get it.
get_join_value = self._convert_key(subquery.join_values, scoped=True) if has_join_value else None

@self.event_callback(subquery.query.event_type)
def sample_callback(event): # type: (Event) -> None
if check_event(event):
if has_join_value: # The regular case where join values exist
join_value = get_join_value(event)
if join_value not in samples:
samples[join_value] = []
samples[join_value].append(event)

if len(samples[join_value]) == size:
next_pipe(samples[join_value])
samples.pop(join_value)

else: # The case where no join values exist
samples.append(event)
if len(samples) == size:
# Pass a copy to the next_pipe to avoid mutation issues
next_pipe(samples[:])
samples.clear()

def _convert_sample(self, node, next_pipe):
# type: (Sample, callable) -> callable

# Check if there's a join value for any subquery
has_join_value_for_any_subquery = any(subquery.join_values for subquery in node.queries)

# Initialize samples based on the presence of join values
samples = {} if has_join_value_for_any_subquery else []
size = len(node.queries)

for _, query in reversed(list(enumerate(node.queries))):
# Create these in reverse order, so one event can't hit multiple callbacks to be propagated
self._convert_sample_term(query, size, samples, next_pipe)

def _convert_sequence_term(self, subquery, position, size, lookups, next_pipe=None):
# type: (SubqueryBy, int, int, list[dict[object, list[Event]]], callable) -> callable
check_event = self.convert(subquery.query)
@@ -1080,6 +1123,9 @@ def callback(event): # type: (Event) -> None
elif isinstance(base_query, Sequence):
self._convert_sequence(base_query, output_pipe)

elif isinstance(base_query, Sample):
self._convert_sample(base_query, output_pipe)

else:
raise EqlCompileError("Unsupported {}".format(type(base_query).__name__))

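In plain terms, `_convert_sample_term` buckets matching events by their join key and flushes a bucket to `next_pipe` once it holds one event per subquery; with no join keys it simply collects events until the count is reached. A simplified, standalone sketch of that bucketing (hypothetical helper names; the real engine drives this through one event callback per subquery):

```python
def bucket_samples(events, num_subqueries, get_join_value=None):
    """Illustrative only: mirrors the grouping done in _convert_sample_term."""
    results = []
    if get_join_value is None:
        # No join keys: emit every `num_subqueries` matching events as one sample.
        pending = []
        for event in events:
            pending.append(event)
            if len(pending) == num_subqueries:
                results.append(pending[:])  # pass a copy, as the engine does
                pending.clear()
    else:
        # Join keys: group events per key and emit a group once it is full.
        buckets = {}
        for event in events:
            key = get_join_value(event)
            buckets.setdefault(key, []).append(event)
            if len(buckets[key]) == num_subqueries:
                results.append(buckets.pop(key))
    return results


# Three matching events keyed by user, against a two-subquery sample:
print(bucket_samples(["alice-1", "alice-2", "bob-1"], 2, lambda e: e.split("-")[0]))
# -> [['alice-1', 'alice-2']]
```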
2 changes: 2 additions & 0 deletions eql/etc/eql.g
@@ -8,10 +8,12 @@ query_with_definitions: definitions piped_query
piped_query: base_query [pipes]
| pipes
base_query: sequence
| sample
| join
| event_query
event_query: [name "where"] expr
sequence: "sequence" [join_values with_params? | with_params join_values?] subquery_by+ [until_subquery_by]
sample: "sample" join_values? subquery_by+
join: "join" join_values? subquery_by subquery_by+ until_subquery_by?
until_subquery_by.2: "until" subquery_by
pipes: pipe+
2 changes: 1 addition & 1 deletion eql/highlighters.py
@@ -32,7 +32,7 @@ class EqlLexer(RegexLexer):
include('whitespace'),
include('comments'),
(r'(and|in|not|or)\b', token.Operator.Word), # Keyword.Pseudo can also work
(r'(join|sequence|until|where)\b', token.Keyword),
(r'(join|sequence|until|where|sample)\b', token.Keyword),
(r'(%s)(=\s+)(where)\b' % _name, bygroups(token.Name, token.Whitespace, token.Keyword)),
(r'(const)(\s+)(%s)\b' % _name, bygroups(token.Keyword.Declaration, token.Whitespace, token.Name.Constant)),
(r'(macro)(\s+)(%s)\b' % _name, bygroups(token.Keyword.Declaration, token.Whitespace, token.Name.Constant)),
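The lexer change just adds `sample` to the keyword alternation, so any Pygments-based highlighting of EQL picks it up. A small usage sketch (standard Pygments API; the query string is illustrative):

```python
from pygments import highlight
from pygments.formatters import TerminalFormatter

from eql.highlighters import EqlLexer

query = 'sample by user [process where opcode == 1] [network where opcode == 2]'
# "sample" is now emitted as token.Keyword, alongside join/sequence/until/where.
print(highlight(query, EqlLexer(), TerminalFormatter()))
```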
23 changes: 20 additions & 3 deletions eql/parser.py
@@ -55,12 +55,15 @@
nullable_fields = ParserConfig(strict_fields=False)
non_nullable_fields = ParserConfig(strict_fields=True)
allow_enum_fields = ParserConfig(enable_enum=True)
allow_sample = ParserConfig(allow_sample=True)
allow_runs = ParserConfig(allow_runs=True)
elasticsearch_syntax = ParserConfig(elasticsearch_syntax=True)
elasticsearch_validate_optional_fields = ParserConfig(elasticsearch_syntax=True, validate_optional_fields=True)
elastic_endpoint_syntax = ParserConfig(elasticsearch_syntax=True, dollar_var=True, allow_alias=True)

keywords = ("and", "by", "const", "false", "in", "join", "macro",
"not", "null", "of", "or", "sequence", "true", "until", "with", "where"
"not", "null", "of", "or", "sample", "sequence", "true", "until",
"with", "where"
)

RESERVED = {n.render(): n for n in [ast.Boolean(True), ast.Boolean(False), ast.Null()]}
@@ -145,7 +148,9 @@ def __init__(self, text):
self._stacks = defaultdict(list)
self._alias_enabled = ParserConfig.read_stack("allow_alias", False)
self._alias_mapping = {}
self._allow_runs = ParserConfig.read_stack("allow_runs", False)
self._in_variable = False
self._allow_sample = ParserConfig.read_stack("allow_sample", False)

@property
def lines(self):
@@ -983,7 +988,7 @@ def pipe(self, node):
return pipe_cls([arg.node for arg in args])

def base_query(self, node):
"""Visit a sequence, join or event query."""
"""Visit a sample, sequence, join or event query."""
return self.visit(node.children[0])

def piped_query(self, node):
@@ -1214,6 +1219,17 @@ def get_sequence_term_parameter(self, param_node, position, close):

return key, ast.Boolean(bool(value.node.value))

def sample(self, node):
"""Callback function to walk the AST for a sample."""
if not self._allow_sample or not self._elasticsearch_syntax:
raise self._error(node, "Sample not supported")

queries, _ = self._get_subqueries_and_close(node, allow_fork=True)
if len(queries) <= 1:
raise self._error(node, "Only one item in the sample",
cls=EqlSemanticError)
return ast.Sample(queries)

def sequence(self, node):
"""Callback function to walk the AST."""
if not self._subqueries_enabled:
@@ -1224,7 +1240,7 @@ def sequence(self, node):
if node['with_params']:
params = self.time_range(node['with_params']['time_range'])

allow_runs = self._elasticsearch_syntax
allow_runs = self._elasticsearch_syntax and self._allow_runs

queries, close = self._get_subqueries_and_close(node, allow_fork=True, allow_runs=allow_runs)
if len(queries) <= 1 and not self._elasticsearch_syntax:
@@ -1507,6 +1523,7 @@ def sequence(self, tree):

# these have similar enough ASTs that this is fine for extracting terms
join = sequence
sample = sequence


def extract_query_terms(text, **kwargs):
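The new `sample()` callback enforces two constraints: sample queries are rejected unless both `allow_sample` and `elasticsearch_syntax` are active, and a sample with a single subquery is rejected as a semantic error. A hedged sketch of the two failure modes (the exception class for the unsupported case is taken from the test added below, which expects `EqlSemanticError`):

```python
from eql.errors import EqlSemanticError
from eql.parser import allow_sample, elasticsearch_syntax, parse_query

# Without the flags, the parser raises "Sample not supported".
try:
    parse_query('sample by user [process where opcode == 1] [process where opcode == 1]')
except EqlSemanticError as exc:
    print(exc)

# Even with the flags, a sample needs more than one subquery
# ("Only one item in the sample" is raised as an EqlSemanticError).
with elasticsearch_syntax, allow_sample:
    try:
        parse_query('sample by user [process where opcode == 1]')
    except EqlSemanticError as exc:
        print(exc)
```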
2 changes: 2 additions & 0 deletions eql/utils.py
@@ -277,6 +277,8 @@ def get_query_type(query):

if isinstance(query.first, ast.Sequence):
return "sequence"
elif isinstance(query.first, ast.Sample):
return "sample"
elif isinstance(query.first, ast.Join):
return "join"
elif isinstance(query.first, ast.EventQuery):
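With this change, query classification covers the new node as well. A short sketch, assuming `get_query_type` accepts the parsed query object (as the `isinstance` checks on `query.first` suggest):

```python
from eql.parser import allow_sample, elasticsearch_syntax, parse_query
from eql.utils import get_query_type

with elasticsearch_syntax, allow_sample:
    query = parse_query('sample by user [process where opcode == 1] [process where opcode == 1]')

print(get_query_type(query))  # -> "sample"
```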
15 changes: 12 additions & 3 deletions tests/test_parser.py
@@ -10,8 +10,9 @@
from eql.ast import * # noqa: F403
from eql.errors import EqlSchemaError, EqlSyntaxError, EqlSemanticError, EqlTypeMismatchError, EqlParseError
from eql.parser import (
parse_query, parse_expression, parse_definitions, ignore_missing_functions, parse_field, parse_literal,
extract_query_terms, keywords, elasticsearch_syntax, elastic_endpoint_syntax, elasticsearch_validate_optional_fields
allow_sample, allow_runs, parse_query, parse_expression, parse_definitions, ignore_missing_functions, parse_field,
parse_literal, extract_query_terms, keywords, elasticsearch_syntax, elastic_endpoint_syntax,
elasticsearch_validate_optional_fields
)
from eql.walkers import DepthFirstWalker
from eql.pipes import * # noqa: F403
@@ -556,7 +557,7 @@ def test_elasticsearch_flag(self):
}
})

with elasticsearch_syntax:
with elasticsearch_syntax, allow_runs:
subquery1 = '[process where opcode == 1] by unique_pid'
runs = [2, 10, 30]
for run in runs:
@@ -629,6 +630,14 @@ def test_elasticsearch_flag(self):
parse_query('process where ?process.name : "cmd.exe"')
parse_query('process where ?process_name : "cmd.exe"')

# sample base query usage
with allow_sample:
parse_query('sample by user [process where opcode == 1] [process where opcode == 1]')

# invalid sample base query usage
self.assertRaises(EqlSemanticError, parse_query,
'sample by user [process where opcode == 1] [process where opcode == 1]')

with schema:
parse_query("process where process_name == 'cmd.exe'")
parse_query("process where process_name == ?'cmd.exe'")
18 changes: 17 additions & 1 deletion tests/test_python_engine.py
@@ -6,7 +6,7 @@
from eql import * # noqa: F403
from eql.ast import * # noqa: F403
from eql.engine import Scope
from eql.parser import ignore_missing_functions
from eql.parser import ignore_missing_functions, allow_sample, elasticsearch_syntax
from eql.schema import EVENT_TYPE_GENERIC
from eql.tests.base import TestEngine

@@ -113,6 +113,22 @@ def test_engine_load(self):
parsed_analytic = parse_analytic({'metadata': {'id': uuid.uuid4()}, 'query': query})
engine.add_analytic(parsed_analytic)

sample_queries = [
'sample [event where x == y] [event where a == b]',
'sample by x,y,z [event where a == 1] [event where b == 2]',
'sample [event where name == "test"] [event where name == "test"]'
]

with ignore_missing_functions, allow_sample, elasticsearch_syntax:
for query in sample_queries:
# Make sure every query can be converted without raising any exceptions
parsed_query = parse_query(query)
engine.add_query(parsed_query)

# Also try to load it as an analytic
parsed_analytic = parse_analytic({'metadata': {'id': uuid.uuid4()}, 'query': query})
engine.add_analytic(parsed_analytic)

def test_raises_errors(self):
"""Confirm that exceptions are raised when expected."""
queries = [