Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Switch search SQL to triple-quote strings. #14311

Merged
merged 5 commits into from
Oct 28, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 98 additions & 89 deletions synapse/storage/databases/main/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,11 @@ def store_search_entries_txn(
if not self.hs.config.server.enable_search:
return
if isinstance(self.database_engine, PostgresEngine):
sql = (
"INSERT INTO event_search"
" (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
sql = """
INSERT INTO event_search
(event_id, room_id, key, vector, stream_ordering, origin_server_ts)
VALUES (?,?,?,to_tsvector('english', ?),?,?)
"""

args1 = (
(
Expand All @@ -101,20 +101,20 @@ def store_search_entries_txn(
txn.execute_batch(sql, args1)

elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
args2 = (
(
entry.event_id,
entry.room_id,
entry.key,
_clean_value_for_search(entry.value),
)
for entry in entries
self.db_pool.simple_insert_many_txn(
Copy link
Member Author

@clokep clokep Oct 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it is clearer switching this to simple_insert_many_txn, but the code is identical... (I think).

(Note that we can't do this for psql since it has the to_tsvector(...) call around a parameter)

txn,
table="event_search",
keys=("event_id", "room_id", "key", "value"),
values=(
(
entry.event_id,
entry.room_id,
entry.key,
_clean_value_for_search(entry.value),
)
for entry in entries
),
)
txn.execute_batch(sql, args2)

else:
# This should be unreachable.
Expand Down Expand Up @@ -162,15 +162,17 @@ async def _background_reindex_search(
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]

def reindex_search_txn(txn: LoggingTransaction) -> int:
sql = (
"SELECT stream_ordering, event_id, room_id, type, json, "
" origin_server_ts FROM events"
" JOIN event_json USING (room_id, event_id)"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)
sql = """
SELECT stream_ordering, event_id, room_id, type, json,
origin_server_ts FROM events
JOIN event_json USING (room_id, event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
AND (%s)
ORDER BY stream_ordering DESC
LIMIT ?
""" % (
" OR ".join("type = '%s'" % (t,) for t in TYPES),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the sort of thing we'd normally use make_in_list_sql_clause for? Maybe one for the future.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe so, yes. More than happy to change it here if you'd like.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice-to-have, but won't insist e.g. if you've got a big TODO stack!

)

txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

Expand Down Expand Up @@ -284,8 +286,10 @@ def create_index(conn: LoggingDatabaseConnection) -> None:

try:
c.execute(
"CREATE INDEX CONCURRENTLY event_search_fts_idx"
" ON event_search USING GIN (vector)"
"""
CREATE INDEX CONCURRENTLY event_search_fts_idx
ON event_search USING GIN (vector)
"""
)
except psycopg2.ProgrammingError as e:
logger.warning(
Expand Down Expand Up @@ -323,12 +327,16 @@ def create_index(conn: LoggingDatabaseConnection) -> None:
# We create with NULLS FIRST so that when we search *backwards*
# we get the ones with non null origin_server_ts *first*
c.execute(
"CREATE INDEX CONCURRENTLY event_search_room_order ON event_search("
"room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
"""
CREATE INDEX CONCURRENTLY event_search_room_order ON event_search(
room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
"""
)
c.execute(
"CREATE INDEX CONCURRENTLY event_search_order ON event_search("
"origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)"
"""
CREATE INDEX CONCURRENTLY event_search_order ON event_search(
origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
"""
)
conn.set_session(autocommit=False)

Expand All @@ -345,14 +353,14 @@ def create_index(conn: LoggingDatabaseConnection) -> None:
)

def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
sql = (
"UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
" origin_server_ts = e.origin_server_ts"
" FROM events AS e"
" WHERE e.event_id = es.event_id"
" AND ? <= e.stream_ordering AND e.stream_ordering < ?"
" RETURNING es.stream_ordering"
)
sql = """
UPDATE event_search AS es SET stream_ordering = e.stream_ordering,
origin_server_ts = e.origin_server_ts
FROM events AS e
WHERE e.event_id = es.event_id
AND ? <= e.stream_ordering AND e.stream_ordering < ?
RETURNING es.stream_ordering
"""

min_stream_id = max_stream_id - batch_size
txn.execute(sql, (min_stream_id, max_stream_id))
Expand Down Expand Up @@ -456,33 +464,33 @@ async def search_msgs(
if isinstance(self.database_engine, PostgresEngine):
search_query = search_term
tsquery_func = self.database_engine.tsquery_func
sql = (
f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank,"
" room_id, event_id"
" FROM event_search"
f" WHERE vector @@ {tsquery_func}('english', ?)"
)
sql = f"""
SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) AS rank,
room_id, event_id
FROM event_search
WHERE vector @@ {tsquery_func}('english', ?)
"""
args = [search_query, search_query] + args

count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
f" WHERE vector @@ {tsquery_func}('english', ?)"
)
count_sql = f"""
SELECT room_id, count(*) as count FROM event_search
WHERE vector @@ {tsquery_func}('english', ?)
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_for_sqlite(search_term)

sql = (
"SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
" FROM event_search"
" WHERE value MATCH ?"
)
sql = """
SELECT rank(matchinfo(event_search)) as rank, room_id, event_id
FROM event_search
WHERE value MATCH ?
"""
args = [search_query] + args

count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE value MATCH ?"
)
count_sql = """
SELECT room_id, count(*) as count FROM event_search
WHERE value MATCH ?
"""
count_args = [search_query] + count_args
else:
# This should be unreachable.
Expand Down Expand Up @@ -588,26 +596,27 @@ async def search_rooms(
raise SynapseError(400, "Invalid pagination token")

clauses.append(
"(origin_server_ts < ?"
" OR (origin_server_ts = ? AND stream_ordering < ?))"
"""(origin_server_ts < ?
OR (origin_server_ts = ? AND stream_ordering < ?))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside: I wonder if this can be expressed as ( (origin_server_ts, stream_ordering) < (?, ?) ) using a lexicographic ordering?

"""
)
args.extend([origin_server_ts, origin_server_ts, stream])

if isinstance(self.database_engine, PostgresEngine):
search_query = search_term
tsquery_func = self.database_engine.tsquery_func
sql = (
f"SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank,"
" origin_server_ts, stream_ordering, room_id, event_id"
" FROM event_search"
f" WHERE vector @@ {tsquery_func}('english', ?) AND "
)
sql = f"""
SELECT ts_rank_cd(vector, {tsquery_func}('english', ?)) as rank,
origin_server_ts, stream_ordering, room_id, event_id
FROM event_search
WHERE vector @@ {tsquery_func}('english', ?) AND
"""
args = [search_query, search_query] + args

count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
f" WHERE vector @@ {tsquery_func}('english', ?) AND "
)
count_sql = f"""
SELECT room_id, count(*) as count FROM event_search
WHERE vector @@ {tsquery_func}('english', ?) AND
"""
count_args = [search_query] + count_args
elif isinstance(self.database_engine, Sqlite3Engine):

Expand All @@ -619,23 +628,23 @@ async def search_rooms(
# in the events table to get the topological ordering. We need
# to use the indexes in this order because sqlite refuses to
# MATCH unless it uses the full text search index
sql = (
"SELECT rank(matchinfo) as rank, room_id, event_id,"
" origin_server_ts, stream_ordering"
" FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
" FROM event_search"
" WHERE value MATCH ?"
" )"
" CROSS JOIN events USING (event_id)"
" WHERE "
sql = """
SELECT rank(matchinfo) as rank, room_id, event_id,
origin_server_ts, stream_ordering
FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo
FROM event_search
WHERE value MATCH ?
)
CROSS JOIN events USING (event_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside: I think this could be a RIGHT JOIN because events.event_id is not NULL?

WHERE
"""
search_query = _parse_query_for_sqlite(search_term)
args = [search_query] + args

count_sql = (
"SELECT room_id, count(*) as count FROM event_search"
" WHERE value MATCH ? AND "
)
count_sql = """
SELECT room_id, count(*) as count FROM event_search
WHERE value MATCH ? AND
"""
count_args = [search_query] + count_args
else:
# This should be unreachable.
Expand All @@ -647,10 +656,10 @@ async def search_rooms(
# We add an arbitrary limit here to ensure we don't try to pull the
# entire table from the database.
if isinstance(self.database_engine, PostgresEngine):
sql += (
" ORDER BY origin_server_ts DESC NULLS LAST,"
" stream_ordering DESC NULLS LAST LIMIT ?"
)
sql += """
ORDER BY origin_server_ts DESC NULLS LAST,
stream_ordering DESC NULLS LAST LIMIT ?
"""
elif isinstance(self.database_engine, Sqlite3Engine):
sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
else:
Expand Down