Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sql store #125

Merged
merged 16 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions nomenklatura/db.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
from contextlib import contextmanager
import os
from pathlib import Path
from contextlib import contextmanager
from functools import cache
from typing import Optional, Generator
from pathlib import Path
from typing import Generator, Optional, Union

from sqlalchemy import MetaData, create_engine
from sqlalchemy.engine import Engine, Connection
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as psql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.schema import Table

from nomenklatura.statement import make_statement_table

DB_PATH = Path("nomenklatura.db").resolve()
DB_URL = os.environ.get("NOMENKLATURA_DB_URL", f"sqlite:///{DB_PATH.as_posix()}")
DB_STORE_TABLE = os.environ.get("NOMENKLATURA_DB_STORE_TABLE", "nk_store")
POOL_SIZE = int(os.environ.get("NOMENKLATURA_DB_POOL_SIZE", "5"))
Conn = Connection
Connish = Optional[Connection]
Expand All @@ -23,6 +31,11 @@ def get_metadata() -> MetaData:
return MetaData()


@cache
def get_statement_table() -> Table:
    """Build (exactly once) the statement table definition for the store."""
    metadata = get_metadata()
    return make_statement_table(metadata, DB_STORE_TABLE)


@contextmanager
def ensure_tx(conn: Connish = None) -> Generator[Connection, None, None]:
if conn is not None:
Expand All @@ -31,3 +44,14 @@ def ensure_tx(conn: Connish = None) -> Generator[Connection, None, None]:
engine = get_engine()
with engine.begin() as conn:
yield conn


@cache
def get_upsert_func(
    engine: Engine,
) -> Union[sqlite_insert, mysql_insert, psql_insert]:
    """Pick the dialect-specific INSERT constructor for the given engine.

    Falls back to the PostgreSQL construct for any dialect other than
    sqlite or mysql.
    """
    by_dialect = {"sqlite": sqlite_insert, "mysql": mysql_insert}
    return by_dialect.get(engine.name, psql_insert)
14 changes: 7 additions & 7 deletions nomenklatura/statement/db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from sqlalchemy import MetaData, Table, Column, DateTime, Unicode, Boolean
from sqlalchemy import Boolean, Column, DateTime, MetaData, Table, Unicode

KEY_LEN = 255
VALUE_LEN = 65535
Expand All @@ -10,16 +10,16 @@ def make_statement_table(metadata: MetaData, name: str = "statement") -> Table:
metadata,
Column("id", Unicode(KEY_LEN), primary_key=True, unique=True),
Column("entity_id", Unicode(KEY_LEN), index=True, nullable=False),
Column("canonical_id", Unicode(KEY_LEN), index=True, nullable=True),
Column("prop", Unicode(KEY_LEN), nullable=False),
Column("prop_type", Unicode(KEY_LEN), nullable=False),
Column("schema", Unicode(KEY_LEN), nullable=False),
Column("canonical_id", Unicode(KEY_LEN), index=True, nullable=False),
Column("prop", Unicode(KEY_LEN), index=True, nullable=False),
Column("prop_type", Unicode(KEY_LEN), index=True, nullable=False),
Column("schema", Unicode(KEY_LEN), index=True, nullable=False),
Column("value", Unicode(VALUE_LEN), nullable=False),
Column("original_value", Unicode(VALUE_LEN), nullable=True),
Column("dataset", Unicode(KEY_LEN), index=True),
Column("lang", Unicode(KEY_LEN), nullable=True),
Column("target", Boolean, default=False, nullable=False),
Column("external", Boolean, default=False, nullable=False),
Column("first_seen", DateTime, nullable=False),
Column("last_seen", DateTime, index=True),
Column("first_seen", DateTime, nullable=True),
Column("last_seen", DateTime, index=True, nullable=True),
)
22 changes: 16 additions & 6 deletions nomenklatura/statement/serialize.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import csv
import click
import orjson
from pathlib import Path
from io import TextIOWrapper
from pathlib import Path
from types import TracebackType
from typing import Optional, Dict, Any, List
from typing import BinaryIO, Generator, Iterable, Type
from typing import Any, BinaryIO, Dict, Generator, Iterable, List, Optional, Type

import click
import orjson
from banal import as_bool
from followthemoney.cli.util import MAX_LINE

from nomenklatura.statement.statement import S
from nomenklatura.util import pack_prop, unpack_prop
from nomenklatura.util import iso_datetime, pack_prop, unpack_prop

JSON = "json"
CSV = "csv"
Expand Down Expand Up @@ -138,6 +139,15 @@ def pack_statement(stmt: S) -> Dict[str, Any]:
return row


def pack_sql_statement(stmt: S) -> Dict[str, Any]:
    """Convert a statement into a row dict suitable for SQL insertion."""
    row: Dict[str, Any] = stmt.to_row()
    # Booleans arrive serialized from to_row(); coerce them back.
    for flag in ("target", "external"):
        row[flag] = as_bool(row[flag])
    # Timestamps are stored as datetime objects in the table.
    for ts_field in ("first_seen", "last_seen"):
        row[ts_field] = iso_datetime(row.get(ts_field))
    return row


def get_statement_writer(fh: BinaryIO, format: str) -> "StatementWriter":
if format == CSV:
return CSVStatementWriter(fh)
Expand Down
13 changes: 9 additions & 4 deletions nomenklatura/store/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
import orjson
from pathlib import Path
from typing import Optional

from nomenklatura.store.base import Store, Writer, View
from nomenklatura.store.memory import MemoryStore
from nomenklatura.resolver import Resolver
import orjson

from nomenklatura.dataset import Dataset
from nomenklatura.entity import CompositeEntity
from nomenklatura.resolver import Resolver
from nomenklatura.store.base import Store, View, Writer
from nomenklatura.store.level import LevelDBStore
from nomenklatura.store.memory import MemoryStore
from nomenklatura.store.sql import SqlStore

SimpleMemoryStore = MemoryStore[Dataset, CompositeEntity]

__all__ = [
"Store",
"Writer",
"View",
"LevelDBStore",
"MemoryStore",
"SimpleMemoryStore",
"SqlStore",
"load_entity_file_store",
]

Expand Down
14 changes: 8 additions & 6 deletions nomenklatura/store/level.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import plyvel # type: ignore
from typing import Generator, Optional, Tuple, Any, List, Set
from pathlib import Path
from followthemoney.types import registry
from typing import Any, Generator, List, Optional, Set, Tuple

import plyvel # type: ignore
from followthemoney.property import Property
from followthemoney.types import registry

from nomenklatura.store.base import Store, View, Writer
from nomenklatura.store.util import pack_statement, unpack_statement
from nomenklatura.statement import Statement
from nomenklatura.dataset import DS
from nomenklatura.entity import CE
from nomenklatura.resolver import Resolver
from nomenklatura.statement import Statement
from nomenklatura.store.base import Store, View, Writer
from nomenklatura.store.util import pack_statement, unpack_statement


class LevelDBStore(Store[DS, CE]):
Expand Down Expand Up @@ -90,6 +91,7 @@ def pop(self, entity_id: str) -> List[Statement]:
key = f"e:{entity_id}:{dataset}".encode("utf-8")
self.batch.delete(key)

self.flush()
return list(statements)


Expand Down
173 changes: 173 additions & 0 deletions nomenklatura/store/sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from typing import Any, Generator, List, Optional, Set, Tuple

from followthemoney.property import Property
from sqlalchemy import Table, create_engine, delete, select
from sqlalchemy.sql.selectable import Select

from nomenklatura.dataset import DS
from nomenklatura.db import (
DB_URL,
POOL_SIZE,
get_metadata,
get_statement_table,
get_upsert_func,
)
from nomenklatura.entity import CE
from nomenklatura.resolver import Resolver
from nomenklatura.statement import Statement
from nomenklatura.statement.serialize import pack_sql_statement
from nomenklatura.store import Store, View, Writer


class SqlStore(Store[DS, CE]):
    """Statement store backed by a SQL database via SQLAlchemy."""

    def __init__(
        self,
        dataset: DS,
        resolver: Resolver[CE],
        uri: str = DB_URL,
        **engine_kwargs: Any,
    ):
        super().__init__(dataset, resolver)
        # Default the connection pool size but let callers override it.
        engine_kwargs["pool_size"] = engine_kwargs.pop("pool_size", POOL_SIZE)
        self.metadata = get_metadata()
        self.engine = create_engine(uri, **engine_kwargs)
        self.table = get_statement_table()
        # Create the statement table if it does not exist yet.
        self.metadata.create_all(self.engine, tables=[self.table], checkfirst=True)

    def writer(self) -> Writer[DS, CE]:
        """Return a batched writer bound to this store."""
        return SqlWriter(self)

    def view(self, scope: DS, external: bool = False) -> View[DS, CE]:
        """Return a read view over *scope* datasets."""
        return SqlView(self, scope, external=external)

    def _execute(self, q: Select, many: bool = True) -> Generator[Any, None, None]:
        # execute any read query against sql backend
        with self.engine.connect() as conn:
            if many:
                # Stream results server-side and fetch in chunks to bound memory.
                conn = conn.execution_options(stream_results=True)
                cursor = conn.execute(q)
                while rows := cursor.fetchmany(10_000):
                    yield from rows
            else:
                yield from conn.execute(q)

    def _iterate_stmts(
        self, q: Select, many: bool = True
    ) -> Generator[Statement, None, None]:
        """Run *q* and deserialize each result row into a Statement."""
        for row in self._execute(q, many=many):
            yield Statement.from_db_row(row)

    def _iterate(self, q: Select, many: bool = True) -> Generator[CE, None, None]:
        """Assemble entities from consecutive runs of statement rows.

        Assumes *q* orders rows so that all statements of one entity are
        adjacent in the result set.
        """
        # NOTE(review): runs are keyed on entity_id, while entities() orders
        # by canonical_id — statements sharing a canonical id but differing
        # in entity_id may be emitted as separate proxies; confirm intended.
        current_id: Optional[str] = None
        current_stmts: List[Statement] = []
        for stmt in self._iterate_stmts(q, many=many):
            entity_id = stmt.entity_id
            if current_id is None:
                current_id = entity_id
            if current_id != entity_id:
                # Entity boundary reached: assemble and emit the previous run.
                proxy = self.assemble(current_stmts)
                if proxy is not None:
                    yield proxy
                current_id = entity_id
                current_stmts = []
            current_stmts.append(stmt)
        if len(current_stmts):
            # Flush the trailing run, if any statements were seen at all.
            proxy = self.assemble(current_stmts)
            if proxy is not None:
                yield proxy


class SqlWriter(Writer[DS, CE]):
    """Batched writer that upserts statements into the SQL store."""

    # Number of statements buffered before an automatic flush.
    BATCH_STATEMENTS = 10_000

    def __init__(self, store: SqlStore[DS, CE]):
        self.store: SqlStore[DS, CE] = store
        # Pending statements; a set, so duplicates collapse within a batch.
        self.batch: Optional[Set[Statement]] = None
        # Dialect-specific INSERT constructor for this engine.
        self.insert = get_upsert_func(self.store.engine)

    def flush(self) -> None:
        """Write all buffered statements in one bulk upsert, then reset."""
        if self.batch:
            values = [pack_sql_statement(s) for s in self.batch]
            istmt = self.insert(self.store.table).values(values)
            # NOTE(review): on_conflict_do_update() exists on the sqlite and
            # postgresql insert constructs, but MySQL's insert uses
            # on_duplicate_key_update instead — confirm the mysql path works.
            stmt = istmt.on_conflict_do_update(
                index_elements=["id"],
                set_=dict(
                    canonical_id=istmt.excluded.canonical_id,
                    schema=istmt.excluded.schema,
                    prop_type=istmt.excluded.prop_type,
                    target=istmt.excluded.target,
                    lang=istmt.excluded.lang,
                    original_value=istmt.excluded.original_value,
                    last_seen=istmt.excluded.last_seen,
                ),
            )
            with self.store.engine.connect() as conn:
                conn.begin()
                conn.execute(stmt)
                conn.commit()
        self.batch = set()

    def add_statement(self, stmt: Statement) -> None:
        """Queue *stmt* for writing; flushes when the buffer is full."""
        if self.batch is None:
            self.batch = set()
        # Statements without an entity id cannot be stored; drop silently.
        if stmt.entity_id is None:
            return
        if len(self.batch) >= self.BATCH_STATEMENTS:
            self.flush()
        # Stamp the statement with its current canonical id from the resolver.
        canonical_id = self.store.resolver.get_canonical(stmt.entity_id)
        stmt.canonical_id = canonical_id
        self.batch.add(stmt)

    def pop(self, entity_id: str) -> List[Statement]:
        """Delete and return all stored statements for *entity_id*."""
        # Flush first so pending statements for this entity are visible.
        self.flush()
        table = self.store.table
        q = select(table).where(table.c.entity_id == entity_id)
        q_delete = delete(table).where(table.c.entity_id == entity_id)
        # Materialize the rows before deleting them.
        statements: List[Statement] = list(self.store._iterate_stmts(q))
        with self.store.engine.connect() as conn:
            conn.begin()
            conn.execute(q_delete)
            conn.commit()
        return statements


class SqlView(View[DS, CE]):
def __init__(
    self, store: SqlStore[DS, CE], scope: DS, external: bool = False
) -> None:
    super().__init__(store, scope, external=external)
    # Re-annotate with the concrete store type so methods see SqlStore attrs.
    self.store: SqlStore[DS, CE] = store

def get_entity(self, id: str) -> Optional[CE]:
    """Assemble the entity for *id*, or None if no statements match."""
    table = self.store.table
    # NOTE(review): the lookup expands *id* through the resolver's
    # connected() cluster rather than trusting the stored canonical_id
    # column — this differs from the other store implementations.
    connected = [str(ref) for ref in self.store.resolver.connected(id)]
    query = select(table).where(
        table.c.entity_id.in_(connected),
        table.c.dataset.in_(self.dataset_names),
    )
    return next(self.store._iterate(query, many=False), None)

def get_inverted(self, id: str) -> Generator[Tuple[Property, CE], None, None]:
    """Yield (reverse property, entity) pairs for entities referencing *id*."""
    table = self.store.table
    # TODO(review): should this match the resolver's connected(id) set
    # rather than only the literal id? Flagged during code review.
    query = (
        select(table)
        .where(table.c.prop_type == "entity", table.c.value == id)
        .distinct(table.c.value)
    )
    for stmt in self.store._iterate_stmts(query):
        if stmt.canonical_id is None:
            continue
        entity = self.get_entity(stmt.canonical_id)
        if entity is None:
            continue
        for prop, value in entity.itervalues():
            if value == id and prop.reverse is not None:
                yield prop.reverse, entity

def entities(self) -> Generator[CE, None, None]:
    """Iterate all assembled entities within this view's dataset scope."""
    table: Table = self.store.table
    # Order by canonical_id so statements of one entity arrive adjacently,
    # as required by the store's run-based assembly.
    query = (
        select(table)
        .where(table.c.dataset.in_(self.dataset_names))
        .order_by("canonical_id")
    )
    yield from self.store._iterate(query)
Loading