Skip to content

Commit

Permalink
add materialize stage
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed Jan 1, 2025
1 parent 1547714 commit 7c4cee5
Show file tree
Hide file tree
Showing 6 changed files with 651 additions and 6 deletions.
1 change: 1 addition & 0 deletions fiftyone/__public__.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@
MatchFrames,
MatchLabels,
MatchTags,
Materialize,
Mongo,
Select,
SelectBy,
Expand Down
32 changes: 32 additions & 0 deletions fiftyone/core/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@ def _is_clips(self):
"""Whether this collection contains clips."""
raise NotImplementedError("Subclass must implement _is_clips")

@property
def _is_materialized(self):
"""Whether this collection contains a materialized view."""
raise NotImplementedError("Subclass must implement _is_materialized")

@property
def _is_dynamic_groups(self):
"""Whether this collection contains dynamic groups."""
Expand Down Expand Up @@ -6193,6 +6198,33 @@ def match_tags(self, tags, bool=None, all=False):
"""
return self._add_view_stage(fos.MatchTags(tags, bool=bool, all=all))

@view_stage
def materialize(self):
"""Materializes the current view into a temporary database collection.
Apply this stage to an expensive view (eg an unindexed filtering
operation on a large dataset) if you plan to perform multiple
downstream operations on the view.
Examples::
import fiftyone as fo
import fiftyone.zoo as foz
from fiftyone import ViewField as F
dataset = foz.load_zoo_dataset("quickstart")
view = dataset.filter_labels("ground_truth", F("label") == "cat")
materialized_view = view.materialize()
print(view.count("ground_truth.detections"))
print(materialized_view.count("ground_truth.detections"))
Returns:
a :class:`fiftyone.core.view.DatasetView`
"""
return self._add_view_stage(fos.Materialize())

@view_stage
def mongo(self, pipeline, _needs_frames=None, _group_slices=None):
"""Adds a view stage defined by a raw MongoDB aggregation pipeline.
Expand Down
41 changes: 35 additions & 6 deletions fiftyone/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,12 @@ def _root_dataset(self):

@property
def _is_generated(self):
return self._is_patches or self._is_frames or self._is_clips
return (
self._is_patches
or self._is_frames
or self._is_clips
or self._is_materialized
)

@property
def _is_patches(self):
Expand All @@ -442,6 +447,10 @@ def _is_frames(self):
def _is_clips(self):
return self._sample_collection_name.startswith("clips.")

@property
def _is_materialized(self):
return self._sample_collection_name.startswith("materialized.")

@property
def _is_dynamic_groups(self):
return False
Expand Down Expand Up @@ -4877,7 +4886,13 @@ def clone(self, name=None, persistent=False):
"""
return self._clone(name=name, persistent=persistent)

def _clone(self, name=None, persistent=False, view=None):
def _clone(
self,
name=None,
persistent=False,
view=None,
materialized=False,
):
if name is None:
name = get_default_dataset_name()

Expand All @@ -4886,7 +4901,12 @@ def _clone(self, name=None, persistent=False, view=None):
else:
sample_collection = self

return _clone_collection(sample_collection, name, persistent)
return _clone_collection(
sample_collection,
name,
persistent=persistent,
materialized=materialized,
)

def clear(self):
"""Removes all samples from the dataset.
Expand Down Expand Up @@ -8271,7 +8291,7 @@ def _clone_collection_indexes(


def _make_sample_collection_name(
dataset_id, patches=False, frames=False, clips=False
dataset_id, patches=False, frames=False, clips=False, materialized=False
):
if patches and frames:
prefix = "patches.frames"
Expand All @@ -8281,6 +8301,8 @@ def _make_sample_collection_name(
prefix = "frames"
elif clips:
prefix = "clips"
elif materialized:
prefix = "materialized"
else:
prefix = "samples"

Expand Down Expand Up @@ -8463,7 +8485,12 @@ def _delete_dataset_doc(dataset_doc):
dataset_doc.delete()


def _clone_collection(sample_collection, name, persistent):
def _clone_collection(
sample_collection,
name,
persistent=False,
materialized=False,
):
slug = _validate_dataset_name(name)

contains_videos = sample_collection._contains_videos(any_slice=True)
Expand All @@ -8490,7 +8517,9 @@ def _clone_collection(sample_collection, name, persistent):
_id = dataset_doc.id
now = datetime.utcnow()

sample_collection_name = _make_sample_collection_name(_id)
sample_collection_name = _make_sample_collection_name(
_id, materialized=materialized
)

if contains_videos:
frame_collection_name = _make_frame_collection_name(
Expand Down
Loading

0 comments on commit 7c4cee5

Please sign in to comment.