-
-
Notifications
You must be signed in to change notification settings - Fork 37
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/sql store #125
Feature/sql store #125
Changes from 5 commits
c4f7503
33e5456
84600b6
ea0c427
cebd4d7
20d5daa
5d9dd5a
f968f90
2d0dd85
f2c60cf
2389094
565dc10
86f288d
8ed0544
c1dc320
3d25cb4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
from datetime import datetime | ||
from typing import Any, Generator, List, Optional, Set, Tuple | ||
|
||
from banal import as_bool | ||
from followthemoney.property import Property | ||
from sqlalchemy import Table, create_engine, delete, select | ||
from sqlalchemy.sql.selectable import Select | ||
|
||
from nomenklatura.dataset import DS | ||
from nomenklatura.db import ( | ||
DB_URL, | ||
POOL_SIZE, | ||
ensure_tx, | ||
get_metadata, | ||
get_statement_table, | ||
get_upsert_func, | ||
) | ||
from nomenklatura.entity import CE | ||
from nomenklatura.resolver import Resolver | ||
from nomenklatura.statement import Statement | ||
from nomenklatura.store import Store, View, Writer | ||
|
||
|
||
def pack_statement(stmt: Statement) -> dict[str, Any]:
    """Convert a statement into a plain row dict ready for bulk upsert.

    Booleans are normalised via ``as_bool`` and missing ``first_seen`` /
    ``last_seen`` values are defaulted to the current UTC time.

    FIX: the original called ``datetime.utcnow()`` twice, so a statement
    missing both timestamps could get a ``first_seen`` and ``last_seen``
    differing by microseconds; compute the fallback once instead.

    NOTE(review): fabricating timestamps mid-pipeline was questioned in
    review — the columns were later made nullable upstream (2d0dd85).
    """
    row: dict[str, Any] = stmt.to_row()
    row["target"] = as_bool(row["target"])
    row["external"] = as_bool(row["external"])
    now = datetime.utcnow()
    row["first_seen"] = row["first_seen"] or now
    row["last_seen"] = row["last_seen"] or now
    return row
|
||
|
||
class SqlStore(Store[DS, CE]):
    """Statement store backed by a SQL database through SQLAlchemy."""

    def __init__(
        self,
        dataset: DS,
        resolver: Resolver[CE],
        uri: str = DB_URL,
        **engine_kwargs: Any,
    ):
        super().__init__(dataset, resolver)
        # Honour an explicit pool_size, otherwise fall back to the default.
        engine_kwargs.setdefault("pool_size", POOL_SIZE)
        self.metadata = get_metadata()
        self.engine = create_engine(uri, **engine_kwargs)
        self.table = get_statement_table()
        self.metadata.create_all(self.engine, tables=[self.table], checkfirst=True)

    def writer(self) -> Writer[DS, CE]:
        """Create a batching writer bound to this store."""
        return SqlWriter(self)

    def view(self, scope: DS, external: bool = False) -> View[DS, CE]:
        """Create a read-only view over *scope*."""
        return SqlView(self, scope, external=external)

    def _iterate_stmts(self, q: Select) -> Generator[Statement, None, None]:
        """Execute *q* and stream the result rows back as statements.

        NOTE(review): server-side streaming adds overhead; review suggested
        making it optional for small point queries such as ``get_entity``.
        """
        with ensure_tx(self.engine.connect()) as conn:
            conn = conn.execution_options(stream_results=True)
            cursor = conn.execute(q)
            while batch := cursor.fetchmany(10_000):
                yield from (Statement.from_db_row(row) for row in batch)

    def _iterate(self, q: Select) -> Generator[CE, None, None]:
        """Group streamed statements by entity id and assemble proxies.

        Assumes *q* yields statements sorted by entity id so that all
        statements of one entity arrive contiguously.
        """
        last_id = None
        buffer: list[Statement] = []
        for stmt in self._iterate_stmts(q):
            if last_id is not None and stmt.entity_id != last_id:
                # Entity boundary reached: emit the buffered statements.
                proxy = self.assemble(buffer)
                if proxy is not None:
                    yield proxy
                buffer = []
            last_id = stmt.entity_id
            buffer.append(stmt)
        if buffer:
            proxy = self.assemble(buffer)
            if proxy is not None:
                yield proxy
|
||
|
||
class SqlWriter(Writer[DS, CE]):
    """Buffered writer that upserts statements into the store in batches."""

    BATCH_STATEMENTS = 10_000

    def __init__(self, store: SqlStore[DS, CE]):
        self.store: SqlStore[DS, CE] = store
        # NOTE(review): a set de-duplicates statements within one batch, so
        # batch_size can overcount slightly; kept as in the original design.
        self.batch: Optional[Set[Statement]] = None
        self.batch_size = 0
        self.insert = get_upsert_func(self.store.engine)

    def flush(self) -> None:
        """Upsert all buffered statements, then reset the buffer."""
        if self.batch:
            rows = [pack_statement(s) for s in self.batch]
            istmt = self.insert(self.store.table).values(rows)
            # On id conflict, refresh the mutable columns from the new row.
            upsert = istmt.on_conflict_do_update(
                index_elements=["id"],
                set_=dict(
                    canonical_id=istmt.excluded.canonical_id,
                    schema=istmt.excluded.schema,
                    prop_type=istmt.excluded.prop_type,
                    target=istmt.excluded.target,
                    lang=istmt.excluded.lang,
                    original_value=istmt.excluded.original_value,
                    last_seen=istmt.excluded.last_seen,
                ),
            )
            with ensure_tx(self.store.engine.connect()) as conn:
                conn.execute(upsert)
        self.batch = set()
        self.batch_size = 0

    def add_statement(self, stmt: Statement) -> None:
        """Buffer one statement, flushing when the batch limit is hit."""
        if self.batch is None:
            self.batch = set()
        if stmt.entity_id is None:
            # Statements without an entity cannot be stored.
            return
        if self.batch_size >= self.BATCH_STATEMENTS:
            self.flush()
        stmt.canonical_id = self.store.resolver.get_canonical(stmt.entity_id)
        self.batch.add(stmt)
        self.batch_size += 1

    def pop(self, entity_id: str) -> List[Statement]:
        """Remove and return every statement belonging to *entity_id*."""
        self.flush()
        table = self.store.table
        clause = table.c.entity_id == entity_id
        statements: List[Statement] = list(
            self.store._iterate_stmts(select(table).where(clause))
        )
        with ensure_tx(self.store.engine.connect()) as conn:
            conn.execute(delete(table).where(clause))
        return statements
|
||
|
||
class SqlView(View[DS, CE]):
    """Read-only view over the statements within a dataset scope."""

    def __init__(
        self, store: SqlStore[DS, CE], scope: DS, external: bool = False
    ) -> None:
        super().__init__(store, scope, external=external)
        self.store: SqlStore[DS, CE] = store

    def get_entity(self, id: str) -> Optional[CE]:
        """Assemble and return the entity with the given id, or None."""
        table = self.store.table
        q = select(table).where(table.c.entity_id == id)
        for proxy in self.store._iterate(q):
            return proxy
        return None

    def get_inverted(self, id: str) -> Generator[Tuple[Property, CE], None, None]:
        """Yield ``(reverse property, entity)`` pairs for entities that
        reference *id* through an entity-typed property.

        FIX: the original applied ``.distinct(table.c.value)``, but every
        matching row already satisfies ``value == id``, so on backends that
        honour per-column DISTINCT the query collapsed to a single row and
        silently dropped referencing entities (flagged in review). Instead,
        de-duplicate by ``canonical_id`` in Python so each referencing
        entity is assembled exactly once.
        """
        table = self.store.table
        q = select(table).where(table.c.prop_type == "entity", table.c.value == id)
        seen: Set[str] = set()
        for stmt in self.store._iterate_stmts(q):
            if stmt.canonical_id is None or stmt.canonical_id in seen:
                continue
            seen.add(stmt.canonical_id)
            entity = self.get_entity(stmt.canonical_id)
            if entity is None:
                continue
            for prop, value in entity.itervalues():
                if value == id and prop.reverse is not None:
                    yield prop.reverse, entity

    def entities(self) -> Generator[CE, None, None]:
        """Iterate all assembled entities in the view's dataset scope."""
        table: Table = self.store.table
        # NOTE(review): ordering should possibly use canonical_id so merged
        # entities group correctly — confirm against assemble() semantics.
        q = (
            select(table)
            .where(table.c.dataset.in_(self.dataset_names))
            .order_by("entity_id")
        )
        yield from self.store._iterate(q)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
""" | ||
Test if the different store implementations all behave the same. | ||
""" | ||
|
||
from pathlib import Path | ||
from typing import Any, Dict, List | ||
|
||
from nomenklatura.dataset import Dataset | ||
from nomenklatura.entity import CompositeEntity | ||
from nomenklatura.resolver import Resolver | ||
from nomenklatura.store import LevelDBStore, SimpleMemoryStore, SqlStore, Store | ||
|
||
|
||
def _run_store_test(
    store: Store, dataset: Dataset, donations_json: List[Dict[str, Any]]
):
    """Exercise one store implementation end-to-end; return True on success."""
    with store.writer() as bulk:
        for record in donations_json:
            bulk.add_entity(CompositeEntity.from_data(dataset, record))

    view = store.default_view()
    proxies = list(view.entities())
    assert len(proxies) == len(donations_json)

    entity = view.get_entity("4e0bd810e1fcb49990a2b31709b6140c4c9139c5")
    assert entity.caption == "Tchibo Holding AG"

    # Verify inverted lookups resolve back to the referencing entity.
    checked_inverted = False
    for prop, value in entity.itervalues():
        if prop.type.name != "entity":
            continue
        for iprop, ientity in view.get_inverted(value):
            assert iprop.reverse == prop
            assert ientity == entity
            checked_inverted = True
    assert checked_inverted

    adjacent = list(view.get_adjacent(entity))
    assert len(adjacent) == 2

    # Popping removes all of the entity's statements from the store.
    writer = store.writer()
    stmts = writer.pop(entity.id)
    assert len(stmts) == len(list(entity.statements))
    assert view.get_entity(entity.id) is None

    # Re-adding the same data must behave as an upsert.
    with store.writer() as bulk:
        for record in donations_json:
            bulk.add_entity(CompositeEntity.from_data(dataset, record))

    entity = view.get_entity(entity.id)
    assert entity.caption == "Tchibo Holding AG"
    return True
|
||
|
||
def test_store_sql(
    tmp_path: Path, test_dataset: Dataset, donations_json: List[Dict[str, Any]]
):
    """Run the shared store test suite against the SQL (sqlite) backend."""
    resolver = Resolver[CompositeEntity]()
    uri = f"sqlite:///{tmp_path / 'test.db'}"
    store = SqlStore(dataset=test_dataset, resolver=resolver, uri=uri)
    # The engine must be bound to the database file we requested.
    assert str(store.engine.url) == uri
    assert _run_store_test(store, test_dataset, donations_json)
|
||
|
||
def test_store_memory(test_dataset: Dataset, donations_json: List[Dict[str, Any]]):
    """Run the shared store test suite against the in-memory backend."""
    resolver = Resolver[CompositeEntity]()
    memory_store = SimpleMemoryStore(dataset=test_dataset, resolver=resolver)
    assert _run_store_test(memory_store, test_dataset, donations_json)
|
||
|
||
def test_store_level(
    tmp_path: Path, test_dataset: Dataset, donations_json: List[Dict[str, Any]]
):
    """Run the shared store test suite against the LevelDB backend."""
    resolver = Resolver[CompositeEntity]()
    level_store = LevelDBStore(
        dataset=test_dataset, resolver=resolver, path=tmp_path / "level.db"
    )
    assert _run_store_test(level_store, test_dataset, donations_json)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wait, are we committing after an exception here? That feels a bit unhealty.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed it here: f968f90 and here: 2389094