Skip to content

Commit

Permalink
error handling + documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Jul 17, 2023
1 parent c9dc517 commit 7d803f7
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 8 deletions.
77 changes: 70 additions & 7 deletions pegasus/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,36 +2,99 @@
# from numba import njit
# from joblib import Memory
import numpy as np

from pegasus.embedding_functions import MultiModalEmbeddingFunction
import logging

# memory = Memory("PegasusStore", verbose=0)

# @memory.cache
# # @njit

# Configure the root logger at import time (timestamp - level - message) and
# create the module-level logger used by everything below.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)


def optimized_embedding_function(modality, data):
    """Create a MultiModalEmbeddingFunction for one modality and apply it to data.

    Args:
        modality: Modality name handed to MultiModalEmbeddingFunction, in
            lower case (e.g. 'text', 'vision', 'audio').
        data: The data to embed; callers in this module pass a numpy array.

    Returns:
        Whatever the MultiModalEmbeddingFunction instance returns when called
        with ``data`` (the generated embeddings).

    Raises:
        Exception: Any error raised while constructing or calling the
            embedding function is logged and then re-raised unchanged.
    """
    try:
        return MultiModalEmbeddingFunction(modality)(data)
    except Exception as e:
        # Log for diagnostics, but let the caller decide how to recover.
        logger.error("Failed to generate embeddings: %s", e)
        raise

class Pegasus:
    """Main multi-modal embedding class.

    Args:
        modality: One of "text", "audio", "vision", "sensor" or "heatmap".
        multi_process: Whether ``embed_data`` should embed items in parallel
            using a process pool.
        n_processes: Number of worker processes to use. Ignored (stored as 1)
            when ``multi_process`` is False.

    Raises:
        ValueError: If ``modality`` is not one of the supported values.
    """

    # Single source of truth for the modalities accepted by __init__ and
    # re-checked in _embed_data.
    _VALID_MODALITIES = frozenset({"text", "audio", "vision", "sensor", "heatmap"})

    def __init__(self, modality, multi_process=False, n_processes=4):
        if modality not in self._VALID_MODALITIES:
            logger.error("Invalid modality: %s", modality)
            raise ValueError("Invalid modality")

        self.modality = modality
        self.multi_process = multi_process
        # A single process is recorded when multiprocessing is disabled.
        self.n_processes = n_processes if multi_process else 1

    def _embed_data(self, data):
        """Embed ``data`` with the embedding function for this modality.

        Args:
            data: The data to embed (a numpy array when called via embed_data).

        Returns:
            The embeddings returned by optimized_embedding_function.

        Raises:
            ValueError: If ``self.modality`` is no longer a supported value
                (e.g. the attribute was reassigned after construction).
        """
        if self.modality not in self._VALID_MODALITIES:
            raise ValueError("Invalid modality")
        return optimized_embedding_function(self.modality, data)

    def embed_data(self, data):
        """Embed ``data``, optionally spreading the work over a process pool.

        Args:
            data: A numpy array, or anything ``np.array`` can convert
                (e.g. a list).

        Returns:
            With multiprocessing disabled: the embeddings for the whole array.
            With multiprocessing enabled: a dict mapping each item obtained by
            iterating over the array to that item's embeddings.

        Raises:
            Exception: Conversion, embedding and process-pool errors are
                logged and re-raised unchanged.

        Note:
            In parallel mode the items are used as dict keys, so they must be
            hashable. Iterating a multi-dimensional array yields ndarray rows,
            which are not hashable and will raise TypeError here.
        """
        if not isinstance(data, np.ndarray):
            try:
                data = np.array(data)
            except Exception as e:
                logger.error("Failed to convert data to numpy array: %s", e)
                raise

        if not self.multi_process:
            return self._embed_data(data)

        try:
            with ProcessPoolExecutor(max_workers=self.n_processes) as executor:
                # Remember which input each future belongs to so results can
                # be keyed by input regardless of completion order.
                future_to_data = {
                    executor.submit(self._embed_data, d): d for d in data
                }
                return {
                    future_to_data[future]: future.result()
                    for future in as_completed(future_to_data)
                }
        except Exception as e:
            logger.error("Failed to embed data in parallel: %s", e)
            raise


2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'pegasusX',
packages = find_packages(exclude=[]),
version = '0.3.7',
version = '0.3.8',
license='MIT',
description = 'pegasus - Pytorch',
author = 'Kye Gomez',
Expand Down

0 comments on commit 7d803f7

Please sign in to comment.