diff --git a/.github/workflows/test_examples.yaml b/.github/workflows/test_examples.yaml
index 5fc3d59b..b3049f1a 100644
--- a/.github/workflows/test_examples.yaml
+++ b/.github/workflows/test_examples.yaml
@@ -25,6 +25,8 @@ jobs:
           - many_to_zero
           - model_inference
           - one_to_many_pipeline
+          - batch_transform_with_filters/simple-example
+          - batch_transform_with_filters/filters-as-function
         executor:
           - SingleThreadExecutor
           - RayExecutor
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 01657a3c..8f158e98 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,8 @@ Major changes:
 
 See "Migration from v0.13 to v0.14" for more details
 
+* `BatchTransform` has a new argument `filters`. It restricts processing to only those transform indexes that match the entries given in `filters`. See `docs/concepts.md` for more details.
+
 # 0.13.14
 
 * Fix [#334](https://github.com/epoch8/datapipe/issues/334)
diff --git a/Makefile b/Makefile
new file mode 100644
index 00000000..9f31f38a
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,6 @@
+black:
+	autoflake -r --in-place --remove-all-unused-imports steps/ *.py brandlink_utils/
+	black --verbose --config black.toml steps/ alembic *.py brandlink_utils/
+
+mypy:
+	mypy -p datapipe --ignore-missing-imports --follow-imports=silent --namespace-packages
\ No newline at end of file
diff --git a/black.toml b/black.toml
new file mode 100644
index 00000000..55ec8d78
--- /dev/null
+++ b/black.toml
@@ -0,0 +1,2 @@
+[tool.black]
+line-length = 120
diff --git a/datapipe/cli.py b/datapipe/cli.py
index 79e678b4..664b7c53 100644
--- a/datapipe/cli.py
+++ b/datapipe/cli.py
@@ -5,6 +5,7 @@
 from typing import Dict, List, Optional, cast
 
 import click
+from datapipe.run_config import RunConfig
 import pandas as pd
 import rich
 from opentelemetry import trace
@@ -318,6 +319,7 @@ def step(
     steps = filter_steps_by_labels_and_name(app, labels=labels_list, name_prefix=name)
 
     ctx.obj["steps"] = steps
+    ctx.obj["labels"] = labels_list
 
 
 def to_human_repr(step: ComputeStep, extra_args: Optional[Dict] = None) -> str:
@@ -379,6 +381,7 @@ def step_list(ctx: click.Context, status: bool) -> None:  # noqa
 def step_run(ctx: click.Context, loop: bool, loop_delay: int) -> None:
     app: DatapipeApp = ctx.obj["pipeline"]
     steps_to_run: List[ComputeStep] = ctx.obj["steps"]
+    run_config = RunConfig(labels={k: v for k, v in ctx.obj["labels"]})
 
     executor: Executor = ctx.obj["executor"]
 
@@ -387,7 +390,7 @@ def step_run(ctx: click.Context, loop: bool, loop_delay: int) -> None:
     while True:
         if len(steps_to_run) > 0:
-            run_steps(app.ds, steps_to_run, executor=executor)
+            run_steps(app.ds, steps_to_run, run_config=run_config, executor=executor)
 
         if not loop:
             break
diff --git a/datapipe/datatable.py b/datapipe/datatable.py
index 7efa6e78..f42f5fa3 100644
--- a/datapipe/datatable.py
+++ b/datapipe/datatable.py
@@ -7,6 +7,11 @@
 from datapipe.event_logger import EventLogger
 from datapipe.meta.sql_meta import MetaTable
 from datapipe.run_config import RunConfig
+from datapipe.sql_util import (
+    sql_apply_idx_filter_to_table,
+    sql_apply_runconfig_filters,
+)
+from datapipe.store.database import DBConn, MetaKey
 from datapipe.store.database import DBConn
 from datapipe.store.table_store import TableStore
 from datapipe.types import DataDF, IndexDF, MetadataDF, data_to_index, index_difference
diff --git a/datapipe/meta/sql_meta.py b/datapipe/meta/sql_meta.py
index 7374226e..4ed8ff06 100644
--- a/datapipe/meta/sql_meta.py
+++ b/datapipe/meta/sql_meta.py
@@ -20,7 +20,7 @@
 import sqlalchemy as sa
 
 from datapipe.run_config import RunConfig
-from
datapipe.sql_util import sql_apply_idx_filter_to_table, sql_apply_runconfig_filter +from datapipe.sql_util import sql_apply_idx_filter_to_table, sql_apply_runconfig_filters from datapipe.store.database import DBConn, MetaKey from datapipe.types import ( DataDF, @@ -389,7 +389,7 @@ def get_stale_idx( ) ) - sql = sql_apply_runconfig_filter( + sql = sql_apply_runconfig_filters( sql, self.sql_table, self.primary_keys, run_config ) @@ -447,7 +447,7 @@ def get_agg_cte( sql = sql.group_by(*key_cols) sql = sql_apply_filters_idx_to_subquery(sql, keys, filters_idx) - sql = sql_apply_runconfig_filter(sql, tbl, self.primary_keys, run_config) + sql = sql_apply_runconfig_filters_to_subquery(sql, self.primary_keys, run_config) return (keys, sql.cte(name=f"{tbl.name}__update")) @@ -649,7 +649,7 @@ def mark_all_rows_unprocessed( .where(self.sql_table.c.is_success == True) ) - sql = sql_apply_runconfig_filter( + sql = sql_apply_runconfig_filters( update_sql, self.sql_table, self.primary_keys, run_config ) @@ -680,6 +680,18 @@ def sql_apply_filters_idx_to_subquery( return sql +def sql_apply_runconfig_filters_to_subquery( + sql: Any, + keys: List[str], + run_config: Optional[RunConfig] = None, +) -> Any: + if run_config is not None: + filters_idx = pd.DataFrame(run_config.filters) + sql = sql_apply_filters_idx_to_subquery(sql, keys, filters_idx) + + return sql + + @dataclass class ComputeInputCTE: cte: Any @@ -795,6 +807,7 @@ def build_changed_idx_sql( ) out = sql_apply_filters_idx_to_subquery(out, transform_keys, filters_idx) + out = sql_apply_runconfig_filters_to_subquery(out, transform_keys, run_config) out = out.cte(name="transform") diff --git a/datapipe/run_config.py b/datapipe/run_config.py index 4ee9b327..6998f973 100644 --- a/datapipe/run_config.py +++ b/datapipe/run_config.py @@ -1,7 +1,8 @@ from dataclasses import dataclass, field -from typing import Any, Dict, Optional +from typing import Any, List, Dict, Optional -LabelDict = Dict[str, Any] +import pandas as pd +from datapipe.types import LabelDict @dataclass @@ -10,7 +11,7 @@ class RunConfig: # если не пуст, то во время запуска обрабатываются только те строки, # которые строго соответствуют фильтру # (в случае, если у таблицы есть идентификатор с совпадающим именем). 
- filters: LabelDict = field(default_factory=dict) + filters: List[LabelDict] = field(default_factory=list) labels: LabelDict = field(default_factory=dict) @classmethod @@ -22,3 +23,16 @@ def add_labels(cls, rc: Optional["RunConfig"], labels: LabelDict) -> "RunConfig" ) else: return RunConfig(labels=labels) + + @classmethod + def add_filters(cls, rc: Optional["RunConfig"], filters: List[LabelDict]) -> "RunConfig": + if rc is not None: + return RunConfig( + filters=list( + pd.concat([pd.DataFrame(rc.filters), pd.DataFrame(filters)], ignore_index=True) + .drop_duplicates() + .apply(lambda row : row.dropna().to_dict(), axis=1) + ), + ) + else: + return RunConfig(filters=filters) \ No newline at end of file diff --git a/datapipe/sql_util.py b/datapipe/sql_util.py index 1a7495a1..26ffd806 100644 --- a/datapipe/sql_util.py +++ b/datapipe/sql_util.py @@ -1,10 +1,12 @@ -from typing import Any, Dict, List, Optional +from collections import defaultdict +from typing import Any, Dict, List, Optional, cast -from sqlalchemy import Column, Integer, String, Table, tuple_ +import pandas as pd +from sqlalchemy import Column, Integer, String, Table, column, tuple_ +from sqlalchemy.sql.expression import and_, or_ from datapipe.run_config import RunConfig -from datapipe.types import IndexDF - +from datapipe.types import IndexDF, LabelDict def sql_apply_idx_filter_to_table( sql: Any, @@ -22,22 +24,28 @@ def sql_apply_idx_filter_to_table( keys = tuple_(*[table.c[key] for key in primary_keys]) # type: ignore sql = sql.where( - keys.in_([tuple([r[key] for key in primary_keys]) for r in idx.to_dict(orient="records")]) # type: ignore + keys.in_( + [ + tuple([r[key] for key in primary_keys]) # type: ignore + for r in idx.to_dict(orient="records") + ] + ) ) return sql -def sql_apply_runconfig_filter( +def sql_apply_runconfig_filters( sql: Any, table: Table, - primary_keys: List[str], + keys: List[str], run_config: Optional[RunConfig] = None, ) -> Any: if run_config is not None: - for k, v in run_config.filters.items(): - if k in primary_keys: - sql = sql.where(table.c[k] == v) + filters_idx = pd.DataFrame(run_config.filters) + primary_keys = [key for key in keys if key in table.c and key in filters_idx.columns] + if len(filters_idx) > 0 and len(primary_keys) > 0: + sql = sql_apply_idx_filter_to_table(sql, table, primary_keys, cast(IndexDF, filters_idx)) return sql @@ -49,4 +57,4 @@ def sql_apply_runconfig_filter( def sql_schema_to_dtype(schema: List[Column]) -> Dict[str, Any]: - return {i.name: SCHEMA_TO_DTYPE_LOOKUP[i.type.__class__] for i in schema} # type: ignore + return {i.name: SCHEMA_TO_DTYPE_LOOKUP[i.type.__class__] for i in schema} \ No newline at end of file diff --git a/datapipe/step/batch_transform.py b/datapipe/step/batch_transform.py index 46aad4c7..88b3e82b 100644 --- a/datapipe/step/batch_transform.py +++ b/datapipe/step/batch_transform.py @@ -34,10 +34,14 @@ ) from datapipe.datatable import DataStore, DataTable, MetaTable from datapipe.executor import Executor, ExecutorConfig, SingleThreadExecutor +from datapipe.run_config import RunConfig +from datapipe.store.database import DBConn from datapipe.meta.sql_meta import TransformMetaTable, build_changed_idx_sql from datapipe.run_config import LabelDict, RunConfig from datapipe.types import ( ChangeList, + LabelDict, + Filters, DataDF, IndexDF, JoinSpec, @@ -86,7 +90,7 @@ def __init__( chunk_size: int = 1000, labels: Optional[Labels] = None, executor_config: Optional[ExecutorConfig] = None, - filters: Optional[Union[LabelDict, Callable[[], LabelDict]]] = 
None, + filters: Optional[Filters] = None, order_by: Optional[List[str]] = None, order: Literal["asc", "desc"] = "asc", ) -> None: @@ -160,25 +164,50 @@ def compute_transform_schema( return (list(inp_out_p_keys), [all_keys[k] for k in inp_out_p_keys]) - def _apply_filters_to_run_config( - self, run_config: Optional[RunConfig] = None - ) -> Optional[RunConfig]: + def _get_filters( + self, + ds: DataStore, + run_config: Optional[RunConfig] = None + ) -> List[LabelDict]: if self.filters is None: - return run_config + return [] + + filters: List[LabelDict] + if isinstance(self.filters, str): + dt = ds.get_table(self.filters) + df = dt.get_data() + filters = cast(List[LabelDict], df[dt.primary_keys].to_dict(orient="records")) + elif isinstance(self.filters, pd.DataFrame): + filters = cast(List[LabelDict], self.filters.to_dict(orient="records")) + elif isinstance(self.filters, list) and all([isinstance(x, dict) for x in self.filters]): + filters = self.filters + elif isinstance(self.filters, Callable): # type: ignore + filters_func = cast(Callable[..., Union[List[LabelDict], IndexDF]], self.filters) + parameters = inspect.signature(filters_func).parameters + kwargs = { + **({"ds": ds} if "ds" in parameters else {}), + **({"run_config": run_config} if "run_config" in parameters else {}) + } + filters_res = filters_func(**kwargs) + if isinstance(filters_res, pd.DataFrame): + filters = cast(List[LabelDict], filters_res.to_dict(orient="records")) + elif isinstance(filters_res, list) and all([isinstance(x, dict) for x in filters_res]): + filters = filters_res + else: + raise TypeError( + "Function filters must return pd.Dataframe or list of key:values." + f" Returned type: {type(filters_res)}" + ) else: - if isinstance(self.filters, dict): - filters = self.filters - elif isinstance(self.filters, Callable): # type: ignore - filters = self.filters() + raise TypeError( + "Argument filters must be pd.Dataframe, list of key:values or function." 
+ f" Got type: {type(self.filters)}" + ) - if run_config is None: - return RunConfig(filters=filters) - else: - run_config = copy.deepcopy(run_config) - filters = copy.deepcopy(filters) - filters.update(run_config.filters) - run_config.filters = filters - return run_config + keys = set([key for keys in filters for key in keys]) + if not all(len(filter) == len(keys) for filter in filters): + raise ValueError("Size of keys from filters must have same length") + return filters def get_status(self, ds: DataStore) -> StepStatus: return StepStatus( @@ -192,7 +221,6 @@ def get_changed_idx_count( ds: DataStore, run_config: Optional[RunConfig] = None, ) -> int: - run_config = self._apply_filters_to_run_config(run_config) _, sql = build_changed_idx_sql( ds=ds, meta_table=self.meta_table, @@ -224,7 +252,6 @@ def get_full_process_ids( - idx_size - количество индексов требующих обработки - idx_df - датафрейм без колонок с данными, только индексная колонка """ - run_config = self._apply_filters_to_run_config(run_config) chunk_size = chunk_size or self.chunk_size with tracer.start_as_current_span("compute ids to process"): @@ -246,24 +273,21 @@ def get_full_process_ids( order=self.order, # type: ignore # pylance is stupid ) - # Список ключей из фильтров, которые нужно добавить в результат - extra_filters: LabelDict if run_config is not None: - extra_filters = { - k: v for k, v in run_config.filters.items() if k not in join_keys - } + extra_filters = pd.DataFrame(run_config.filters) else: - extra_filters = {} + extra_filters = None def alter_res_df(): with ds.meta_dbconn.con.begin() as con: for df in pd.read_sql_query(u1, con=con, chunksize=chunk_size): df = df[self.transform_keys] - - for k, v in extra_filters.items(): - df[k] = v - - yield cast(IndexDF, df) + if extra_filters is not None and len(extra_filters) > 0: + if len(set(df.columns).intersection(extra_filters.columns)) > 0: + df = pd.merge(df, extra_filters) + else: + df = pd.merge(df, extra_filters, how="cross") + yield df return math.ceil(idx_count / chunk_size), alter_res_df() @@ -273,7 +297,6 @@ def get_change_list_process_ids( change_list: ChangeList, run_config: Optional[RunConfig] = None, ) -> Tuple[int, Iterable[IndexDF]]: - run_config = self._apply_filters_to_run_config(run_config) with tracer.start_as_current_span("compute ids to process"): changes = [pd.DataFrame(columns=self.transform_keys)] @@ -324,8 +347,6 @@ def store_batch_result( process_ts: float, run_config: Optional[RunConfig] = None, ) -> ChangeList: - run_config = self._apply_filters_to_run_config(run_config) - changes = ChangeList() if output_dfs is not None: @@ -371,8 +392,6 @@ def store_batch_err( process_ts: float, run_config: Optional[RunConfig] = None, ) -> None: - run_config = self._apply_filters_to_run_config(run_config) - idx_records = idx.to_dict(orient="records") logger.error( @@ -475,8 +494,10 @@ def run_full( logger.info(f"Running: {self.name}") run_config = RunConfig.add_labels(run_config, {"step_name": self.name}) + filters = self._get_filters(ds, run_config) + run_config_with_filters = RunConfig.add_filters(run_config, filters) - (idx_count, idx_gen) = self.get_full_process_ids(ds=ds, run_config=run_config) + (idx_count, idx_gen) = self.get_full_process_ids(ds=ds, run_config=run_config_with_filters) logger.info(f"Batches to process {idx_count}") @@ -489,7 +510,7 @@ def run_full( idx_count=idx_count, idx_gen=idx_gen, process_fn=self.process_batch, - run_config=run_config, + run_config=run_config_with_filters, executor_config=self.executor_config, ) @@ -506,9 
+527,11 @@ def run_changelist( executor = SingleThreadExecutor() run_config = RunConfig.add_labels(run_config, {"step_name": self.name}) + filters = self._get_filters(ds, run_config) + run_config_with_filters = RunConfig.add_filters(run_config, filters) (idx_count, idx_gen) = self.get_change_list_process_ids( - ds, change_list, run_config + ds, change_list, run_config_with_filters ) logger.info(f"Batches to process {idx_count}") @@ -524,7 +547,7 @@ def run_changelist( idx_count=idx_count, idx_gen=idx_gen, process_fn=self.process_batch, - run_config=run_config, + run_config=run_config_with_filters, executor_config=self.executor_config, ) @@ -559,6 +582,10 @@ class DatatableBatchTransform(PipelineStep): transform_keys: Optional[List[str]] = None kwargs: Optional[Dict] = None labels: Optional[Labels] = None + executor_config: Optional[ExecutorConfig] = None + filters: Optional[Filters] = None + order_by: Optional[List[str]] = None + order: Literal["asc", "desc"] = "asc" def build_compute(self, ds: DataStore, catalog: Catalog) -> List[ComputeStep]: input_dts = [catalog.get_datatable(ds, name) for name in self.inputs] @@ -575,6 +602,10 @@ def build_compute(self, ds: DataStore, catalog: Catalog) -> List[ComputeStep]: transform_keys=self.transform_keys, chunk_size=self.chunk_size, labels=self.labels, + executor_config=self.executor_config, + filters=self.filters, + order_by=self.order_by, + order=self.order ) ] @@ -591,6 +622,10 @@ def __init__( transform_keys: Optional[List[str]] = None, chunk_size: int = 1000, labels: Optional[Labels] = None, + executor_config: Optional[ExecutorConfig] = None, + filters: Optional[Filters] = None, + order_by: Optional[List[str]] = None, + order: Literal["asc", "desc"] = "asc", ) -> None: super().__init__( ds=ds, @@ -600,10 +635,14 @@ def __init__( transform_keys=transform_keys, chunk_size=chunk_size, labels=labels, + executor_config=executor_config, + filters=filters, + order_by=order_by, + order=order ) self.func = func - self.kwargs = kwargs + self.kwargs = kwargs or {} def process_batch_dts( self, @@ -630,7 +669,7 @@ class BatchTransform(PipelineStep): transform_keys: Optional[List[str]] = None labels: Optional[Labels] = None executor_config: Optional[ExecutorConfig] = None - filters: Optional[Union[LabelDict, Callable[[], LabelDict]]] = None + filters: Optional[Filters] = None order_by: Optional[List[str]] = None order: Literal["asc", "desc"] = "asc" @@ -690,7 +729,7 @@ def __init__( chunk_size: int = 1000, labels: Optional[Labels] = None, executor_config: Optional[ExecutorConfig] = None, - filters: Optional[Union[LabelDict, Callable[[], LabelDict]]] = None, + filters: Optional[Filters] = None, order_by: Optional[List[str]] = None, order: Literal["asc", "desc"] = "asc", ) -> None: diff --git a/datapipe/store/database.py b/datapipe/store/database.py index aab69d36..518a07bd 100644 --- a/datapipe/store/database.py +++ b/datapipe/store/database.py @@ -13,7 +13,7 @@ from sqlalchemy.sql.expression import delete, select from datapipe.run_config import RunConfig -from datapipe.sql_util import sql_apply_idx_filter_to_table, sql_apply_runconfig_filter +from datapipe.sql_util import sql_apply_idx_filter_to_table, sql_apply_runconfig_filters from datapipe.store.table_store import TableStore from datapipe.types import DataDF, DataSchema, IndexDF, MetaSchema, OrmTable, TAnyDF @@ -303,7 +303,7 @@ def read_rows_meta_pseudo_df( ) -> Iterator[DataDF]: sql = select(*self.data_table.c) - sql = sql_apply_runconfig_filter( + sql = sql_apply_runconfig_filters( sql, 
self.data_table, self.primary_keys, run_config ) diff --git a/datapipe/store/filedir.py b/datapipe/store/filedir.py index 9519e878..6b30bb02 100644 --- a/datapipe/store/filedir.py +++ b/datapipe/store/filedir.py @@ -5,7 +5,7 @@ import re from abc import ABC from pathlib import Path -from typing import IO, Any, Dict, Iterator, List, Optional, Union, cast +from typing import IO, Any, Dict, Iterator, List, Optional, Union, cast, Set import fsspec import numpy as np @@ -241,7 +241,7 @@ def __init__( self.attrnames = _pattern_to_attrnames(filename_pattern) self.filename_glob = [_pattern_to_glob(pat) for pat in self.filename_patterns] self.filename_match = _pattern_to_match(filename_pattern_for_match) - self.filename_match_first_suffix = _pattern_to_match(self.filename_patterns[0]) + self.filename_match_suffixes = [_pattern_to_match(pattern) for pattern in self.filename_patterns] # Any * and ** pattern check if "*" in path: @@ -492,15 +492,22 @@ def read_rows_meta_pseudo_df( ids: Dict[str, List[str]] = {attrname: [] for attrname in self.attrnames} ukeys = [] filepaths = [] - + looked_keys: Set[Any] = set() for f in files: - m = re.match(self.filename_match_first_suffix, f.path) - + for filemath_match_suffix in self.filename_match_suffixes: + m = re.match(filemath_match_suffix, f"{self.protocol_str}{f.path}") + if m is not None: + break if m is None: continue - for attrname in self.attrnames: - ids[attrname].append(m.group(attrname)) + keys_values = tuple(m.group(attrname) for attrname in self.attrnames) + if keys_values in looked_keys: + continue + looked_keys.add(keys_values) + + for attrname, key_value in zip(self.attrnames, keys_values): + ids[attrname].append(key_value) ukeys.append(files.fs.ukey(f.path)) # type: ignore filepaths.append(f"{self.protocol_str}{f.path}") diff --git a/datapipe/types.py b/datapipe/types.py index 65e8f2fa..620d8b56 100644 --- a/datapipe/types.py +++ b/datapipe/types.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from typing import ( TYPE_CHECKING, + Any, Callable, Dict, List, @@ -15,7 +16,6 @@ Union, cast, ) - import pandas as pd from sqlalchemy import Column @@ -41,6 +41,9 @@ TransformResult = Union[DataDF, List[DataDF], Tuple[DataDF, ...]] +LabelDict = Dict[str, Any] +Filters = Union[str, IndexDF, List[LabelDict], Callable[..., List[LabelDict]], Callable[..., IndexDF]] + from sqlalchemy.orm.decl_api import DeclarativeMeta OrmTable = Type[DeclarativeMeta] diff --git a/docs/source/SUMMARY.md b/docs/source/SUMMARY.md index 87a7f5a3..af4ccbe3 100644 --- a/docs/source/SUMMARY.md +++ b/docs/source/SUMMARY.md @@ -8,6 +8,10 @@ - [Transform](./concepts-transform.md) - [How merging works](./how-merging-works.md) +# Concepts + +- [Concepts](./concepts.md) + # Command Line Interface - [Command Line Interface](./cli.md) diff --git a/docs/source/concepts-transform.md b/docs/source/concepts-transform.md index 59b66aab..c9d125d6 100644 --- a/docs/source/concepts-transform.md +++ b/docs/source/concepts-transform.md @@ -50,9 +50,17 @@ обработки. Имеет magic injection: -- Если у функции `func` есть аргумент `ds`, то туда передатся используемый `DataStore`. -- Если у функции `func` есть аргумент `run_config`, то туда передатся используемый текущий `RunConfig`. -- Если у функции `func` есть аргумент `idx`, то туда передатся используемый `IndexDF` -- текущие индексы обработки. +* Если у функции `func` есть аргумент `ds`, то туда передатся используемый `DataStore`. +* Если у функции `func` есть аргумент `run_config`, то туда передатся используемый текущий `RunConfig`. 
+* Если у функции `func` есть аргумент `idx`, то туда передаётся используемый `IndexDF` -- текущие индексы обработки.
+
+Имеет поддержку аргумента `filters`, аналогичного одноимённому полю в `RunConfig`. При использовании фильтров в вычислении текущих индексов обработки участвуют только те значения, которые указаны в `filters`. Индексы, не попавшие под фильтр, не считаются обработанными; их можно обработать позже, убрав фильтрацию.
+Параметр принимает одно из значений:
+* список словарей ключ:значение, например `[{"idx": "0"}, {"idx": "1"}]`
+* `IndexDF` (датафрейм с индексами)
+* вызываемую функцию, которая возвращает либо список словарей ключ:значение, либо `IndexDF`. Для этой функции также поддерживается magic injection аргументов `ds` и `run_config`; с помощью динамической функции и `run_config` можно реализовать довольно сложную логику фильтрации.
+
+Примеры запуска трансформации с фильтрами можно найти в `examples/batch_transform_with_filters`.
 
 ## `BatchGenerate`
diff --git a/examples/batch_transform_with_filters/filters-as-function/.gitignore b/examples/batch_transform_with_filters/filters-as-function/.gitignore
new file mode 100644
index 00000000..74d0d3a3
--- /dev/null
+++ b/examples/batch_transform_with_filters/filters-as-function/.gitignore
@@ -0,0 +1 @@
+output.jsonline
diff --git a/examples/batch_transform_with_filters/filters-as-function/app.py b/examples/batch_transform_with_filters/filters-as-function/app.py
new file mode 100644
index 00000000..9ec7a3fc
--- /dev/null
+++ b/examples/batch_transform_with_filters/filters-as-function/app.py
@@ -0,0 +1,67 @@
+import pandas as pd
+import sqlalchemy as sa
+
+from datapipe.compute import Catalog, DatapipeApp, Pipeline, Table
+from datapipe.datatable import DataStore
+from datapipe.run_config import RunConfig
+from datapipe.step.batch_transform import BatchTransform
+from datapipe.step.update_external_table import UpdateExternalTable
+from datapipe.store.database import DBConn
+from datapipe.store.pandas import TableStoreJsonLine
+
+dbconn = DBConn("sqlite+pysqlite3:///db.sqlite")
+ds = DataStore(dbconn)
+
+
+def filter_cases(ds: DataStore, run_config: RunConfig):
+    label = run_config.labels.get("stage", None)
+    if label == "label1":
+        return pd.DataFrame({"input_id": [1, 3, 4, 6, 9]})
+    elif label == "label2":
+        return pd.DataFrame({"input_id": [2, 6, 9]})
+    else:
+        return pd.DataFrame({"input_id": [6, 9]})
+
+
+def apply_transformation(input_df: pd.DataFrame) -> pd.DataFrame:
+    input_df["text"] = "Yay! I have been transformed."
+ return input_df + + +input_tbl = Table( + name="input", + store=TableStoreJsonLine( + filename="input.jsonline", + primary_schema=[ + sa.Column("input_id", sa.Integer, primary_key=True), + ], + ), +) + +output_tbl = Table( + name="output", + store=TableStoreJsonLine( + filename="output.jsonline", + primary_schema=[ + sa.Column("input_id", sa.Integer, primary_key=True), + ], + ), +) + + +pipeline = Pipeline( + [ + UpdateExternalTable( + output=input_tbl, + ), + BatchTransform( + apply_transformation, + inputs=[input_tbl], + outputs=[output_tbl], + labels=[("stage", "label1"), ("stage", "label2")], + filters=filter_cases + ), + ] +) + +app = DatapipeApp(ds, Catalog({}), pipeline) diff --git a/examples/batch_transform_with_filters/filters-as-function/input.jsonline b/examples/batch_transform_with_filters/filters-as-function/input.jsonline new file mode 100644 index 00000000..7108cc6a --- /dev/null +++ b/examples/batch_transform_with_filters/filters-as-function/input.jsonline @@ -0,0 +1,9 @@ +{"input_id": 1, "text": "I need to be transformed when labels=stage=label1."} +{"input_id": 2, "text": "I need to be transformed when labels=stage=label2."} +{"input_id": 3, "text": "I need to be transformed when labels=stage=label1."} +{"input_id": 4, "text": "I need to be transformed when labels=stage=label1."} +{"input_id": 5, "text": "I need to be ignored."} +{"input_id": 6, "text": "I need to be transformed anytime."} +{"input_id": 7, "text": "I need to be ignored."} +{"input_id": 8, "text": "I need to be ignored."} +{"input_id": 9, "text": "I need to be transformed anytime."} diff --git a/examples/batch_transform_with_filters/simple-example/.gitignore b/examples/batch_transform_with_filters/simple-example/.gitignore new file mode 100644 index 00000000..74d0d3a3 --- /dev/null +++ b/examples/batch_transform_with_filters/simple-example/.gitignore @@ -0,0 +1 @@ +output.jsonline diff --git a/examples/batch_transform_with_filters/simple-example/app.py b/examples/batch_transform_with_filters/simple-example/app.py new file mode 100644 index 00000000..6dda3ae0 --- /dev/null +++ b/examples/batch_transform_with_filters/simple-example/app.py @@ -0,0 +1,61 @@ +import pandas as pd +import sqlalchemy as sa + +from datapipe.compute import Catalog, DatapipeApp, Pipeline, Table +from datapipe.datatable import DataStore +from datapipe.step.batch_transform import BatchTransform +from datapipe.step.update_external_table import UpdateExternalTable +from datapipe.store.database import DBConn +from datapipe.store.pandas import TableStoreJsonLine + +dbconn = DBConn("sqlite+pysqlite3:///db.sqlite") +ds = DataStore(dbconn) + + +def apply_transformation(input_df: pd.DataFrame) -> pd.DataFrame: + input_df["text"] = "Yay! I have been transformed." 
+ return input_df + + +input_tbl = Table( + name="input", + store=TableStoreJsonLine( + filename="input.jsonline", + primary_schema=[ + sa.Column("input_id", sa.Integer, primary_key=True), + ], + ), +) + +output_tbl = Table( + name="output", + store=TableStoreJsonLine( + filename="output.jsonline", + primary_schema=[ + sa.Column("input_id", sa.Integer, primary_key=True), + ], + ), +) + + +pipeline = Pipeline( + [ + UpdateExternalTable( + output=input_tbl, + ), + BatchTransform( + apply_transformation, + inputs=[input_tbl], + outputs=[output_tbl], + filters=[ + {"input_id": 1}, + {"input_id": 3}, + {"input_id": 4}, + {"input_id": 6}, + {"input_id": 9} + ] + ), + ] +) + +app = DatapipeApp(ds, Catalog({}), pipeline) diff --git a/examples/batch_transform_with_filters/simple-example/input.jsonline b/examples/batch_transform_with_filters/simple-example/input.jsonline new file mode 100644 index 00000000..77185c50 --- /dev/null +++ b/examples/batch_transform_with_filters/simple-example/input.jsonline @@ -0,0 +1,9 @@ +{"input_id": 1, "text": "I need to be transformed."} +{"input_id": 2, "text": "I need to be ignored."} +{"input_id": 3, "text": "I need to be transformed."} +{"input_id": 4, "text": "I need to be transformed."} +{"input_id": 5, "text": "I need to be ignored."} +{"input_id": 6, "text": "I need to be transformed."} +{"input_id": 7, "text": "I need to be ignored."} +{"input_id": 8, "text": "I need to be ignored."} +{"input_id": 9, "text": "I need to be transformed"} diff --git a/examples/model_inference/app.py b/examples/model_inference/app.py index 26b03c3b..237ba1ac 100644 --- a/examples/model_inference/app.py +++ b/examples/model_inference/app.py @@ -35,8 +35,8 @@ def apply_model(input_df: pd.DataFrame, model_df: pd.DataFrame) -> pd.DataFrame: store=TableStoreJsonLine( filename="input.jsonline", primary_schema=[ - sa.Column("pipeline_id", sa.String, primary_key=True), sa.Column("input_id", sa.Integer, primary_key=True), + sa.Column("pipeline_id", sa.String, primary_key=True) ], ), ) @@ -57,8 +57,8 @@ def apply_model(input_df: pd.DataFrame, model_df: pd.DataFrame) -> pd.DataFrame: store=TableStoreJsonLine( filename="output.jsonline", primary_schema=[ - sa.Column("pipeline_id", sa.String, primary_key=True), sa.Column("input_id", sa.Integer, primary_key=True), + sa.Column("pipeline_id", sa.String, primary_key=True), sa.Column("model_id", sa.String, primary_key=True), ], ), diff --git a/examples/model_inference/playground.ipynb b/examples/model_inference/playground.ipynb deleted file mode 100644 index 0196f2e5..00000000 --- a/examples/model_inference/playground.ipynb +++ /dev/null @@ -1,76 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import app\n", - "step = app.app.steps[-1]" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WITH input_meta__update_ts AS \n", - "(SELECT input_id, max(input_meta.update_ts) AS update_ts \n", - "FROM input_meta GROUP BY input_id), \n", - "models_meta__update_ts AS \n", - "(SELECT model_id, max(models_meta.update_ts) AS update_ts \n", - "FROM models_meta GROUP BY model_id), \n", - "all__update_ts AS \n", - "(SELECT input_meta__update_ts.input_id AS input_id, models_meta__update_ts.model_id AS model_id, max(input_meta__update_ts.update_ts, models_meta__update_ts.update_ts) AS update_ts \n", - "FROM input_meta__update_ts FULL OUTER JOIN models_meta__update_ts ON :param_1), 
\n", - "transform AS \n", - "(SELECT input_id, model_id, apply_functions_634cbbc660_meta.process_ts AS process_ts \n", - "FROM apply_functions_634cbbc660_meta \n", - "WHERE apply_functions_634cbbc660_meta.is_success = true GROUP BY input_id, model_id)\n", - " SELECT coalesce(all__update_ts.input_id, transform.input_id) AS input_id, coalesce(all__update_ts.model_id, transform.model_id) AS model_id \n", - "FROM all__update_ts FULL OUTER JOIN transform ON all__update_ts.input_id = transform.input_id AND all__update_ts.model_id = transform.model_id \n", - "WHERE transform.process_ts < all__update_ts.update_ts OR all__update_ts.update_ts IS NULL OR transform.process_ts IS NULL\n" - ] - } - ], - "source": [ - "_, sql = step._build_changed_idx_sql(ds=app.app.ds)\n", - "print(str(sql))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "venv", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/tests/test_complex_pipeline.py b/tests/test_complex_pipeline.py index 8368a45a..f9dc079b 100644 --- a/tests/test_complex_pipeline.py +++ b/tests/test_complex_pipeline.py @@ -293,6 +293,227 @@ def train( ) == len(TEST__FROZEN_DATASET) * len(TEST__TRAIN_CONFIG) +def test_complex_transform_with_filters(dbconn): + ds = DataStore(dbconn, create_meta_table=True) + catalog = Catalog({ + "tbl_image": Table( + store=TableStoreDB( + dbconn, + "tbl_image", + [ + Column("image_id", Integer, primary_key=True), + ], + True + ) + ), + "tbl_subset__has__image": Table( + store=TableStoreDB( + dbconn, + "tbl_subset__has__image", + [ + Column("image_id", Integer, primary_key=True), + Column("subset_id", Integer, primary_key=True), + ], + True + ) + ), + "tbl_prediction": Table( + store=TableStoreDB( + dbconn, + "tbl_prediction", + [ + Column("image_id", Integer, primary_key=True), + Column("model_id", Integer, primary_key=True), + ], + True + ) + ), + "tbl_output": Table( + store=TableStoreDB( + dbconn, + "tbl_output", + [ + Column("subset_id", Integer, primary_key=True), + Column("model_id", Integer, primary_key=True), + Column("count", Integer), + ], + True + ) + ) + }) + + def gen_tbl(df): + yield df + + test_df__image = pd.DataFrame({ + "image_id": range(1000) + }) + test_df__subset__has__image = pd.DataFrame({ + "image_id": range(1000), + "subset_id": [0 for _ in range(200)] + [1 for _ in range(200)] + [2 for _ in range(600)] + }) + test_df__prediction = pd.DataFrame({ + "image_id": list(range(1000)) + list(range(1000)), + "model_id": [0] * 1000 + [1] * 1000 + }) + + def count_func( + df__image: pd.DataFrame, + df__subset__has__image: pd.DataFrame, + df__prediction: pd.DataFrame, + ): + df__image = pd.merge(df__image, df__subset__has__image, on=["image_id"]) + df__image = pd.merge(df__image, df__prediction, on=["image_id"]) + df__output = df__image.groupby(["subset_id", "model_id"]).agg(len).reset_index().rename(columns={"image_id": "count"}) + return df__output + + pipeline = Pipeline( + [ + BatchGenerate( + func=gen_tbl, + outputs=["tbl_image"], + kwargs=dict(df=test_df__image), + ), + BatchGenerate( + func=gen_tbl, + outputs=["tbl_subset__has__image"], + 
kwargs=dict(df=test_df__subset__has__image), + ), + BatchGenerate( + func=gen_tbl, + outputs=["tbl_prediction"], + kwargs=dict(df=test_df__prediction), + ), + BatchTransform( + func=count_func, + inputs=["tbl_image", "tbl_subset__has__image", "tbl_prediction"], + outputs=["tbl_output"], + transform_keys=["subset_id", "model_id"], + chunk_size=100, + filters=[ + {"subset_id": 0}, + {"subset_id": 1}, + ] + ), + ] + ) + steps = build_compute(ds, catalog, pipeline) + run_steps(ds, steps) + test__df_output = count_func( + df__image=test_df__image, + df__subset__has__image=test_df__subset__has__image[ + test_df__subset__has__image["subset_id"].isin([0, 1]) + ], + df__prediction=test_df__prediction + ) + assert_df_equal( + ds.get_table("tbl_output").get_data(), + test__df_output, + index_cols=["model_id"] + ) + + +def complex_transform_with_filters2_by_N(dbconn, N): + ds = DataStore(dbconn, create_meta_table=True) + catalog = Catalog({ + "tbl_image": Table( + store=TableStoreDB( + dbconn, + "tbl_image", + [ + Column("image_id", Integer, primary_key=True), + ], + True + ) + ), + "tbl_model": Table( + store=TableStoreDB( + dbconn, + "tbl_model", + [ + Column("model_id", Integer, primary_key=True), + ], + True + ) + ), + "tbl_prediction": Table( + store=TableStoreDB( + dbconn, + "tbl_prediction", + [ + Column("image_id", Integer, primary_key=True), + Column("model_id", Integer, primary_key=True), + ], + True + ) + ) + }) + + def gen_tbl(df): + yield df + + test_df__image = pd.DataFrame({ + "image_id": range(N) + }) + test_df__model = pd.DataFrame({ + "model_id": [0, 1, 2, 3, 4] + }) + + def filters_images(): + return [{"image_id": i} for i in range(N // 2)] + + def make_prediction( + df__image: pd.DataFrame, + df__model: pd.DataFrame, + ): + df__prediction = pd.merge(df__image, df__model, how="cross") + return df__prediction + + pipeline = Pipeline( + [ + BatchGenerate( + func=gen_tbl, + outputs=["tbl_image"], + kwargs=dict(df=test_df__image), + ), + BatchGenerate( + func=gen_tbl, + outputs=["tbl_model"], + kwargs=dict(df=test_df__model), + ), + BatchTransform( + func=make_prediction, + inputs=["tbl_image", "tbl_model"], + outputs=["tbl_prediction"], + transform_keys=["image_id", "model_id"], + chunk_size=1000, + filters=filters_images + ), + ] + ) + steps = build_compute(ds, catalog, pipeline) + run_steps(ds, steps) + test__df_output = make_prediction( + df__image=test_df__image[ + test_df__image["image_id"].isin([r["image_id"] for r in filters_images()]) + ], + df__model=test_df__model + ) + assert_df_equal( + ds.get_table("tbl_prediction").get_data(), + test__df_output, + index_cols=["model_id"] + ) + +def test_complex_transform_with_filters2_N100(dbconn): + complex_transform_with_filters2_by_N(dbconn, N=100) + + +@pytest.mark.skip(reason="big filters not supported yet") +def test_complex_transform_with_filters2_N10000(dbconn): + complex_transform_with_filters2_by_N(dbconn, N=10000) + + def complex_transform_with_many_recordings(dbconn, N: int): ds = DataStore(dbconn, create_meta_table=True) catalog = Catalog( diff --git a/tests/test_core_steps2.py b/tests/test_core_steps2.py index 26b54022..90237339 100644 --- a/tests/test_core_steps2.py +++ b/tests/test_core_steps2.py @@ -4,6 +4,7 @@ # import pytest import time +from typing import Optional, cast import pandas as pd from sqlalchemy import Column, String @@ -15,7 +16,7 @@ from datapipe.step.batch_generate import do_batch_generate from datapipe.step.batch_transform import BatchTransformStep from datapipe.store.database import MetaKey, TableStoreDB 
-from datapipe.types import ChangeList, IndexDF +from datapipe.types import ChangeList, Filters, IndexDF from .util import assert_datatable_equal, assert_df_equal @@ -116,7 +117,7 @@ def test_batch_transform(dbconn): assert all(meta_df["process_ts"] == process_ts) -def test_batch_transform_with_filter(dbconn): +def test_batch_transform_with_filter_in_run_config(dbconn): ds = DataStore(dbconn, create_meta_table=True) tbl1 = ds.create_table( @@ -139,13 +140,42 @@ def test_batch_transform_with_filter(dbconn): step.run_full( ds, run_config=RunConfig( - filters={"pipeline_id": 0}, + filters=[{"pipeline_id": 0}], ), ) assert_datatable_equal(tbl2, TEST_DF1_1.query("pipeline_id == 0")) +def test_batch_transform_with_filter_in_run_config_not_in_transform_index(dbconn): + ds = DataStore(dbconn, create_meta_table=True) + + tbl1 = ds.create_table( + "tbl1", table_store=TableStoreDB(dbconn, "tbl1_data", TEST_SCHEMA1, True) + ) + + tbl2 = ds.create_table( + "tbl2", table_store=TableStoreDB(dbconn, "tbl2_data", TEST_SCHEMA2, True) + ) + + tbl1.store_chunk(TEST_DF1_2, now=0) + + step = BatchTransformStep( + ds=ds, + name="test", + func=lambda df: df[["item_id", "a"]], + input_dts=[ComputeInput(dt=tbl1, join_type="full")], + output_dts=[tbl2], + ) + + step.run_full( + ds, + run_config=RunConfig(filters=[{"pipeline_id": 0}]), + ) + + assert_datatable_equal(tbl2, TEST_DF1_2.query("pipeline_id == 0")[["item_id", "a"]]) + + def test_batch_transform_with_filter_not_in_transform_index(dbconn): ds = DataStore(dbconn, create_meta_table=True) @@ -165,11 +195,11 @@ def test_batch_transform_with_filter_not_in_transform_index(dbconn): func=lambda df: df[["item_id", "a"]], input_dts=[ComputeInput(dt=tbl1, join_type="full")], output_dts=[tbl2], + filters=[{"pipeline_id": 0}] ) step.run_full( ds, - run_config=RunConfig(filters={"pipeline_id": 0}), ) assert_datatable_equal(tbl2, TEST_DF1_2.query("pipeline_id == 0")[["item_id", "a"]]) @@ -301,7 +331,7 @@ def gen_func(): func=gen_func, ds=ds, output_dts=[tbl], - run_config=RunConfig(filters={"pipeline_id": 0}), + run_config=RunConfig(filters=[{"pipeline_id": 0}]), ) assert_datatable_equal( @@ -394,3 +424,115 @@ def update_df(products: pd.DataFrame, items: pd.DataFrame): items2_df = merged_df[["item_id", "pipeline_id", "product_id", "a"]] assert_df_equal(items2.get_data(), items2_df, index_cols=["item_id", "pipeline_id"]) + + +def batch_transform_with_filters(dbconn, filters: Filters, ds: Optional[DataStore] = None): + if ds is None: + ds = DataStore(dbconn, create_meta_table=True) + + item = ds.create_table( + "item", + table_store=TableStoreDB( + dbconn, + "item", + [Column("item_id", Integer, primary_key=True)], + True + ), + ) + + inner_item = ds.create_table( + "inner_item", table_store=TableStoreDB( + dbconn, + "inner_item", + [ + Column("item_id", Integer, primary_key=True), + Column("inner_item_id", Integer, primary_key=True) + ], + True + ) + ) + + output = ds.create_table( + "output", table_store=TableStoreDB( + dbconn, + "output", + [ + Column("item_id", Integer, primary_key=True), + Column("inner_item_id", Integer, primary_key=True) + ], + True + ) + ) + + test_df__item = pd.DataFrame( + { + "item_id": list(range(10)), + } + ) + + test_df__inner_item = pd.DataFrame( + { + "item_id": list(range(10)) * 10, + "inner_item_id": list(range(100)), + } + ) + item.store_chunk(test_df__item, now=0) + inner_item.store_chunk(test_df__inner_item, now=0) + + def update_df(df__item: pd.DataFrame, df__inner_item: pd.DataFrame): + merged_df = pd.merge(df__item, df__inner_item, 
on=["item_id"]) + return merged_df + + step = BatchTransformStep( + ds=ds, + name="test", + func=update_df, + input_dts=[ComputeInput(dt=item, join_type="full"), ComputeInput(dt=inner_item, join_type="full")], + output_dts=[output], + filters=filters + ) + + step.run_full(ds) + + test_df__output = update_df( + df__item=test_df__item[test_df__item["item_id"].isin([0, 1, 2])], + df__inner_item=test_df__inner_item[test_df__inner_item["item_id"].isin([0, 1, 2])] + ) + + assert_df_equal(output.get_data(), test_df__output, index_cols=["item_id", "inner_item_id"]) + + +def test_batch_transform_with_filters_as_str(dbconn): + ds = DataStore(dbconn, create_meta_table=True) + filters_data = pd.DataFrame([{"item_id": 0}, {"item_id": 1}, {"item_id": 2}]) + filters = ds.create_table( + "filters_data", table_store=TableStoreDB( + dbconn, "filters_data", [Column("item_id", Integer, primary_key=True)], True + ) + ) + filters.store_chunk(filters_data, now=0) + batch_transform_with_filters(dbconn, filters="filters_data", ds=ds) + + +def test_batch_transform_with_filters_as_IndexDF(dbconn): + batch_transform_with_filters( + dbconn, filters=cast(IndexDF, pd.DataFrame([{"item_id": 0}, {"item_id": 1}, {"item_id": 2}])) + ) + + +def test_batch_transform_with_filters_as_list_of_dict(dbconn): + batch_transform_with_filters(dbconn, filters=[{"item_id": 0}, {"item_id": 1}, {"item_id": 2}]) + + +def test_batch_transform_with_filters_as_callable_IndexDF(dbconn): + def callable(ds: DataStore, run_config: Optional[RunConfig]): + return cast(IndexDF, pd.DataFrame([{"item_id": 0}, {"item_id": 1}, {"item_id": 2}])) + + batch_transform_with_filters(dbconn, filters=callable) + + +def test_batch_transform_with_filters_as_callable_list_of_dict(dbconn): + def callable(ds: DataStore, run_config: Optional[RunConfig]): + return [{"item_id": 0}, {"item_id": 1}, {"item_id": 2}] + + batch_transform_with_filters(dbconn, filters=callable) diff --git a/tests/test_table_store.py b/tests/test_table_store.py index c0a0d84c..2929b938 100644 --- a/tests/test_table_store.py +++ b/tests/test_table_store.py @@ -472,7 +472,7 @@ def test_read_rows_meta_pseudo_df_with_runconfig(store: TableStore, test_df: pd. assert_ts_contains(store, test_df) # TODO проверять, что runconfig реально влияет на результирующие данные - pseudo_df_iter = store.read_rows_meta_pseudo_df(run_config=RunConfig(filters={"a": 1})) + pseudo_df_iter = store.read_rows_meta_pseudo_df(run_config=RunConfig(filters=[{"a": 1}])) assert isinstance(pseudo_df_iter, Iterable) for pseudo_df in pseudo_df_iter: assert isinstance(pseudo_df, DataDF)
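For quick orientation, below is a condensed, self-contained sketch of how the new `filters` argument can be wired into a pipeline, adapted from the `examples/batch_transform_with_filters` apps added in this diff. The table names, schemas, and filter values are illustrative only; any one of the three filter forms shown (list of dicts, index dataframe, or callable with `ds`/`run_config` injection) can be passed.

import pandas as pd
import sqlalchemy as sa

from datapipe.compute import Catalog, DatapipeApp, Pipeline, Table
from datapipe.datatable import DataStore
from datapipe.run_config import RunConfig
from datapipe.step.batch_transform import BatchTransform
from datapipe.step.update_external_table import UpdateExternalTable
from datapipe.store.database import DBConn
from datapipe.store.pandas import TableStoreJsonLine

dbconn = DBConn("sqlite+pysqlite3:///db.sqlite")
ds = DataStore(dbconn)

input_tbl = Table(
    name="input",
    store=TableStoreJsonLine(
        filename="input.jsonline",
        primary_schema=[sa.Column("input_id", sa.Integer, primary_key=True)],
    ),
)
output_tbl = Table(
    name="output",
    store=TableStoreJsonLine(
        filename="output.jsonline",
        primary_schema=[sa.Column("input_id", sa.Integer, primary_key=True)],
    ),
)


def apply_transformation(input_df: pd.DataFrame) -> pd.DataFrame:
    # Trivial stand-in for a real transformation.
    input_df["text"] = "transformed"
    return input_df


# Form 1: a static list of key:value dicts.
static_filters = [{"input_id": 1}, {"input_id": 3}, {"input_id": 4}]

# Form 2: an index dataframe with the same meaning.
index_filters = pd.DataFrame({"input_id": [1, 3, 4]})


# Form 3: a callable; `ds` and `run_config` are injected when declared in the signature.
def dynamic_filters(ds: DataStore, run_config: RunConfig) -> pd.DataFrame:
    return pd.DataFrame({"input_id": [1, 3, 4]})


pipeline = Pipeline(
    [
        UpdateExternalTable(output=input_tbl),
        BatchTransform(
            apply_transformation,
            inputs=[input_tbl],
            outputs=[output_tbl],
            filters=static_filters,  # or index_filters / dynamic_filters
        ),
    ]
)

app = DatapipeApp(ds, Catalog({}), pipeline)

With such an app, only the listed `input_id` values are picked up for processing; indexes outside the filter are not marked as processed and will be picked up later if the filter is removed, as described in the docs change above.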