feat: added update_bucket methods to datastores, finished support for bucket 'data' attribute #121

Merged · 6 commits · May 15, 2023
7 changes: 6 additions & 1 deletion aw_datastore/datastore.py
@@ -57,13 +57,18 @@ def create_bucket(
hostname: str,
created: datetime = datetime.now(timezone.utc),
name: Optional[str] = None,
data: Optional[dict] = None,
) -> "Bucket":
self.logger.info(f"Creating bucket '{bucket_id}'")
self.storage_strategy.create_bucket(
bucket_id, type, client, hostname, created.isoformat(), name=name
bucket_id, type, client, hostname, created.isoformat(), name=name, data=data
)
return self[bucket_id]

def update_bucket(self, bucket_id: str, **kwargs):
self.logger.info(f"Updating bucket '{bucket_id}'")
return self.storage_strategy.update_bucket(bucket_id, **kwargs)

def delete_bucket(self, bucket_id: str):
self.logger.info(f"Deleting bucket '{bucket_id}'")
if bucket_id in self.bucket_instances:
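For orientation only (not part of the diff): a minimal sketch of how a caller could exercise the new `data` argument and the `update_bucket` pass-through on the `Datastore` facade. The storage import path and the facade's constructor arguments are assumptions about the surrounding package, not something this PR defines.

```python
# Illustrative sketch only; the MemoryStorage import path and the Datastore
# constructor signature are assumed, not taken from this diff.
from aw_datastore import Datastore
from aw_datastore.storages import MemoryStorage

ds = Datastore(storage_strategy=MemoryStorage, testing=True)

# `data` is the new optional metadata dict supported by create_bucket.
ds.create_bucket(
    "aw-watcher-example_somehost",  # bucket_id
    "example.type",                 # type
    "example-client",               # client
    "somehost",                     # hostname
    data={"device_id": "1234"},
)

# update_bucket forwards straight to the storage backend; only the fields
# actually passed are changed.
ds.update_bucket("aw-watcher-example_somehost", data={"device_id": "5678"})
```
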
13 changes: 13 additions & 0 deletions aw_datastore/storages/abstract.py
@@ -30,6 +30,19 @@ def create_bucket(
hostname: str,
created: str,
name: Optional[str] = None,
data: Optional[dict] = None,
) -> None:
raise NotImplementedError

@abstractmethod
def update_bucket(
self,
bucket_id: str,
type_id: Optional[str] = None,
client: Optional[str] = None,
hostname: Optional[str] = None,
name: Optional[str] = None,
data: Optional[dict] = None,
) -> None:
raise NotImplementedError

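As a reference for what the new abstract signature implies (not part of the diff): each backend treats arguments left at None as "don't touch" and only overwrites the fields that were actually provided. A standalone sketch of that merge rule, using a hypothetical helper name:

```python
# Hypothetical helper illustrating the partial-update rule the concrete
# backends implement against their own storage; not part of the PR.
from typing import Optional


def apply_bucket_update(
    metadata: dict,
    type_id: Optional[str] = None,
    client: Optional[str] = None,
    hostname: Optional[str] = None,
    name: Optional[str] = None,
    data: Optional[dict] = None,
) -> dict:
    candidates = {
        "type": type_id,
        "client": client,
        "hostname": hostname,
        "name": name,
        "data": data,
    }
    # Only fields that were actually provided overwrite the stored metadata.
    metadata.update({k: v for k, v in candidates.items() if v is not None})
    return metadata


meta = {"type": "t", "client": "c", "hostname": "h", "name": "n", "data": {}}
apply_bucket_update(meta, name="renamed")
assert meta["name"] == "renamed"
assert meta["type"] == "t"  # untouched field keeps its value
```
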
37 changes: 34 additions & 3 deletions aw_datastore/storages/memory.py
@@ -1,7 +1,7 @@
import sys
import copy
import sys
from datetime import datetime
from typing import List, Dict, Optional
from typing import Dict, List, Optional

from aw_core.models import Event

@@ -21,7 +21,14 @@ def __init__(self, testing: bool) -> None:
self._metadata: Dict[str, dict] = dict()

def create_bucket(
self, bucket_id, type_id, client, hostname, created, name=None
self,
bucket_id,
type_id,
client,
hostname,
created,
name=None,
data=None,
) -> None:
if not name:
name = bucket_id
@@ -32,9 +39,33 @@ def create_bucket(
"client": client,
"hostname": hostname,
"created": created,
"data": data or {},
}
self.db[bucket_id] = []

def update_bucket(
self,
bucket_id: str,
type_id: Optional[str] = None,
client: Optional[str] = None,
hostname: Optional[str] = None,
name: Optional[str] = None,
data: Optional[dict] = None,
) -> None:
if bucket_id in self._metadata:
if type_id:
self._metadata[bucket_id]["type"] = type_id
if client:
self._metadata[bucket_id]["client"] = client
if hostname:
self._metadata[bucket_id]["hostname"] = hostname
if name:
self._metadata[bucket_id]["name"] = name
if data:
self._metadata[bucket_id]["data"] = data
else:
raise Exception("Bucket did not exist, could not update")

def delete_bucket(self, bucket_id: str) -> None:
if bucket_id in self.db:
del self.db[bucket_id]
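A quick usage sketch of the in-memory backend with the new methods (assuming the class in this module is MemoryStorage and that get_metadata returns the stored metadata dict, as elsewhere in aw_datastore):

```python
# Sketch only; the class name and get_metadata behaviour are assumptions
# based on the surrounding module, not shown in this diff.
from aw_datastore.storages import MemoryStorage

storage = MemoryStorage(testing=True)
storage.create_bucket(
    "bucket-1",
    "example.type",
    "example-client",
    "somehost",
    created="2023-05-15T00:00:00+00:00",
    data={"device_id": "1234"},
)

# Partial update: unspecified fields keep their previous values.
storage.update_bucket("bucket-1", name="renamed", data={"device_id": "5678"})

meta = storage.get_metadata("bucket-1")
assert meta["name"] == "renamed"
assert meta["data"] == {"device_id": "5678"}
assert meta["client"] == "example-client"
```
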
88 changes: 73 additions & 15 deletions aw_datastore/storages/peewee.py
@@ -1,24 +1,29 @@
from typing import Optional, List, Dict, Any
from datetime import datetime, timezone, timedelta
import json
import os
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import (
Any,
Dict,
List,
Optional,
)

import iso8601
from aw_core.dirs import get_data_dir
from aw_core.models import Event
from playhouse.sqlite_ext import SqliteExtDatabase

import peewee
from peewee import (
Model,
AutoField,
CharField,
IntegerField,
DecimalField,
DateTimeField,
DecimalField,
ForeignKeyField,
AutoField,
IntegerField,
Model,
)
from playhouse.sqlite_ext import SqliteExtDatabase

from aw_core.models import Event
from aw_core.dirs import get_data_dir

from .abstract import AbstractStorage

@@ -38,6 +43,24 @@
LATEST_VERSION = 2


def auto_migrate(path: str) -> None:
from playhouse.migrate import SqliteMigrator, migrate

db = SqliteExtDatabase(path)
migrator = SqliteMigrator(db)

# check if bucketmodel has datastr field
info = db.execute_sql("PRAGMA table_info(bucketmodel)")
has_datastr = any(row[1] == "datastr" for row in info)

if not has_datastr:
datastr_field = CharField(default="{}")
with db.atomic():
migrate(migrator.add_column("bucketmodel", "datastr", datastr_field))

db.close()


def chunks(ls, n):
"""Yield successive n-sized chunks from ls.
From: https://stackoverflow.com/a/312464/965332"""
@@ -67,6 +90,7 @@ class BucketModel(BaseModel):
type = CharField()
client = CharField()
hostname = CharField()
datastr = CharField(null=True) # JSON-encoded object

def json(self):
return {
@@ -78,6 +102,7 @@ def json(self):
"type": self.type,
"client": self.client,
"hostname": self.hostname,
"data": json.loads(self.datastr) if self.datastr else {},
}


@@ -124,21 +149,24 @@ def __init__(self, testing: bool = True, filepath: Optional[str] = None) -> None
self.db = _db
self.db.init(filepath)
logger.info(f"Using database file: {filepath}")

self.db.connect()

self.bucket_keys: Dict[str, int] = {}
BucketModel.create_table(safe=True)
EventModel.create_table(safe=True)
self.update_bucket_keys()

# Migrate database if needed, requires closing the connection first
self.db.close()
auto_migrate(filepath)
self.db.connect()

def update_bucket_keys(self) -> None:
buckets = BucketModel.select()
self.bucket_keys = {bucket.id: bucket.key for bucket in buckets}

def buckets(self) -> Dict[str, Dict[str, Any]]:
buckets = {bucket.id: bucket.json() for bucket in BucketModel.select()}
return buckets
return {bucket.id: bucket.json() for bucket in BucketModel.select()}

def create_bucket(
self,
@@ -148,6 +176,7 @@ def create_bucket(
hostname: str,
created: str,
name: Optional[str] = None,
data: Optional[Dict[str, Any]] = None,
):
BucketModel.create(
id=bucket_id,
@@ -156,9 +185,37 @@ def create_bucket(
hostname=hostname,
created=created,
name=name,
datastr=json.dumps(data) if data else "{}",
)
self.update_bucket_keys()

def update_bucket(
self,
bucket_id: str,
type_id: Optional[str] = None,
client: Optional[str] = None,
hostname: Optional[str] = None,
name: Optional[str] = None,
data: Optional[dict] = None,
) -> None:
if bucket_id in self.bucket_keys:
bucket = BucketModel.get(BucketModel.key == self.bucket_keys[bucket_id])

if type_id is not None:
bucket.type = type_id
if client is not None:
bucket.client = client
if hostname is not None:
bucket.hostname = hostname
if name is not None:
bucket.name = name
if data is not None:
bucket.datastr = json.dumps(data) # Encoding data dictionary to JSON

bucket.save()
else:
raise Exception("Bucket did not exist, could not update")

def delete_bucket(self, bucket_id: str) -> None:
if bucket_id in self.bucket_keys:
EventModel.delete().where(
@@ -173,9 +230,10 @@ def delete_bucket(self, bucket_id: str) -> None:

def get_metadata(self, bucket_id: str):
if bucket_id in self.bucket_keys:
return BucketModel.get(
bucket = BucketModel.get(
BucketModel.key == self.bucket_keys[bucket_id]
).json()
return bucket
else:
raise Exception("Bucket did not exist, could not get metadata")

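To make the migration step concrete: auto_migrate only adds the datastr column to bucketmodel when it is missing, so running it against an already-migrated database is a no-op. The same check can be sketched with nothing but the sqlite3 standard library (the real code uses playhouse's SqliteMigrator, as shown above):

```python
# Standalone sketch of the "add the column only if it is missing" check that
# auto_migrate performs, using the sqlite3 stdlib instead of playhouse.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE bucketmodel (key INTEGER PRIMARY KEY, id TEXT)")


def ensure_datastr_column(conn: sqlite3.Connection) -> None:
    # PRAGMA table_info yields one row per column; index 1 is the column name.
    columns = [row[1] for row in conn.execute("PRAGMA table_info(bucketmodel)")]
    if "datastr" not in columns:
        conn.execute("ALTER TABLE bucketmodel ADD COLUMN datastr TEXT DEFAULT '{}'")


ensure_datastr_column(conn)  # adds the column on first run
ensure_datastr_column(conn)  # no-op on second run
print([row[1] for row in conn.execute("PRAGMA table_info(bucketmodel)")])
# ['key', 'id', 'datastr']
```
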
55 changes: 47 additions & 8 deletions aw_datastore/storages/sqlite.py
@@ -1,13 +1,12 @@
from typing import Optional, List, Iterable
from datetime import datetime, timezone, timedelta
import json
import os
import logging

import os
import sqlite3
from datetime import datetime, timedelta, timezone
from typing import Iterable, List, Optional

from aw_core.models import Event
from aw_core.dirs import get_data_dir
from aw_core.models import Event

from .abstract import AbstractStorage

@@ -136,7 +135,7 @@ def buckets(self):
buckets = {}
c = self.conn.cursor()
for row in c.execute(
"SELECT id, name, type, client, hostname, created FROM buckets"
"SELECT id, name, type, client, hostname, created, datastr FROM buckets"
):
buckets[row[0]] = {
"id": row[0],
@@ -145,6 +144,7 @@
"client": row[3],
"hostname": row[4],
"created": row[5],
"data": json.loads(row[6] or "{}"),
}
return buckets

@@ -156,12 +156,50 @@ def create_bucket(
hostname: str,
created: str,
name: Optional[str] = None,
data: Optional[dict] = None,
):
self.conn.execute(
"INSERT INTO buckets(id, name, type, client, hostname, created, datastr) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?)",
[bucket_id, name, type_id, client, hostname, created, str({})],
[
bucket_id,
name,
type_id,
client,
hostname,
created,
json.dumps(data or {}),
],
)
self.commit()
return self.get_metadata(bucket_id)

def update_bucket(
self,
bucket_id: str,
type_id: Optional[str] = None,
client: Optional[str] = None,
hostname: Optional[str] = None,
name: Optional[str] = None,
data: Optional[dict] = None,
):
update_values = [
("type", type_id),
("client", client),
("hostname", hostname),
("name", name),
("datastr", json.dumps(data) if data is not None else None),
]
provided = [(k, v) for k, v in update_values if v is not None]
if not provided:
raise ValueError("At least one field must be updated.")
updates, values = zip(*provided)

sql = (
"UPDATE buckets SET "
+ ", ".join(f"{u} = ?" for u in updates)
+ " WHERE id = ?"
)
self.conn.execute(sql, (*values, bucket_id))
self.commit()
return self.get_metadata(bucket_id)

@@ -178,7 +216,7 @@ def delete_bucket(self, bucket_id: str):
def get_metadata(self, bucket_id: str):
c = self.conn.cursor()
res = c.execute(
"SELECT id, name, type, client, hostname, created FROM buckets WHERE id = ?",
"SELECT id, name, type, client, hostname, created, datastr FROM buckets WHERE id = ?",
[bucket_id],
)
row = res.fetchone()
@@ -190,6 +228,7 @@ def get_metadata(self, bucket_id: str):
"client": row[3],
"hostname": row[4],
"created": row[5],
"data": json.loads(row[6] or "{}"),
}
else:
raise Exception("Bucket did not exist, could not get metadata")
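For illustration (not part of the diff): the UPDATE statement built by the sqlite backend's update_bucket depends on which arguments are non-None. A standalone sketch of the same assembly, with the emptiness check done before unpacking:

```python
# Standalone sketch mirroring the dynamic UPDATE assembly in update_bucket;
# the function name is hypothetical and only the SQL-building step is shown.
import json
from typing import Optional


def build_update_sql(
    bucket_id: str,
    type_id: Optional[str] = None,
    client: Optional[str] = None,
    hostname: Optional[str] = None,
    name: Optional[str] = None,
    data: Optional[dict] = None,
):
    candidates = [
        ("type", type_id),
        ("client", client),
        ("hostname", hostname),
        ("name", name),
        ("datastr", json.dumps(data) if data is not None else None),
    ]
    provided = [(column, value) for column, value in candidates if value is not None]
    if not provided:
        raise ValueError("At least one field must be updated.")
    columns, values = zip(*provided)
    sql = "UPDATE buckets SET " + ", ".join(f"{c} = ?" for c in columns) + " WHERE id = ?"
    return sql, (*values, bucket_id)


print(build_update_sql("bucket-1", name="renamed", data={"device_id": "5678"}))
# ('UPDATE buckets SET name = ?, datastr = ? WHERE id = ?',
#  ('renamed', '{"device_id": "5678"}', 'bucket-1'))
```
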