Skip to content

Commit

Permalink
Fixes specifying model_kwargs for embeddings (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
3coins authored Jul 3, 2024
1 parent 97ddea7 commit 4d5ac21
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 26 deletions.
30 changes: 10 additions & 20 deletions libs/aws/langchain_aws/embeddings/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,48 +112,40 @@ def validate_environment(cls, values: Dict) -> Dict:

return values

def _embedding_func(
self, text: str, dim: int = 1024, norm: bool = True
) -> List[float]:
def _embedding_func(self, text: str) -> List[float]:
"""Call out to Bedrock embedding endpoint."""
# replace newlines, which can negatively affect performance.
text = text.replace(os.linesep, " ")

# format input body for provider
provider = self.model_id.split(".")[0]
_model_kwargs = self.model_kwargs or {}
input_body = {**_model_kwargs}
input_body: Dict[str, Any] = {}
if provider == "cohere":
if "input_type" not in input_body.keys():
input_body["input_type"] = "search_document"
input_body["input_type"] = "search_document"
input_body["texts"] = [text]
else:
# includes common provider == "amazon"
input_body["inputText"] = text

# v2 and beyond titan embeddings with changing dimensions
if "v1" not in self.model_id:
input_body["dimensions"] = dim
input_body["normalize"] = norm
if self.model_kwargs:
input_body = {**input_body, **self.model_kwargs}

body = json.dumps(input_body)

try:
# invoke bedrock API
response = self.client.invoke_model(
body=body,
modelId=self.model_id,
accept="application/json",
contentType="application/json",
)

# format output based on provider
response_body = json.loads(response.get("body").read())
if provider == "cohere":
return response_body.get("embeddings")[0]
else:
# includes common provider == "amazon"
return response_body.get("embedding")

except Exception as e:
raise ValueError(f"Error raised by inference endpoint: {e}")

Expand All @@ -163,9 +155,7 @@ def _normalize_vector(self, embeddings: List[float]) -> List[float]:
norm_emb = emb / np.linalg.norm(emb)
return norm_emb.tolist()

def embed_documents(
self, texts: List[str], dim: int = 1024, norm: bool = True
) -> List[List[float]]:
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute doc embeddings using a Bedrock model.
Args:
Expand All @@ -176,7 +166,7 @@ def embed_documents(
"""
results = []
for text in texts:
response = self._embedding_func(text, dim, norm)
response = self._embedding_func(text)

if self.normalize:
response = self._normalize_vector(response)
Expand All @@ -185,7 +175,7 @@ def embed_documents(

return results

def embed_query(self, text: str, dim: int = 1024, norm: bool = True) -> List[float]:
def embed_query(self, text: str) -> List[float]:
"""Compute query embeddings using a Bedrock model.
Args:
Expand All @@ -194,7 +184,7 @@ def embed_query(self, text: str, dim: int = 1024, norm: bool = True) -> List[flo
Returns:
Embeddings for the text.
"""
embedding = self._embedding_func(text, dim, norm)
embedding = self._embedding_func(text)

if self.normalize:
return self._normalize_vector(embedding)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ def bedrock_embeddings() -> BedrockEmbeddings:
return BedrockEmbeddings(model_id="amazon.titan-embed-text-v1")


@pytest.fixture
def bedrock_embeddings_v2() -> BedrockEmbeddings:
    # Titan v2 embeddings accept extra request parameters via model_kwargs;
    # here "dimensions" requests 256-dim output vectors and "normalize"
    # asks the service to unit-normalize them.
    return BedrockEmbeddings(
        model_id="amazon.titan-embed-text-v2:0",
        model_kwargs={"dimensions": 256, "normalize": True},
    )


@pytest.mark.scheduled
def test_bedrock_embedding_documents(bedrock_embeddings) -> None:
documents = ["foo bar"]
Expand All @@ -18,6 +26,14 @@ def test_bedrock_embedding_documents(bedrock_embeddings) -> None:
assert len(output[0]) == 1536


@pytest.mark.scheduled
def test_bedrock_embedding_documents_with_v2(bedrock_embeddings_v2) -> None:
    # The v2 fixture sets model_kwargs={"dimensions": 256, ...}, so the
    # returned vector should have 256 dimensions (vs. the v1 default 1536).
    documents = ["foo bar"]
    output = bedrock_embeddings_v2.embed_documents(documents)
    assert len(output) == 1
    assert len(output[0]) == 256


@pytest.mark.scheduled
def test_bedrock_embedding_documents_multiple(bedrock_embeddings) -> None:
documents = ["foo bar", "bar foo", "foo"]
Expand Down Expand Up @@ -76,15 +92,12 @@ def test_embed_query_normalized(bedrock_embeddings) -> None:


@pytest.mark.scheduled
def test_embed_query_with_size(bedrock_embeddings_v2) -> None:
    """Titan v2 honors the "dimensions" model kwarg for docs and queries.

    The fixture requests 256-dimensional vectors via model_kwargs, so both
    embed_documents and embed_query should return vectors of that size.
    """
    prompt_data = """Priority should be funding retirement through ROTH/IRA/401K
over HSA extra. You need to fund your HSA for reasonable and expected medical
expenses.
"""
    response = bedrock_embeddings_v2.embed_documents([prompt_data])
    output = bedrock_embeddings_v2.embed_query(prompt_data)
    assert len(response[0]) == 256
    assert len(output) == 256

0 comments on commit 4d5ac21

Please sign in to comment.