Skip to content

Commit 8290134

Browse files
authored
Merge branch 'main' into improve_update_context
2 parents 8e4021a + bf7e4d6 commit 8290134

File tree

113 files changed

+4889
-615
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

113 files changed

+4889
-615
lines changed

OAI_CONFIG_LIST_sample

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
"api_key": "<your Azure OpenAI API key here>",
1414
"base_url": "<your Azure OpenAI API base here>",
1515
"api_type": "azure",
16-
"api_version": "2024-02-15-preview"
16+
"api_version": "2024-02-01"
1717
},
1818
{
1919
"model": "<your Azure OpenAI deployment name>",
2020
"api_key": "<your Azure OpenAI API key here>",
2121
"base_url": "<your Azure OpenAI API base here>",
2222
"api_type": "azure",
23-
"api_version": "2024-02-15-preview"
23+
"api_version": "2024-02-01"
2424
}
2525
]

autogen/agentchat/contrib/vectordb/pgvectordb.py

+38-46
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@ class Collection:
3232
client: The PGVector client.
3333
collection_name (str): The name of the collection. Default is "documents".
3434
embedding_function (Callable): The embedding function used to generate the vector representation.
35+
Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
36+
Models can be chosen from:
37+
https://huggingface.co/models?library=sentence-transformers
3538
metadata (Optional[dict]): The metadata of the collection.
3639
get_or_create (Optional): The flag indicating whether to get or create the collection.
37-
model_name: (Optional str) | Sentence embedding model to use. Models can be chosen from:
38-
https://huggingface.co/models?library=sentence-transformers
3940
"""
4041

4142
def __init__(
@@ -45,7 +46,6 @@ def __init__(
4546
embedding_function: Callable = None,
4647
metadata=None,
4748
get_or_create=None,
48-
model_name="all-MiniLM-L6-v2",
4949
):
5050
"""
5151
Initialize the Collection object.
@@ -56,30 +56,26 @@ def __init__(
5656
embedding_function: The embedding function used to generate the vector representation.
5757
metadata: The metadata of the collection.
5858
get_or_create: The flag indicating whether to get or create the collection.
59-
model_name: | Sentence embedding model to use. Models can be chosen from:
60-
https://huggingface.co/models?library=sentence-transformers
6159
Returns:
6260
None
6361
"""
6462
self.client = client
65-
self.embedding_function = embedding_function
66-
self.model_name = model_name
6763
self.name = self.set_collection_name(collection_name)
6864
self.require_embeddings_or_documents = False
6965
self.ids = []
70-
try:
71-
self.embedding_function = (
72-
SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
73-
)
74-
except Exception as e:
75-
logger.error(
76-
f"Validate the model name entered: {self.model_name} "
77-
f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
78-
)
79-
raise e
66+
if embedding_function:
67+
self.embedding_function = embedding_function
68+
else:
69+
self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode
8070
self.metadata = metadata if metadata else {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}
8171
self.documents = ""
8272
self.get_or_create = get_or_create
73+
# Determine the embedding dimension by embedding a sample sentence and measuring the resulting vector's length
74+
sentences = [
75+
"The weather is lovely today in paradise.",
76+
]
77+
embeddings = self.embedding_function(sentences)
78+
self.dimension = len(embeddings[0])
8379

8480
def set_collection_name(self, collection_name) -> str:
8581
name = re.sub("-", "_", collection_name)
@@ -115,14 +111,14 @@ def add(self, ids: List[ItemID], documents: List, embeddings: List = None, metad
115111
elif metadatas is not None:
116112
for doc_id, metadata, document in zip(ids, metadatas, documents):
117113
metadata = re.sub("'", '"', str(metadata))
118-
embedding = self.embedding_function.encode(document)
114+
embedding = self.embedding_function(document)
119115
sql_values.append((doc_id, metadata, embedding, document))
120116
sql_string = (
121117
f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n" f"VALUES (%s, %s, %s, %s);\n"
122118
)
123119
else:
124120
for doc_id, document in zip(ids, documents):
125-
embedding = self.embedding_function.encode(document)
121+
embedding = self.embedding_function(document)
126122
sql_values.append((doc_id, document, embedding))
127123
sql_string = f"INSERT INTO {self.name} (id, documents, embedding)\n" f"VALUES (%s, %s, %s);\n"
128124
logger.debug(f"Add SQL String:\n{sql_string}\n{sql_values}")
@@ -166,7 +162,7 @@ def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, me
166162
elif metadatas is not None:
167163
for doc_id, metadata, document in zip(ids, metadatas, documents):
168164
metadata = re.sub("'", '"', str(metadata))
169-
embedding = self.embedding_function.encode(document)
165+
embedding = self.embedding_function(document)
170166
sql_values.append((doc_id, metadata, embedding, document, metadata, document, embedding))
171167
sql_string = (
172168
f"INSERT INTO {self.name} (id, metadatas, embedding, documents)\n"
@@ -176,7 +172,7 @@ def upsert(self, ids: List[ItemID], documents: List, embeddings: List = None, me
176172
)
177173
else:
178174
for doc_id, document in zip(ids, documents):
179-
embedding = self.embedding_function.encode(document)
175+
embedding = self.embedding_function(document)
180176
sql_values.append((doc_id, document, embedding, document))
181177
sql_string = (
182178
f"INSERT INTO {self.name} (id, documents, embedding)\n"
@@ -304,7 +300,7 @@ def get(
304300
)
305301
except (psycopg.errors.UndefinedTable, psycopg.errors.UndefinedColumn) as e:
306302
logger.info(f"Error executing select on non-existent table: {self.name}. Creating it instead. Error: {e}")
307-
self.create_collection(collection_name=self.name)
303+
self.create_collection(collection_name=self.name, dimension=self.dimension)
308304
logger.info(f"Created table {self.name}")
309305

310306
cursor.close()
@@ -419,7 +415,7 @@ def query(
419415
cursor = self.client.cursor()
420416
results = []
421417
for query_text in query_texts:
422-
vector = self.embedding_function.encode(query_text, convert_to_tensor=False).tolist()
418+
vector = self.embedding_function(query_text, convert_to_tensor=False).tolist()
423419
if distance_type.lower() == "cosine":
424420
index_function = "<=>"
425421
elif distance_type.lower() == "euclidean":
@@ -526,22 +522,31 @@ def delete_collection(self, collection_name: Optional[str] = None) -> None:
526522
cursor.execute(f"DROP TABLE IF EXISTS {self.name}")
527523
cursor.close()
528524

529-
def create_collection(self, collection_name: Optional[str] = None) -> None:
525+
def create_collection(
526+
self, collection_name: Optional[str] = None, dimension: Optional[Union[str, int]] = None
527+
) -> None:
530528
"""
531529
Create a new collection.
532530
533531
Args:
534532
collection_name (Optional[str]): The name of the new collection.
533+
dimension (Optional[Union[str, int]]): The dimension size of the sentence embedding model. When omitted, the collection's existing dimension is kept, falling back to 384 if it is None.
535534
536535
Returns:
537536
None
538537
"""
539538
if collection_name:
540539
self.name = collection_name
540+
541+
if dimension:
542+
self.dimension = dimension
543+
elif self.dimension is None:
544+
self.dimension = 384
545+
541546
cursor = self.client.cursor()
542547
cursor.execute(
543548
f"CREATE TABLE {self.name} ("
544-
f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector(384));"
549+
f"documents text, id CHAR(8) PRIMARY KEY, metadatas JSONB, embedding vector({self.dimension}));"
545550
f"CREATE INDEX "
546551
f'ON {self.name} USING hnsw (embedding vector_l2_ops) WITH (m = {self.metadata["hnsw:M"]}, '
547552
f'ef_construction = {self.metadata["hnsw:construction_ef"]});'
@@ -573,7 +578,6 @@ def __init__(
573578
connect_timeout: Optional[int] = 10,
574579
embedding_function: Callable = None,
575580
metadata: Optional[dict] = None,
576-
model_name: Optional[str] = "all-MiniLM-L6-v2",
577581
) -> None:
578582
"""
579583
Initialize the vector database.
@@ -591,15 +595,14 @@ def __init__(
591595
username: str | The database username to use. Default is None.
592596
password: str | The database user password to use. Default is None.
593597
connect_timeout: int | The timeout to set for the connection. Default is 10.
594-
embedding_function: Callable | The embedding function used to generate the vector representation
595-
of the documents. Default is None.
598+
embedding_function: Callable | The embedding function used to generate the vector representation.
599+
Default is None. SentenceTransformer("all-MiniLM-L6-v2").encode will be used when None.
600+
Models can be chosen from:
601+
https://huggingface.co/models?library=sentence-transformers
596602
metadata: dict | The metadata of the vector database. Default is None. If None, it will use this
597603
setting: {"hnsw:space": "ip", "hnsw:construction_ef": 32, "hnsw:M": 16}. Creates Index on table
598604
using hnsw (embedding vector_l2_ops) WITH (m = hnsw:M) ef_construction = "hnsw:construction_ef".
599605
For more info: https://github.com/pgvector/pgvector?tab=readme-ov-file#hnsw
600-
model_name: str | Sentence embedding model to use. Models can be chosen from:
601-
https://huggingface.co/models?library=sentence-transformers
602-
603606
Returns:
604607
None
605608
"""
@@ -613,17 +616,10 @@ def __init__(
613616
password=password,
614617
connect_timeout=connect_timeout,
615618
)
616-
self.model_name = model_name
617-
try:
618-
self.embedding_function = (
619-
SentenceTransformer(self.model_name) if embedding_function is None else embedding_function
620-
)
621-
except Exception as e:
622-
logger.error(
623-
f"Validate the model name entered: {self.model_name} "
624-
f"from https://huggingface.co/models?library=sentence-transformers\nError: {e}"
625-
)
626-
raise e
619+
if embedding_function:
620+
self.embedding_function = embedding_function
621+
else:
622+
self.embedding_function = SentenceTransformer("all-MiniLM-L6-v2").encode
627623
self.metadata = metadata
628624
register_vector(self.client)
629625
self.active_collection = None
@@ -738,7 +734,6 @@ def create_collection(
738734
embedding_function=self.embedding_function,
739735
get_or_create=get_or_create,
740736
metadata=self.metadata,
741-
model_name=self.model_name,
742737
)
743738
collection.set_collection_name(collection_name=collection_name)
744739
collection.create_collection(collection_name=collection_name)
@@ -751,7 +746,6 @@ def create_collection(
751746
embedding_function=self.embedding_function,
752747
get_or_create=get_or_create,
753748
metadata=self.metadata,
754-
model_name=self.model_name,
755749
)
756750
collection.set_collection_name(collection_name=collection_name)
757751
collection.create_collection(collection_name=collection_name)
@@ -765,7 +759,6 @@ def create_collection(
765759
embedding_function=self.embedding_function,
766760
get_or_create=get_or_create,
767761
metadata=self.metadata,
768-
model_name=self.model_name,
769762
)
770763
collection.set_collection_name(collection_name=collection_name)
771764
collection.create_collection(collection_name=collection_name)
@@ -797,7 +790,6 @@ def get_collection(self, collection_name: str = None) -> Collection:
797790
client=self.client,
798791
collection_name=collection_name,
799792
embedding_function=self.embedding_function,
800-
model_name=self.model_name,
801793
)
802794
return self.active_collection
803795

autogen/agentchat/conversable_agent.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
3232
from ..io.base import IOStream
3333
from ..oai.client import ModelClient, OpenAIWrapper
34-
from ..runtime_logging import log_event, log_new_agent, logging_enabled
34+
from ..runtime_logging import log_event, log_function_use, log_new_agent, logging_enabled
3535
from .agent import Agent, LLMAgent
3636
from .chat import ChatResult, a_initiate_chats, initiate_chats
3737
from .utils import consolidate_chat_info, gather_usage_summary
@@ -1357,9 +1357,7 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[
13571357

13581358
# TODO: #1143 handle token limit exceeded error
13591359
response = llm_client.create(
1360-
context=messages[-1].pop("context", None),
1361-
messages=all_messages,
1362-
cache=cache,
1360+
context=messages[-1].pop("context", None), messages=all_messages, cache=cache, agent=self
13631361
)
13641362
extracted_response = llm_client.extract_text_or_completion_object(response)[0]
13651363

@@ -2528,13 +2526,14 @@ def _wrap_function(self, func: F) -> F:
25282526
@functools.wraps(func)
25292527
def _wrapped_func(*args, **kwargs):
25302528
retval = func(*args, **kwargs)
2531-
2529+
log_function_use(self, func, kwargs, retval)
25322530
return serialize_to_str(retval)
25332531

25342532
@load_basemodels_if_needed
25352533
@functools.wraps(func)
25362534
async def _a_wrapped_func(*args, **kwargs):
25372535
retval = await func(*args, **kwargs)
2536+
log_function_use(self, func, kwargs, retval)
25382537
return serialize_to_str(retval)
25392538

25402539
wrapped_func = _a_wrapped_func if inspect.iscoroutinefunction(func) else _wrapped_func

autogen/logger/base_logger.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
import sqlite3
44
import uuid
55
from abc import ABC, abstractmethod
6-
from typing import TYPE_CHECKING, Any, Dict, List, Union
6+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, TypeVar, Union
77

88
from openai import AzureOpenAI, OpenAI
99
from openai.types.chat import ChatCompletion
1010

1111
if TYPE_CHECKING:
1212
from autogen import Agent, ConversableAgent, OpenAIWrapper
1313

14+
F = TypeVar("F", bound=Callable[..., Any])
1415
ConfigItem = Dict[str, Union[str, List[str]]]
1516
LLMConfig = Dict[str, Union[None, float, int, ConfigItem, List[ConfigItem]]]
1617

@@ -32,6 +33,7 @@ def log_chat_completion(
3233
invocation_id: uuid.UUID,
3334
client_id: int,
3435
wrapper_id: int,
36+
source: Union[str, Agent],
3537
request: Dict[str, Union[float, str, List[Dict[str, str]]]],
3638
response: Union[str, ChatCompletion],
3739
is_cached: int,
@@ -49,9 +51,10 @@ def log_chat_completion(
4951
invocation_id (uuid): A unique identifier for the invocation to the OpenAIWrapper.create method call
5052
client_id (int): A unique identifier for the underlying OpenAI client instance
5153
wrapper_id (int): A unique identifier for the OpenAIWrapper instance
52-
request (dict): A dictionary representing the the request or call to the OpenAI client endpoint
54+
source (str or Agent): The source/creator of the event as a string name or an Agent instance
55+
request (dict): A dictionary representing the request or call to the OpenAI client endpoint
5356
response (str or ChatCompletion): The response from OpenAI
54-
is_chached (int): 1 if the response was a cache hit, 0 otherwise
57+
is_cached (int): 1 if the response was a cache hit, 0 otherwise
5558
cost(float): The cost for OpenAI response
5659
start_time (str): A string representing the moment the request was initiated
5760
"""
@@ -104,6 +107,18 @@ def log_new_client(
104107
"""
105108
...
106109

110+
@abstractmethod
111+
def log_function_use(self, source: Union[str, Agent], function: F, args: Dict[str, Any], returns: Any) -> None:
112+
"""
113+
Log the use of a registered function (could be a tool)
114+
115+
Args:
116+
source (str or Agent): The source/creator of the event as a string name or an Agent instance
117+
function (F): The registered function that was called
118+
args (dict): The function args to log
119+
returns (any): The value returned by the function
120+
"""
121+
107122
@abstractmethod
108123
def stop(self) -> None:
109124
"""

0 commit comments

Comments
 (0)