This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Switch search SQL to triple-quote strings. #14311

Merged · 5 commits · Oct 28, 2022
Changes from 4 commits
1 change: 1 addition & 0 deletions changelog.d/14311.feature
@@ -0,0 +1 @@
+Allow use of Postgres and SQLite full-text search operators in search queries.
188 changes: 99 additions & 89 deletions synapse/storage/databases/main/search.py
@@ -80,11 +80,11 @@ def store_search_entries_txn(
         if not self.hs.config.server.enable_search:
             return
         if isinstance(self.database_engine, PostgresEngine):
-            sql = (
-                "INSERT INTO event_search"
-                " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
-                " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
-            )
+            sql = """
+            INSERT INTO event_search
+            (event_id, room_id, key, vector, stream_ordering, origin_server_ts)
+            VALUES (?,?,?,to_tsvector('english', ?),?,?)
+            """

             args1 = (
                 (
@@ -101,20 +101,20 @@ def store_search_entries_txn(
             txn.execute_batch(sql, args1)

         elif isinstance(self.database_engine, Sqlite3Engine):
-            sql = (
-                "INSERT INTO event_search (event_id, room_id, key, value)"
-                " VALUES (?,?,?,?)"
-            )
-            args2 = (
-                (
-                    entry.event_id,
-                    entry.room_id,
-                    entry.key,
-                    _clean_value_for_search(entry.value),
-                )
-                for entry in entries
-            )
-            txn.execute_batch(sql, args2)
+            self.db_pool.simple_insert_many_txn(
+                txn,
+                table="event_search",
+                keys=("event_id", "room_id", "key", "value"),
+                values=(
+                    (
+                        entry.event_id,
+                        entry.room_id,
+                        entry.key,
+                        _clean_value_for_search(entry.value),
+                    )
+                    for entry in entries
+                ),
+            )

[review comment] @clokep (Member, Author), Oct 27, 2022, on the simple_insert_many_txn call:
I'm not sure if it is clearer switching this to simple_insert_many_txn, but the code is identical... (I think).
(Note that we can't do this for psql since it has the to_tsvector(...) call around a parameter)
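A small sketch of clokep's parenthetical. The helper below is hypothetical; it only mimics the bare-placeholder INSERT that a generic insert-many helper such as simple_insert_many_txn builds, to show why the Postgres branch cannot use one:

```python
def _placeholder_insert(table: str, columns: tuple) -> str:
    """Build the bare-placeholder INSERT a generic insert-many helper emits."""
    placeholders = ", ".join("?" for _ in columns)
    return f"INSERT INTO {table} ({', '.join(columns)}) VALUES ({placeholders})"

# SQLite branch: every value is a bare placeholder, so a generic helper fits.
print(_placeholder_insert("event_search", ("event_id", "room_id", "key", "value")))
# -> INSERT INTO event_search (event_id, room_id, key, value) VALUES (?, ?, ?, ?)

# The Postgres branch must wrap one parameter in an expression,
# VALUES (?,?,?,to_tsvector('english', ?),?,?), which a columns-plus-
# placeholders helper has no way to express; hence the hand-written SQL
# and txn.execute_batch stay on that side.
```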

         else:
             # This should be unreachable.
@@ -162,15 +162,17 @@ async def _background_reindex_search(
         TYPES = ["m.room.name", "m.room.message", "m.room.topic"]

         def reindex_search_txn(txn: LoggingTransaction) -> int:
-            sql = (
-                "SELECT stream_ordering, event_id, room_id, type, json, "
-                " origin_server_ts FROM events"
-                " JOIN event_json USING (room_id, event_id)"
-                " WHERE ? <= stream_ordering AND stream_ordering < ?"
-                " AND (%s)"
-                " ORDER BY stream_ordering DESC"
-                " LIMIT ?"
-            ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)
+            sql = """
+            SELECT stream_ordering, event_id, room_id, type, json, origin_server_ts
+            FROM events
+            JOIN event_json USING (room_id, event_id)
+            WHERE ? <= stream_ordering AND stream_ordering < ?
+            AND (%s)
+            ORDER BY stream_ordering DESC
+            LIMIT ?
+            """ % (
+                " OR ".join("type = '%s'" % (t,) for t in TYPES),
+            )

             txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

[review comment] Contributor, on the " OR ".join(...) line:
Is this the sort of thing we'd normally use make_in_list_sql_clause for? Maybe one for the future.

[review comment] @clokep (Member, Author):
I believe so, yes. More than happy to change it here if you'd like.

[review comment] Contributor:
Nice-to-have, but won't insist, e.g. if you've got a big TODO stack!
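A sketch of the reviewer's suggestion, assuming make_in_list_sql_clause keeps the (clause, args) return shape it has in synapse.storage.database (not verified against this revision); it would replace the %s interpolation in the body above:

```python
from synapse.storage.database import make_in_list_sql_clause

# Yields an engine-appropriate clause such as "type = ANY(?)" (Postgres) or
# "type IN (?,?,?)" (SQLite), plus the matching argument list, so the event
# types become bound parameters instead of string-interpolated literals.
clause, type_args = make_in_list_sql_clause(self.database_engine, "type", TYPES)

sql = """
SELECT stream_ordering, event_id, room_id, type, json, origin_server_ts
FROM events
JOIN event_json USING (room_id, event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
AND (%s)
ORDER BY stream_ordering DESC
LIMIT ?
""" % (clause,)

# Argument order has to mirror placeholder order in the SQL above.
txn.execute(sql, [target_min_stream_id, max_stream_id, *type_args, batch_size])
```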

@@ -284,8 +286,10 @@ def create_index(conn: LoggingDatabaseConnection) -> None:

             try:
                 c.execute(
-                    "CREATE INDEX CONCURRENTLY event_search_fts_idx"
-                    " ON event_search USING GIN (vector)"
+                    """
+                    CREATE INDEX CONCURRENTLY event_search_fts_idx
+                    ON event_search USING GIN (vector)
+                    """
                 )
             except psycopg2.ProgrammingError as e:
                 logger.warning(
@@ -323,12 +327,16 @@ def create_index(conn: LoggingDatabaseConnection) -> None:
             # We create with NULLS FIRST so that when we search *backwards*
             # we get the ones with non null origin_server_ts *first*
             c.execute(
-                "CREATE INDEX CONCURRENTLY event_search_room_order ON event_search("
-                "room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
+                """
+                CREATE INDEX CONCURRENTLY event_search_room_order
+                ON event_search(room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
+                """
             )
             c.execute(
-                "CREATE INDEX CONCURRENTLY event_search_order ON event_search("
-                "origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
+                """
+                CREATE INDEX CONCURRENTLY event_search_order
+                ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
+                """
             )
             conn.set_session(autocommit=False)
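Context for the autocommit handling around these statements (standard PostgreSQL/psycopg2 behaviour, not something this diff changes): CREATE INDEX CONCURRENTLY cannot run inside a transaction block, and psycopg2 starts one implicitly, so the session is flipped to autocommit for the index builds and restored afterwards. A minimal sketch, with an illustrative DSN and index name:

```python
import psycopg2

conn = psycopg2.connect("dbname=synapse")  # illustrative DSN
conn.set_session(autocommit=True)  # CONCURRENTLY refuses to run in a txn block
try:
    with conn.cursor() as c:
        c.execute("CREATE INDEX CONCURRENTLY example_idx ON event_search (room_id)")
finally:
    conn.set_session(autocommit=False)  # restore normal transactional behaviour
```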

@@ -345,14 +353,14 @@ def create_index(conn: LoggingDatabaseConnection) -> None:
             )

         def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
-            sql = (
-                "UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
-                " origin_server_ts = e.origin_server_ts"
-                " FROM events AS e"
-                " WHERE e.event_id = es.event_id"
-                " AND ? <= e.stream_ordering AND e.stream_ordering < ?"
-                " RETURNING es.stream_ordering"
-            )
+            sql = """
+            UPDATE event_search AS es
+            SET stream_ordering = e.stream_ordering, origin_server_ts = e.origin_server_ts
+            FROM events AS e
+            WHERE e.event_id = es.event_id
+            AND ? <= e.stream_ordering AND e.stream_ordering < ?
+            RETURNING es.stream_ordering
+            """

             min_stream_id = max_stream_id - batch_size
             txn.execute(sql, (min_stream_id, max_stream_id))
@@ -456,33 +464,33 @@ async def search_msgs(
         if isinstance(self.database_engine, PostgresEngine):
             search_query = search_term
             tsquery_func = self.database_engine.tsquery_func
-            sql = (
-                f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank,"
-                " room_id, event_id"
-                " FROM event_search"
-                f" WHERE vector @@ {tsquery_func}('english', ?)"
-            )
+            sql = f"""
+            SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank,
+            room_id, event_id
+            FROM event_search
+            WHERE vector @@ {tsquery_func}('english', ?)
+            """
             args = [search_query, search_query] + args

-            count_sql = (
-                "SELECT room_id, count(*) as count FROM event_search"
-                f" WHERE vector @@ {tsquery_func}('english', ?)"
-            )
+            count_sql = f"""
+            SELECT room_id, count(*) as count FROM event_search
+            WHERE vector @@ {tsquery_func}('english', ?)
+            """
             count_args = [search_query] + count_args
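Note that tsquery_func is resolved on the engine rather than hard-coded. A hedged sketch of the kind of version gate this implies; the class, attribute, and constant below are assumptions for illustration, not taken from this diff (the underlying fact is that Postgres 11 introduced websearch_to_tsquery, which understands user-facing operators such as quoted phrases and -negation):

```python
class PostgresEngineSketch:
    """Illustrative stand-in for the engine's tsquery_func property."""

    def __init__(self, server_version_num: int):
        # e.g. 110000 is how Postgres reports version 11.0
        self._version = server_version_num

    @property
    def tsquery_func(self) -> str:
        # Older servers fall back to plainto_tsquery, which treats the
        # whole search term as plain text with no operators.
        if self._version >= 110000:
            return "websearch_to_tsquery"
        return "plainto_tsquery"

assert PostgresEngineSketch(110000).tsquery_func == "websearch_to_tsquery"
assert PostgresEngineSketch(100015).tsquery_func == "plainto_tsquery"
```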
         elif isinstance(self.database_engine, Sqlite3Engine):
             search_query = _parse_query_for_sqlite(search_term)

-            sql = (
-                "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
-                " FROM event_search"
-                " WHERE value MATCH ?"
-            )
+            sql = """
+            SELECT rank(matchinfo(event_search)) as rank, room_id, event_id
+            FROM event_search
+            WHERE value MATCH ?
+            """
             args = [search_query] + args

-            count_sql = (
-                "SELECT room_id, count(*) as count FROM event_search"
-                " WHERE value MATCH ?"
-            )
+            count_sql = """
+            SELECT room_id, count(*) as count FROM event_search
+            WHERE value MATCH ?
+            """
             count_args = [search_query] + count_args
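_parse_query_for_sqlite rewrites the user's search term into the MATCH syntax SQLite's FTS extension expects. A deliberately minimal, hypothetical stand-in to show the shape of that translation (the real helper also handles quoted phrases and the other operators this PR's changelog mentions):

```python
import re

def parse_query_for_sqlite_sketch(search_term: str) -> str:
    """Hypothetical stand-in: AND together the bare words of the query,
    which the FTS MATCH operator treats as all-terms-required."""
    return " AND ".join(re.findall(r"\w+", search_term))

assert parse_query_for_sqlite_sketch("cake recipe") == "cake AND recipe"
```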
         else:
             # This should be unreachable.
@@ -588,26 +596,27 @@ async def search_rooms(
                 raise SynapseError(400, "Invalid pagination token")

             clauses.append(
-                "(origin_server_ts < ?"
-                " OR (origin_server_ts = ? AND stream_ordering < ?))"
+                """
+                (origin_server_ts < ? OR (origin_server_ts = ? AND stream_ordering < ?))
+                """
             )
             args.extend([origin_server_ts, origin_server_ts, stream])
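The appended clause is keyset pagination over the (origin_server_ts, stream_ordering) pair. A small sketch of the comparison it encodes (function name and data are illustrative):

```python
def sorts_after_cursor(row: tuple, cursor: tuple) -> bool:
    """True iff `row` belongs on a later page than `cursor` when results
    are ordered by (origin_server_ts DESC, stream_ordering DESC); this is
    the test the SQL placeholders express."""
    ts, stream = row
    cursor_ts, cursor_stream = cursor
    return ts < cursor_ts or (ts == cursor_ts and stream < cursor_stream)

# Ties on origin_server_ts fall back to stream_ordering, so events sharing
# a timestamp are neither skipped nor repeated across pages.
assert sorts_after_cursor((10, 5), (10, 6))
assert not sorts_after_cursor((10, 6), (10, 6))
```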

         if isinstance(self.database_engine, PostgresEngine):
             search_query = search_term
             tsquery_func = self.database_engine.tsquery_func
-            sql = (
-                f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank,"
-                " origin_server_ts, stream_ordering, room_id, event_id"
-                " FROM event_search"
-                f" WHERE vector @@ {tsquery_func}('english', ?) AND "
-            )
+            sql = f"""
+            SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank,
+            origin_server_ts, stream_ordering, room_id, event_id
+            FROM event_search
+            WHERE vector @@ {tsquery_func}('english', ?) AND
+            """
             args = [search_query, search_query] + args

-            count_sql = (
-                "SELECT room_id, count(*) as count FROM event_search"
-                f" WHERE vector @@ {tsquery_func}('english', ?) AND "
-            )
+            count_sql = f"""
+            SELECT room_id, count(*) as count FROM event_search
+            WHERE vector @@ {tsquery_func}('english', ?) AND
+            """
             count_args = [search_query] + count_args
         elif isinstance(self.database_engine, Sqlite3Engine):

@@ -619,23 +628,24 @@
             # in the events table to get the topological ordering. We need
             # to use the indexes in this order because sqlite refuses to
             # MATCH unless it uses the full text search index
-            sql = (
-                "SELECT rank(matchinfo) as rank, room_id, event_id,"
-                " origin_server_ts, stream_ordering"
-                " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
-                " FROM event_search"
-                " WHERE value MATCH ?"
-                " )"
-                " CROSS JOIN events USING (event_id)"
-                " WHERE "
-            )
+            sql = """
+            SELECT
+                rank(matchinfo) as rank, room_id, event_id, origin_server_ts, stream_ordering
+            FROM (
+                SELECT key, event_id, matchinfo(event_search) as matchinfo
+                FROM event_search
+                WHERE value MATCH ?
+            )
+            CROSS JOIN events USING (event_id)
+            WHERE
+            """
             search_query = _parse_query_for_sqlite(search_term)
             args = [search_query] + args

-            count_sql = (
-                "SELECT room_id, count(*) as count FROM event_search"
-                " WHERE value MATCH ? AND "
-            )
+            count_sql = """
+            SELECT room_id, count(*) as count FROM event_search
+            WHERE value MATCH ? AND
+            """
             count_args = [search_query] + count_args

[review comment] Contributor, on the CROSS JOIN line:
Aside: I think this could be a RIGHT JOIN because events.event_id is not NULL?
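Background on why the CROSS JOIN is load-bearing (documented SQLite behaviour, not something this diff states): SQLite treats CROSS JOIN as an inner join whose table order the query planner will not rearrange, which is what guarantees the MATCH runs against the full-text index before events is consulted. A self-contained sketch with a toy schema, assuming an SQLite build with fts4 enabled:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE VIRTUAL TABLE event_search USING fts4(event_id, value);
    CREATE TABLE events (event_id TEXT NOT NULL, origin_server_ts INTEGER);
    INSERT INTO event_search VALUES ('$a', 'hello world');
    INSERT INTO events VALUES ('$a', 1000);
    """
)
# CROSS JOIN pins the join order: the FTS subquery runs first, then events
# is probed by event_id -- the property the code comment above relies on.
rows = conn.execute(
    """
    SELECT event_id, origin_server_ts
    FROM (SELECT event_id FROM event_search WHERE value MATCH ?)
    CROSS JOIN events USING (event_id)
    """,
    ("hello",),
).fetchall()
print(rows)  # [('$a', 1000)]
```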
         else:
             # This should be unreachable.
@@ -647,10 +657,10 @@
         # We add an arbitrary limit here to ensure we don't try to pull the
         # entire table from the database.
         if isinstance(self.database_engine, PostgresEngine):
-            sql += (
-                " ORDER BY origin_server_ts DESC NULLS LAST,"
-                " stream_ordering DESC NULLS LAST LIMIT ?"
-            )
+            sql += """
+            ORDER BY origin_server_ts DESC NULLS LAST, stream_ordering DESC NULLS LAST
+            LIMIT ?
+            """
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
         else: