From 1bc1b244ab09897d0658961b820010074b8b3ecb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Thu, 11 May 2023 14:38:58 +0200 Subject: [PATCH 1/6] feat: added `update_bucket` methods to datastores, finished support for bucket 'data' attribute --- aw_datastore/datastore.py | 4 ++ aw_datastore/storages/abstract.py | 13 +++++ aw_datastore/storages/memory.py | 37 +++++++++++-- aw_datastore/storages/peewee.py | 89 +++++++++++++++++++++++++------ aw_datastore/storages/sqlite.py | 68 ++++++++++++++++++++--- tests/test_datastore.py | 22 +++++++- 6 files changed, 205 insertions(+), 28 deletions(-) diff --git a/aw_datastore/datastore.py b/aw_datastore/datastore.py index b2e3add..8638823 100644 --- a/aw_datastore/datastore.py +++ b/aw_datastore/datastore.py @@ -64,6 +64,10 @@ def create_bucket( ) return self[bucket_id] + def update_bucket(self, bucket_id: str, **kwargs): + self.logger.info(f"Updating bucket '{bucket_id}'") + return self.storage_strategy.update_bucket(bucket_id, **kwargs) + def delete_bucket(self, bucket_id: str): self.logger.info(f"Deleting bucket '{bucket_id}'") if bucket_id in self.bucket_instances: diff --git a/aw_datastore/storages/abstract.py b/aw_datastore/storages/abstract.py index 7b10254..b6bb49c 100644 --- a/aw_datastore/storages/abstract.py +++ b/aw_datastore/storages/abstract.py @@ -30,6 +30,19 @@ def create_bucket( hostname: str, created: str, name: Optional[str] = None, + data: Optional[dict] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def update_bucket( + self, + bucket_id: str, + type_id: Optional[str] = None, + client: Optional[str] = None, + hostname: Optional[str] = None, + name: Optional[str] = None, + data: Optional[dict] = None, ) -> None: raise NotImplementedError diff --git a/aw_datastore/storages/memory.py b/aw_datastore/storages/memory.py index 2f625eb..98aee91 100644 --- a/aw_datastore/storages/memory.py +++ b/aw_datastore/storages/memory.py @@ -1,7 +1,7 @@ -import sys import copy +import sys from datetime import datetime -from typing import List, Dict, Optional +from typing import Dict, List, Optional from aw_core.models import Event @@ -21,7 +21,14 @@ def __init__(self, testing: bool) -> None: self._metadata: Dict[str, dict] = dict() def create_bucket( - self, bucket_id, type_id, client, hostname, created, name=None + self, + bucket_id, + type_id, + client, + hostname, + created, + name=None, + data=None, ) -> None: if not name: name = bucket_id @@ -32,9 +39,33 @@ def create_bucket( "client": client, "hostname": hostname, "created": created, + "data": data, } self.db[bucket_id] = [] + def update_bucket( + self, + bucket_id: str, + type_id: Optional[str] = None, + client: Optional[str] = None, + hostname: Optional[str] = None, + name: Optional[str] = None, + data: Optional[dict] = None, + ) -> None: + if bucket_id in self._metadata: + if type_id: + self._metadata[bucket_id]["type"] = type_id + if client: + self._metadata[bucket_id]["client"] = client + if hostname: + self._metadata[bucket_id]["hostname"] = hostname + if name: + self._metadata[bucket_id]["name"] = name + if data: + self._metadata[bucket_id]["data"] = data + else: + raise Exception("Bucket did not exist, could not update") + def delete_bucket(self, bucket_id: str) -> None: if bucket_id in self.db: del self.db[bucket_id] diff --git a/aw_datastore/storages/peewee.py b/aw_datastore/storages/peewee.py index 7569ebe..d22fa3c 100644 --- a/aw_datastore/storages/peewee.py +++ b/aw_datastore/storages/peewee.py @@ -1,24 +1,29 @@ -from typing import Optional, List, Dict, Any -from datetime import datetime, timezone, timedelta import json -import os import logging +import os +from datetime import datetime, timedelta, timezone +from typing import ( + Any, + Dict, + List, + Optional, +) + import iso8601 +from aw_core.dirs import get_data_dir +from aw_core.models import Event +from playhouse.sqlite_ext import SqliteExtDatabase import peewee from peewee import ( - Model, + AutoField, CharField, - IntegerField, - DecimalField, DateTimeField, + DecimalField, ForeignKeyField, - AutoField, + IntegerField, + Model, ) -from playhouse.sqlite_ext import SqliteExtDatabase - -from aw_core.models import Event -from aw_core.dirs import get_data_dir from .abstract import AbstractStorage @@ -38,6 +43,25 @@ LATEST_VERSION = 2 +def auto_migrate(path: str) -> None: + from playhouse.migrate import SqliteMigrator, migrate + + db = SqliteExtDatabase(path) + migrator = SqliteMigrator(db) + + # check db version (NOTE: this is not the same as the file-based version) + db_version = next(db.execute_sql("PRAGMA user_version"))[0] + + if db_version == 0: + datastr_field = CharField(default="{}") + with db.atomic(): + migrate(migrator.add_column("bucketmodel", "datastr", datastr_field)) + # bump version + db.execute_sql("PRAGMA user_version = 1") + + db.close() + + def chunks(ls, n): """Yield successive n-sized chunks from ls. From: https://stackoverflow.com/a/312464/965332""" @@ -67,6 +91,7 @@ class BucketModel(BaseModel): type = CharField() client = CharField() hostname = CharField() + datastr = CharField(null=True) # JSON-encoded object def json(self): return { @@ -78,6 +103,7 @@ def json(self): "type": self.type, "client": self.client, "hostname": self.hostname, + "data": json.loads(self.datastr) if self.datastr else {}, } @@ -124,7 +150,6 @@ def __init__(self, testing: bool = True, filepath: Optional[str] = None) -> None self.db = _db self.db.init(filepath) logger.info(f"Using database file: {filepath}") - self.db.connect() self.bucket_keys: Dict[str, int] = {} @@ -132,13 +157,17 @@ def __init__(self, testing: bool = True, filepath: Optional[str] = None) -> None EventModel.create_table(safe=True) self.update_bucket_keys() + # Migrate database if needed, requires closing the connection first + self.db.close() + auto_migrate(filepath) + self.db.connect() + def update_bucket_keys(self) -> None: buckets = BucketModel.select() self.bucket_keys = {bucket.id: bucket.key for bucket in buckets} def buckets(self) -> Dict[str, Dict[str, Any]]: - buckets = {bucket.id: bucket.json() for bucket in BucketModel.select()} - return buckets + return {bucket.id: bucket.json() for bucket in BucketModel.select()} def create_bucket( self, @@ -148,6 +177,7 @@ def create_bucket( hostname: str, created: str, name: Optional[str] = None, + data: Optional[Dict[str, Any]] = None, ): BucketModel.create( id=bucket_id, @@ -156,9 +186,37 @@ def create_bucket( hostname=hostname, created=created, name=name, + datastr=json.dumps(data or "{}"), ) self.update_bucket_keys() + def update_bucket( + self, + bucket_id: str, + type_id: Optional[str] = None, + client: Optional[str] = None, + hostname: Optional[str] = None, + name: Optional[str] = None, + data: Optional[dict] = None, + ) -> None: + if bucket_id in self.bucket_keys: + bucket = BucketModel.get(BucketModel.key == self.bucket_keys[bucket_id]) + + if type_id is not None: + bucket.type = type_id + if client is not None: + bucket.client = client + if hostname is not None: + bucket.hostname = hostname + if name is not None: + bucket.name = name + if data is not None: + bucket.datastr = json.dumps(data) # Encoding data dictionary to JSON + + bucket.save() + else: + raise Exception("Bucket did not exist, could not update") + def delete_bucket(self, bucket_id: str) -> None: if bucket_id in self.bucket_keys: EventModel.delete().where( @@ -173,9 +231,10 @@ def delete_bucket(self, bucket_id: str) -> None: def get_metadata(self, bucket_id: str): if bucket_id in self.bucket_keys: - return BucketModel.get( + bucket = BucketModel.get( BucketModel.key == self.bucket_keys[bucket_id] ).json() + return bucket else: raise Exception("Bucket did not exist, could not get metadata") diff --git a/aw_datastore/storages/sqlite.py b/aw_datastore/storages/sqlite.py index debee99..ec57e88 100644 --- a/aw_datastore/storages/sqlite.py +++ b/aw_datastore/storages/sqlite.py @@ -1,13 +1,12 @@ -from typing import Optional, List, Iterable -from datetime import datetime, timezone, timedelta import json -import os import logging - +import os import sqlite3 +from datetime import datetime, timedelta, timezone +from typing import Iterable, List, Optional -from aw_core.models import Event from aw_core.dirs import get_data_dir +from aw_core.models import Event from .abstract import AbstractStorage @@ -136,7 +135,7 @@ def buckets(self): buckets = {} c = self.conn.cursor() for row in c.execute( - "SELECT id, name, type, client, hostname, created FROM buckets" + "SELECT id, name, type, client, hostname, created, datastr FROM buckets" ): buckets[row[0]] = { "id": row[0], @@ -145,6 +144,7 @@ def buckets(self): "client": row[3], "hostname": row[4], "created": row[5], + "data": json.loads(row[6] or "{}"), } return buckets @@ -156,15 +156,66 @@ def create_bucket( hostname: str, created: str, name: Optional[str] = None, + data: Optional[dict] = None, ): self.conn.execute( "INSERT INTO buckets(id, name, type, client, hostname, created, datastr) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", - [bucket_id, name, type_id, client, hostname, created, str({})], + [ + bucket_id, + name, + type_id, + client, + hostname, + created, + json.dumps(data or {}), + ], ) self.commit() return self.get_metadata(bucket_id) + def update_bucket( + self, + bucket_id: str, + type_id: Optional[str] = None, + client: Optional[str] = None, + hostname: Optional[str] = None, + name: Optional[str] = None, + data: Optional[dict] = None, + ): + updates = [] + values = [] + + if type_id is not None: + updates.append("type = ?") + values.append(type_id) + + if client is not None: + updates.append("client = ?") + values.append(client) + + if hostname is not None: + updates.append("hostname = ?") + values.append(hostname) + + if name is not None: + updates.append("name = ?") + values.append(name) + + if data is not None: + updates.append("datastr = ?") + values.append(json.dumps(data)) + + values.append(bucket_id) + + if not updates: + raise ValueError("At least one field must be updated.") + + sql = "UPDATE buckets SET " + ", ".join(updates) + " WHERE id = ?" + self.conn.execute(sql, values) + self.commit() + return self.get_metadata(bucket_id) + def delete_bucket(self, bucket_id: str): self.conn.execute( "DELETE FROM events WHERE bucketrow IN (SELECT rowid FROM buckets WHERE id = ?)", @@ -178,7 +229,7 @@ def delete_bucket(self, bucket_id: str): def get_metadata(self, bucket_id: str): c = self.conn.cursor() res = c.execute( - "SELECT id, name, type, client, hostname, created FROM buckets WHERE id = ?", + "SELECT id, name, type, client, hostname, created, datastr FROM buckets WHERE id = ?", [bucket_id], ) row = res.fetchone() @@ -190,6 +241,7 @@ def get_metadata(self, bucket_id: str): "client": row[3], "hostname": row[4], "created": row[5], + "data": json.loads(row[6] or "{}"), } else: raise Exception("Bucket did not exist, could not get metadata") diff --git a/tests/test_datastore.py b/tests/test_datastore.py index 0c26043..615469d 100644 --- a/tests/test_datastore.py +++ b/tests/test_datastore.py @@ -27,7 +27,10 @@ def test_get_buckets(datastore): """ Tests fetching buckets """ - datastore.buckets() + buckets = datastore.buckets() + for bucket in buckets.values(): + assert bucket["id"] in buckets + assert bucket["data"] == {} @pytest.mark.parametrize("datastore", param_datastore_objects()) @@ -55,9 +58,24 @@ def test_create_bucket(datastore): assert bid not in datastore.buckets() +@pytest.mark.parametrize("datastore", param_datastore_objects()) +def test_update_bucket(datastore): + try: + bid = "test-" + str(random.randint(0, 1000000)) + datastore.create_bucket( + bucket_id=bid, type="test", client="test", hostname="test", name="test" + ) + datastore.update_bucket(bid, name="new name") + assert datastore[bid].metadata()["name"] == "new name" + datastore.update_bucket(bid, data={"key": "value"}) + assert datastore[bid].metadata()["data"] == {"key": "value"} + finally: + datastore.delete_bucket(bid) + + @pytest.mark.parametrize("datastore", param_datastore_objects()) def test_delete_bucket(datastore): - bid = "test" + bid = "test-" + str(random.randint(0, 1000000)) datastore.create_bucket( bucket_id=bid, type="test", client="test", hostname="test", name="test" ) From 7eb23629d3b213baedb76c5ca05783124ca0f05d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Thu, 11 May 2023 15:29:44 +0200 Subject: [PATCH 2/6] fix: fixed bug when migrating/adding datastr column in peewee --- aw_datastore/storages/peewee.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/aw_datastore/storages/peewee.py b/aw_datastore/storages/peewee.py index d22fa3c..d5bbfe7 100644 --- a/aw_datastore/storages/peewee.py +++ b/aw_datastore/storages/peewee.py @@ -49,15 +49,14 @@ def auto_migrate(path: str) -> None: db = SqliteExtDatabase(path) migrator = SqliteMigrator(db) - # check db version (NOTE: this is not the same as the file-based version) - db_version = next(db.execute_sql("PRAGMA user_version"))[0] + # check if bucketmodel has datastr field + info = db.execute_sql("PRAGMA table_info(bucketmodel)") + has_datastr = any(row[1] == "datastr" for row in info) - if db_version == 0: + if not has_datastr: datastr_field = CharField(default="{}") with db.atomic(): migrate(migrator.add_column("bucketmodel", "datastr", datastr_field)) - # bump version - db.execute_sql("PRAGMA user_version = 1") db.close() From 4aae9420661a42d3023393374b6d65a9062d5701 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Thu, 11 May 2023 15:34:24 +0200 Subject: [PATCH 3/6] fix: added missing param --- aw_datastore/datastore.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aw_datastore/datastore.py b/aw_datastore/datastore.py index 8638823..15cae8c 100644 --- a/aw_datastore/datastore.py +++ b/aw_datastore/datastore.py @@ -57,10 +57,11 @@ def create_bucket( hostname: str, created: datetime = datetime.now(timezone.utc), name: Optional[str] = None, + data: Optional[dict] = None, ) -> "Bucket": self.logger.info(f"Creating bucket '{bucket_id}'") self.storage_strategy.create_bucket( - bucket_id, type, client, hostname, created.isoformat(), name=name + bucket_id, type, client, hostname, created.isoformat(), name=name, data=data ) return self[bucket_id] From bc8fac66837038a0608752e0749ec766f22e6e89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Thu, 11 May 2023 15:45:22 +0200 Subject: [PATCH 4/6] fix: fixed serialization issue --- aw_datastore/storages/memory.py | 2 +- aw_datastore/storages/peewee.py | 2 +- tests/test_datastore.py | 15 +++++++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/aw_datastore/storages/memory.py b/aw_datastore/storages/memory.py index 98aee91..9019eaa 100644 --- a/aw_datastore/storages/memory.py +++ b/aw_datastore/storages/memory.py @@ -39,7 +39,7 @@ def create_bucket( "client": client, "hostname": hostname, "created": created, - "data": data, + "data": data or {}, } self.db[bucket_id] = [] diff --git a/aw_datastore/storages/peewee.py b/aw_datastore/storages/peewee.py index d5bbfe7..bbf79c2 100644 --- a/aw_datastore/storages/peewee.py +++ b/aw_datastore/storages/peewee.py @@ -185,7 +185,7 @@ def create_bucket( hostname=hostname, created=created, name=name, - datastr=json.dumps(data or "{}"), + datastr=json.dumps(data) if data else "{}", ) self.update_bucket_keys() diff --git a/tests/test_datastore.py b/tests/test_datastore.py index 615469d..3ec3a71 100644 --- a/tests/test_datastore.py +++ b/tests/test_datastore.py @@ -45,14 +45,17 @@ def test_create_bucket(datastore): hostname="testhost", name=name, created=now, + data={"key": "value"}, ) - assert bid == bucket.metadata()["id"] - assert name == bucket.metadata()["name"] - assert "testtype" == bucket.metadata()["type"] - assert "testclient" == bucket.metadata()["client"] - assert "testhost" == bucket.metadata()["hostname"] - assert now == iso8601.parse_date(bucket.metadata()["created"]) + metadata = bucket.metadata() + assert bid == metadata["id"] + assert name == metadata["name"] + assert "testtype" == metadata["type"] + assert "testclient" == metadata["client"] + assert "testhost" == metadata["hostname"] + assert now == iso8601.parse_date(metadata["created"]) assert bid in datastore.buckets() + assert {"key": "value"} == metadata["data"] finally: datastore.delete_bucket(bid) assert bid not in datastore.buckets() From 54222bdf9bf27303c3e7aa67b6053e55f5f9faa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Thu, 11 May 2023 15:58:55 +0200 Subject: [PATCH 5/6] refactor: refactored update_bucket for SqliteStorage --- aw_datastore/storages/sqlite.py | 41 +++++++++++---------------------- 1 file changed, 14 insertions(+), 27 deletions(-) diff --git a/aw_datastore/storages/sqlite.py b/aw_datastore/storages/sqlite.py index ec57e88..67ecbce 100644 --- a/aw_datastore/storages/sqlite.py +++ b/aw_datastore/storages/sqlite.py @@ -183,36 +183,23 @@ def update_bucket( name: Optional[str] = None, data: Optional[dict] = None, ): - updates = [] - values = [] - - if type_id is not None: - updates.append("type = ?") - values.append(type_id) - - if client is not None: - updates.append("client = ?") - values.append(client) - - if hostname is not None: - updates.append("hostname = ?") - values.append(hostname) - - if name is not None: - updates.append("name = ?") - values.append(name) - - if data is not None: - updates.append("datastr = ?") - values.append(json.dumps(data)) - - values.append(bucket_id) - + update_values = [ + ("type", type_id), + ("client", client), + ("hostname", hostname), + ("name", name), + ("datastr", json.dumps(data) if data is not None else None), + ] + updates, values = zip(*[(k, v) for k, v in update_values if v is not None]) if not updates: raise ValueError("At least one field must be updated.") - sql = "UPDATE buckets SET " + ", ".join(updates) + " WHERE id = ?" - self.conn.execute(sql, values) + sql = ( + "UPDATE buckets SET " + + ", ".join(f"{u} = ?" for u in updates) + + " WHERE id = ?" + ) + self.conn.execute(sql, (*values, bucket_id)) self.commit() return self.get_metadata(bucket_id) From 889ff267c59c38ab3cb682495c2f102fcd5a7d7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erik=20Bj=C3=A4reholt?= Date: Thu, 11 May 2023 16:40:48 +0200 Subject: [PATCH 6/6] fix: simplified statement --- aw_datastore/storages/peewee.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aw_datastore/storages/peewee.py b/aw_datastore/storages/peewee.py index bbf79c2..a4afe22 100644 --- a/aw_datastore/storages/peewee.py +++ b/aw_datastore/storages/peewee.py @@ -185,7 +185,7 @@ def create_bucket( hostname=hostname, created=created, name=name, - datastr=json.dumps(data) if data else "{}", + datastr=json.dumps(data or {}), ) self.update_bucket_keys()