Skip to content

Commit

Permalink
Merge pull request #41 from pbeasly/dev/prepopulated_descr_generator
Browse files Browse the repository at this point in the history
Add prepopulated descriptor generator
  • Loading branch information
Purg authored Jun 10, 2024
2 parents c52b243 + 32abcbd commit 0130332
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/release_notes/pending_release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Pending Release Notes
Updates / New Features
----------------------

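* Added ``PrePopulatedDescriptorGenerator``, a descriptor generator implementation for configurations where the descriptor set is already populated.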

Fixes
-----

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ ipython = ">=7.16.3"
# DescriptorGenerator
"smqtk_descriptors.impls.descriptor_generator.caffe1" = "smqtk_descriptors.impls.descriptor_generator.caffe1"
"smqtk_descriptors.impls.descriptor_generator.pytorch" = "smqtk_descriptors.impls.descriptor_generator.pytorch"
"smqtk_descriptors.impls.descriptor_generator.prepopulated" = "smqtk_descriptors.impls.descriptor_generator.prepopulated"
# DescriptorSet
"smqtk_descriptors.impls.descriptor_set.memory" = "smqtk_descriptors.impls.descriptor_set.memory"
"smqtk_descriptors.impls.descriptor_set.postgres" = "smqtk_descriptors.impls.descriptor_set.postgres"
Expand Down
34 changes: 34 additions & 0 deletions smqtk_descriptors/impls/descriptor_generator/prepopulated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import logging
import numpy as np
from typing import Any, Dict, Iterable, Set, TypeVar

from smqtk_dataprovider import DataElement
from smqtk_descriptors import DescriptorGenerator

# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)

# Names exported as this module's public API.
__all__ = ["PrePopulatedDescriptorGenerator"]
# Type variable bound to the generator class.
# NOTE(review): not referenced in the visible part of this module -- confirm
# whether it is still needed.
T = TypeVar("T", bound="PrePopulatedDescriptorGenerator")


class PrePopulatedDescriptorGenerator(DescriptorGenerator):
    """
    Stand-in descriptor generator for configurations whose descriptor set
    has already been filled.

    Select this implementation in a config when descriptors are known ahead
    of time or were produced earlier by some external process (for example,
    when driving an IQR process over existing descriptors). No descriptor
    computation is available here: ``_generate_arrays()`` always raises an
    ``AssertionError``.
    """

    def get_config(self) -> Dict[str, Any]:
        """
        :return: An empty dictionary.
        """
        return dict()

    def valid_content_types(self) -> Set:
        """
        :return: An empty set.
        """
        no_content_types: Set = set()
        return no_content_types

    def _generate_arrays(
        self, data_iter: Iterable[DataElement]
    ) -> Iterable[np.ndarray]:
        """
        Unconditionally fail; ``data_iter`` is never consumed.

        :param data_iter: Iterable of data elements (ignored).

        :raises AssertionError: Always.
        """
        message = (
            "Method should not be called since descriptors are prepopulated."
        )
        raise AssertionError(message)
30 changes: 30 additions & 0 deletions tests/impls/descriptor_generator/test_prepopulated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest
from smqtk_descriptors.impls.descriptor_generator.prepopulated import (
PrePopulatedDescriptorGenerator,
)


def test_valid_content_types() -> None:
    """
    Verify that valid_content_types() yields an empty set.
    """
    content_types = PrePopulatedDescriptorGenerator().valid_content_types()
    assert content_types == set()


def test_generate_arrays() -> None:
    """
    Verify that invoking _generate_arrays() is rejected with an
    AssertionError.
    """
    descr_gen = PrePopulatedDescriptorGenerator()
    no_elements: list = []

    with pytest.raises(AssertionError):
        descr_gen._generate_arrays(no_elements)


def test_get_config() -> None:
    """
    Verify that get_config() yields an empty dictionary.
    """
    config = PrePopulatedDescriptorGenerator().get_config()
    assert config == {}

0 comments on commit 0130332

Please sign in to comment.