diff --git a/pycrdt_websocket/ystore.py b/pycrdt_websocket/ystore.py index 76e3a50..7b59f00 100644 --- a/pycrdt_websocket/ystore.py +++ b/pycrdt_websocket/ystore.py @@ -9,7 +9,7 @@ from inspect import isawaitable from logging import Logger, getLogger from pathlib import Path -from typing import AsyncIterator, Awaitable, Callable, cast +from typing import AsyncIterator, Awaitable, Callable, Literal, cast import anyio from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group @@ -323,6 +323,10 @@ class MySQLiteYStore(SQLiteYStore): # latest update of a document must be before purging document history. # Defaults to never purging document history (None). document_ttl: int | None = None + # Maximum age in seconds of updates kept as history; older ones get squashed (None: disabled). + history_length: int | None = None + # Seconds beyond history_length that the oldest update's age must exceed to trigger a cleanup. + min_cleanup_interval: int = 60 path: str lock: Lock db_initialized: Event | None @@ -478,25 +482,31 @@ async def write(self, data: bytes) -> None: async with self._db: # first, determine time elapsed since last update cursor = await self._db.cursor() - await cursor.execute( - "SELECT timestamp FROM yupdates WHERE path = ? 
" - "ORDER BY timestamp DESC LIMIT 1", - (self.path,), - ) - row = await cursor.fetchone() - diff = (time.time() - row[0]) if row else 0 - if self.document_ttl is not None and diff > self.document_ttl: + newest_diff = await self._get_time_differential_to_entry(cursor, direction="DESC") + oldest_diff = await self._get_time_differential_to_entry(cursor, direction="ASC") + + squashed = False + if (self.document_ttl is not None and newest_diff > self.document_ttl) or ( + self.history_length is not None + and oldest_diff > self.min_cleanup_interval + self.history_length + ): # squash updates ydoc = Doc() + older_than = time.time() - ( + self.history_length if self.history_length is not None else 0 + ) await cursor.execute( - "SELECT yupdate FROM yupdates WHERE path = ?", - (self.path,), + "SELECT yupdate FROM yupdates WHERE path = ? AND timestamp < ?", + (self.path, older_than), ) for (update,) in await cursor.fetchall(): ydoc.apply_update(update) - # delete history - await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) + # delete the squashed rows (same path and cutoff as the SELECT above) + await cursor.execute( + "DELETE FROM yupdates WHERE path = ? 
AND timestamp < ?", + (self.path, older_than), + ) # insert squashed updates squashed_update = ydoc.get_update() metadata = await self.get_metadata() @@ -504,6 +514,7 @@ async def write(self, data: bytes) -> None: "INSERT INTO yupdates VALUES (?, ?, ?, ?)", (self.path, squashed_update, metadata, time.time()), ) + squashed = True # finally, write this update to the DB metadata = await self.get_metadata() @@ -511,3 +522,20 @@ async def write(self, data: bytes) -> None: "INSERT INTO yupdates VALUES (?, ?, ?, ?)", (self.path, data, metadata, time.time()), ) + + if squashed: + # commit first: VACUUM presumably cannot run inside an open transaction (confirm) + await self._db.commit() + await cursor.execute("VACUUM") + + async def _get_time_differential_to_entry( + self, cursor, direction: Literal["ASC", "DESC"] = "DESC" + ) -> float: + """Return seconds since the newest (DESC) / oldest (ASC) update of this path; 0 if none.""" + await cursor.execute( + "SELECT timestamp FROM yupdates WHERE path = ? " + f"ORDER BY timestamp {direction} LIMIT 1", + (self.path,), + ) + row = await cursor.fetchone() + return (time.time() - row[0]) if row else 0