diff --git a/doc/source/notebooks/embedding.ipynb b/doc/source/notebooks/embedding.ipynb
index 4bd7a95e..2ca15979 100644
--- a/doc/source/notebooks/embedding.ipynb
+++ b/doc/source/notebooks/embedding.ipynb
@@ -12,7 +12,12 @@
{
"cell_type": "code",
"execution_count": 1,
- "metadata": {},
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-25T12:57:01.715707Z",
+ "start_time": "2023-08-25T12:56:54.919200Z"
+ }
+ },
"outputs": [
{
"name": "stdout",
@@ -51,7 +56,12 @@
{
"cell_type": "code",
"execution_count": 2,
- "metadata": {},
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-25T12:57:02.644919Z",
+ "start_time": "2023-08-25T12:57:01.723149Z"
+ }
+ },
"outputs": [],
"source": [
"content = [\"I have a dog.\", \"I like eating apples.\"]\n",
@@ -81,35 +91,17 @@
{
"cell_type": "code",
"execution_count": 3,
- "metadata": {},
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-25T12:57:08.645604Z",
+ "start_time": "2023-08-25T12:57:02.646625Z"
+ }
+ },
"outputs": [
{
"data": {
- "text/html": [
- "
\n",
- "\t\n",
- "\t\t| id | \n",
- "\t\tcontent | \n",
- "\t
\n",
- "\t\n",
- "\t\t| 0 | \n",
- "\t\tI have a dog. | \n",
- "\t
\n",
- "\t\n",
- "\t\t| 1 | \n",
- "\t\tI like eating apples. | \n",
- "\t
\n",
- "
"
- ],
- "text/plain": [
- "----------------------------\n",
- " id | content \n",
- "----+-----------------------\n",
- " 0 | I have a dog. \n",
- " 1 | I like eating apples. \n",
- "----------------------------\n",
- "(2 rows)"
- ]
+ "text/plain": "----------------------------\n id | content \n----+-----------------------\n 0 | I have a dog. \n 1 | I like eating apples. \n----------------------------\n(2 rows)",
+ "text/html": "\n\t\n\t\t| id | \n\t\tcontent | \n\t
\n\t\n\t\t| 0 | \n\t\tI have a dog. | \n\t
\n\t\n\t\t| 1 | \n\t\tI like eating apples. | \n\t
\n
"
},
"execution_count": 3,
"metadata": {},
@@ -133,30 +125,17 @@
{
"cell_type": "code",
"execution_count": 4,
- "metadata": {},
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2023-08-25T12:57:14.069009Z",
+ "start_time": "2023-08-25T12:57:08.643273Z"
+ }
+ },
"outputs": [
{
"data": {
- "text/html": [
- "\n",
- "\t\n",
- "\t\t| id | \n",
- "\t\tcontent | \n",
- "\t
\n",
- "\t\n",
- "\t\t| 1 | \n",
- "\t\tI like eating apples. | \n",
- "\t
\n",
- "
"
- ],
- "text/plain": [
- "----------------------------\n",
- " id | content \n",
- "----+-----------------------\n",
- " 1 | I like eating apples. \n",
- "----------------------------\n",
- "(1 row)"
- ]
+ "text/plain": "----------------------------\n id | content \n----+-----------------------\n 1 | I like eating apples. \n----------------------------\n(1 row)",
+ "text/html": "\n\t\n\t\t| id | \n\t\tcontent | \n\t
\n\t\n\t\t| 1 | \n\t\tI like eating apples. | \n\t
\n
"
},
"execution_count": 4,
"metadata": {},
@@ -169,39 +148,81 @@
},
{
"cell_type": "markdown",
- "metadata": {},
"source": [
- "## Cleaning All at Once"
- ]
+ "## Batched k-NN search"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "outputs": [],
+ "source": [
+ "query = (\n",
+ " db.create_dataframe(columns={\"idd\": range(3), \"query\": [\"apple\", \"dog\", \"banana\"]})\n",
+ " .save_as(\n",
+ " table_name=\"query_sample\",\n",
+ " column_names=[\"idd\", \"query\"],\n",
+ " distribution_key={\"idd\"},\n",
+ " distribution_type=\"hash\",\n",
+ " )\n",
+ " .check_unique(columns={\"idd\"})\n",
+ " .embedding()\n",
+ " .create_index(column=\"query\", model=\"all-MiniLM-L6-v2\")\n",
+ ")"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2023-08-25T12:57:17.400047Z",
+ "start_time": "2023-08-25T12:57:14.059315Z"
+ }
+ }
},
{
"cell_type": "code",
"execution_count": 6,
- "metadata": {},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- " * postgresql://localhost:7000\n",
- "Done.\n"
- ]
- },
{
"data": {
- "text/plain": [
- "[]"
- ]
+ "text/plain": "-------------------------------------------\n idd | id | query | content \n-----+----+--------+-----------------------\n 1 | 0 | dog | I have a dog. \n 2 | 1 | banana | I like eating apples. \n 0 | 1 | apple | I like eating apples. \n-------------------------------------------\n(3 rows)",
+ "text/html": "\n\t\n\t\t| idd | \n\t\tid | \n\t\tquery | \n\t\tcontent | \n\t
\n\t\n\t\t| 1 | \n\t\t0 | \n\t\tdog | \n\t\tI have a dog. | \n\t
\n\t\n\t\t| 2 | \n\t\t1 | \n\t\tbanana | \n\t\tI like eating apples. | \n\t
\n\t\n\t\t| 0 | \n\t\t1 | \n\t\tapple | \n\t\tI like eating apples. | \n\t
\n
"
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
+ "source": [
+ "t.embedding().search(column=\"content\", query=query[\"query\"], top_k=1)"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2023-08-25T12:57:18.305871Z",
+ "start_time": "2023-08-25T12:57:17.402679Z"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Cleaning All at Once"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"%reload_ext sql\n",
"%sql postgresql://localhost:7000\n",
- "%sql DROP TABLE text_sample CASCADE;"
+ "%sql DROP TABLE text_sample CASCADE;\n",
+ "%sql DROP TABLE query_sample CASCADE;"
]
}
],
diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py
index 04d8dff4..5cedeb36 100644
--- a/greenplumpython/experimental/embedding.py
+++ b/greenplumpython/experimental/embedding.py
@@ -1,8 +1,9 @@
from collections import abc
-from typing import Any, Callable, cast
+from typing import Any, Callable, Dict, List, Optional, Set, Union, cast
from uuid import uuid4
import greenplumpython as gp
+from greenplumpython.col import Column
from greenplumpython.row import Row
from greenplumpython.type import TypeCast
@@ -144,7 +145,13 @@ def create_index(self, column: str, model: str) -> gp.DataFrame:
)
return self._dataframe
- def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame:
+ def search(
+ self,
+ column: str,
+ query: Any,
+ top_k: int,
+ query_unique_key_columns: Optional[Union[Dict[str, Optional[str]], Set[str]]] = None,
+ ) -> gp.DataFrame:
"""
Searche unstructured data based on semantic similarity on embeddings.
@@ -155,54 +162,122 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame:
Returns:
Dataframe with the top k most similar results in the `column` of `query`.
-
- Example:
- Please refer to :ref:`embedding-example` for more details.
"""
- assert self._dataframe._db is not None
- embdedding_info = self._dataframe._db._execute(
- f"""
- WITH indexed_col_info AS (
- SELECT attrelid, attnum
- FROM pg_attribute
+
+ def find_embedding_df(df: gp.DataFrame, column_c: str):
+ assert df._db is not None
+
+ embdedding_info = df._db._execute(
+ f"""
+ WITH indexed_col_info AS (
+ SELECT attrelid, attnum
+ FROM pg_attribute
+ WHERE
+ attrelid = '{df._qualified_table_name}'::regclass::oid AND
+ attname = '{column_c}'
+ ), reloptions AS (
+ SELECT unnest(reloptions) AS option
+ FROM pg_class, indexed_col_info
+ WHERE pg_class.oid = attrelid
+ ), embedding_info_json AS (
+ SELECT split_part(option, '=', 2)::json AS val
+ FROM reloptions, indexed_col_info
+ WHERE option LIKE format('_pygp_emb_%s=%%', attnum)
+ ), embedding_info AS (
+ SELECT *
+ FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text)
+ )
+ SELECT nspname, relname, attname, model
+ FROM embedding_info, pg_class, pg_namespace, pg_attribute
WHERE
- attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND
- attname = '{column}'
- ), reloptions AS (
- SELECT unnest(reloptions) AS option
- FROM pg_class, indexed_col_info
- WHERE pg_class.oid = attrelid
- ), embedding_info_json AS (
- SELECT split_part(option, '=', 2)::json AS val
- FROM reloptions, indexed_col_info
- WHERE option LIKE format('_pygp_emb_%s=%%', attnum)
- ), embedding_info AS (
- SELECT *
- FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text)
+ pg_class.oid = embedding_relid AND
+ relnamespace = pg_namespace.oid AND
+ embedding_relid = attrelid AND
+ pg_attribute.attnum = 2;
+ """
)
- SELECT nspname, relname, attname, model
- FROM embedding_info, pg_class, pg_namespace, pg_attribute
- WHERE
- pg_class.oid = embedding_relid AND
- relnamespace = pg_namespace.oid AND
- embedding_relid = attrelid AND
- pg_attribute.attnum = 2;
- """
- )
- row: Row = embdedding_info[0]
- schema: str = row["nspname"]
- embedding_table_name: str = row["relname"]
- model = row["model"]
- embedding_col_name = row["attname"]
- embedding_df = self._dataframe._db.create_dataframe(
- table_name=embedding_table_name, schema=schema
- )
+ row: Row = embdedding_info[0]
+ schema: str = row["nspname"]
+ embedding_table_name: str = row["relname"]
+ model = row["model"]
+ embedding_col_name = row["attname"]
+ embedding_df = df._db.create_dataframe(table_name=embedding_table_name, schema=schema)
+ return embedding_df, embedding_table_name, embedding_col_name, model
+
+ def _bind(t: str, columns: Union[Dict[str, Optional[str]], Set[str]]) -> List[str]:
+ target_list: List[str] = []
+ for k in columns:
+ v = columns[k] if isinstance(columns, dict) else None
+ col_serialize = t + "." + k + (f' AS "{v}"' if v is not None else "")
+ target_list.append(col_serialize)
+ return target_list
+
+ (
+ self_embedding_df,
+ self_embedding_table_name,
+ self_embedding_col_name,
+ self_model,
+ ) = find_embedding_df(self._dataframe, column)
assert self._dataframe.unique_key is not None
distance = gp.operator("<->") # L2 distance is the default operator class in pgvector
+ if isinstance(query, Column):
+ assert query._dataframe is not None
+ (_, query_embedding_table_name, query_embedding_col_name, _,) = find_embedding_df(
+ query._dataframe.embedding()._dataframe, query._name # type: ignore reportUnknownArgumentType
+ )
+ assert query._dataframe.unique_key is not None
+ joint_table_name = "cte_" + uuid4().hex
+ right_join_table_name = "cte_" + uuid4().hex
+ query_df_unique_keys: List[str] = list(query._dataframe.unique_key)
+ self_df_unique_keys: List[str] = list(self._dataframe.unique_key)
+ assert query_df_unique_keys is not None
+ assert self_df_unique_keys is not None
+ lateral_join_df = gp.DataFrame(
+ query=f"""
+ WITH {joint_table_name} as (
+ SELECT
+ {",".join(_bind(query_embedding_table_name, columns=query_unique_key_columns))
+ if query_unique_key_columns is not None
+ else ",".join(
+ [(query_embedding_table_name+"."+key) for key in query_df_unique_keys]
+ )},
+ {",".join([(right_join_table_name+"."+key) for key in self_df_unique_keys])},
+ {query_embedding_table_name}.{query_embedding_col_name},
+ {right_join_table_name}.{self_embedding_col_name}
+ FROM {query_embedding_table_name} CROSS JOIN LATERAL (
+ SELECT * FROM {self_embedding_table_name}
+ ORDER BY {self_embedding_table_name}.{self_embedding_col_name} <-> {query_embedding_table_name}.{query_embedding_col_name}
+ LIMIT {top_k}
+ ) AS {right_join_table_name}
+ )
+
+ SELECT
+ {",".join(_bind(query._dataframe._qualified_table_name, columns=query_unique_key_columns))
+ if query_unique_key_columns is not None
+ else ",".join(
+ [(query._dataframe._qualified_table_name+"."+key) for key in query_df_unique_keys]
+ )},
+ {",".join([(self._dataframe._qualified_table_name+"."+key) for key in self_df_unique_keys])},
+ {query._dataframe._qualified_table_name}.{query._name},
+ {self._dataframe._qualified_table_name}.{column}
+ FROM {joint_table_name}
+ JOIN {query._dataframe._qualified_table_name}
+ ON {" AND ".join([
+ (f"{query._dataframe._qualified_table_name}.{key} = {joint_table_name}.{(query_unique_key_columns.get(key) or key) if isinstance(query_unique_key_columns, dict) else key}")
+ for key in query_df_unique_keys
+ ])}
+ JOIN {self._dataframe._qualified_table_name}
+ ON {" AND ".join([(self._dataframe._qualified_table_name+"."+key+" = "+joint_table_name+"." + key) for key in self_df_unique_keys])}
+ """,
+ db=self._dataframe._db,
+ )
+ return lateral_join_df
+
return self._dataframe.join(
- embedding_df.assign(
+ self_embedding_df.assign(
distance=lambda t: distance(
- embedding_df[embedding_col_name], _generate_embedding(query, model)
+ self_embedding_df[self_embedding_col_name],
+ _generate_embedding(query, self_model),
)
).order_by("distance")[:top_k],
how="inner",