Skip to content

Commit

Permalink
fix: better gard nan value from numpy for issue langgenius#11827 (lan…
Browse files Browse the repository at this point in the history
…ggenius#11864)

Signed-off-by: yihong0618 <[email protected]>
  • Loading branch information
yihong0618 authored and 刘江波 committed Dec 20, 2024
1 parent bac03bf commit 869e20f
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def _invoke(
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding

# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ def _invoke(
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding

# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ def _invoke(
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding

# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,10 @@ def _invoke(
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding

usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

Expand Down
2 changes: 2 additions & 0 deletions api/core/rag/embedding/cached_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def embed_query(self, text: str) -> list[float]:

embedding_results = embedding_result.embeddings[0]
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
if dify_config.DEBUG:
logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")
Expand Down

0 comments on commit 869e20f

Please sign in to comment.