
Commit 78ab66a

feat: ability to filter dataset when downloading (#762)

* draft
* add integ test
* use backward compatible type
* tidy up docs
* set numpy upper bound in semantic segmentation example due to issue with opencv

1 parent d2bc1c5 · commit 78ab66a

File tree: 6 files changed (+118 −2 lines)

docs/reference/dataset/index.md (+4)

@@ -11,3 +11,7 @@
     options:
       members: ["upload_dataset_embeddings"]
       show_root_heading: false
+::: kolena._api.v2.dataset
+    options:
+      members: ["Filters", "GeneralFieldFilter"]
+      show_root_heading: false

examples/dataset/semantic_segmentation/pyproject.toml (+1)

@@ -16,6 +16,7 @@ dependencies = [
     "boto3>=1.25,<2",
     "scikit-learn>=1.1.2,<2",
     "scikit-image>=0.19.3,<1",
+    "numpy<2",
 ]

 [tool.uv]

kolena/_api/v2/dataset.py (+35)

@@ -11,13 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from dataclasses import field
 from enum import Enum
 from typing import Dict
 from typing import List
 from typing import Optional
+from typing import Union
+
+from typing_extensions import Literal

 from kolena._api.v1.batched_load import BatchedLoad
 from kolena._utils.pydantic_v1 import conint
+from kolena._utils.pydantic_v1 import StrictBool
+from kolena._utils.pydantic_v1 import StrictStr
 from kolena._utils.pydantic_v1.dataclasses import dataclass


@@ -41,11 +47,40 @@ class RegisterRequest:
     description: Optional[str] = None


+@dataclass(frozen=True)
+class GeneralFieldFilter:
+    """
+    Generic representation of a filter on Kolena.
+    """
+
+    value_in: Optional[List[Union[StrictStr, StrictBool]]] = None
+    """A list of desired categorical values."""
+    null_value: Optional[Literal[True]] = None
+    """Whether to filter for cases where the field has a null value or the field name does not exist."""
+
+
+@dataclass(frozen=True)
+class Filters:
+    """
+    Filters to be applied on the dataset during the operation. Currently only used as an optional argument
+    in [`download_dataset`][kolena.dataset.download_dataset].
+    """
+
+    datapoint: Dict[str, GeneralFieldFilter] = field(default_factory=dict)
+    """
+    Dictionary of a field name of the datapoint to the [`GeneralFieldFilter`][kolena.dataset.GeneralFieldFilter] to be
+    applied on the field. In case of nested objects, use `.` as the delimiter to separate the keys. For example, if you
+    have a `ground_truth` column of [`Label`][kolena.annotation.Label] type, you can use `ground_truth.label` as the key
+    to query for the class label.
+    """
+
+
 @dataclass(frozen=True)
 class LoadDatapointsRequest(BatchedLoad.BaseInitDownloadRequest):
     name: str
     commit: Optional[str] = None
     include_extracted_properties: bool = False
+    filters: Optional[Filters] = None


 @dataclass(frozen=True)
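
For orientation, a minimal sketch of how these dataclasses compose; the field names `city` and `ground_truth.label` are hypothetical stand-ins for whatever columns a dataset actually carries:

from kolena.dataset import Filters, GeneralFieldFilter

# Hypothetical fields: keep rows whose "city" is one of two values, and whose
# nested ground_truth.label is "cat" or is null/absent (null_value=True).
filters = Filters(
    datapoint={
        "city": GeneralFieldFilter(value_in=["london", "paris"]),
        "ground_truth.label": GeneralFieldFilter(value_in=["cat"], null_value=True),
    },
)

Judging from the integration test further down, multiple entries in `datapoint` are combined with AND, while `value_in` and `null_value` within a single `GeneralFieldFilter` act as an OR.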

kolena/dataset/__init__.py (+4)

@@ -21,9 +21,13 @@
 from kolena.dataset.evaluation import ModelEntity
 from kolena.dataset.evaluation import get_models
 from kolena.dataset.embeddings import upload_dataset_embeddings
+from kolena._api.v2.dataset import Filters
+from kolena._api.v2.dataset import GeneralFieldFilter

 __all__ = [
     "upload_dataset",
+    "Filters",
+    "GeneralFieldFilter",
     "download_dataset",
     "upload_results",
     "download_results",

kolena/dataset/dataset.py (+8 −2)

@@ -31,6 +31,7 @@
 from kolena._api.v1.event import EventAPI
 from kolena._api.v2.dataset import CommitData
 from kolena._api.v2.dataset import EntityData
+from kolena._api.v2.dataset import Filters
 from kolena._api.v2.dataset import ListCommitHistoryRequest
 from kolena._api.v2.dataset import ListCommitHistoryResponse
 from kolena._api.v2.dataset import ListDatasetsResponse

@@ -368,13 +369,15 @@ def _iter_dataset_raw(
     commit: Optional[str] = None,
     batch_size: int = BatchSize.LOAD_SAMPLES.value,
     include_extracted_properties: bool = False,
+    filters: Optional[Filters] = None,
 ) -> Iterator[pd.DataFrame]:
     validate_batch_size(batch_size)
     init_request = LoadDatapointsRequest(
         name=name,
         commit=commit,
         batch_size=batch_size,
         include_extracted_properties=include_extracted_properties,
+        filters=filters,
     )
     yield from _BatchedLoader.iter_data(
         init_request=init_request,

@@ -389,11 +392,12 @@ def _iter_dataset(
     commit: Optional[str] = None,
     batch_size: int = BatchSize.LOAD_SAMPLES.value,
     include_extracted_properties: bool = False,
+    filters: Optional[Filters] = None,
 ) -> Iterator[pd.DataFrame]:
     """
     Get an iterator over datapoints in the dataset.
     """
-    for df_batch in _iter_dataset_raw(name, commit, batch_size, include_extracted_properties):
+    for df_batch in _iter_dataset_raw(name, commit, batch_size, include_extracted_properties, filters):
         yield _to_deserialized_dataframe(df_batch, column=COL_DATAPOINT)


@@ -403,6 +407,7 @@ def download_dataset(
     *,
     commit: Optional[str] = None,
     include_extracted_properties: bool = False,
+    filters: Optional[Filters] = None,
 ) -> pd.DataFrame:
     """
     Download an entire dataset given its name.

@@ -411,9 +416,10 @@ def download_dataset(
     :param commit: The commit hash for version control. Get the latest commit when this value is `None`.
     :param include_extracted_properties: If True, include kolena extracted properties from automated extractions
         in the dataset as separate columns
+    :param filters: [Experimental] Optional filter to specify which datapoints should be downloaded.
     :return: A DataFrame containing the specified dataset.
     """
-    df_batches = list(_iter_dataset(name, commit, BatchSize.LOAD_SAMPLES.value, include_extracted_properties))
+    df_batches = list(_iter_dataset(name, commit, BatchSize.LOAD_SAMPLES.value, include_extracted_properties, filters))
     log.info(f"downloaded dataset '{name}'")
     df_dataset = pd.concat(df_batches, ignore_index=True) if df_batches else pd.DataFrame()
     return df_dataset
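
Taken together, the new argument threads from `download_dataset` through `_iter_dataset` and `_iter_dataset_raw` into the `LoadDatapointsRequest` payload, so filtering happens server-side before batches are streamed back. A minimal caller-side sketch, assuming a registered dataset named "example-dataset" with a `split` column (both hypothetical):

from kolena.dataset import download_dataset, Filters, GeneralFieldFilter

# Download only the datapoints whose (hypothetical) "split" field is "test";
# omitting filters downloads the entire dataset, as before.
df_test = download_dataset(
    "example-dataset",
    filters=Filters(datapoint={"split": GeneralFieldFilter(value_in=["test"])}),
)
print(f"downloaded {len(df_test)} filtered datapoints")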

tests/integration/dataset/test_dataset.py (+66)

@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
+from typing import Dict
 from typing import Iterator
 from typing import List
+from typing import Optional
 from typing import Tuple

 import numpy as np

@@ -25,6 +27,8 @@
 from kolena.annotation import BoundingBox
 from kolena.annotation import LabeledBoundingBox
 from kolena.dataset import download_dataset
+from kolena.dataset import Filters
+from kolena.dataset import GeneralFieldFilter
 from kolena.dataset import list_datasets
 from kolena.dataset import upload_dataset
 from kolena.dataset.dataset import _fetch_dataset_history

@@ -422,3 +426,65 @@ def test__upload_dataset__with_description() -> None:
     )
     dataset = _load_dataset_metadata(name)
     assert dataset.description == description_v2
+
+
+@pytest.fixture(scope="module")
+def download_datapoints_with_filters_data() -> Tuple[str, List[str], List[Dict]]:
+    name = with_test_prefix(f"{__file__}::test__download_dataset__with_filters")
+    id_fields = ["value"]
+    n_datapoints = 10
+    columns = ["value", "str", "nested"]
+    datapoints = [
+        dict(
+            value=i,
+            str=f"str-{i}",
+            nested={
+                "bool ean": i % 2 == 0,
+                "optional_col": str(i) if i % 5 > 0 else None,
+            },
+        )
+        for i in range(n_datapoints)
+    ]
+    df_datapoints = pd.DataFrame(datapoints, columns=["value", "str", "nested"])
+
+    upload_dataset(
+        name,
+        df_datapoints,
+        id_fields=id_fields,
+    )
+    return name, columns, datapoints
+
+
+@pytest.mark.parametrize(
+    "filters, expected_datapoint_inds",
+    [
+        (None, list(range(10))),
+        (Filters(datapoint={"str": GeneralFieldFilter(value_in=["str-0", "str-1"])}), [0, 1]),
+        (Filters(datapoint={"value": GeneralFieldFilter(value_in=["2", "3"])}), [2, 3]),
+        (Filters(datapoint={'nested."bool ean"': GeneralFieldFilter(value_in=[True])}), [0, 2, 4, 6, 8]),
+        (Filters(datapoint={"nested.optional_col": GeneralFieldFilter(value_in=["7"], null_value=True)}), [0, 5, 7]),
+        (
+            Filters(
+                datapoint={
+                    "str": GeneralFieldFilter(value_in=["str-0", "str-1", "str-2", "str-5"]),
+                    "nested.optional_col": GeneralFieldFilter(value_in=["0", "3"], null_value=True),
+                },
+            ),
+            [0, 5],
+        ),
+    ],
+)
+def test__download_dataset__with_filters(
+    download_datapoints_with_filters_data: Tuple[str, List[str], List[Dict]],
+    filters: Optional[Filters],
+    expected_datapoint_inds: List[int],
+) -> None:
+    name, columns, datapoints = download_datapoints_with_filters_data
+    expected_datapoints = (
+        pd.DataFrame([datapoints[ind] for ind in expected_datapoint_inds], columns=columns)
+        .sort_values(by="value")
+        .reset_index(drop=True)
+    )
+    loaded_datapoints = download_dataset(name, filters=filters)
+    loaded_datapoints = loaded_datapoints.sort_values(by="value").reset_index(drop=True)
+    assert_frame_equal(loaded_datapoints, expected_datapoints, columns)
