Feature/sql store #125

Merged: 16 commits, Aug 23, 2023

Changes from 5 commits
48 changes: 38 additions & 10 deletions nomenklatura/db.py
@@ -1,13 +1,21 @@
+from contextlib import contextmanager
 import os
-from pathlib import Path
-from contextlib import contextmanager
 from functools import cache
-from typing import Optional, Generator
+from pathlib import Path
+from typing import Generator, Optional, Union

from sqlalchemy import MetaData, create_engine
-from sqlalchemy.engine import Engine, Connection
+from sqlalchemy.dialects.mysql import insert as mysql_insert
+from sqlalchemy.dialects.postgresql import insert as psql_insert
+from sqlalchemy.dialects.sqlite import insert as sqlite_insert
+from sqlalchemy.engine import Connection, Engine
+from sqlalchemy.schema import Table

+from nomenklatura.statement import make_statement_table

DB_PATH = Path("nomenklatura.db").resolve()
DB_URL = os.environ.get("NOMENKLATURA_DB_URL", f"sqlite:///{DB_PATH.as_posix()}")
+DB_STORE_TABLE = os.environ.get("NOMENKLATURA_DB_STORE_TABLE", "nk_store")
+POOL_SIZE = int(os.environ.get("NOMENKLATURA_DB_POOL_SIZE", "5"))
Conn = Connection
Connish = Optional[Connection]
@@ -23,11 +31,31 @@ def get_metadata() -> MetaData:
    return MetaData()


+@cache
+def get_statement_table() -> Table:
+    return make_statement_table(get_metadata(), DB_STORE_TABLE)


 @contextmanager
 def ensure_tx(conn: Connish = None) -> Generator[Connection, None, None]:
-    if conn is not None:
-        yield conn
-        return
-    engine = get_engine()
-    with engine.begin() as conn:
-        yield conn
+    try:
+        if conn is not None:
+            yield conn
+        else:
+            engine = get_engine()
+            with engine.begin() as conn:
+                yield conn
+    finally:
+        if conn is not None:
+            conn.commit()
Member:
Wait, are we committing after an exception here? That feels a bit unhealthy.

Contributor (Author):
Fixed it here: f968f90 and here: 2389094
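
For reference, a minimal sketch of the commit-on-success shape those fixes point at (not the exact code from the linked commits): let engine.begin() own commit and rollback, and never commit unconditionally in a finally block. A caller-provided connection is left to the caller to finish.

@contextmanager
def ensure_tx(conn: Connish = None) -> Generator[Connection, None, None]:
    if conn is not None:
        # Caller owns this connection and its transaction lifecycle.
        yield conn
        return
    # engine.begin() commits on clean exit and rolls back on exception,
    # so nothing gets committed after a failure.
    engine = get_engine()
    with engine.begin() as conn:
        yield conn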



+@cache
+def get_upsert_func(
+    engine: Engine,
+) -> Union[sqlite_insert, mysql_insert, psql_insert]:
+    if engine.name == "sqlite":
+        return sqlite_insert
+    if engine.name == "mysql":
+        return mysql_insert
+    return psql_insert
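
A rough usage sketch of the returned constructor (engine, table, and rows here are stand-ins, not part of this diff). One caveat worth flagging: the SQLite and PostgreSQL inserts expose on_conflict_do_update, but SQLAlchemy's MySQL insert offers on_duplicate_key_update instead, so the union return type papers over a real API difference.

insert = get_upsert_func(engine)
istmt = insert(table).values(rows)
stmt = istmt.on_conflict_do_update(
    index_elements=["id"],
    set_=dict(last_seen=istmt.excluded.last_seen),
)
with engine.begin() as conn:
    conn.execute(stmt)  # upsert: insert new rows, refresh last_seen on conflict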
8 changes: 4 additions & 4 deletions nomenklatura/statement/db.py
@@ -1,4 +1,4 @@
-from sqlalchemy import MetaData, Table, Column, DateTime, Unicode, Boolean
+from sqlalchemy import Boolean, Column, DateTime, MetaData, Table, Unicode

KEY_LEN = 255
VALUE_LEN = 65535
@@ -11,9 +11,9 @@ def make_statement_table(metadata: MetaData, name: str = "statement") -> Table:
         Column("id", Unicode(KEY_LEN), primary_key=True, unique=True),
         Column("entity_id", Unicode(KEY_LEN), index=True, nullable=False),
         Column("canonical_id", Unicode(KEY_LEN), index=True, nullable=True),
-        Column("prop", Unicode(KEY_LEN), nullable=False),
-        Column("prop_type", Unicode(KEY_LEN), nullable=False),
-        Column("schema", Unicode(KEY_LEN), nullable=False),
+        Column("prop", Unicode(KEY_LEN), index=True, nullable=False),
+        Column("prop_type", Unicode(KEY_LEN), index=True, nullable=False),
+        Column("schema", Unicode(KEY_LEN), index=True, nullable=False),
         Column("value", Unicode(VALUE_LEN), nullable=False),
         Column("original_value", Unicode(VALUE_LEN), nullable=True),
         Column("dataset", Unicode(KEY_LEN), index=True),
13 changes: 9 additions & 4 deletions nomenklatura/store/__init__.py
@@ -1,21 +1,26 @@
-import orjson
 from pathlib import Path
 from typing import Optional
 
-from nomenklatura.store.base import Store, Writer, View
-from nomenklatura.store.memory import MemoryStore
-from nomenklatura.resolver import Resolver
+import orjson
+
 from nomenklatura.dataset import Dataset
 from nomenklatura.entity import CompositeEntity
+from nomenklatura.resolver import Resolver
+from nomenklatura.store.base import Store, View, Writer
+from nomenklatura.store.level import LevelDBStore
+from nomenklatura.store.memory import MemoryStore
+from nomenklatura.store.sql import SqlStore

SimpleMemoryStore = MemoryStore[Dataset, CompositeEntity]

 __all__ = [
     "Store",
     "Writer",
     "View",
+    "LevelDBStore",
     "MemoryStore",
     "SimpleMemoryStore",
+    "SqlStore",
     "load_entity_file_store",
 ]

14 changes: 8 additions & 6 deletions nomenklatura/store/level.py
@@ -1,15 +1,16 @@
-import plyvel # type: ignore
-from typing import Generator, Optional, Tuple, Any, List, Set
 from pathlib import Path
-from followthemoney.types import registry
+from typing import Any, Generator, List, Optional, Set, Tuple
+
+import plyvel # type: ignore
 from followthemoney.property import Property
+from followthemoney.types import registry
 
-from nomenklatura.store.base import Store, View, Writer
-from nomenklatura.store.util import pack_statement, unpack_statement
-from nomenklatura.statement import Statement
 from nomenklatura.dataset import DS
 from nomenklatura.entity import CE
 from nomenklatura.resolver import Resolver
+from nomenklatura.statement import Statement
+from nomenklatura.store.base import Store, View, Writer
+from nomenklatura.store.util import pack_statement, unpack_statement


class LevelDBStore(Store[DS, CE]):
@@ -90,6 +91,7 @@ def pop(self, entity_id: str) -> List[Statement]:
             key = f"e:{entity_id}:{dataset}".encode("utf-8")
             self.batch.delete(key)
 
+        self.flush()
         return list(statements)


171 changes: 171 additions & 0 deletions nomenklatura/store/sql.py
@@ -0,0 +1,171 @@
from datetime import datetime
from typing import Any, Generator, List, Optional, Set, Tuple

from banal import as_bool
from followthemoney.property import Property
from sqlalchemy import Table, create_engine, delete, select
from sqlalchemy.sql.selectable import Select

from nomenklatura.dataset import DS
from nomenklatura.db import (
    DB_URL,
    POOL_SIZE,
    ensure_tx,
    get_metadata,
    get_statement_table,
    get_upsert_func,
)
from nomenklatura.entity import CE
from nomenklatura.resolver import Resolver
from nomenklatura.statement import Statement
from nomenklatura.store import Store, View, Writer


def pack_statement(stmt: Statement) -> dict[str, Any]:
    data: dict[str, Any] = stmt.to_row()
    data["target"] = as_bool(data["target"])
    data["external"] = as_bool(data["external"])
    data["first_seen"] = data["first_seen"] or datetime.utcnow()
Member:
I don't think we should make up dates somewhere in the middle of the pipeline. That would lead to pretty random outcomes. Maybe we make those columns nullable instead?

Been struggling with the same issue: https://github.com/opensanctions/opensanctions/blob/main/zavod/zavod/tools/load_db.py#L42-L45

Contributor (Author):
Made it nullable: 2d0dd85

I don't know what side-effects this could have in other parts of your pipeline? Especially for existing statement tables... (migrations?)

data["last_seen"] = data["last_seen"] or datetime.utcnow()
return data
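
For reference, a minimal sketch of the nullable-column alternative settled on above: with first_seen/last_seen made nullable in the statement table, the packing step can pass timestamps through unchanged (function name hypothetical).

def pack_statement_nullable(stmt: Statement) -> dict[str, Any]:
    # No utcnow() backfill: missing timestamps survive as NULL in the
    # database instead of being fabricated mid-pipeline.
    data: dict[str, Any] = stmt.to_row()
    data["target"] = as_bool(data["target"])
    data["external"] = as_bool(data["external"])
    return data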


class SqlStore(Store[DS, CE]):
    def __init__(
        self,
        dataset: DS,
        resolver: Resolver[CE],
        uri: str = DB_URL,
        **engine_kwargs: Any,
    ):
        super().__init__(dataset, resolver)
        engine_kwargs["pool_size"] = engine_kwargs.pop("pool_size", POOL_SIZE)
        self.metadata = get_metadata()
        self.engine = create_engine(uri, **engine_kwargs)
        self.table = get_statement_table()
        self.metadata.create_all(self.engine, tables=[self.table], checkfirst=True)

    def writer(self) -> Writer[DS, CE]:
        return SqlWriter(self)

    def view(self, scope: DS, external: bool = False) -> View[DS, CE]:
        return SqlView(self, scope, external=external)

    def _iterate_stmts(self, q: Select) -> Generator[Statement, None, None]:
        with ensure_tx(self.engine.connect()) as conn:
            conn = conn.execution_options(stream_results=True)
Member:
This has quite a bit of overhead. I wonder if we should pass an option into this func that can disable it for get_entity.

            cursor = conn.execute(q)
            while rows := cursor.fetchmany(10_000):
                for row in rows:
                    yield Statement.from_db_row(row)
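
Picking up the streaming concern above, a sketch of what an opt-out flag could look like (parameter name hypothetical, not part of this diff):

    def _iterate_stmts(
        self, q: Select, stream: bool = True
    ) -> Generator[Statement, None, None]:
        with ensure_tx(self.engine.connect()) as conn:
            if stream:
                # Server-side cursor: constant memory on large scans, but
                # needless round-trip overhead for single-entity lookups.
                conn = conn.execution_options(stream_results=True)
            cursor = conn.execute(q)
            while rows := cursor.fetchmany(10_000):
                for row in rows:
                    yield Statement.from_db_row(row)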

    def _iterate(self, q: Select) -> Generator[CE, None, None]:
        current_id = None
        current_stmts: list[Statement] = []
        for stmt in self._iterate_stmts(q):
            entity_id = stmt.entity_id
            if current_id is None:
                current_id = entity_id
            if current_id != entity_id:
                proxy = self.assemble(current_stmts)
                if proxy is not None:
                    yield proxy
                current_id = entity_id
                current_stmts = []
            current_stmts.append(stmt)
        if len(current_stmts):
            proxy = self.assemble(current_stmts)
            if proxy is not None:
                yield proxy


class SqlWriter(Writer[DS, CE]):
    BATCH_STATEMENTS = 10_000

    def __init__(self, store: SqlStore[DS, CE]):
        self.store: SqlStore[DS, CE] = store
        self.batch: Optional[Set[Statement]] = None
        self.batch_size = 0
Member:
Do we need batch_size if we have a batch array? I assume len(batch) is more precise.

Contributor (Author):
3d25cb4

(this originally came from my belief that int > int is more performant than len(foo) > int but gosh, this would never be the bottleneck in the overall sql store implementation) 😂 🙈

        self.insert = get_upsert_func(self.store.engine)

    def flush(self) -> None:
        if self.batch:
            values = [pack_statement(s) for s in self.batch]
            istmt = self.insert(self.store.table).values(values)
            stmt = istmt.on_conflict_do_update(
                index_elements=["id"],
                set_=dict(
                    canonical_id=istmt.excluded.canonical_id,
                    schema=istmt.excluded.schema,
                    prop_type=istmt.excluded.prop_type,
                    target=istmt.excluded.target,
                    lang=istmt.excluded.lang,
                    original_value=istmt.excluded.original_value,
                    last_seen=istmt.excluded.last_seen,
                ),
            )
            with ensure_tx(self.store.engine.connect()) as conn:
                conn.execute(stmt)
        self.batch = set()
        self.batch_size = 0

    def add_statement(self, stmt: Statement) -> None:
        if self.batch is None:
            self.batch = set()
        if stmt.entity_id is None:
            return
        if self.batch_size >= self.BATCH_STATEMENTS:
            self.flush()
        canonical_id = self.store.resolver.get_canonical(stmt.entity_id)
        stmt.canonical_id = canonical_id
        self.batch.add(stmt)
        self.batch_size += 1
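
Following the batch_size exchange above, a sketch of the len()-based variant (assuming batch is initialized to an empty set in __init__, so the counter and the None checks drop out):

    def add_statement(self, stmt: Statement) -> None:
        if stmt.entity_id is None:
            return
        if len(self.batch) >= self.BATCH_STATEMENTS:
            self.flush()
        stmt.canonical_id = self.store.resolver.get_canonical(stmt.entity_id)
        # len(self.batch) counts unique statements exactly; a separate
        # counter can drift because adding a duplicate to a set is a no-op.
        self.batch.add(stmt)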

    def pop(self, entity_id: str) -> List[Statement]:
        self.flush()
        table = self.store.table
        q = select(table).where(table.c.entity_id == entity_id)
        q_delete = delete(table).where(table.c.entity_id == entity_id)
        statements: List[Statement] = list(self.store._iterate_stmts(q))
        with ensure_tx(self.store.engine.connect()) as conn:
            conn.execute(q_delete)
        return statements


class SqlView(View[DS, CE]):
    def __init__(
        self, store: SqlStore[DS, CE], scope: DS, external: bool = False
    ) -> None:
        super().__init__(store, scope, external=external)
        self.store: SqlStore[DS, CE] = store

    def get_entity(self, id: str) -> Optional[CE]:
        table = self.store.table
        q = select(table).where(table.c.entity_id == id)
        for proxy in self.store._iterate(q):
            return proxy
        return None

    def get_inverted(self, id: str) -> Generator[Tuple[Property, CE], None, None]:
        table = self.store.table
        q = (
            select(table)
            .where(table.c.prop_type == "entity", table.c.value == id)
Member:
I think it needs to check self.store.resolver.connected(id) not just id.

            .distinct(table.c.value)
        )
        for stmt in self.store._iterate_stmts(q):
            if stmt.canonical_id is not None:
                entity = self.get_entity(stmt.canonical_id)
                if entity is not None:
                    for prop, value in entity.itervalues():
                        if value == id and prop.reverse is not None:
                            yield prop.reverse, entity
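
A sketch of the reviewer's suggestion above: widen the match to every identifier the resolver connects to id (assuming Resolver.connected returns the linked identifiers and that they stringify to the stored values):

    def get_inverted(self, id: str) -> Generator[Tuple[Property, CE], None, None]:
        table = self.store.table
        # Statements may reference any identifier in the connected cluster,
        # not just the literal id passed in.
        connected = {str(i) for i in self.store.resolver.connected(id)}
        q = (
            select(table)
            .where(table.c.prop_type == "entity", table.c.value.in_(connected))
            .distinct(table.c.value)
        )
        for stmt in self.store._iterate_stmts(q):
            if stmt.canonical_id is not None:
                entity = self.get_entity(stmt.canonical_id)
                if entity is not None:
                    for prop, value in entity.itervalues():
                        if value in connected and prop.reverse is not None:
                            yield prop.reverse, entity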

    def entities(self) -> Generator[CE, None, None]:
        table: Table = self.store.table
        q = (
            select(table)
            .where(table.c.dataset.in_(self.dataset_names))
            .order_by("entity_id")
Member:
Does this need to be canonical_id? Otherwise it'll return fragmented entities, no?

Contributor (Author):
totally right. fixed here: c1dc320

but the canonical_id column was nullable in the sql table, which it should not be, right?

        )
        yield from self.store._iterate(q)
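
For reference, a sketch of the ordering fix agreed above: sort on canonical_id so all statement fragments of one merged entity arrive adjacently (this assumes _iterate groups on the same key and that canonical_id is populated, hence the nullability question):

    def entities(self) -> Generator[CE, None, None]:
        table: Table = self.store.table
        q = (
            select(table)
            .where(table.c.dataset.in_(self.dataset_names))
            .order_by(table.c.canonical_id)
        )
        yield from self.store._iterate(q)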
80 changes: 80 additions & 0 deletions tests/store/test_stores.py
@@ -0,0 +1,80 @@
"""
Test if the different store implementations all behave the same.
"""

from pathlib import Path
from typing import Any, Dict, List

from nomenklatura.dataset import Dataset
from nomenklatura.entity import CompositeEntity
from nomenklatura.resolver import Resolver
from nomenklatura.store import LevelDBStore, SimpleMemoryStore, SqlStore, Store


def _run_store_test(
    store: Store, dataset: Dataset, donations_json: List[Dict[str, Any]]
):
    with store.writer() as bulk:
        for data in donations_json:
            proxy = CompositeEntity.from_data(dataset, data)
            bulk.add_entity(proxy)

    view = store.default_view()
    proxies = [e for e in view.entities()]
    assert len(proxies) == len(donations_json)

    entity = view.get_entity("4e0bd810e1fcb49990a2b31709b6140c4c9139c5")
    assert entity.caption == "Tchibo Holding AG"

    tested = False
    for prop, value in entity.itervalues():
        if prop.type.name == "entity":
            for iprop, ientity in view.get_inverted(value):
                assert iprop.reverse == prop
                assert ientity == entity
                tested = True
    assert tested

    adjacent = list(view.get_adjacent(entity))
    assert len(adjacent) == 2

    writer = store.writer()
    stmts = writer.pop(entity.id)
    assert len(stmts) == len(list(entity.statements))
    assert view.get_entity(entity.id) is None

    # upsert
    with store.writer() as bulk:
        for data in donations_json:
            proxy = CompositeEntity.from_data(dataset, data)
            bulk.add_entity(proxy)

    entity = view.get_entity(entity.id)
    assert entity.caption == "Tchibo Holding AG"
    return True


def test_store_sql(
    tmp_path: Path, test_dataset: Dataset, donations_json: List[Dict[str, Any]]
):
    resolver = Resolver[CompositeEntity]()
    uri = f"sqlite:///{tmp_path / 'test.db'}"
    store = SqlStore(dataset=test_dataset, resolver=resolver, uri=uri)
    assert str(store.engine.url) == uri
    assert _run_store_test(store, test_dataset, donations_json)


def test_store_memory(test_dataset: Dataset, donations_json: List[Dict[str, Any]]):
    resolver = Resolver[CompositeEntity]()
    store = SimpleMemoryStore(dataset=test_dataset, resolver=resolver)
    assert _run_store_test(store, test_dataset, donations_json)


def test_store_level(
    tmp_path: Path, test_dataset: Dataset, donations_json: List[Dict[str, Any]]
):
    resolver = Resolver[CompositeEntity]()
    store = LevelDBStore(
        dataset=test_dataset, resolver=resolver, path=tmp_path / "level.db"
    )
    assert _run_store_test(store, test_dataset, donations_json)
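
The three tests differ only in store construction, so a parametrized variant is one possible consolidation (a sketch using pytest.mark.parametrize, assuming the same test_dataset and donations_json fixtures):

import pytest


@pytest.mark.parametrize("backend", ["memory", "sql", "level"])
def test_store_roundtrip(
    backend: str,
    tmp_path: Path,
    test_dataset: Dataset,
    donations_json: List[Dict[str, Any]],
):
    resolver = Resolver[CompositeEntity]()
    if backend == "memory":
        store: Store = SimpleMemoryStore(dataset=test_dataset, resolver=resolver)
    elif backend == "sql":
        store = SqlStore(
            dataset=test_dataset,
            resolver=resolver,
            uri=f"sqlite:///{tmp_path / 'test.db'}",
        )
    else:
        store = LevelDBStore(
            dataset=test_dataset, resolver=resolver, path=tmp_path / "level.db"
        )
    assert _run_store_test(store, test_dataset, donations_json)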