diff --git a/pyproject.toml b/pyproject.toml index 68e568f..4afd895 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", ] dependencies = [ + "deprecated", "anyio >=3.6.2,<5", "aiosqlite >=0.18.0,<1", "y-py >=0.6.0,<0.7.0", @@ -42,6 +43,7 @@ test = [ "pytest-asyncio", "websockets >=10.0", "uvicorn", + "types-Deprecated" ] docs = [ "mkdocs", diff --git a/tests/test_file_store.py b/tests/test_file_store.py new file mode 100644 index 0000000..b65140b --- /dev/null +++ b/tests/test_file_store.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import struct +import time +from pathlib import Path + +import anyio +import pytest + +from ypy_websocket.stores import FileYStore, YDocExists, YDocNotFound +from ypy_websocket.yutils import Decoder, write_var_uint + + +@pytest.fixture +def create_store(): + async def _inner(path: str, version: int) -> None: + await anyio.Path(path).mkdir(parents=True, exist_ok=True) + version_path = Path(path, "__version__") + async with await anyio.open_file(version_path, "wb") as f: + version_bytes = str(version).encode() + await f.write(version_bytes) + + return _inner + + +@pytest.fixture +def add_document(): + async def _inner(path: str, doc_path: str, version: int, data: bytes | None = None) -> None: + file_path = Path(path, (doc_path + ".y")) + await anyio.Path(file_path.parent).mkdir(parents=True, exist_ok=True) + + async with await anyio.open_file(file_path, "ab") as f: + version_bytes = f"VERSION:{version}\n".encode() + await f.write(version_bytes) + + if data is not None: + data_len = write_var_uint(len(data)) + await f.write(data_len + data) + metadata = b"" + metadata_len = write_var_uint(len(metadata)) + await f.write(metadata_len + metadata) + timestamp = struct.pack(" None: + async with aiosqlite.connect(path) as db: + if tables: + await db.execute( + "CREATE TABLE IF NOT EXISTS documents (path TEXT PRIMARY KEY, version INTEGER NOT NULL)" + ) + await db.execute( + "CREATE TABLE IF NOT EXISTS yupdates (path TEXT NOT NULL, yupdate BLOB, metadata BLOB, timestamp REAL NOT NULL)" + ) + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_yupdates_path_timestamp ON yupdates (path, timestamp)" + ) + await db.execute(f"PRAGMA user_version = {version}") + await db.commit() + + return _inner + + +@pytest.fixture +def add_document(): + async def _inner(path: str, doc_path: str, version: int, data: bytes | None = None) -> None: + async with aiosqlite.connect(path) as db: + await db.execute( + "INSERT INTO documents VALUES (?, ?)", + (doc_path, version), + ) + if data is not None: + await db.execute( + "INSERT INTO yupdates VALUES (?, ?, ?, ?)", + (doc_path, data, b"", time.time()), + ) + await db.commit() + + return _inner + + +@pytest.mark.anyio +async def test_initialization(tmp_path): + path = tmp_path / "tmp.db" + store = SQLiteYStore(str(path)) + await store.initialize() + + assert store.initialized + + await _check_db(path, store) + + +@pytest.mark.anyio +async def test_initialization_with_old_database(tmp_path, create_database): + path = tmp_path / "tmp.db" + + # Create a database with an old version + await create_database(path, 1) + + store = SQLiteYStore(str(path)) + await store.initialize() + + assert store.initialized + + await _check_db(path, store) + + +@pytest.mark.anyio +async def test_initialization_with_empty_database(tmp_path, create_database): + path = tmp_path / "tmp.db" + + # Create a database + await create_database(path, SQLiteYStore.version, False) + + store = SQLiteYStore(str(path)) + await store.initialize() + + assert store.initialized + + await _check_db(path, store) + + +@pytest.mark.anyio +async def test_initialization_with_existing_database(tmp_path, create_database, add_document): + path = tmp_path / "tmp.db" + doc_path = "test.txt" + + # Create a database + await create_database(path, SQLiteYStore.version) + await add_document(path, doc_path, 0) + + store = SQLiteYStore(str(path)) + await store.initialize() + + assert store.initialized + + await _check_db(path, store, doc_path) + + +@pytest.mark.anyio +async def test_exists(tmp_path, create_database, add_document): + path = tmp_path / "tmp.db" + doc_path = "test.txt" + + # Create a database with an old version + await create_database(path, SQLiteYStore.version) + await add_document(path, doc_path, 0) + + store = SQLiteYStore(str(path)) + await store.initialize() + + assert store.initialized + + assert await store.exists(doc_path) + + assert not await store.exists("random.path") + + +@pytest.mark.anyio +async def test_list(tmp_path, create_database, add_document): + path = tmp_path / "tmp.db" + doc1 = "test_1.txt" + doc2 = "test_2.txt" + + # Create a database with an old version + await create_database(path, SQLiteYStore.version) + await add_document(path, doc1, 0) + await add_document(path, doc2, 0) + + store = SQLiteYStore(str(path)) + await store.initialize() + + assert store.initialized + + count = 0 + async for doc in store.list(): + count += 1 + assert doc in [doc1, doc2] + + assert count == 2 + + +@pytest.mark.anyio +async def test_get(tmp_path, create_database, add_document): + path = tmp_path / "tmp.db" + doc_path = "test.txt" + + # Create a database + await create_database(path, SQLiteYStore.version) + await add_document(path, doc_path, 0) + + store = SQLiteYStore(str(path)) + await store.initialize() + + assert store.initialized + + res = await store.get(doc_path) + assert res["path"] == doc_path + assert res["version"] == 0 + + res = await store.get("random.doc") + assert res is None + + +@pytest.mark.anyio +async def test_create(tmp_path, create_database, add_document): + path = tmp_path / "tmp.db" + doc_path = "test.txt" + + # Create a database + await create_database(path, SQLiteYStore.version) + await add_document(path, doc_path, 0) + + store = SQLiteYStore(str(path)) + await store.initialize() + + assert store.initialized + + new_doc = "new_doc.path" + await store.create(new_doc, 0) + async with aiosqlite.connect(path) as db: + cursor = await db.execute( + "SELECT path, version FROM documents WHERE path = ?", + (new_doc,), + ) + res = await cursor.fetchone() + assert res[0] == new_doc + assert res[1] == 0 + + with pytest.raises(YDocExists) as e: + await store.create(doc_path, 0) + assert str(e.value) == f"The document {doc_path} already exists." + + +@pytest.mark.anyio +async def test_remove(tmp_path, create_database, add_document): + path = tmp_path / "tmp.db" + doc_path = "test.txt" + + # Create a database + await create_database(path, SQLiteYStore.version) + await add_document(path, doc_path, 0) + + store = SQLiteYStore(str(path)) + await store.initialize() + + assert store.initialized + + assert await store.exists(doc_path) + await store.remove(doc_path) + assert not await store.exists(doc_path) + + new_doc = "new_doc.path" + assert not await store.exists(new_doc) + with pytest.raises(YDocNotFound) as e: + await store.remove(new_doc) + assert str(e.value) == f"The document {new_doc} doesn't exists." + assert not await store.exists(new_doc) + + +@pytest.mark.anyio +async def test_read(tmp_path, create_database, add_document): + path = tmp_path / "tmp.db" + doc_path = "test.txt" + update = b"foo" + + # Create a database + await create_database(path, SQLiteYStore.version) + await add_document(path, doc_path, 0, update) + + store = SQLiteYStore(str(path)) + await store.initialize() + + assert store.initialized + + count = 0 + async for u, _, _ in store.read(doc_path): + count += 1 + assert update == u + + assert count == 1 + + +@pytest.mark.anyio +async def test_write(tmp_path, create_database, add_document): + path = tmp_path / "tmp.db" + doc_path = "test.txt" + + # Create a database + await create_database(path, SQLiteYStore.version) + await add_document(path, doc_path, 0) + + store = SQLiteYStore(str(path)) + await store.initialize() + + assert store.initialized + + update = b"foo" + await store.write(doc_path, update) + + async with aiosqlite.connect(path) as db: + async with db.execute("SELECT yupdate FROM yupdates WHERE path = ?", (doc_path,)) as cursor: + count = 0 + async for u, in cursor: + count += 1 + assert u == update + assert count == 1 + + +async def _check_db(path: str, store: SQLiteYStore, doc_path: str | None = None): + async with aiosqlite.connect(path) as db: + cursor = await db.execute("pragma user_version") + res = await cursor.fetchone() + assert res + assert store.version == res[0] + + cursor = await db.execute( + "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='documents'" + ) + res = await cursor.fetchone() + assert res + assert res[0] == 1 + + cursor = await db.execute( + "SELECT count(*) FROM sqlite_master WHERE type='table' AND name='yupdates'" + ) + res = await cursor.fetchone() + assert res + assert res and res[0] == 1 + + if doc_path is not None: + cursor = await db.execute( + "SELECT path, version FROM documents WHERE path = ?", + (doc_path,), + ) + res = await cursor.fetchone() + assert res + assert res[0] == doc_path + assert res[1] == 0 diff --git a/tests/test_ystore.py b/tests/test_ystore.py index 39208bd..d999e83 100644 --- a/tests/test_ystore.py +++ b/tests/test_ystore.py @@ -1,5 +1,3 @@ -import os -import tempfile import time from pathlib import Path from unittest.mock import patch @@ -7,7 +5,7 @@ import aiosqlite import pytest -from ypy_websocket.ystore import SQLiteYStore, TempFileYStore +from ypy_websocket.stores import SQLiteYStore, TempFileYStore class MetadataCallback: @@ -24,35 +22,31 @@ class MyTempFileYStore(TempFileYStore): prefix_dir = "test_temp_" -MY_SQLITE_YSTORE_DB_PATH = str(Path(tempfile.mkdtemp(prefix="test_sql_")) / "ystore.db") - - class MySQLiteYStore(SQLiteYStore): - db_path = MY_SQLITE_YSTORE_DB_PATH document_ttl = 1000 - def __init__(self, *args, delete_db=False, **kwargs): - if delete_db: - os.remove(self.db_path) - super().__init__(*args, **kwargs) - @pytest.mark.anyio @pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore)) -async def test_ystore(YStore): - store_name = "my_store" - ystore = YStore(store_name, metadata_callback=MetadataCallback()) - await ystore.start() +async def test_ystore(tmp_path, YStore): + store_path = tmp_path / "my_store" + doc_name = "my_doc.txt" + + ystore = YStore(str(store_path), metadata_callback=MetadataCallback()) + await ystore.initialize() + + await ystore.create(doc_name, 0) + data = [b"foo", b"bar", b"baz"] for d in data: - await ystore.write(d) + await ystore.write(doc_name, d) if YStore == MyTempFileYStore: - assert (Path(MyTempFileYStore.base_dir) / store_name).exists() + assert (Path(store_path) / (doc_name + ".y")).exists() elif YStore == MySQLiteYStore: - assert Path(MySQLiteYStore.db_path).exists() + assert Path(store_path).exists() i = 0 - async for d, m, t in ystore.read(): + async for d, m, t in ystore.read(doc_name): assert d == data[i] # data assert m == str(i).encode() # metadata i += 1 @@ -61,18 +55,23 @@ async def test_ystore(YStore): @pytest.mark.anyio -async def test_document_ttl_sqlite_ystore(test_ydoc): - store_name = "my_store" - ystore = MySQLiteYStore(store_name, delete_db=True) - await ystore.start() +async def test_document_ttl_sqlite_ystore(tmp_path, test_ydoc): + store_path = tmp_path / "my_store.db" + doc_name = "my_doc.txt" + + ystore = MySQLiteYStore(str(store_path)) + await ystore.initialize() + + await ystore.create(doc_name, 0) + now = time.time() for i in range(3): # assert that adding a record before document TTL doesn't delete document history with patch("time.time") as mock_time: mock_time.return_value = now - await ystore.write(test_ydoc.update()) - async with aiosqlite.connect(ystore.db_path) as db: + await ystore.write(doc_name, test_ydoc.update()) + async with aiosqlite.connect(store_path) as db: assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[ 0 ] == i + 1 @@ -80,20 +79,7 @@ async def test_document_ttl_sqlite_ystore(test_ydoc): # assert that adding a record after document TTL deletes previous document history with patch("time.time") as mock_time: mock_time.return_value = now + ystore.document_ttl + 1 - await ystore.write(test_ydoc.update()) - async with aiosqlite.connect(ystore.db_path) as db: + await ystore.write(doc_name, test_ydoc.update()) + async with aiosqlite.connect(store_path) as db: # two updates in DB: one squashed update and the new update assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 2 - - -@pytest.mark.anyio -@pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore)) -async def test_version(YStore, caplog): - store_name = "my_store" - prev_version = YStore.version - YStore.version = -1 - ystore = YStore(store_name) - await ystore.start() - await ystore.write(b"foo") - YStore.version = prev_version - assert "YStore version mismatch" in caplog.text diff --git a/ypy_websocket/stores/__init__.py b/ypy_websocket/stores/__init__.py new file mode 100644 index 0000000..4bbaf63 --- /dev/null +++ b/ypy_websocket/stores/__init__.py @@ -0,0 +1,4 @@ +from .base_store import BaseYStore # noqa +from .file_store import FileYStore, TempFileYStore # noqa +from .sqlite_store import SQLiteYStore # noqa +from .utils import YDocExists, YDocNotFound # noqa diff --git a/ypy_websocket/stores/base_store.py b/ypy_websocket/stores/base_store.py new file mode 100644 index 0000000..7943156 --- /dev/null +++ b/ypy_websocket/stores/base_store.py @@ -0,0 +1,154 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from inspect import isawaitable +from typing import AsyncIterator, Awaitable, Callable, cast + +import y_py as Y +from anyio import Event + + +class BaseYStore(ABC): + """ + Base class for the stores. + """ + + version = 3 + metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None + + _store_path: str + _initialized: Event | None = None + + @abstractmethod + def __init__( + self, path: str, metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None + ): + """ + Initialize the object. + + Arguments: + path: The path where the store will be located. + metadata_callback: An optional callback to call to get the metadata. + log: An optional logger. + """ + ... + + @abstractmethod + async def initialize(self) -> None: + """ + Initializes the store. + """ + ... + + @abstractmethod + async def exists(self, path: str) -> bool: + """ + Returns True if the document exists, else returns False. + + Arguments: + path: The document name/path. + """ + ... + + @abstractmethod + async def list(self) -> AsyncIterator[str]: + """ + Returns a list with the name/path of the documents stored. + """ + ... + + @abstractmethod + async def get(self, path: str, updates: bool = False) -> dict | None: + """ + Returns the document's metadata or None if the document does't exist. + + Arguments: + path: The document name/path. + updates: Whether to return document's content or only the metadata. + """ + ... + + @abstractmethod + async def create(self, path: str, version: int) -> None: + """ + Creates a new document. + + Arguments: + path: The document name/path. + version: Document version. + """ + ... + + @abstractmethod + async def remove(self, path: str) -> dict | None: + """ + Removes a document. + + Arguments: + path: The document name/path. + """ + ... + + @abstractmethod + async def write(self, path: str, data: bytes) -> None: + """ + Store a document update. + + Arguments: + path: The document name/path. + data: The update to store. + """ + ... + + @abstractmethod + async def read(self, path: str) -> AsyncIterator[tuple[bytes, bytes]]: + """ + Async iterator for reading document's updates. + + Arguments: + path: The document name/path. + + Returns: + A tuple of (update, metadata, timestamp) for each update. + """ + ... + + @property + def initialized(self) -> bool: + if self._initialized is not None: + return self._initialized.is_set() + return False + + async def get_metadata(self) -> bytes: + """ + Returns: + The metadata. + """ + if self.metadata_callback is None: + return b"" + + metadata = self.metadata_callback() + if isawaitable(metadata): + metadata = await metadata + metadata = cast(bytes, metadata) + return metadata + + async def encode_state_as_update(self, path: str, ydoc: Y.YDoc) -> None: + """Store a YDoc state. + + Arguments: + path: The document name/path. + ydoc: The YDoc from which to store the state. + """ + update = Y.encode_state_as_update(ydoc) # type: ignore + await self.write(path, update) + + async def apply_updates(self, path: str, ydoc: Y.YDoc) -> None: + """Apply all stored updates to the YDoc. + + Arguments: + path: The document name/path. + ydoc: The YDoc on which to apply the updates. + """ + async for update, *rest in self.read(path): # type: ignore + Y.apply_update(ydoc, update) # type: ignore diff --git a/ypy_websocket/stores/file_store.py b/ypy_websocket/stores/file_store.py new file mode 100644 index 0000000..165a9f2 --- /dev/null +++ b/ypy_websocket/stores/file_store.py @@ -0,0 +1,305 @@ +from __future__ import annotations + +import struct +import tempfile +import time +from logging import Logger, getLogger +from pathlib import Path +from typing import AsyncIterator, Awaitable, Callable + +import anyio +from anyio import Event, Lock +from deprecated import deprecated + +from ..yutils import Decoder, get_new_path, write_var_uint +from .base_store import BaseYStore +from .utils import YDocExists, YDocNotFound + + +class FileYStore(BaseYStore): + """A YStore which uses one file per document.""" + + _lock: Lock + metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None + + def __init__( + self, + path: str = "./ystore", + metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None, + log: Logger | None = None, + ) -> None: + """Initialize the object. + + Arguments: + path: The file path used to store the updates. + metadata_callback: An optional callback to call to get the metadata. + log: An optional logger. + """ + self._lock = Lock() + self._store_path = path + self.metadata_callback = metadata_callback + self.log = log or getLogger(__name__) + + async def initialize(self) -> None: + """ + Initializes the store. + """ + if self.initialized or self._initialized is not None: + return + self._initialized = Event() + + version_path = Path(self._store_path, "__version__") + if not await anyio.Path(self._store_path).exists(): + await anyio.Path(self._store_path).mkdir(parents=True, exist_ok=True) + + version = -1 + create_version = False + if await anyio.Path(version_path).exists(): + async with await anyio.open_file(version_path, "rb") as f: + version = int(await f.readline()) + + # Store version mismatch. Move store and create a new one. + if self.version != version: + create_version = True + + if create_version: + new_path = await get_new_path(self._store_path) + self.log.warning( + f"YStore version mismatch, moving {self._store_path} to {new_path}" + ) + await anyio.Path(self._store_path).rename(new_path) + await anyio.Path(self._store_path).mkdir(parents=True, exist_ok=True) + + else: + create_version = True + + if create_version: + async with await anyio.open_file(version_path, "wb") as f: + version_bytes = str(self.version).encode() + await f.write(version_bytes) + + self._initialized.set() + + async def exists(self, path: str) -> bool: + """ + Returns True if the document exists, else returns False. + + Arguments: + path: The document name/path. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + return await anyio.Path(self._get_document_path(path)).exists() + + async def list(self) -> AsyncIterator[str]: # type: ignore[override] + """ + Returns a list with the name/path of the documents stored. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + async for child in anyio.Path(self._store_path).glob("**/*.y"): + yield child.relative_to(self._store_path).with_suffix("").as_posix() + + async def get(self, path: str, updates: bool = False) -> dict | None: + """ + Returns the document's metadata and updates or None if the document does't exist. + + Arguments: + path: The document name/path. + updates: Whether to return document's content or only the metadata. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + file_path = self._get_document_path(path) + if not await anyio.Path(file_path).exists(): + return None + else: + version = None + async with await anyio.open_file(file_path, "rb") as f: + header = await f.read(8) + if header == b"VERSION:": + version = int(await f.readline()) + + list_updates: list[tuple[bytes, bytes, float]] = [] + if updates: + data = await f.read() + async for update, metadata, timestamp in self._decode_data(data): + list_updates.append((update, metadata, timestamp)) + + return dict(path=path, version=version, updates=list_updates) + + async def create(self, path: str, version: int) -> None: + """ + Creates a new document. + + Arguments: + path: The document name/path. + version: Document version. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + file_path = self._get_document_path(path) + if await anyio.Path(file_path).exists(): + raise YDocExists(f"The document {path} already exists.") + + else: + await anyio.Path(file_path.parent).mkdir(parents=True, exist_ok=True) + async with await anyio.open_file(file_path, "wb") as f: + version_bytes = f"VERSION:{version}\n".encode() + await f.write(version_bytes) + + async def remove(self, path: str) -> None: + """ + Removes a document. + + Arguments: + path: The document name/path. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + file_path = self._get_document_path(path) + if await anyio.Path(file_path).exists(): + await anyio.Path(file_path).unlink(missing_ok=False) + else: + raise YDocNotFound(f"The document {path} doesn't exists.") + + async def read(self, path: str) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: ignore + """Async iterator for reading the store content. + + Returns: + A tuple of (update, metadata, timestamp) for each update. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + async with self._lock: + file_path = self._get_document_path(path) + if not await anyio.Path(file_path).exists(): + raise YDocNotFound + + offset = await self._get_data_offset(file_path) + async with await anyio.open_file(file_path, "rb") as f: + await f.seek(offset) + data = await f.read() + if not data: + raise YDocNotFound + + async for res in self._decode_data(data): + yield res + + async def write(self, path: str, data: bytes) -> None: + """Store an update. + + Arguments: + data: The update to store. + """ + async with self._lock: + file_path = self._get_document_path(path) + if not await anyio.Path(file_path).exists(): + raise YDocNotFound + + async with await anyio.open_file(file_path, "ab") as f: + data_len = write_var_uint(len(data)) + await f.write(data_len + data) + metadata = await self.get_metadata() + metadata_len = write_var_uint(len(metadata)) + await f.write(metadata_len + metadata) + timestamp = struct.pack(" int: + try: + async with await anyio.open_file(path, "rb") as f: + header = await f.read(8) + if header == b"VERSION:": + await f.readline() + return await f.tell() + else: + raise Exception + + except Exception: + raise YDocNotFound(f"File {str(path)} not found.") + + async def _decode_data(self, data) -> AsyncIterator[tuple[bytes, bytes, float]]: + i = 0 + for d in Decoder(data).read_messages(): + if i == 0: + update = d + elif i == 1: + metadata = d + else: + timestamp = struct.unpack(" Path: + return Path(self._store_path, path + ".y") + + +@deprecated(reason="Use FileYStore instead") +class TempFileYStore(FileYStore): + """ + A YStore which uses the system's temporary directory. + Files are writen under a common directory. + To prefix the directory name (e.g. /tmp/my_prefix_b4whmm7y/): + + ```py + class PrefixTempFileYStore(TempFileYStore): + prefix_dir = "my_prefix_" + ``` + + ## Note: + This class is deprecated. Use FileYStore and pass the tmp folder + as path argument. For example: + + ```py + tmp_dir = tempfile.mkdtemp(prefix="prefix/directory/") + store = FileYStore(tmp_dir) + ``` + """ + + prefix_dir: str | None = None + base_dir: str | None = None + + def __init__( + self, + path: str, + metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None, + log: Logger | None = None, + ): + """Initialize the object. + + Arguments: + path: The file path used to store the updates. + metadata_callback: An optional callback to call to get the metadata. + log: An optional logger. + """ + full_path = str(Path(self.get_base_dir()) / path) + super().__init__(full_path, metadata_callback=metadata_callback, log=log) + + def get_base_dir(self) -> str: + """Get the base directory where the update file is written. + + Returns: + The base directory path. + """ + if self.base_dir is None: + self.make_directory() + assert self.base_dir is not None + return self.base_dir + + def make_directory(self): + """Create the base directory where the update file is written.""" + type(self).base_dir = tempfile.mkdtemp(prefix=self.prefix_dir) diff --git a/ypy_websocket/stores/sqlite_store.py b/ypy_websocket/stores/sqlite_store.py new file mode 100644 index 0000000..f7d8375 --- /dev/null +++ b/ypy_websocket/stores/sqlite_store.py @@ -0,0 +1,288 @@ +from __future__ import annotations + +import time +from logging import Logger, getLogger +from typing import Any, AsyncIterator, Awaitable, Callable, Iterable + +import aiosqlite +import anyio +import y_py as Y +from anyio import Event, Lock + +from ..yutils import get_new_path +from .base_store import BaseYStore +from .utils import YDocExists, YDocNotFound + + +class SQLiteYStore(BaseYStore): + """A YStore which uses an SQLite database. + Unlike file-based YStores, the Y updates of all documents are stored in the same database. + + Subclass to point to your database file: + + ```py + class MySQLiteYStore(SQLiteYStore): + _store_path = "path/to/my_ystore.db" + ``` + """ + + _lock: Lock + # Determines the "time to live" for all documents, i.e. how recent the + # latest update of a document must be before purging document history. + # Defaults to never purging document history (None). + document_ttl: int | None = None + + def __init__( + self, + path: str = "./ystore.db", + metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None, + log: Logger | None = None, + ) -> None: + """Initialize the object. + + Arguments: + path: The database path used to store the updates. + metadata_callback: An optional callback to call to get the metadata. + log: An optional logger. + """ + self._lock = Lock() + self._store_path = path + self.metadata_callback = metadata_callback + self.log = log or getLogger(__name__) + + async def initialize(self) -> None: + """ + Initializes the store. + """ + if self.initialized or self._initialized is not None: + return + self._initialized = Event() + + async with self._lock: + if await anyio.Path(self._store_path).exists(): + version = -1 + async with aiosqlite.connect(self._store_path) as db: + cursor = await db.execute("pragma user_version") + row = await cursor.fetchone() + if row is not None: + version = row[0] + + # The DB has an old version. Move the database. + if self.version != version: + new_path = await get_new_path(self._store_path) + self.log.warning( + f"YStore version mismatch, moving {self._store_path} to {new_path}" + ) + await anyio.Path(self._store_path).rename(new_path) + + # Make sure every table exists. + async with aiosqlite.connect(self._store_path) as db: + await db.execute( + "CREATE TABLE IF NOT EXISTS documents (path TEXT PRIMARY KEY, version INTEGER NOT NULL)" + ) + await db.execute( + "CREATE TABLE IF NOT EXISTS yupdates (path TEXT NOT NULL, yupdate BLOB, metadata BLOB, timestamp REAL NOT NULL)" + ) + await db.execute( + "CREATE INDEX IF NOT EXISTS idx_yupdates_path_timestamp ON yupdates (path, timestamp)" + ) + await db.execute(f"PRAGMA user_version = {self.version}") + await db.commit() + + self._initialized.set() + + async def exists(self, path: str) -> bool: + """ + Returns True if the document exists, else returns False. + + Arguments: + path: The document name/path. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + async with self._lock: + async with aiosqlite.connect(self._store_path) as db: + cursor = await db.execute( + "SELECT path, version FROM documents WHERE path = ?", + (path,), + ) + return (await cursor.fetchone()) is not None + + async def list(self) -> AsyncIterator[str]: # type: ignore[override] + """ + Returns a list with the name/path of the documents stored. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + async with self._lock: + async with aiosqlite.connect(self._store_path) as db: + async with db.execute("SELECT path FROM documents") as cursor: + async for path in cursor: + yield path[0] + + async def get(self, path: str, updates: bool = False) -> dict | None: + """ + Returns the document's metadata and updates or None if the document does't exist. + + Arguments: + path: The document name/path. + updates: Whether to return document's content or only the metadata. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + async with self._lock: + async with aiosqlite.connect(self._store_path) as db: + cursor = await db.execute( + "SELECT path, version FROM documents WHERE path = ?", + (path,), + ) + doc = await cursor.fetchone() + + if doc is None: + return None + + list_updates: Iterable[Any] = [] + if updates: + cursor = await db.execute( + "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", + (path,), + ) + list_updates = await cursor.fetchall() + + return dict(path=doc[0], version=doc[1], updates=list_updates) + + async def create(self, path: str, version: int) -> None: + """ + Creates a new document. + + Arguments: + path: The document name/path. + version: Document version. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + async with self._lock: + try: + async with aiosqlite.connect(self._store_path) as db: + await db.execute( + "INSERT INTO documents VALUES (?, ?)", + (path, version), + ) + await db.commit() + except aiosqlite.IntegrityError: + raise YDocExists(f"The document {path} already exists.") + + async def remove(self, path: str) -> None: + """ + Removes a document. + + Arguments: + path: The document name/path. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + async with self._lock: + async with aiosqlite.connect(self._store_path) as db: + cursor = await db.execute( + "SELECT path, version FROM documents WHERE path = ?", + (path,), + ) + if (await cursor.fetchone()) is None: + raise YDocNotFound(f"The document {path} doesn't exists.") + + await db.execute( + "DELETE FROM documents WHERE path = ?", + (path,), + ) + await db.execute( + "DELETE FROM yupdates WHERE path = ?", + (path,), + ) + await db.commit() + + async def read(self, path: str) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: ignore + """Async iterator for reading the store content. + + Arguments: + path: The document name/path. + + Returns: + A tuple of (update, metadata, timestamp) for each update. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + try: + async with self._lock: + async with aiosqlite.connect(self._store_path) as db: + async with db.execute( + "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", + (path,), + ) as cursor: + found = False + async for update, metadata, timestamp in cursor: + found = True + yield update, metadata, timestamp + if not found: + raise YDocNotFound + except Exception: + raise YDocNotFound + + async def write(self, path: str, data: bytes) -> None: + """ + Store an update. + + Arguments: + path: The document name/path. + data: The update to store. + """ + if self._initialized is None: + raise Exception("The store was not initialized.") + await self._initialized.wait() + + async with self._lock: + async with aiosqlite.connect(self._store_path) as db: + # first, determine time elapsed since last update + cursor = await db.execute( + "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", + (path,), + ) + row = await cursor.fetchone() + diff = (time.time() - row[0]) if row else 0 + + if self.document_ttl is not None and diff > self.document_ttl: + # squash updates + ydoc = Y.YDoc() + async with db.execute( + "SELECT yupdate FROM yupdates WHERE path = ?", (path,) + ) as cursor: + async for update, in cursor: + Y.apply_update(ydoc, update) + # delete history + await db.execute("DELETE FROM yupdates WHERE path = ?", (path,)) + # insert squashed updates + squashed_update = Y.encode_state_as_update(ydoc) + metadata = await self.get_metadata() + await db.execute( + "INSERT INTO yupdates VALUES (?, ?, ?, ?)", + (path, squashed_update, metadata, time.time()), + ) + + # finally, write this update to the DB + metadata = await self.get_metadata() + await db.execute( + "INSERT INTO yupdates VALUES (?, ?, ?, ?)", + (path, data, metadata, time.time()), + ) + await db.commit() diff --git a/ypy_websocket/stores/utils.py b/ypy_websocket/stores/utils.py new file mode 100644 index 0000000..6c5bf76 --- /dev/null +++ b/ypy_websocket/stores/utils.py @@ -0,0 +1,6 @@ +class YDocNotFound(Exception): + pass + + +class YDocExists(Exception): + pass diff --git a/ypy_websocket/yroom.py b/ypy_websocket/yroom.py index d602574..a406ce3 100644 --- a/ypy_websocket/yroom.py +++ b/ypy_websocket/yroom.py @@ -17,8 +17,8 @@ from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from .awareness import Awareness +from .stores import BaseYStore from .websocket import Websocket -from .ystore import BaseYStore from .yutils import ( YMessageType, create_update_message, @@ -120,9 +120,6 @@ def on_message(self, value: Callable[[bytes], Awaitable[bool] | bool] | None): self._on_message = value async def _broadcast_updates(self): - if self.ystore is not None and not self.ystore.started.is_set(): - self._task_group.start_soon(self.ystore.start) - async with self._update_receive_stream: async for update in self._update_receive_stream: if self._task_group.cancel_scope.cancel_called: @@ -135,7 +132,7 @@ async def _broadcast_updates(self): self._task_group.start_soon(client.send, message) if self.ystore: self.log.debug("Writing Y update to YStore") - self._task_group.start_soon(self.ystore.write, update) + self._task_group.start_soon(self.ystore.write, client.path, update) async def __aenter__(self) -> YRoom: if self._task_group is not None: diff --git a/ypy_websocket/ystore.py b/ypy_websocket/ystore.py deleted file mode 100644 index f4ad417..0000000 --- a/ypy_websocket/ystore.py +++ /dev/null @@ -1,446 +0,0 @@ -from __future__ import annotations - -import struct -import tempfile -import time -from abc import ABC, abstractmethod -from contextlib import AsyncExitStack -from inspect import isawaitable -from logging import Logger, getLogger -from pathlib import Path -from typing import AsyncIterator, Awaitable, Callable, cast - -import aiosqlite -import anyio -import y_py as Y -from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group -from anyio.abc import TaskGroup, TaskStatus - -from .yutils import Decoder, get_new_path, write_var_uint - - -class YDocNotFound(Exception): - pass - - -class BaseYStore(ABC): - - metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None - version = 2 - _started: Event | None = None - _starting: bool = False - _task_group: TaskGroup | None = None - - @abstractmethod - def __init__( - self, path: str, metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None - ): - ... - - @abstractmethod - async def write(self, data: bytes) -> None: - ... - - @abstractmethod - async def read(self) -> AsyncIterator[tuple[bytes, bytes]]: - ... - - @property - def started(self) -> Event: - if self._started is None: - self._started = Event() - return self._started - - async def __aenter__(self) -> BaseYStore: - if self._task_group is not None: - raise RuntimeError("YStore already running") - - async with AsyncExitStack() as exit_stack: - tg = create_task_group() - self._task_group = await exit_stack.enter_async_context(tg) - self._exit_stack = exit_stack.pop_all() - tg.start_soon(self.start) - - return self - - async def __aexit__(self, exc_type, exc_value, exc_tb): - if self._task_group is None: - raise RuntimeError("YStore not running") - - self._task_group.cancel_scope.cancel() - self._task_group = None - return await self._exit_stack.__aexit__(exc_type, exc_value, exc_tb) - - async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): - """Start the store. - - Arguments: - task_status: The status to set when the task has started. - """ - if self._starting: - return - else: - self._starting = True - - if self._task_group is not None: - raise RuntimeError("YStore already running") - - self.started.set() - self._starting = False - task_status.started() - - def stop(self) -> None: - """Stop the store.""" - if self._task_group is None: - raise RuntimeError("YStore not running") - - self._task_group.cancel_scope.cancel() - self._task_group = None - - async def get_metadata(self) -> bytes: - """ - Returns: - The metadata. - """ - if self.metadata_callback is None: - return b"" - - metadata = self.metadata_callback() - if isawaitable(metadata): - metadata = await metadata - metadata = cast(bytes, metadata) - return metadata - - async def encode_state_as_update(self, ydoc: Y.YDoc) -> None: - """Store a YDoc state. - - Arguments: - ydoc: The YDoc from which to store the state. - """ - update = Y.encode_state_as_update(ydoc) # type: ignore - await self.write(update) - - async def apply_updates(self, ydoc: Y.YDoc) -> None: - """Apply all stored updates to the YDoc. - - Arguments: - ydoc: The YDoc on which to apply the updates. - """ - async for update, *rest in self.read(): # type: ignore - Y.apply_update(ydoc, update) # type: ignore - - -class FileYStore(BaseYStore): - """A YStore which uses one file per document.""" - - path: str - metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None - lock: Lock - - def __init__( - self, - path: str, - metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None, - log: Logger | None = None, - ) -> None: - """Initialize the object. - - Arguments: - path: The file path used to store the updates. - metadata_callback: An optional callback to call to get the metadata. - log: An optional logger. - """ - self.path = path - self.metadata_callback = metadata_callback - self.log = log or getLogger(__name__) - self.lock = Lock() - - async def check_version(self) -> int: - """Check the version of the store format. - - Returns: - The offset where the data is located in the file. - """ - if not await anyio.Path(self.path).exists(): - version_mismatch = True - else: - version_mismatch = False - move_file = False - async with await anyio.open_file(self.path, "rb") as f: - header = await f.read(8) - if header == b"VERSION:": - version = int(await f.readline()) - if version == self.version: - offset = await f.tell() - else: - version_mismatch = True - else: - version_mismatch = True - if version_mismatch: - move_file = True - if move_file: - new_path = await get_new_path(self.path) - self.log.warning(f"YStore version mismatch, moving {self.path} to {new_path}") - await anyio.Path(self.path).rename(new_path) - if version_mismatch: - async with await anyio.open_file(self.path, "wb") as f: - version_bytes = f"VERSION:{self.version}\n".encode() - await f.write(version_bytes) - offset = len(version_bytes) - return offset - - async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: ignore - """Async iterator for reading the store content. - - Returns: - A tuple of (update, metadata, timestamp) for each update. - """ - async with self.lock: - if not await anyio.Path(self.path).exists(): - raise YDocNotFound - offset = await self.check_version() - async with await anyio.open_file(self.path, "rb") as f: - await f.seek(offset) - data = await f.read() - if not data: - raise YDocNotFound - i = 0 - for d in Decoder(data).read_messages(): - if i == 0: - update = d - elif i == 1: - metadata = d - else: - timestamp = struct.unpack(" None: - """Store an update. - - Arguments: - data: The update to store. - """ - parent = Path(self.path).parent - async with self.lock: - await anyio.Path(parent).mkdir(parents=True, exist_ok=True) - await self.check_version() - async with await anyio.open_file(self.path, "ab") as f: - data_len = write_var_uint(len(data)) - await f.write(data_len + data) - metadata = await self.get_metadata() - metadata_len = write_var_uint(len(metadata)) - await f.write(metadata_len + metadata) - timestamp = struct.pack(" str: - """Get the base directory where the update file is written. - - Returns: - The base directory path. - """ - if self.base_dir is None: - self.make_directory() - assert self.base_dir is not None - return self.base_dir - - def make_directory(self): - """Create the base directory where the update file is written.""" - type(self).base_dir = tempfile.mkdtemp(prefix=self.prefix_dir) - - -class SQLiteYStore(BaseYStore): - """A YStore which uses an SQLite database. - Unlike file-based YStores, the Y updates of all documents are stored in the same database. - - Subclass to point to your database file: - - ```py - class MySQLiteYStore(SQLiteYStore): - db_path = "path/to/my_ystore.db" - ``` - """ - - db_path: str = "ystore.db" - # Determines the "time to live" for all documents, i.e. how recent the - # latest update of a document must be before purging document history. - # Defaults to never purging document history (None). - document_ttl: int | None = None - path: str - lock: Lock - db_initialized: Event - - def __init__( - self, - path: str, - metadata_callback: Callable[[], Awaitable[bytes] | bytes] | None = None, - log: Logger | None = None, - ) -> None: - """Initialize the object. - - Arguments: - path: The file path used to store the updates. - metadata_callback: An optional callback to call to get the metadata. - log: An optional logger. - """ - self.path = path - self.metadata_callback = metadata_callback - self.log = log or getLogger(__name__) - self.lock = Lock() - self.db_initialized = Event() - - async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): - """Start the SQLiteYStore. - - Arguments: - task_status: The status to set when the task has started. - """ - if self._starting: - return - else: - self._starting = True - - if self._task_group is not None: - raise RuntimeError("YStore already running") - - async with create_task_group() as self._task_group: - self._task_group.start_soon(self._init_db) - self.started.set() - self._starting = False - task_status.started() - - async def _init_db(self): - create_db = False - move_db = False - if not await anyio.Path(self.db_path).exists(): - create_db = True - else: - async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - cursor = await db.execute( - "SELECT count(name) FROM sqlite_master WHERE type='table' and name='yupdates'" - ) - table_exists = (await cursor.fetchone())[0] - if table_exists: - cursor = await db.execute("pragma user_version") - version = (await cursor.fetchone())[0] - if version != self.version: - move_db = True - create_db = True - else: - create_db = True - if move_db: - new_path = await get_new_path(self.db_path) - self.log.warning(f"YStore version mismatch, moving {self.db_path} to {new_path}") - await anyio.Path(self.db_path).rename(new_path) - if create_db: - async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - await db.execute( - "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, metadata BLOB, timestamp REAL NOT NULL)" - ) - await db.execute( - "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)" - ) - await db.execute(f"PRAGMA user_version = {self.version}") - await db.commit() - self.db_initialized.set() - - async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]: # type: ignore - """Async iterator for reading the store content. - - Returns: - A tuple of (update, metadata, timestamp) for each update. - """ - await self.db_initialized.wait() - try: - async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - async with db.execute( - "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?", - (self.path,), - ) as cursor: - found = False - async for update, metadata, timestamp in cursor: - found = True - yield update, metadata, timestamp - if not found: - raise YDocNotFound - except Exception: - raise YDocNotFound - - async def write(self, data: bytes) -> None: - """Store an update. - - Arguments: - data: The update to store. - """ - await self.db_initialized.wait() - async with self.lock: - async with aiosqlite.connect(self.db_path) as db: - # first, determine time elapsed since last update - cursor = await db.execute( - "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1", - (self.path,), - ) - row = await cursor.fetchone() - diff = (time.time() - row[0]) if row else 0 - - if self.document_ttl is not None and diff > self.document_ttl: - # squash updates - ydoc = Y.YDoc() - async with db.execute( - "SELECT yupdate FROM yupdates WHERE path = ?", (self.path,) - ) as cursor: - async for update, in cursor: - Y.apply_update(ydoc, update) - # delete history - await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,)) - # insert squashed updates - squashed_update = Y.encode_state_as_update(ydoc) - metadata = await self.get_metadata() - await db.execute( - "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, squashed_update, metadata, time.time()), - ) - - # finally, write this update to the DB - metadata = await self.get_metadata() - await db.execute( - "INSERT INTO yupdates VALUES (?, ?, ?, ?)", - (self.path, data, metadata, time.time()), - ) - await db.commit()