diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py
index 6d5bd97544..6ef7411332 100644
--- a/rag/llm/embedding_model.py
+++ b/rag/llm/embedding_model.py
@@ -27,8 +27,7 @@
 import numpy as np
 
 from api.utils.file_utils import get_project_base_directory, get_home_cache_dir
-from rag.utils import num_tokens_from_string
-
+from rag.utils import num_tokens_from_string, truncate
 
 try:
     flag_model = FlagModel(os.path.join(get_home_cache_dir(), "bge-large-zh-v1.5"),
@@ -70,7 +69,7 @@ def __init__(self, *args, **kwargs):
         self.model = flag_model
 
     def encode(self, texts: list, batch_size=32):
-        texts = [t[:2000] for t in texts]
+        texts = [truncate(t, 2048) for t in texts]
         token_count = 0
         for t in texts:
             token_count += num_tokens_from_string(t)
@@ -93,12 +92,14 @@ def __init__(self, key, model_name="text-embedding-ada-002",
         self.model_name = model_name
 
     def encode(self, texts: list, batch_size=32):
+        texts = [truncate(t, 8196) for t in texts]
         res = self.client.embeddings.create(input=texts,
                                             model=self.model_name)
-        return np.array([d.embedding for d in res.data]), res.usage.total_tokens
+        return np.array([d.embedding for d in res.data]
+                        ), res.usage.total_tokens
 
     def encode_queries(self, text):
-        res = self.client.embeddings.create(input=[text],
+        res = self.client.embeddings.create(input=[truncate(text, 8196)],
                                             model=self.model_name)
         return np.array(res.data[0].embedding), res.usage.total_tokens
 
@@ -112,7 +113,7 @@ def encode(self, texts: list, batch_size=10):
         import dashscope
         res = []
         token_count = 0
-        texts = [txt[:2048] for txt in texts]
+        texts = [truncate(t, 2048) for t in texts]
         for i in range(0, len(texts), batch_size):
             resp = dashscope.TextEmbedding.call(
                 model=self.model_name,
diff --git a/rag/utils/__init__.py b/rag/utils/__init__.py
index 8536111000..f1db2e409d 100644
--- a/rag/utils/__init__.py
+++ b/rag/utils/__init__.py
@@ -63,3 +63,7 @@ def num_tokens_from_string(string: str) -> int:
     num_tokens = len(encoder.encode(string))
     return num_tokens
 
+
+def truncate(string: str, max_len: int) -> str:
+    """Returns the text truncated to max_len tokens if it exceeds that length."""
+    return encoder.decode(encoder.encode(string)[:max_len])
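
A minimal standalone sketch of what the new truncate helper does, assuming the module-level encoder in rag/utils is a tiktoken encoding (cl100k_base is an assumption here, not taken from the patch). Unlike the old character slices t[:2000] / txt[:2048], the limit is applied in tokens, which is what the embedding APIs actually enforce:

    # Sketch only: token-based truncation, not the repository's exact wiring.
    import tiktoken

    encoder = tiktoken.get_encoding("cl100k_base")  # assumed encoder; rag/utils defines its own

    def truncate(string: str, max_len: int) -> str:
        """Return `string` cut down to at most `max_len` tokens."""
        return encoder.decode(encoder.encode(string)[:max_len])

    text = "long chunk of document text " * 2000
    print(len(text))                                  # character count: far above any token limit
    print(len(encoder.encode(truncate(text, 8196))))  # token count: capped at 8196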