diff --git a/doc/source/notebooks/embedding.ipynb b/doc/source/notebooks/embedding.ipynb index 4bd7a95e..2ca15979 100644 --- a/doc/source/notebooks/embedding.ipynb +++ b/doc/source/notebooks/embedding.ipynb @@ -12,7 +12,12 @@ { "cell_type": "code", "execution_count": 1, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-25T12:57:01.715707Z", + "start_time": "2023-08-25T12:56:54.919200Z" + } + }, "outputs": [ { "name": "stdout", @@ -51,7 +56,12 @@ { "cell_type": "code", "execution_count": 2, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-25T12:57:02.644919Z", + "start_time": "2023-08-25T12:57:01.723149Z" + } + }, "outputs": [], "source": [ "content = [\"I have a dog.\", \"I like eating apples.\"]\n", @@ -81,35 +91,17 @@ { "cell_type": "code", "execution_count": 3, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-25T12:57:08.645604Z", + "start_time": "2023-08-25T12:57:02.646625Z" + } + }, "outputs": [ { "data": { - "text/html": [ - "\n", - "\t\n", - "\t\t\n", - "\t\t\n", - "\t\n", - "\t\n", - "\t\t\n", - "\t\t\n", - "\t\n", - "\t\n", - "\t\t\n", - "\t\t\n", - "\t\n", - "
idcontent
0I have a dog.
1I like eating apples.
" - ], - "text/plain": [ - "----------------------------\n", - " id | content \n", - "----+-----------------------\n", - " 0 | I have a dog. \n", - " 1 | I like eating apples. \n", - "----------------------------\n", - "(2 rows)" - ] + "text/plain": "----------------------------\n id | content \n----+-----------------------\n 0 | I have a dog. \n 1 | I like eating apples. \n----------------------------\n(2 rows)", + "text/html": "\n\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\n
idcontent
0I have a dog.
1I like eating apples.
" }, "execution_count": 3, "metadata": {}, @@ -133,30 +125,17 @@ { "cell_type": "code", "execution_count": 4, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-25T12:57:14.069009Z", + "start_time": "2023-08-25T12:57:08.643273Z" + } + }, "outputs": [ { "data": { - "text/html": [ - "\n", - "\t\n", - "\t\t\n", - "\t\t\n", - "\t\n", - "\t\n", - "\t\t\n", - "\t\t\n", - "\t\n", - "
idcontent
1I like eating apples.
" - ], - "text/plain": [ - "----------------------------\n", - " id | content \n", - "----+-----------------------\n", - " 1 | I like eating apples. \n", - "----------------------------\n", - "(1 row)" - ] + "text/plain": "----------------------------\n id | content \n----+-----------------------\n 1 | I like eating apples. \n----------------------------\n(1 row)", + "text/html": "\n\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\n
idcontent
1I like eating apples.
" }, "execution_count": 4, "metadata": {}, @@ -169,39 +148,81 @@ }, { "cell_type": "markdown", - "metadata": {}, "source": [ - "## Cleaning All at Once" - ] + "## Batched k-NN search" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [], + "source": [ + "query = (\n", + " db.create_dataframe(columns={\"idd\": range(3), \"query\": [\"apple\", \"dog\", \"banana\"]})\n", + " .save_as(\n", + " table_name=\"query_sample\",\n", + " column_names=[\"idd\", \"query\"],\n", + " distribution_key={\"idd\"},\n", + " distribution_type=\"hash\",\n", + " )\n", + " .check_unique(columns={\"idd\"})\n", + " .embedding()\n", + " .create_index(column=\"query\", model=\"all-MiniLM-L6-v2\")\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-25T12:57:17.400047Z", + "start_time": "2023-08-25T12:57:14.059315Z" + } + } }, { "cell_type": "code", "execution_count": 6, - "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " * postgresql://localhost:7000\n", - "Done.\n" - ] - }, { "data": { - "text/plain": [ - "[]" - ] + "text/plain": "-------------------------------------------\n idd | id | query | content \n-----+----+--------+-----------------------\n 1 | 0 | dog | I have a dog. \n 2 | 1 | banana | I like eating apples. \n 0 | 1 | apple | I like eating apples. \n-------------------------------------------\n(3 rows)", + "text/html": "\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n
iddidquerycontent
10dogI have a dog.
21bananaI like eating apples.
01appleI like eating apples.
" }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], + "source": [ + "t.embedding().search(column=\"content\", query=query[\"query\"], top_k=1)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-25T12:57:18.305871Z", + "start_time": "2023-08-25T12:57:17.402679Z" + } + } + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleaning All at Once" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "%reload_ext sql\n", "%sql postgresql://localhost:7000\n", "%sql DROP TABLE text_sample CASCADE;" + "%sql DROP TABLE query_sample CASCADE;" ] } ], diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index 04d8dff4..5cedeb36 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -1,8 +1,9 @@ from collections import abc -from typing import Any, Callable, cast +from typing import Any, Callable, Dict, List, Optional, Set, Union, cast from uuid import uuid4 import greenplumpython as gp +from greenplumpython.col import Column from greenplumpython.row import Row from greenplumpython.type import TypeCast @@ -144,7 +145,13 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: ) return self._dataframe - def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: + def search( + self, + column: str, + query: Any, + top_k: int, + query_unique_key_columns: Optional[Union[Dict[str, Optional[str]], Set[str]]] = None, + ) -> gp.DataFrame: """ Searche unstructured data based on semantic similarity on embeddings. @@ -155,54 +162,122 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: Returns: Dataframe with the top k most similar results in the `column` of `query`. - - Example: - Please refer to :ref:`embedding-example` for more details. """ - assert self._dataframe._db is not None - embdedding_info = self._dataframe._db._execute( - f""" - WITH indexed_col_info AS ( - SELECT attrelid, attnum - FROM pg_attribute + + def find_embedding_df(df: gp.DataFrame, column_c: str): + assert df._db is not None + + embdedding_info = df._db._execute( + f""" + WITH indexed_col_info AS ( + SELECT attrelid, attnum + FROM pg_attribute + WHERE + attrelid = '{df._qualified_table_name}'::regclass::oid AND + attname = '{column_c}' + ), reloptions AS ( + SELECT unnest(reloptions) AS option + FROM pg_class, indexed_col_info + WHERE pg_class.oid = attrelid + ), embedding_info_json AS ( + SELECT split_part(option, '=', 2)::json AS val + FROM reloptions, indexed_col_info + WHERE option LIKE format('_pygp_emb_%s=%%', attnum) + ), embedding_info AS ( + SELECT * + FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text) + ) + SELECT nspname, relname, attname, model + FROM embedding_info, pg_class, pg_namespace, pg_attribute WHERE - attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND - attname = '{column}' - ), reloptions AS ( - SELECT unnest(reloptions) AS option - FROM pg_class, indexed_col_info - WHERE pg_class.oid = attrelid - ), embedding_info_json AS ( - SELECT split_part(option, '=', 2)::json AS val - FROM reloptions, indexed_col_info - WHERE option LIKE format('_pygp_emb_%s=%%', attnum) - ), embedding_info AS ( - SELECT * - FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text) + pg_class.oid = embedding_relid AND + relnamespace = pg_namespace.oid AND + embedding_relid = attrelid AND + pg_attribute.attnum = 2; + """ ) - SELECT nspname, relname, attname, model - FROM embedding_info, pg_class, pg_namespace, pg_attribute - WHERE - pg_class.oid = embedding_relid AND - relnamespace = pg_namespace.oid AND - embedding_relid = attrelid AND - pg_attribute.attnum = 2; - """ - ) - row: Row = embdedding_info[0] - schema: str = row["nspname"] - embedding_table_name: str = row["relname"] - model = row["model"] - embedding_col_name = row["attname"] - embedding_df = self._dataframe._db.create_dataframe( - table_name=embedding_table_name, schema=schema - ) + row: Row = embdedding_info[0] + schema: str = row["nspname"] + embedding_table_name: str = row["relname"] + model = row["model"] + embedding_col_name = row["attname"] + embedding_df = df._db.create_dataframe(table_name=embedding_table_name, schema=schema) + return embedding_df, embedding_table_name, embedding_col_name, model + + def _bind(t: str, columns: Union[Dict[str, Optional[str]], Set[str]]) -> List[str]: + target_list: List[str] = [] + for k in columns: + v = columns[k] if isinstance(columns, dict) else None + col_serialize = t + "." + k + (f' AS "{v}"' if v is not None else "") + target_list.append(col_serialize) + return target_list + + ( + self_embedding_df, + self_embedding_table_name, + self_embedding_col_name, + self_model, + ) = find_embedding_df(self._dataframe, column) assert self._dataframe.unique_key is not None distance = gp.operator("<->") # L2 distance is the default operator class in pgvector + if isinstance(query, Column): + assert query._dataframe is not None + (_, query_embedding_table_name, query_embedding_col_name, _,) = find_embedding_df( + query._dataframe.embedding()._dataframe, query._name # type: ignore reportUnknownArgumentType + ) + assert query._dataframe.unique_key is not None + joint_table_name = "cte_" + uuid4().hex + right_join_table_name = "cte_" + uuid4().hex + query_df_unique_keys: List[str] = list(query._dataframe.unique_key) + self_df_unique_keys: List[str] = list(self._dataframe.unique_key) + assert query_df_unique_keys is not None + assert self_df_unique_keys is not None + lateral_join_df = gp.DataFrame( + query=f""" + WITH {joint_table_name} as ( + SELECT + {",".join(_bind(query_embedding_table_name, columns=query_unique_key_columns)) + if query_unique_key_columns is not None + else ",".join( + [(query_embedding_table_name+"."+key) for key in query_df_unique_keys] + )}, + {",".join([(right_join_table_name+"."+key) for key in self_df_unique_keys])}, + {query_embedding_table_name}.{query_embedding_col_name}, + {right_join_table_name}.{self_embedding_col_name} + FROM {query_embedding_table_name} CROSS JOIN LATERAL ( + SELECT * FROM {self_embedding_table_name} + ORDER BY {self_embedding_table_name}.{self_embedding_col_name} <-> {query_embedding_table_name}.{query_embedding_col_name} + LIMIT {top_k} + ) AS {right_join_table_name} + ) + + SELECT + {",".join(_bind(query._dataframe._qualified_table_name, columns=query_unique_key_columns)) + if query_unique_key_columns is not None + else ",".join( + [(query._dataframe._qualified_table_name+"."+key) for key in query_df_unique_keys] + )}, + {",".join([(self._dataframe._qualified_table_name+"."+key) for key in self_df_unique_keys])}, + {query._dataframe._qualified_table_name}.{query._name}, + {self._dataframe._qualified_table_name}.{column} + FROM {joint_table_name} + JOIN {query._dataframe._qualified_table_name} + ON {"AND".join([ + (f"{query._dataframe._qualified_table_name}.{key} = {joint_table_name}.{query_unique_key_columns[key] if query_unique_key_columns is not None and key in query_unique_key_columns else key}") + for key in query_df_unique_keys + ])} + JOIN {self._dataframe._qualified_table_name} + ON {"AND".join([(self._dataframe._qualified_table_name+"."+key+" = "+joint_table_name+"." + key) for key in self_df_unique_keys])} + """, + db=self._dataframe._db, + ) + return lateral_join_df + return self._dataframe.join( - embedding_df.assign( + self_embedding_df.assign( distance=lambda t: distance( - embedding_df[embedding_col_name], _generate_embedding(query, model) + self_embedding_df[self_embedding_col_name], + _generate_embedding(query, self_model), ) ).order_by("distance")[:top_k], how="inner",