-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create prepopulated generator module and class
Update prepopulated module to be an implementation Space format Cleanup
- Loading branch information
Showing
4 changed files
with
143 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
103 changes: 103 additions & 0 deletions
103
smqtk_descriptors/impls/descriptor_generator/prepopulated.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import copy | ||
import logging | ||
from typing import ( | ||
Any, | ||
Callable, | ||
Dict, | ||
Iterable, | ||
Mapping, | ||
Set, | ||
Type, | ||
TypeVar, | ||
Union, | ||
) | ||
|
||
import numpy as np | ||
|
||
from smqtk_dataprovider import DataElement | ||
from smqtk_descriptors import DescriptorGenerator | ||
from smqtk_image_io import ImageReader | ||
from smqtk_core.configuration import ( | ||
from_config_dict, | ||
make_default_config, | ||
to_config_dict, | ||
) | ||
from smqtk_descriptors.utils.parallel import parallel_map | ||
|
||
LOG = logging.getLogger(__name__) | ||
|
||
__all__ = ["PrePopulatedDescriptorGenerator"] | ||
T = TypeVar("T", bound="PrePopulatedDescriptorGenerator") | ||
|
||
|
||
class PrePopulatedDescriptorGenerator(DescriptorGenerator): | ||
|
||
@classmethod | ||
def is_usable(cls) -> bool: | ||
return True | ||
|
||
@classmethod | ||
def get_default_config(cls) -> Dict[str, Any]: | ||
c = super().get_default_config() | ||
return c | ||
|
||
@classmethod | ||
def from_config(cls: Type[T], config_dict: Dict, merge_default: bool = True) -> T: | ||
# Copy config to prevent input modification | ||
config_dict = copy.deepcopy(config_dict) | ||
return super().from_config(config_dict, merge_default) | ||
|
||
def __init__( | ||
self, | ||
): | ||
super().__init__() | ||
|
||
def __getstate__(self) -> Dict[str, Any]: | ||
return self.get_config() | ||
|
||
def __setstate__(self, state: Mapping[str, Any]) -> None: | ||
# This ``__dict__.update`` works because configuration parameters | ||
# exactly match up with instance attributes currently. | ||
self.__dict__.update(state) | ||
|
||
def valid_content_types(self) -> Set: | ||
return set() | ||
|
||
def is_valid_element(self, data_element: DataElement) -> bool: | ||
# Check element validity though the ImageReader algorithm instance | ||
return False | ||
|
||
def _generate_arrays( | ||
self, data_iter: Iterable[DataElement] | ||
) -> Iterable[np.ndarray]: | ||
# Generically load image data [in parallel], iterating results into | ||
# template method. | ||
ir_load: Callable[..., Any] = self.image_reader.load_as_matrix | ||
i_load_threads = self.image_load_threads | ||
|
||
gen_fn = ( | ||
self.generate_arrays_from_images_naive, # False | ||
self.generate_arrays_from_images_iter, # True | ||
)[self.iter_runtime] | ||
|
||
if i_load_threads is None or i_load_threads > 1: | ||
return gen_fn(parallel_map(ir_load, data_iter, cores=i_load_threads)) | ||
else: | ||
return gen_fn(ir_load(d) for d in data_iter) | ||
|
||
# NOTE: may need to create wrapper function around _make_transform | ||
# that adds a PIL.Image.from_array and convert transformation to | ||
# ensure being in the expected image format for the network. | ||
|
||
def generate_arrays_from_images_naive( | ||
self, img_mat_iter: Iterable[np.ndarray] | ||
) -> Iterable[np.ndarray]: | ||
raise NotImplementedError("This method is purposefully not implemented.") | ||
|
||
def generate_arrays_from_images_iter( | ||
self, img_mat_iter: Iterable[DataElement] | ||
) -> Iterable[np.ndarray]: | ||
raise NotImplementedError("This method is purposefully not implemented.") | ||
|
||
def get_config(self) -> Dict[str, Any]: | ||
return {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
def test_prepopulated_descriptor_generator_plugin(): | ||
""" | ||
test that plugin and implementation is discoverable | ||
""" | ||
|
||
try: | ||
from smqtk_descriptors.impls.descriptor_generator.prepopulated import ( | ||
PrePopulatedDescriptorGenerator, | ||
) | ||
|
||
print("Module imported successfully") | ||
except ImportError as e: | ||
print(f"ImportError: {e}") | ||
|
||
# print("Class info: ", PrePopulatedDescriptorGenerator) | ||
|
||
import pkg_resources | ||
|
||
entry_points = pkg_resources.iter_entry_points("smqtk_plugins") | ||
|
||
for entry_point in entry_points: | ||
if entry_point.name.startswith("smqtk_descriptors.impls.descriptor_generator."): | ||
print(f"Entry point: {entry_point.name}") | ||
plugin = entry_point.load() | ||
print(f"Loaded plugin: {plugin}") | ||
|
||
from smqtk_descriptors import DescriptorGenerator | ||
|
||
# List all available implementations | ||
print("\n Begin check for available implementations") | ||
|
||
available_impls = DescriptorGenerator.get_impls() | ||
print("Available DescriptorGenerator implementations:") | ||
for impl in available_impls: | ||
print(impl) | ||
|
||
# TODO: add assertion statements |