Skip to content

Commit

Permalink
Create prepopulated generator module and class
Browse files Browse the repository at this point in the history
Update prepopulated module to be an implementation

Space format

Cleanup
  • Loading branch information
pbeasly committed May 21, 2024
1 parent c52b243 commit d4701df
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/release_notes/pending_release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Pending Release Notes
Updates / New Features
----------------------

* Add prepopulated descriptor generator

Fixes
-----

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ ipython = ">=7.16.3"
# DescriptorGenerator
"smqtk_descriptors.impls.descriptor_generator.caffe1" = "smqtk_descriptors.impls.descriptor_generator.caffe1"
"smqtk_descriptors.impls.descriptor_generator.pytorch" = "smqtk_descriptors.impls.descriptor_generator.pytorch"
"smqtk_descriptors.impls.descriptor_generator.prepopulated" = "smqtk_descriptors.impls.descriptor_generator.prepopulated"
# DescriptorSet
"smqtk_descriptors.impls.descriptor_set.memory" = "smqtk_descriptors.impls.descriptor_set.memory"
"smqtk_descriptors.impls.descriptor_set.postgres" = "smqtk_descriptors.impls.descriptor_set.postgres"
Expand Down
103 changes: 103 additions & 0 deletions smqtk_descriptors/impls/descriptor_generator/prepopulated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import copy
import logging
from typing import (
Any,
Callable,
Dict,
Iterable,
Mapping,
Set,
Type,
TypeVar,
Union,
)

import numpy as np

from smqtk_dataprovider import DataElement
from smqtk_descriptors import DescriptorGenerator
from smqtk_image_io import ImageReader
from smqtk_core.configuration import (
from_config_dict,
make_default_config,
to_config_dict,
)
from smqtk_descriptors.utils.parallel import parallel_map

LOG = logging.getLogger(__name__)

__all__ = ["PrePopulatedDescriptorGenerator"]
T = TypeVar("T", bound="PrePopulatedDescriptorGenerator")


class PrePopulatedDescriptorGenerator(DescriptorGenerator):

@classmethod
def is_usable(cls) -> bool:
return True

@classmethod
def get_default_config(cls) -> Dict[str, Any]:
c = super().get_default_config()
return c

@classmethod
def from_config(cls: Type[T], config_dict: Dict, merge_default: bool = True) -> T:
# Copy config to prevent input modification
config_dict = copy.deepcopy(config_dict)
return super().from_config(config_dict, merge_default)

def __init__(
self,
):
super().__init__()

def __getstate__(self) -> Dict[str, Any]:
return self.get_config()

def __setstate__(self, state: Mapping[str, Any]) -> None:
# This ``__dict__.update`` works because configuration parameters
# exactly match up with instance attributes currently.
self.__dict__.update(state)

def valid_content_types(self) -> Set:
return set()

def is_valid_element(self, data_element: DataElement) -> bool:
# Check element validity though the ImageReader algorithm instance
return False

def _generate_arrays(
self, data_iter: Iterable[DataElement]
) -> Iterable[np.ndarray]:
# Generically load image data [in parallel], iterating results into
# template method.
ir_load: Callable[..., Any] = self.image_reader.load_as_matrix
i_load_threads = self.image_load_threads

gen_fn = (
self.generate_arrays_from_images_naive, # False
self.generate_arrays_from_images_iter, # True
)[self.iter_runtime]

if i_load_threads is None or i_load_threads > 1:
return gen_fn(parallel_map(ir_load, data_iter, cores=i_load_threads))
else:
return gen_fn(ir_load(d) for d in data_iter)

# NOTE: may need to create wrapper function around _make_transform
# that adds a PIL.Image.from_array and convert transformation to
# ensure being in the expected image format for the network.

def generate_arrays_from_images_naive(
self, img_mat_iter: Iterable[np.ndarray]
) -> Iterable[np.ndarray]:
raise NotImplementedError("This method is purposefully not implemented.")

def generate_arrays_from_images_iter(
self, img_mat_iter: Iterable[DataElement]
) -> Iterable[np.ndarray]:
raise NotImplementedError("This method is purposefully not implemented.")

def get_config(self) -> Dict[str, Any]:
return {}
37 changes: 37 additions & 0 deletions tests/impls/descriptor_generator/test_prepopulated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
def test_prepopulated_descriptor_generator_plugin():
"""
test that plugin and implementation is discoverable
"""

try:
from smqtk_descriptors.impls.descriptor_generator.prepopulated import (
PrePopulatedDescriptorGenerator,
)

print("Module imported successfully")
except ImportError as e:
print(f"ImportError: {e}")

# print("Class info: ", PrePopulatedDescriptorGenerator)

import pkg_resources

entry_points = pkg_resources.iter_entry_points("smqtk_plugins")

for entry_point in entry_points:
if entry_point.name.startswith("smqtk_descriptors.impls.descriptor_generator."):
print(f"Entry point: {entry_point.name}")
plugin = entry_point.load()
print(f"Loaded plugin: {plugin}")

from smqtk_descriptors import DescriptorGenerator

# List all available implementations
print("\n Begin check for available implementations")

available_impls = DescriptorGenerator.get_impls()
print("Available DescriptorGenerator implementations:")
for impl in available_impls:
print(impl)

# TODO: add assertion statements

0 comments on commit d4701df

Please sign in to comment.