Skip to content

Commit

Permalink
Merge pull request #39 from bigbernnn/main
Browse files Browse the repository at this point in the history
Determine embedding size with Titan Embedding v2 model
  • Loading branch information
3coins authored May 17, 2024
2 parents 3dfbad4 + b867d08 commit 622756c
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
20 changes: 15 additions & 5 deletions libs/aws/langchain_aws/embeddings/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def validate_environment(cls, values: Dict) -> Dict:

return values

def _embedding_func(self, text: str) -> List[float]:
def _embedding_func(
self, text: str, dim: int = 1024, norm: bool = True
) -> List[float]:
"""Call out to Bedrock embedding endpoint."""
# replace newlines, which can negatively affect performance.
text = text.replace(os.linesep, " ")
Expand All @@ -128,6 +130,12 @@ def _embedding_func(self, text: str) -> List[float]:
else:
# includes common provider == "amazon"
input_body["inputText"] = text

# v2 and beyond titan embeddings with changing dimensions
if "v1" not in self.model_id:
input_body["dimensions"] = dim
input_body["normalize"] = norm

body = json.dumps(input_body)

try:
Expand Down Expand Up @@ -155,7 +163,9 @@ def _normalize_vector(self, embeddings: List[float]) -> List[float]:
norm_emb = emb / np.linalg.norm(emb)
return norm_emb.tolist()

def embed_documents(self, texts: List[str]) -> List[List[float]]:
def embed_documents(
self, texts: List[str], dim: int = 1024, norm: bool = True
) -> List[List[float]]:
"""Compute doc embeddings using a Bedrock model.
Args:
Expand All @@ -166,7 +176,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
results = []
for text in texts:
response = self._embedding_func(text)
response = self._embedding_func(text, dim, norm)

if self.normalize:
response = self._normalize_vector(response)
Expand All @@ -175,7 +185,7 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:

return results

def embed_query(self, text: str) -> List[float]:
def embed_query(self, text: str, dim: int = 1024, norm: bool = True) -> List[float]:
"""Compute query embeddings using a Bedrock model.
Args:
Expand All @@ -184,7 +194,7 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embeddings for the text.
"""
embedding = self._embedding_func(text)
embedding = self._embedding_func(text, dim, norm)

if self.normalize:
return self._normalize_vector(embedding)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,18 @@ def test_embed_query_normalized(bedrock_embeddings) -> None:
bedrock_embeddings.normalize = True
output = bedrock_embeddings.embed_query("foo walked to the market")
assert np.isclose(np.linalg.norm(output), 1.0)


@pytest.mark.scheduled
def test_embed_query_with_size(bedrock_embeddings) -> None:
    """Verify Titan v2 honors a requested embedding dimension.

    Titan Embeddings v2 accepts ``dimensions``/``normalize`` in the request
    body; both ``embed_documents`` and ``embed_query`` should return vectors
    of the requested length regardless of the normalize flag.
    """
    prompt_data = """Priority should be funding retirement through ROTH/IRA/401K
    over HSA extra. You need to fund your HSA for reasonable and expected medical
    expenses.
    """
    embed_size = 256
    normalize = True
    embed_model = BedrockEmbeddings(model_id="amazon.titan-embed-text-v2:0")
    response = embed_model.embed_documents([prompt_data], embed_size, normalize)
    # embed_query with normalize=False exercises the non-normalized path.
    output = embed_model.embed_query(prompt_data, embed_size, False)
    # Assert against embed_size (not a duplicated literal) so the expected
    # length stays in sync with the dimension actually requested above.
    assert len(response[0]) == embed_size
    assert len(output) == embed_size

0 comments on commit 622756c

Please sign in to comment.