From 99daf516569abcdae74049f09e2da5df766248d5 Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Tue, 25 Jul 2023 22:16:10 -0400 Subject: [PATCH 01/19] Add module for indexing and searching embeddings --- doc/source/notebooks/abalone.ipynb | 24 ++-- doc/source/notebooks/basic.ipynb | 12 +- doc/source/notebooks/embedding.ipynb | 76 ++++++++++ doc/source/notebooks/pandas.ipynb | 73 +++++----- greenplumpython/dataframe.py | 35 +++-- greenplumpython/experimental/__init__.py | 0 greenplumpython/experimental/embedding.py | 160 ++++++++++++++++++++++ greenplumpython/type.py | 7 +- 8 files changed, 318 insertions(+), 69 deletions(-) create mode 100644 doc/source/notebooks/embedding.ipynb create mode 100644 greenplumpython/experimental/__init__.py create mode 100644 greenplumpython/experimental/embedding.py diff --git a/doc/source/notebooks/abalone.ipynb b/doc/source/notebooks/abalone.ipynb index 78e87f4e..aaf2070b 100644 --- a/doc/source/notebooks/abalone.ipynb +++ b/doc/source/notebooks/abalone.ipynb @@ -504,6 +504,7 @@ "import numpy as np\n", "import pickle\n", "\n", + "\n", "@gp.create_column_function\n", "def linreg_func(length: List[float], shucked_weight: List[float], rings: List[int]) -> LinregType:\n", " X = np.array([length, shucked_weight]).T\n", @@ -560,9 +561,8 @@ "# ) a\n", "# ) DISTRIBUTED BY (sex);\n", "\n", - "linreg_fitted = (\n", - " abalone_train.group_by(\"sex\")\n", - " .apply(lambda t: linreg_func(t[\"length\"], t[\"shucked_weight\"], t[\"rings\"]), expand=True)\n", + "linreg_fitted = abalone_train.group_by(\"sex\").apply(\n", + " lambda t: linreg_func(t[\"length\"], t[\"shucked_weight\"], t[\"rings\"]), expand=True\n", ")" ] }, @@ -800,7 +800,7 @@ "linreg_test_fit = linreg_fitted.inner_join(\n", " abalone_test,\n", " cond=lambda t1, t2: t1[\"sex\"] == t2[\"sex\"],\n", - " self_columns=[\"col_nm\", \"coef\", \"intercept\", \"serialized_linreg_model\", \"created_dt\"]\n", + " self_columns=[\"col_nm\", \"coef\", \"intercept\", \"serialized_linreg_model\", \"created_dt\"],\n", ")" ] }, @@ -836,12 +836,11 @@ "\n", "\n", "linreg_pred = linreg_test_fit.assign(\n", - " rings_pred=lambda t:\n", - " linreg_pred_func(\n", - " t[\"serialized_linreg_model\"],\n", - " t[\"length\"],\n", - " t[\"shucked_weight\"],\n", - " ),\n", + " rings_pred=lambda t: linreg_pred_func(\n", + " t[\"serialized_linreg_model\"],\n", + " t[\"length\"],\n", + " t[\"shucked_weight\"],\n", + " ),\n", ")[[\"id\", \"sex\", \"rings\", \"rings_pred\"]]" ] }, @@ -946,6 +945,7 @@ "# , r2_score float8\n", "# );\n", "\n", + "\n", "@dataclasses.dataclass\n", "class linreg_eval_type:\n", " mae: float\n", @@ -974,6 +974,7 @@ "source": [ "from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score\n", "\n", + "\n", "@gp.create_column_function\n", "def linreg_eval(y_actual: List[float], y_pred: List[float]) -> linreg_eval_type:\n", " mae = mean_absolute_error(y_actual, y_pred)\n", @@ -1066,8 +1067,7 @@ "# ) a\n", "\n", "\n", - "linreg_pred.group_by(\"sex\").apply(\n", - " lambda t: linreg_eval(t[\"rings\"], t[\"rings_pred\"]), expand=True)" + "linreg_pred.group_by(\"sex\").apply(lambda t: linreg_eval(t[\"rings\"], t[\"rings_pred\"]), expand=True)" ] } ], diff --git a/doc/source/notebooks/basic.ipynb b/doc/source/notebooks/basic.ipynb index 839f71f9..c75257ae 100644 --- a/doc/source/notebooks/basic.ipynb +++ b/doc/source/notebooks/basic.ipynb @@ -787,14 +787,8 @@ "t_join = t1.join(\n", " t2,\n", " on=\"val\",\n", - " self_columns = {\n", - " \"id\": \"t1_id\",\n", - " \"val\": \"t1_val\"\n", - " },\n", - " other_columns = {\n", - " \"id\": \"t2_id\",\n", - " \"val\": \"t2_val\"\n", - " }\n", + " self_columns={\"id\": \"t1_id\", \"val\": \"t1_val\"},\n", + " other_columns={\"id\": \"t2_id\", \"val\": \"t2_val\"},\n", ")\n", "t_join" ] @@ -1075,7 +1069,7 @@ " numbers.assign(is_even=lambda t: t[\"val\"] % 2 == 0)\n", " .group_by(\"is_even\")\n", " .apply(lambda t: F.sum(t[\"val\"]))\n", - ")\n" + ")" ] } ], diff --git a/doc/source/notebooks/embedding.ipynb b/doc/source/notebooks/embedding.ipynb new file mode 100644 index 00000000..12353331 --- /dev/null +++ b/doc/source/notebooks/embedding.ipynb @@ -0,0 +1,76 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%cd $(find ~ -name GreenplumPython)\n", + "!python3 -m pip install --upgrade ." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "content = [\"I have a dog.\", \"I like eating apples.\"]\n", + "\n", + "import greenplumpython as gp\n", + "\n", + "db = gp.database(\"postgresql://localhost:7000\")\n", + "t = db.create_dataframe(columns={\"id\": range(len(content)), \"content\": content})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import greenplumpython.experimental.embedding\n", + "\n", + "gp.config.print_sql = True\n", + "\n", + "t = t.save_as(\n", + " column_names=[\"id\", \"content\"], distribution_key={\"id\"}, distribution_type=\"hash\"\n", + ").check_unique(columns={\"id\"})\n", + "t = t.embedding().create_index(column=\"content\", model=\"all-MiniLM-L6-v2\")\n", + "t" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t.embedding().search(column=\"content\", query=\"apple\", top_k=1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/doc/source/notebooks/pandas.ipynb b/doc/source/notebooks/pandas.ipynb index 64ee3c3b..61caa6a5 100644 --- a/doc/source/notebooks/pandas.ipynb +++ b/doc/source/notebooks/pandas.ipynb @@ -31,7 +31,7 @@ "import pandas as pd\n", "import greenplumpython as gp\n", "\n", - "gp\n" + "gp" ] }, { @@ -80,7 +80,7 @@ ], "source": [ "students = [(\"alice\", 18), (\"bob\", 19), (\"carol\", 19)]\n", - "students\n" + "students" ] }, { @@ -154,7 +154,7 @@ ], "source": [ "pd_df = pd.DataFrame.from_records(students, columns=[\"name\", \"age\"])\n", - "pd_df\n" + "pd_df" ] }, { @@ -210,7 +210,7 @@ "source": [ "db = gp.database(\"postgresql://localhost/gpadmin\")\n", "gp_df = gp.DataFrame.from_rows(students, column_names=[\"name\", \"age\"], db=db)\n", - "gp_df\n" + "gp_df" ] }, { @@ -271,7 +271,7 @@ } ], "source": [ - "gp_df.save_as(\"student\", column_names=[\"name\", \"age\"], temp=True)\n" + "gp_df.save_as(\"student\", column_names=[\"name\", \"age\"], temp=True)" ] }, { @@ -287,7 +287,7 @@ "metadata": {}, "outputs": [], "source": [ - "pd_df.to_csv(\"/tmp/student.csv\")\n" + "pd_df.to_csv(\"/tmp/student.csv\")" ] }, { @@ -352,7 +352,7 @@ ], "source": [ "student = db.create_dataframe(table_name=\"student\")\n", - "student\n" + "student" ] }, { @@ -429,7 +429,7 @@ } ], "source": [ - "pd.read_csv(\"/tmp/student.csv\")\n" + "pd.read_csv(\"/tmp/student.csv\")" ] }, { @@ -460,7 +460,7 @@ ], "source": [ "for row in gp_df:\n", - " print(row[\"name\"], row[\"age\"])\n" + " print(row[\"name\"], row[\"age\"])" ] }, { @@ -487,7 +487,7 @@ ], "source": [ "for row in pd_df.iterrows():\n", - " print(row[1][\"name\"], row[1][\"age\"])\n" + " print(row[1][\"name\"], row[1][\"age\"])" ] }, { @@ -584,7 +584,7 @@ } ], "source": [ - "pd_df[[\"name\", \"age\"]]\n" + "pd_df[[\"name\", \"age\"]]" ] }, { @@ -638,7 +638,7 @@ } ], "source": [ - "student[[\"name\", \"age\"]]\n" + "student[[\"name\", \"age\"]]" ] }, { @@ -679,7 +679,7 @@ } ], "source": [ - "pd_df[\"name\"]\n" + "pd_df[\"name\"]" ] }, { @@ -706,7 +706,7 @@ } ], "source": [ - "gp_df[\"name\"]\n" + "gp_df[\"name\"]" ] }, { @@ -788,7 +788,7 @@ } ], "source": [ - "pd_df[lambda df: df[\"name\"] == \"alice\"]\n" + "pd_df[lambda df: df[\"name\"] == \"alice\"]" ] }, { @@ -832,7 +832,7 @@ } ], "source": [ - "student[lambda t: t[\"name\"] == \"alice\"]\n" + "student[lambda t: t[\"name\"] == \"alice\"]" ] }, { @@ -898,7 +898,7 @@ } ], "source": [ - "student[:2]\n" + "student[:2]" ] }, { @@ -965,7 +965,7 @@ } ], "source": [ - "pd_df[:2]\n" + "pd_df[:2]" ] }, { @@ -1062,7 +1062,7 @@ } ], "source": [ - "pd_df.sort_values([\"age\", \"name\"], ascending=[False, False])\n" + "pd_df.sort_values([\"age\", \"name\"], ascending=[False, False])" ] }, { @@ -1116,7 +1116,7 @@ } ], "source": [ - "student.order_by(\"age\", ascending=False).order_by(\"name\", ascending=False)[:]\n" + "student.order_by(\"age\", ascending=False).order_by(\"name\", ascending=False)[:]" ] }, { @@ -1219,7 +1219,7 @@ "import datetime\n", "\n", "this_year = datetime.date.today().year\n", - "pd_df.assign(year_of_birth=lambda df: -df[\"age\"] + this_year)\n" + "pd_df.assign(year_of_birth=lambda df: -df[\"age\"] + this_year)" ] }, { @@ -1277,7 +1277,7 @@ } ], "source": [ - "student.assign(year_of_birth=lambda t: -t[\"age\"] + this_year)\n" + "student.assign(year_of_birth=lambda t: -t[\"age\"] + this_year)" ] }, { @@ -1297,9 +1297,10 @@ "source": [ "from hashlib import sha256\n", "\n", + "\n", "@gp.create_function\n", "def hash_name(name: str) -> str:\n", - " return sha256(name.encode(\"utf-8\")).hexdigest()\n" + " return sha256(name.encode(\"utf-8\")).hexdigest()" ] }, { @@ -1359,7 +1360,7 @@ } ], "source": [ - "student.assign(name_=lambda t: hash_name(t[\"name\"]))\n" + "student.assign(name_=lambda t: hash_name(t[\"name\"]))" ] }, { @@ -1430,7 +1431,7 @@ " return Student(name=sha256(name.encode(\"utf-8\")).hexdigest(), age=age)\n", "\n", "\n", - "student.apply(lambda t: gp.create_function(hide_name)(t[\"name\"], t[\"age\"]), expand=True)\n" + "student.apply(lambda t: gp.create_function(hide_name)(t[\"name\"], t[\"age\"]), expand=True)" ] }, { @@ -1503,11 +1504,7 @@ } ], "source": [ - "pd_df.apply(\n", - " lambda df: asdict(hide_name(df[\"name\"], df[\"age\"])),\n", - " axis=1, \n", - " result_type=\"expand\"\n", - ")\n" + "pd_df.apply(lambda df: asdict(hide_name(df[\"name\"], df[\"age\"])), axis=1, result_type=\"expand\")" ] }, { @@ -1555,7 +1552,7 @@ "source": [ "import numpy as np\n", "\n", - "pd_df.groupby(\"age\").apply(lambda df: np.count_nonzero(df[\"name\"]))\n" + "pd_df.groupby(\"age\").apply(lambda df: np.count_nonzero(df[\"name\"]))" ] }, { @@ -1606,7 +1603,7 @@ "source": [ "count = gp.aggregate_function(\"count\")\n", "\n", - "student.group_by(\"age\").apply(lambda t: count(t[\"name\"]))\n" + "student.group_by(\"age\").apply(lambda t: count(t[\"name\"]))" ] }, { @@ -1681,7 +1678,7 @@ } ], "source": [ - "pd_df.drop_duplicates(\"age\")\n" + "pd_df.drop_duplicates(\"age\")" ] }, { @@ -1730,7 +1727,7 @@ } ], "source": [ - "student.distinct_on(\"age\")\n" + "student.distinct_on(\"age\")" ] }, { @@ -1774,7 +1771,7 @@ } ], "source": [ - "student.apply(lambda t: count.distinct(t[\"age\"]))\n" + "student.apply(lambda t: count.distinct(t[\"age\"]))" ] }, { @@ -1873,7 +1870,7 @@ } ], "source": [ - "pd_df.merge(pd_df, on=\"age\", suffixes=(\"\", \"_2\"))\n" + "pd_df.merge(pd_df, on=\"age\", suffixes=(\"\", \"_2\"))" ] }, { @@ -1943,7 +1940,7 @@ } ], "source": [ - "student.join(student, on=\"age\", other_columns={\"name\": \"name_2\"})\n" + "student.join(student, on=\"age\", other_columns={\"name\": \"name_2\"})" ] }, { @@ -2029,7 +2026,7 @@ "num_1 = pd.DataFrame({\"val\": [1, 3, 5, 7, 9]})\n", "num_2 = pd.DataFrame({\"val\": [2, 4, 6, 8, 10]})\n", "\n", - "num_1[num_2[\"val\"] % 2 == 0] # Even numbers?\n" + "num_1[num_2[\"val\"] % 2 == 0] # Even numbers?" ] }, { diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py index 28aed86e..5a7469d7 100644 --- a/greenplumpython/dataframe.py +++ b/greenplumpython/dataframe.py @@ -941,6 +941,11 @@ def save_as( if table_name is None: table_name = self._name if not self.is_saved else "cte_" + uuid4().hex qualified_table_name = f'"{table_name}"' if schema is None else f'"{schema}"."{table_name}"' + if distribution_type is not None: + distribution_type = distribution_type.lower() + assert (distribution_key is not None and distribution_type == "hash") or ( + distribution_key is None and distribution_type == "randomly" or "replicated" + ), f"Distribution type '{distribution_type}' on key '{distribution_key}' is invalid." distribution_clause = ( f""" DISTRIBUTED {f"BY ({','.join(distribution_key)})" @@ -1018,10 +1023,6 @@ def group_by(self, *column_names: str) -> DataFrameGroupingSet: :class:`~dataframe.DataFrame`. Each group is identified by a different set of values of the columns in the arguments. """ - # State transition diagram: - # DataFrame --group_by()-> DataFrameGroupingSet --aggregate()-> FunctionExpr - # ^ | - # |------------------------- assign() or apply() ---------------| return DataFrameGroupingSet(self, [column_names]) def distinct_on(self, *column_names: str) -> "DataFrame": @@ -1063,6 +1064,24 @@ def distinct_on(self, *column_names: str) -> "DataFrame": parents=[self], ) + @property + def unique_key(self) -> List[str]: + return self._unique_key + + def check_unique(self, columns: set[str]) -> "DataFrame": + """ + Check whether a given set of columns, i.e. key, is unique. + """ + assert self.is_saved, "DataFrame must be saved before checking uniqueness." + assert self._db is not None, "Database is required to check uniqueness." + print(self) + self._db._execute( + f"CREATE UNIQUE INDEX ON {self._qualified_table_name} ({','.join(columns)})", + has_results=False, + ) + self._unique_key = columns + return self + # dataframe_name can be table/view name @classmethod def from_table(cls, table_name: str, db: Database, schema: Optional[str] = None) -> "DataFrame": @@ -1080,7 +1099,7 @@ def from_table(cls, table_name: str, db: Database, schema: Optional[str] = None) """ qualified_name = f'"{schema}"."{table_name}"' if schema is not None else f'"{table_name}"' - return DataFrame(f"TABLE {qualified_name}", db=db, qualified_table_name=qualified_name) + return cls(f"TABLE {qualified_name}", db=db, qualified_table_name=qualified_name) @classmethod def from_rows( @@ -1141,9 +1160,7 @@ def from_rows( column_names = [f'"{name}"' for name in column_names] columns_string = f"({','.join(column_names)})" table_name = "cte_" + uuid4().hex - return DataFrame( - f"SELECT * FROM (VALUES {rows_string}) AS {table_name} {columns_string}", db=db - ) + return cls(f"SELECT * FROM (VALUES {rows_string}) AS {table_name} {columns_string}", db=db) @classmethod def from_columns(cls, columns: Dict[str, Iterable[Any]], db: Database) -> "DataFrame": @@ -1176,4 +1193,4 @@ def from_columns(cls, columns: Dict[str, Iterable[Any]], db: Database) -> "DataF columns_string = ",".join( [f'unnest({_serialize(list(v))}) AS "{k}"' for k, v in columns.items()] ) - return DataFrame(f"SELECT {columns_string}", db=db) + return cls(f"SELECT {columns_string}", db=db) diff --git a/greenplumpython/experimental/__init__.py b/greenplumpython/experimental/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py new file mode 100644 index 00000000..8f7e366a --- /dev/null +++ b/greenplumpython/experimental/embedding.py @@ -0,0 +1,160 @@ +from collections import abc +from typing import Any +from uuid import uuid4 + +import greenplumpython as gp + +_vector_type = gp.type_("vector", modifier=384) + + +@gp.create_function +def _generate_embedding(content: str, model_name: str) -> _vector_type: + import sys + + SD = globals().get("SD", sys.modules["plpy"]._SD) + if "model" not in SD: + from sentence_transformers import SentenceTransformer + + model = SentenceTransformer(model_name) + SD["model"] = model + else: + model = SD["model"] + + # Sentences are encoded by calling model.encode() + emb = model.encode(content, normalize_embeddings=True) + return emb.tolist() + + +class Embedding: + def __init__(self, dataframe: gp.DataFrame) -> None: + self._dataframe = dataframe + + def create_index(self, column: str, model: str) -> gp.DataFrame: + """ + Generate embeddings and create index for a column of unstructured data. + This include + - texts, + - images, or + - videos, etc. + + This enables searching unstructured data based on semantic similarity, + That is, whether they mean or contain similar things. + + For better efficiency, the generated embeddings is stored in a + column-oriented approach, i.e., separated from the input DataFrame. The + input DataFrame must have a **unique key** to identify the tuples in the + search results. + """ + + assert self._dataframe.unique_key is not None, "Unique key is required to create index." + + embedding_col_name = "_emb_" + uuid4().hex + embedding_df_cols = list(self._dataframe.unique_key) + [embedding_col_name] + embedding_df: gp.DataFrame = ( + self._dataframe.assign( + **{ + embedding_col_name: lambda t: _vector_type( + _generate_embedding(t[column], model) + ) + }, + )[embedding_df_cols] + .save_as( + column_names=embedding_df_cols, + distribution_key=self._dataframe.unique_key, + distribution_type="hash", + ) + .check_unique(self._dataframe.unique_key) + .create_index(columns={embedding_col_name}, method="ivfflat") + ) + assert self._dataframe._db is not None + self._dataframe._db._execute( + f""" + DO $$ + BEGIN + SET LOCAL allow_system_table_mods TO ON; + WITH embedding_info AS ( + SELECT attrelid, attnum, `{model}` AS model + FROM pg_attribute + WHERE + attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND + attname = '{column}' + ), add_option AS ( + UPDATE pg_class + FROM embedding_info + SET reloptions = array_append( + reloptions, + format('_pygp_emb_%s=%s', attnum::text, to_json(embedding_info)) + ) + ), add_dependency AS ( + INSERT INTO pg_depend + SELECT + 'pg_class'::regclass::oid AS classid, + '{embedding_df._qualified_table_name}'::regclass::oid AS objid, + 0::oid AS objsubid, + 'pg_class'::regclass::oid AS refclassid, + embedding_info.attrelid AS refobjid, + embedding_info.attnum AS refobjsubid + FROM embedding_info + RETURNING * + ) + SELECT * FROM add_dependency; + END; + $$; + """ + ) + return self._dataframe + + def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: + assert self._dataframe._db is not None + embdedding_info = self._dataframe._db._execute( + f""" + WITH embedding_oid AS ( + SELECT attrelid, attnum + FROM pg_attribute + WHERE + attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND + attname = '{column}' + ), reloptions AS ( + SELECT unnest(reloptions) AS option + FROM pg_class, embedding_oid + WEHRE oid = attrelid + ), embedding_info AS ( + SELECT split_part(option, '=', 2)::jsonb AS info + FROM reloptions, embedding_oid + WHERE option LIKE format('_pygp_emb%s=%%', attnum) + ) embedding_table_qualified_name AS ( + SELECT nspname, relname, embedding.info->'model' AS model + FROM embedding_table, pg_class, pg_namespace + WHERE + pg_class.oid = embedding.info->'attrelid' AND + relnamespace = pg_namespace.oid + ) + SELECT * FROM embedding_table_qualified_name + """ + ) + assert isinstance(embdedding_info, abc.Mapping[str, Any]) + embedding_table_name = None + for row in embdedding_info: + embedding_table_name = f'"{row["nspname"]}"."{row["relname"]}"' + model = row["model"] + break + assert embedding_table_name is not None + embedding_df = self._dataframe._db.create_dataframe(embedding_table_name) + assert self._dataframe.unique_key is not None + distance = gp.operator("<#>") + return self._dataframe.join( + embedding_df.assign( + distance=lambda t: distance(t["_emb_"], _generate_embedding(query, model)) + ).order_by("distance")[:top_k], + how="inner", + on=self._dataframe.unique_key, + self_columns={"*"}, + other_columns={}, + ) + + +def _embedding(dataframe: gp.DataFrame) -> Embedding: + return Embedding(dataframe=dataframe) + + +setattr(gp.DataFrame, "embedding", _embedding) diff --git a/greenplumpython/type.py b/greenplumpython/type.py index d3ea1d32..cb3f4bdd 100644 --- a/greenplumpython/type.py +++ b/greenplumpython/type.py @@ -135,6 +135,9 @@ def _create_in_db(self, db: Database): """ if self._created_in_dbs is None or db in self._created_in_dbs: return + assert isinstance( + self._annotation, type + ), "Only composite data types can be created in database." schema = "pg_temp" members = get_type_hints(self._annotation) if len(members) == 0: @@ -183,7 +186,7 @@ def _qualified_name(self) -> Tuple[Optional[str], str]: bytes: Type(name="bytea"), } - +# FIXME: Change to data_type() to make it more clear. def type_(name: str, schema: Optional[str] = None, modifier: Optional[int] = None) -> Type: """ Get access to a type predefined in database. @@ -219,6 +222,8 @@ def to_pg_type( Returns: str: name of type in SQL """ + if isinstance(annotation, Type): + return annotation._qualified_name_str if annotation is not None and hasattr(annotation, "__origin__"): # The `or` here is to make the function work on Python 3.6. # Python 3.6 is the default Python version on CentOS 7 and Ubuntu 18.04 From 3c64cc134da6cc09f08f128324f9e6aa23751aa7 Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Tue, 25 Jul 2023 22:16:10 -0400 Subject: [PATCH 02/19] Fix issues in SQL --- doc/source/notebooks/embedding.ipynb | 225 +++++++++++++++++++++- greenplumpython/experimental/embedding.py | 96 ++++----- greenplumpython/type.py | 4 +- 3 files changed, 271 insertions(+), 54 deletions(-) diff --git a/doc/source/notebooks/embedding.ipynb b/doc/source/notebooks/embedding.ipynb index 12353331..b112239a 100644 --- a/doc/source/notebooks/embedding.ipynb +++ b/doc/source/notebooks/embedding.ipynb @@ -2,17 +2,39 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/gpadmin/GreenplumPython\n", + "Defaulting to user installation because normal site-packages is not writeable\n", + "Processing /home/gpadmin/GreenplumPython\n", + " Installing build dependencies ... \u001b[?25ldone\n", + "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", + "\u001b[?25h Preparing wheel metadata ... \u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied, skipping upgrade: dill==0.3.6 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (0.3.6)\n", + "Requirement already satisfied, skipping upgrade: psycopg2-binary==2.9.5 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (2.9.5)\n", + "Building wheels for collected packages: greenplum-python\n", + " Building wheel for greenplum-python (PEP 517) ... \u001b[?25ldone\n", + "\u001b[?25h Created wheel for greenplum-python: filename=greenplum_python-1.0.1-py3-none-any.whl size=70509 sha256=995d00c1fdf47e7721a42c1f1f1e0ffa3af7b02ea1403620ee8c2cebdacf69c6\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-34hu1ytc/wheels/bb/1f/99/ff8594e48ec11df99af6e0ee8611a5e560e9f44d1a3fefb351\n", + "Successfully built greenplum-python\n", + "Installing collected packages: greenplum-python\n", + "Successfully installed greenplum-python-1.0.1\n" + ] + } + ], "source": [ - "%cd $(find ~ -name GreenplumPython)\n", + "%cd ../../../\n", "!python3 -m pip install --upgrade ." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -26,9 +48,147 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " CREATE TABLE \"cte_7721c0fd66c041c4b7c5ff4074b3479b\"\n", + " (id,content)\n", + " \n", + " AS SELECT unnest(ARRAY[0,1]) AS \"id\",unnest(ARRAY['I have a dog.','I like eating apples.']) AS \"content\"\n", + " \n", + " DISTRIBUTED BY (id)\n", + " \n", + " \n", + "WITH cte_d0f82a0d4eee4d86a93fa03651ebcb8a AS (TABLE \"cte_7721c0fd66c041c4b7c5ff4074b3479b\")SELECT to_json(cte_1fd18725c25a4c9a877fd6a0aa180e0a)::TEXT FROM cte_d0f82a0d4eee4d86a93fa03651ebcb8a AS cte_1fd18725c25a4c9a877fd6a0aa180e0a\n", + "----------------------------\n", + " id | content \n", + "----+-----------------------\n", + " 0 | I have a dog. \n", + " 1 | I like eating apples. \n", + "----------------------------\n", + "(2 rows)\n", + "\n", + "CREATE UNIQUE INDEX ON \"cte_7721c0fd66c041c4b7c5ff4074b3479b\" (id)\n", + "CREATE FUNCTION \"pg_temp\".\"func_54569d267ed6412e9026f55a5bf7601b\" (content \"text\",model_name \"text\") RETURNS \"vector\"(384) AS $$\n", + "try:\n", + " return GD['__func_54569d267ed6412e9026f55a5bf7601b'](content=content,model_name=model_name)\n", + "except KeyError:\n", + " try:\n", + " import dill as __lib_64f0f94e2fe74186b2d943881574a343\n", + " import sysconfig as __lib_605eb3b464e84c5999150fadf59627c7\n", + " import base64 as __lib_bddbb4833d4c4ea7879fcdffa7ff8270\n", + " import sys as __lib_6f40e80834a64c8d91808b5226eda448\n", + " if __lib_605eb3b464e84c5999150fadf59627c7.get_python_version() != '3.9':\n", + " raise ModuleNotFoundError\n", + " setattr(__lib_6f40e80834a64c8d91808b5226eda448.modules['plpy'], '_SD', SD)\n", + " GD['__func_54569d267ed6412e9026f55a5bf7601b'] = __lib_64f0f94e2fe74186b2d943881574a343.loads(__lib_bddbb4833d4c4ea7879fcdffa7ff8270.b64decode(b'gASVNgMAAAAAAACMCmRpbGwuX2RpbGyUjBBfY3JlYXRlX2Z1bmN0aW9ulJOUKGgAjAxfY3JlYXRlX2NvZGWUk5QoSwJLAEsASwdLBUtDQ2JkAWQAbAB9AnQBgwCgAmQCfAJqA2QDGQBqBKECfQNkBHwDdgFyRGQBZAVsBW0GfQQBAHwEfAGDAX0FfAV8A2QEPABuCHwDZAQZAH0FfAVqB3wAZAZkB40CfQZ8BqAIoQBTAJQoTksAjAJTRJSMBHBscHmUjAVtb2RlbJSME1NlbnRlbmNlVHJhbnNmb3JtZXKUhZSIjBRub3JtYWxpemVfZW1iZWRkaW5nc5SFlHSUKIwDc3lzlIwHZ2xvYmFsc5SMA2dldJSMB21vZHVsZXOUjANfU0SUjBVzZW50ZW5jZV90cmFuc2Zvcm1lcnOUaAmMBmVuY29kZZSMBnRvbGlzdJR0lCiMB2NvbnRlbnSUjAptb2RlbF9uYW1llGgOaAZoCWgIjANlbWKUdJSMRy9ob21lL2dwYWRtaW4vR3JlZW5wbHVtUHl0aG9uL2dyZWVucGx1bXB5dGhvbi9leHBlcmltZW50YWwvZW1iZWRkaW5nLnB5lIwTX2dlbmVyYXRlX2VtYmVkZGluZ5RLCkMSAAIIAhYBCAEMAggBCgIIAw4BlCkpdJRSlH2UjAhfX25hbWVfX5SMJmdyZWVucGx1bXB5dGhvbi5leHBlcmltZW50YWwuZW1iZWRkaW5nlHNoHE5OdJRSlH2UfZSMD19fYW5ub3RhdGlvbnNfX5R9lChoF2gAjApfbG9hZF90eXBllJOUjANzdHKUhZRSlGgYaC2MBnJldHVybpSMFGdyZWVucGx1bXB5dGhvbi50eXBllIwEVHlwZZSTlCmBlH2UKIwFX25hbWWUjAZ2ZWN0b3KUjAtfYW5ub3RhdGlvbpROjA9fY3JlYXRlZF9pbl9kYnOUTowHX3NjaGVtYZROjAlfbW9kaWZpZXKUTYABjBNfcXVhbGlmaWVkX25hbWVfc3RylIwNInZlY3RvciIoMzg0KZR1YnVzhpRiaCCMB2dsb2JhbHOUjAhidWlsdGluc5SMB2dsb2JhbHOUk5RzMC4='))\n", + " except ModuleNotFoundError:\n", + " exec(\"def __func_54569d267ed6412e9026f55a5bf7601b(content, model_name):\\n import sys\\n SD = globals().get('SD', sys.modules['plpy']._SD)\\n if 'model' not in SD:\\n from sentence_transformers import SentenceTransformer\\n model = SentenceTransformer(model_name)\\n SD['model'] = model\\n else:\\n model = SD['model']\\n emb = model.encode(content, normalize_embeddings=True)\\n return emb.tolist()\", globals())\n", + " GD['__func_54569d267ed6412e9026f55a5bf7601b'] = globals()['__func_54569d267ed6412e9026f55a5bf7601b']\n", + " return GD['__func_54569d267ed6412e9026f55a5bf7601b'](content=content,model_name=model_name)\n", + "$$ LANGUAGE plpython3u;\n", + "\n", + " CREATE TABLE \"cte_25a74c5b96e64b4183b4b55256867459\"\n", + " (id,_emb_23dd83d7748d40cd8eda21e5f2129629)\n", + " \n", + " AS WITH cte_d0f82a0d4eee4d86a93fa03651ebcb8a AS (TABLE \"cte_7721c0fd66c041c4b7c5ff4074b3479b\"),cte_6a96d6faca4846f6ab4870e8dfc80e29 AS (SELECT *, (\"pg_temp\".\"func_54569d267ed6412e9026f55a5bf7601b\"( cte_d0f82a0d4eee4d86a93fa03651ebcb8a.\"content\",'all-MiniLM-L6-v2')::\"vector\"(384)) AS _emb_23dd83d7748d40cd8eda21e5f2129629 FROM cte_d0f82a0d4eee4d86a93fa03651ebcb8a)\n", + " SELECT cte_6a96d6faca4846f6ab4870e8dfc80e29.\"id\",cte_6a96d6faca4846f6ab4870e8dfc80e29.\"_emb_23dd83d7748d40cd8eda21e5f2129629\"\n", + " FROM cte_6a96d6faca4846f6ab4870e8dfc80e29\n", + " \n", + " \n", + " DISTRIBUTED BY (id)\n", + " \n", + " \n", + "WITH cte_14c1fce52f0e4af08972a2eb07df02a9 AS (TABLE \"cte_25a74c5b96e64b4183b4b55256867459\")SELECT to_json(cte_baadcbf448304ef7a6693f98b217a984)::TEXT FROM cte_14c1fce52f0e4af08972a2eb07df02a9 AS cte_baadcbf448304ef7a6693f98b217a984\n", + "-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", + " id | _emb_23dd83d7748d40cd8eda21e5f2129629 \n", + "----+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", + " 0 | [-0.03659846,-0.012087755,0.08805456,0.061151367,-0.043457735,-0.015592895,0.07047544,-0.002039735,0.082576185,-0.027372142,0.0441429,-0.03269936,0.013636178,0.04161656,0.010410312,-0.0015930303,-0.06705982,-0.044098374,-0.0057846555,-0.06406435,-0.06587656,0.07500747,0.012162317,-0.005788625,-0.109905936,0.027304199,-0.039163478,-0.05016219,0.0029828846,-0.038391765,-0.015229778,-0.055909287,-0.011802607,-0.0048770546,-0.042732246,-0.041694522,0.0065653264,-0.013692718,0.103015944,0.08045541,0.047172282,0.014515043,0.06301976,-0.008371227,-0.00376406,0.037010953,-0.08730184,-0.019860007,0.116164975,-0.009175166,-0.02942207,0.05760907,-0.017986163,0.030363861,-0.018659858,-0.023920199,0.0075364825,0.030293372,-0.0017754125,-0.027292343,0.010452809,0.06776974,0.009428498,0.047237813,0.00020278565,0.02074401,-0.061770048,0.06334932,-0.06630364,0.055175032,0.03636048,0.033629246,0.0447184,0.09541598,-0.035489805,-0.10748742,0.06386299,-0.030471267,0.18042347,0.07920897,-0.08705959,-0.06174667,-0.042777605,0.04772073,0.04045522,0.011489479,0.07283991,0.06658613,-0.117522456,0.011569888,-0.022578657,-0.049202647,-0.03411386,0.01763423,-0.0032649997,-0.010033316,-0.022944551,-0.033948276,-0.021662703,0.089601316,0.008443466,0.028806083,0.07188612,0.045687683,0.09596068,0.023099585,-0.09392849,0.060846657,-0.010293214,0.0019619083,-0.012627958,0.009032658,-0.023953663,0.10200912,0.047290858,0.045499124,-0.07541447,-0.024221145,0.06080323,-0.09191942,0.011989021,0.021896897,-0.044340197,0.02122626,0.019848485,-0.058525886,0.03497772,-9.026044e-33,-0.0053451294,-0.0290533,0.014672238,0.04659998,-0.028272917,0.013217353,-0.038185634,0.030172182,-0.05259568,-0.016775027,0.0034630536,0.00057961704,-0.020373803,-0.034381036,-0.0033685626,0.0013990739,0.051134076,0.01848566,0.08034311,-0.00014362897,-0.013998878,-0.021286957,0.03914335,0.017298121,-0.017837863,-0.012515484,-0.013980058,-0.083431505,-0.02655956,0.024582446,0.028264284,0.020893436,0.045632925,-0.041542996,-0.105518915,-0.03636643,-0.05349343,-0.05543646,-0.043980815,0.052545346,0.08640961,-0.0042671263,0.017281698,-0.0003456163,0.0046999496,-0.034805853,0.008263829,0.020119112,-0.09260091,0.01470312,0.011787518,-0.03307292,0.0042901468,-0.08931934,-0.029248364,-0.041016944,0.059762184,-0.009189991,0.019669637,0.08591937,0.022527453,0.0075523625,-0.030852512,0.029306179,0.051727347,-0.090517536,-0.09521753,-0.041740306,-0.0011757809,0.014292587,-0.024682235,-0.0035219707,0.0077362237,-0.017399674,0.07142882,-0.0123587465,-0.00534226,-0.003308827,-0.01875911,-0.07966144,0.019006351,0.0018609086,0.00706818,0.057706404,0.07751448,0.059841618,-0.029955173,-0.0058063352,-0.023169437,0.0026582994,-0.065715685,-0.043993074,0.03394865,-0.027996289,0.04052676,5.238207e-33,0.010241467,0.03607309,0.046909,0.013635992,-0.005335444,0.0016521218,-0.020371608,0.04564495,-0.082175665,0.06402261,-0.001709206,0.044672277,0.10069537,0.00045673744,0.062299304,0.03769347,-0.039460346,-0.019606683,0.050265815,-0.05616924,-0.18455045,0.08040064,0.07426137,0.019323843,-0.026447829,0.040501535,-0.019648906,-0.02372921,-0.058951914,-0.0853744,-0.045682464,-0.12889871,-0.055900462,-0.068548314,-0.0058031273,0.066947535,-0.023167405,-0.1457526,-0.0123237185,-0.059538133,0.036701616,-0.0021032416,0.04832922,0.078937754,0.014486305,0.029141134,0.014654051,-0.06743171,0.00976345,0.033080045,-0.026131311,-0.008976268,-0.028050678,-0.062519975,-0.0033331455,-0.014157539,-0.07179509,-0.067832775,0.014238785,0.008521254,-0.031684905,0.09964349,-0.05202337,0.13799058,-0.019717641,-0.0868198,-0.0071095424,-0.0557247,0.011921498,-0.07336916,-0.0079654865,0.07029794,-0.031166447,-0.055607356,0.0108316,0.04010841,0.051589135,-0.0015768349,0.03786852,0.015498465,-0.06851167,-0.04085385,0.009224494,-0.010765805,-0.001525135,-0.03769954,-0.00508086,0.05028555,-0.0018060899,0.047179505,-0.032873716,0.0786257,0.0219288,-0.055561442,0.0068103974,-1.6011152e-08,-0.047843613,-0.0016648023,-0.0019612245,-0.0025547266,0.05134095,0.035634715,0.0084129,-0.06416773,-0.03193827,-0.019677935,0.03140499,-0.0173519,-0.043358684,0.02033876,0.10461025,0.025110237,0.017567858,8.451519e-06,0.034815624,0.1194926,-0.071207054,0.014109294,0.079820834,-0.006870605,-0.0052823867,-0.029617261,0.073567234,0.06555545,-0.09733238,0.06841361,-0.032084044,0.10998643,-0.031699374,0.018973608,0.02462254,-0.069597505,0.070999734,-0.050207775,0.044230375,0.021497803,0.05741905,0.12532367,-0.08883319,-0.01811394,0.0011768066,0.06459078,-0.0014821606,-0.09094167,-0.0075864797,-0.00019054905,-0.124157004,-0.064882055,0.09381429,0.051018275,-0.020306546,-0.004231254,-0.018098317,-0.07439528,0.056705363,0.036972076,0.03879501,0.044584196,-0.080352895,-0.030577179] \n", + " 1 | [0.021809116,-0.0155318845,0.011607787,0.08773645,-0.060896672,-0.035311002,0.1109756,-0.05388055,0.015478594,0.025643239,0.034682155,-0.09349968,0.018253846,0.003201303,0.043405153,-0.037074342,0.088959046,-0.0040923767,-0.010021047,0.005995185,-0.078318,0.066143945,0.042326793,-0.027101,0.017702201,0.04703828,0.06959306,-0.037545238,-0.08466894,-0.0149313845,-0.05919541,2.3302038e-05,0.013309361,0.012327695,-0.054391142,0.008196508,0.14044063,-0.07974372,-0.041333504,-0.02224858,0.01838698,0.066759095,0.060005356,0.040904347,-0.057686333,-0.008572934,-0.0006931544,-0.017934252,0.09348519,0.04610809,0.042312067,0.0042564836,-0.035399742,-0.031868268,0.05509771,0.03063401,0.017477227,0.007607798,0.002851456,-0.00848902,0.070586,-0.06596947,-0.003001832,0.017515425,0.03681233,-0.05101503,-0.05168192,-0.007240675,-0.056723353,-0.0003316012,-0.016689977,0.05097667,0.09232242,0.048701957,-0.0233264,0.014426018,0.09440483,-0.08410635,-0.065320976,0.010295285,-0.06000792,-0.0066203084,0.018760884,0.006218706,-0.016821042,-0.051536806,-0.019194003,0.019247968,-0.055921093,0.0744291,0.0011268512,-0.018572511,-0.033866387,0.04826349,0.0018755798,0.02145841,0.026700653,-0.07195235,-0.035215978,0.09375797,0.009641733,0.03153927,-0.006521103,0.059988238,0.02907713,0.006436106,-0.1688827,-0.0121928835,0.00831767,-0.0010369162,0.020289466,-0.015101374,-0.036400627,-0.0053182486,0.016343202,0.04836314,0.052492004,0.0022888575,0.013867832,-0.011067111,-0.0063246978,0.08962686,-0.05633277,-5.0772534e-05,0.0003743617,-0.043979187,0.030548064,-6.112132e-33,-0.100447245,-0.047969893,0.050677996,-0.031848874,0.017650874,0.00557819,0.035132963,0.09510477,0.09157566,-0.02606478,-0.0059387456,-0.023844875,-0.037891146,-0.0062694843,0.024072707,-0.06319935,-0.025684582,0.07265957,-0.04208773,-0.014134076,-0.017349897,-0.092400596,-0.006409122,0.09291194,-0.027069112,-0.08738224,0.042585004,-0.12305705,0.062073916,0.01713978,0.043850727,-0.0055547897,-0.035159733,-0.05796307,-0.0016850628,-0.029315371,0.0721106,0.04989424,-0.028748112,0.0011031058,-0.0070465477,0.02051565,0.0671912,0.021492135,0.06486443,0.006083919,0.025401652,0.07397291,-0.030965947,-0.007620959,-0.045778204,-0.048278432,0.09053187,0.032227647,-0.015725302,-0.010724716,0.013521858,-0.036038376,-0.092461206,0.013104383,-0.078536704,0.049683243,0.0088001145,-0.007872616,-0.11311237,0.11412768,-0.035817996,-0.047303315,0.014969718,0.023965022,-0.042791124,0.03148287,-0.022683943,0.0005804951,-0.11246335,-0.09787001,0.04521075,-0.031591777,-0.055069365,-0.023562724,0.052014757,-0.0024513777,0.0039027003,-0.010034752,0.033652805,0.122117504,-0.067184255,-0.0667508,0.1081975,-0.015414982,0.00400915,0.021052312,0.016455468,0.019499231,-0.12814389,5.5857056e-33,-0.0018572047,-0.080793984,-0.013305316,0.018411051,-0.037682965,-0.06759447,-0.08707167,0.013579468,-0.02803436,-0.03244555,-0.026130449,-0.0068652094,-0.02230584,-0.016416704,0.023153821,0.024428546,-0.011959954,0.09368941,-0.0325776,0.026465558,-0.046098772,0.008481753,-0.006716867,0.019120444,0.016167276,-0.023132937,-0.0042774444,0.04393394,-0.018111901,0.059962098,0.051095933,-0.07903501,-0.059705775,-0.13360043,0.04902078,0.03544222,-0.093780324,-0.056613877,-0.0022577501,0.03077088,0.015449573,0.0032539142,0.031303108,0.11281747,0.036288813,0.093467966,0.03139063,0.058778938,0.02215492,0.05777489,0.0009719842,-0.026091043,-0.06628838,0.015047404,0.03955509,0.05236228,0.0069718575,0.0009398838,-0.03959814,-0.075498044,-0.10264736,0.06432404,0.018766893,0.0139612565,0.060313296,-0.02941947,-0.030336095,-0.05356687,-0.07672768,0.012401418,-0.009276501,-0.054574188,-0.056601916,-0.024081068,-0.0397901,-0.035410695,0.01184497,0.036265045,-0.08490439,0.05896337,-0.030408576,0.10739633,0.0100452835,0.06581673,0.04995253,0.056139104,-0.018259417,0.023479536,-0.04595968,0.038907755,-0.005904816,-0.015094102,0.013457788,-0.039148435,0.01151064,-1.5212933e-08,-0.045827802,-0.029699314,0.03503022,-0.01087893,-0.0031904539,0.07422464,-0.07662781,0.054133236,0.021378785,-0.040636804,0.062867165,0.085515775,-0.08906479,0.056114767,0.048328143,0.008293789,0.08469364,-0.027762378,-0.015386821,0.06791649,-0.09377292,0.018911818,-0.013141011,0.04376481,-0.018527081,0.021828363,0.0024259256,0.020919835,0.10574034,0.063920565,0.05623135,0.053664792,-0.0830025,0.06855377,-0.005921287,-0.0768514,0.010081457,-0.01137772,-0.012504763,-0.10047469,-0.049601573,-0.002936133,0.015598559,-0.04278624,-0.0998226,0.02282328,0.06384441,0.011207105,0.0207268,0.08571721,0.04142781,0.026192738,0.096607804,0.08237023,0.03691293,-0.014799454,0.04348573,-0.07760758,0.01575173,0.078169346,0.11799159,0.05871559,0.021846,-0.016581286] \n", + "-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", + "(2 rows)\n", + "\n", + "CREATE UNIQUE INDEX ON \"cte_25a74c5b96e64b4183b4b55256867459\" (id)\n", + "CREATE INDEX \"idx_f820481a5497411a9b74d29d2fe98ea5\" ON \"cte_25a74c5b96e64b4183b4b55256867459\" USING \"ivfflat\" ( \"_emb_23dd83d7748d40cd8eda21e5f2129629\")\n", + "\n", + " DO $$\n", + " BEGIN\n", + " SET LOCAL allow_system_table_mods TO ON;\n", + "\n", + " WITH embedding_info AS (\n", + " SELECT attrelid, attnum, 'all-MiniLM-L6-v2' AS model\n", + " FROM pg_attribute\n", + " WHERE \n", + " attrelid = '\"cte_7721c0fd66c041c4b7c5ff4074b3479b\"'::regclass::oid AND\n", + " attname = 'content'\n", + " )\n", + " UPDATE pg_class\n", + " SET reloptions = array_append(\n", + " reloptions, \n", + " format('_pygp_emb_%s=%s', attnum::text, to_json(embedding_info))\n", + " )\n", + " FROM embedding_info;\n", + "\n", + " WITH embedding_info AS (\n", + " SELECT attrelid, attnum, 'all-MiniLM-L6-v2' AS model\n", + " FROM pg_attribute\n", + " WHERE \n", + " attrelid = '\"cte_7721c0fd66c041c4b7c5ff4074b3479b\"'::regclass::oid AND\n", + " attname = 'content'\n", + " )\n", + " INSERT INTO pg_depend\n", + " SELECT\n", + " 'pg_class'::regclass::oid AS classid,\n", + " '\"cte_25a74c5b96e64b4183b4b55256867459\"'::regclass::oid AS objid,\n", + " 0::oid AS objsubid,\n", + " 'pg_class'::regclass::oid AS refclassid,\n", + " embedding_info.attrelid AS refobjid,\n", + " embedding_info.attnum AS refobjsubid,\n", + " 'n' AS deptype\n", + " FROM embedding_info;\n", + " END;\n", + " $$;\n", + " \n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\t\n", + "\t\t\n", + "\t\t\n", + "\t\n", + "\t\n", + "\t\t\n", + "\t\t\n", + "\t\n", + "\t\n", + "\t\t\n", + "\t\t\n", + "\t\n", + "
idcontent
0I have a dog.
1I like eating apples.
" + ], + "text/plain": [ + "----------------------------\n", + " id | content \n", + "----+-----------------------\n", + " 0 | I have a dog. \n", + " 1 | I like eating apples. \n", + "----------------------------\n", + "(2 rows)" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import greenplumpython.experimental.embedding\n", "\n", @@ -43,9 +203,58 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " WITH embedding_oid AS (\n", + " SELECT attrelid, attnum\n", + " FROM pg_attribute\n", + " WHERE \n", + " attrelid = '\"cte_7721c0fd66c041c4b7c5ff4074b3479b\"'::regclass::oid AND\n", + " attname = 'content'\n", + " ), reloptions AS (\n", + " SELECT unnest(reloptions) AS option\n", + " FROM pg_class, embedding_oid\n", + " WHERE pg_class.oid = attrelid\n", + " ), embedding_info_json AS (\n", + " SELECT split_part(option, '=', 2)::json AS val\n", + " FROM reloptions, embedding_oid\n", + " WHERE option LIKE format('_pygp_emb_%s=%%', attnum)\n", + " ), embedding_info AS (\n", + " SELECT * \n", + " FROM embedding_info_json, json_to_record(val) AS (attnum int4, attrelid oid, model text)\n", + " )\n", + " SELECT nspname, relname, model\n", + " FROM embedding_info, pg_class, pg_namespace\n", + " WHERE \n", + " pg_class.oid = attrelid AND\n", + " relnamespace = pg_namespace.oid;\n", + " \n" + ] + }, + { + "ename": "AssertionError", + "evalue": "Database is required to create function.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m t\u001b[39m.\u001b[39;49membedding()\u001b[39m.\u001b[39;49msearch(column\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mcontent\u001b[39;49m\u001b[39m\"\u001b[39;49m, query\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mapple\u001b[39;49m\u001b[39m\"\u001b[39;49m, top_k\u001b[39m=\u001b[39;49m\u001b[39m1\u001b[39;49m)\n", + "File \u001b[0;32m~/GreenplumPython/greenplumpython/experimental/embedding.py:154\u001b[0m, in \u001b[0;36mEmbedding.search\u001b[0;34m(self, column, query, top_k)\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataframe\u001b[39m.\u001b[39munique_key \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 152\u001b[0m distance \u001b[39m=\u001b[39m gp\u001b[39m.\u001b[39moperator(\u001b[39m\"\u001b[39m\u001b[39m<#>\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 153\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataframe\u001b[39m.\u001b[39mjoin(\n\u001b[0;32m--> 154\u001b[0m embedding_df\u001b[39m.\u001b[39;49massign(\n\u001b[1;32m 155\u001b[0m distance\u001b[39m=\u001b[39;49m\u001b[39mlambda\u001b[39;49;00m t: distance(t[\u001b[39m\"\u001b[39;49m\u001b[39m_emb_\u001b[39;49m\u001b[39m\"\u001b[39;49m], _generate_embedding(query, model))\n\u001b[1;32m 156\u001b[0m )\u001b[39m.\u001b[39morder_by(\u001b[39m\"\u001b[39m\u001b[39mdistance\u001b[39m\u001b[39m\"\u001b[39m)[:top_k],\n\u001b[1;32m 157\u001b[0m how\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39minner\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 158\u001b[0m on\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataframe\u001b[39m.\u001b[39munique_key,\n\u001b[1;32m 159\u001b[0m self_columns\u001b[39m=\u001b[39m{\u001b[39m\"\u001b[39m\u001b[39m*\u001b[39m\u001b[39m\"\u001b[39m},\n\u001b[1;32m 160\u001b[0m other_columns\u001b[39m=\u001b[39m{},\n\u001b[1;32m 161\u001b[0m )\n", + "File \u001b[0;32m~/GreenplumPython/greenplumpython/dataframe.py:492\u001b[0m, in \u001b[0;36mDataFrame.assign\u001b[0;34m(self, **new_columns)\u001b[0m\n\u001b[1;32m 490\u001b[0m other_parents[v\u001b[39m.\u001b[39m_other_dataframe\u001b[39m.\u001b[39m_name] \u001b[39m=\u001b[39m v\u001b[39m.\u001b[39m_other_dataframe\n\u001b[1;32m 491\u001b[0m v \u001b[39m=\u001b[39m v\u001b[39m.\u001b[39m_bind(db\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_db)\n\u001b[0;32m--> 492\u001b[0m targets\u001b[39m.\u001b[39mappend(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m_serialize(v)\u001b[39m}\u001b[39;00m\u001b[39m AS \u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 493\u001b[0m \u001b[39mreturn\u001b[39;00m DataFrame(\n\u001b[1;32m 494\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mSELECT *, \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m,\u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(targets)\u001b[39m}\u001b[39;00m\u001b[39m FROM \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_name\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 495\u001b[0m parents\u001b[39m=\u001b[39m[\u001b[39mself\u001b[39m] \u001b[39m+\u001b[39m \u001b[39mlist\u001b[39m(other_parents\u001b[39m.\u001b[39mvalues()),\n\u001b[1;32m 496\u001b[0m )\n", + "File \u001b[0;32m~/GreenplumPython/greenplumpython/expr.py:561\u001b[0m, in \u001b[0;36m_serialize\u001b[0;34m(value)\u001b[0m\n\u001b[1;32m 551\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 552\u001b[0m \u001b[39m:meta private:\u001b[39;00m\n\u001b[1;32m 553\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[39m in Python 3 and Python 2 is EOL officially.\u001b[39;00m\n\u001b[1;32m 559\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 560\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(value, Expr):\n\u001b[0;32m--> 561\u001b[0m \u001b[39mreturn\u001b[39;00m value\u001b[39m.\u001b[39;49m_serialize()\n\u001b[1;32m 562\u001b[0m \u001b[39mreturn\u001b[39;00m adapt(value)\u001b[39m.\u001b[39mgetquoted()\u001b[39m.\u001b[39mdecode(\u001b[39m\"\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m\"\u001b[39m)\n", + "File \u001b[0;32m~/GreenplumPython/greenplumpython/expr.py:636\u001b[0m, in \u001b[0;36mBinaryExpr._serialize\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 633\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mgreenplumpython\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mexpr\u001b[39;00m \u001b[39mimport\u001b[39;00m _serialize\n\u001b[1;32m 635\u001b[0m left_str \u001b[39m=\u001b[39m _serialize(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_left)\n\u001b[0;32m--> 636\u001b[0m right_str \u001b[39m=\u001b[39m _serialize(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_right)\n\u001b[1;32m 637\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m(\u001b[39m\u001b[39m{\u001b[39;00mleft_str\u001b[39m}\u001b[39;00m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_operator\u001b[39m}\u001b[39;00m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00mright_str\u001b[39m}\u001b[39;00m\u001b[39m)\u001b[39m\u001b[39m\"\u001b[39m\n", + "File \u001b[0;32m~/GreenplumPython/greenplumpython/expr.py:561\u001b[0m, in \u001b[0;36m_serialize\u001b[0;34m(value)\u001b[0m\n\u001b[1;32m 551\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 552\u001b[0m \u001b[39m:meta private:\u001b[39;00m\n\u001b[1;32m 553\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[39m in Python 3 and Python 2 is EOL officially.\u001b[39;00m\n\u001b[1;32m 559\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 560\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(value, Expr):\n\u001b[0;32m--> 561\u001b[0m \u001b[39mreturn\u001b[39;00m value\u001b[39m.\u001b[39;49m_serialize()\n\u001b[1;32m 562\u001b[0m \u001b[39mreturn\u001b[39;00m adapt(value)\u001b[39m.\u001b[39mgetquoted()\u001b[39m.\u001b[39mdecode(\u001b[39m\"\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m\"\u001b[39m)\n", + "File \u001b[0;32m~/GreenplumPython/greenplumpython/func.py:86\u001b[0m, in \u001b[0;36mFunctionExpr._serialize\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_serialize\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mstr\u001b[39m:\n\u001b[1;32m 84\u001b[0m \u001b[39m# noqa D400\u001b[39;00m\n\u001b[1;32m 85\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\":meta private:\"\"\"\u001b[39;00m\n\u001b[0;32m---> 86\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_db \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m, \u001b[39m\"\u001b[39m\u001b[39mDatabase is required to create function.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 87\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_function\u001b[39m.\u001b[39m_create_in_db(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_db)\n\u001b[1;32m 88\u001b[0m distinct \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mDISTINCT\u001b[39m\u001b[39m\"\u001b[39m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_distinct \u001b[39melse\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39m\"\u001b[39m\n", + "\u001b[0;31mAssertionError\u001b[0m: Database is required to create function." + ] + } + ], "source": [ "t.embedding().search(column=\"content\", query=\"apple\", top_k=1)" ] diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index 8f7e366a..79028ec6 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -72,35 +72,42 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: DO $$ BEGIN SET LOCAL allow_system_table_mods TO ON; + WITH embedding_info AS ( - SELECT attrelid, attnum, `{model}` AS model + SELECT attrelid, attnum, '{model}' AS model + FROM pg_attribute + WHERE + attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND + attname = '{column}' + ) + UPDATE pg_class + SET reloptions = array_append( + reloptions, + format('_pygp_emb_%s=%s', attnum::text, to_json(embedding_info)) + ) + FROM embedding_info; + + WITH embedding_info AS ( + SELECT attrelid, attnum, '{model}' AS model FROM pg_attribute WHERE attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND attname = '{column}' - ), add_option AS ( - UPDATE pg_class - FROM embedding_info - SET reloptions = array_append( - reloptions, - format('_pygp_emb_%s=%s', attnum::text, to_json(embedding_info)) - ) - ), add_dependency AS ( - INSERT INTO pg_depend - SELECT - 'pg_class'::regclass::oid AS classid, - '{embedding_df._qualified_table_name}'::regclass::oid AS objid, - 0::oid AS objsubid, - 'pg_class'::regclass::oid AS refclassid, - embedding_info.attrelid AS refobjid, - embedding_info.attnum AS refobjsubid - FROM embedding_info - RETURNING * ) - SELECT * FROM add_dependency; + INSERT INTO pg_depend + SELECT + 'pg_class'::regclass::oid AS classid, + '{embedding_df._qualified_table_name}'::regclass::oid AS objid, + 0::oid AS objsubid, + 'pg_class'::regclass::oid AS refclassid, + embedding_info.attrelid AS refobjid, + embedding_info.attnum AS refobjsubid, + 'n' AS deptype + FROM embedding_info; END; $$; - """ + """, + has_results=False ) return self._dataframe @@ -108,31 +115,32 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: assert self._dataframe._db is not None embdedding_info = self._dataframe._db._execute( f""" - WITH embedding_oid AS ( - SELECT attrelid, attnum - FROM pg_attribute - WHERE - attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND - attname = '{column}' - ), reloptions AS ( - SELECT unnest(reloptions) AS option - FROM pg_class, embedding_oid - WEHRE oid = attrelid - ), embedding_info AS ( - SELECT split_part(option, '=', 2)::jsonb AS info - FROM reloptions, embedding_oid - WHERE option LIKE format('_pygp_emb%s=%%', attnum) - ) embedding_table_qualified_name AS ( - SELECT nspname, relname, embedding.info->'model' AS model - FROM embedding_table, pg_class, pg_namespace - WHERE - pg_class.oid = embedding.info->'attrelid' AND - relnamespace = pg_namespace.oid - ) - SELECT * FROM embedding_table_qualified_name + WITH embedding_oid AS ( + SELECT attrelid, attnum + FROM pg_attribute + WHERE + attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND + attname = '{column}' + ), reloptions AS ( + SELECT unnest(reloptions) AS option + FROM pg_class, embedding_oid + WHERE pg_class.oid = attrelid + ), embedding_info_json AS ( + SELECT split_part(option, '=', 2)::json AS val + FROM reloptions, embedding_oid + WHERE option LIKE format('_pygp_emb_%s=%%', attnum) + ), embedding_info AS ( + SELECT * + FROM embedding_info_json, json_to_record(val) AS (attnum int4, attrelid oid, model text) + ) + SELECT nspname, relname, model + FROM embedding_info, pg_class, pg_namespace + WHERE + pg_class.oid = attrelid AND + relnamespace = pg_namespace.oid; """ ) - assert isinstance(embdedding_info, abc.Mapping[str, Any]) + # assert isinstance(embdedding_info, abc.Mapping) embedding_table_name = None for row in embdedding_info: embedding_table_name = f'"{row["nspname"]}"."{row["relname"]}"' diff --git a/greenplumpython/type.py b/greenplumpython/type.py index cb3f4bdd..8956aa3b 100644 --- a/greenplumpython/type.py +++ b/greenplumpython/type.py @@ -222,8 +222,6 @@ def to_pg_type( Returns: str: name of type in SQL """ - if isinstance(annotation, Type): - return annotation._qualified_name_str if annotation is not None and hasattr(annotation, "__origin__"): # The `or` here is to make the function work on Python 3.6. # Python 3.6 is the default Python version on CentOS 7 and Ubuntu 18.04 @@ -235,6 +233,8 @@ def to_pg_type( return f"{to_pg_type(args[0], db)}[]" # type: ignore raise NotImplementedError() else: + if isinstance(annotation, Type): + return annotation._qualified_name_str assert db is not None, "Database is required to create type" if annotation not in _defined_types: type_name = "type_" + uuid4().hex From bd35e3171cde050b3c408ae6e56e06ba92f5ac6a Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Tue, 25 Jul 2023 23:56:08 -0400 Subject: [PATCH 03/19] Make the example work --- doc/source/notebooks/embedding.ipynb | 223 ++++++---------------- greenplumpython/dataframe.py | 1 - greenplumpython/experimental/embedding.py | 43 +++-- greenplumpython/expr.py | 5 +- greenplumpython/type.py | 1 + 5 files changed, 90 insertions(+), 183 deletions(-) diff --git a/doc/source/notebooks/embedding.ipynb b/doc/source/notebooks/embedding.ipynb index b112239a..feb4661d 100644 --- a/doc/source/notebooks/embedding.ipynb +++ b/doc/source/notebooks/embedding.ipynb @@ -1,5 +1,12 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Installing the Package" + ] + }, { "cell_type": "code", "execution_count": 1, @@ -15,12 +22,12 @@ " Installing build dependencies ... \u001b[?25ldone\n", "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", "\u001b[?25h Preparing wheel metadata ... \u001b[?25ldone\n", - "\u001b[?25hRequirement already satisfied, skipping upgrade: dill==0.3.6 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (0.3.6)\n", - "Requirement already satisfied, skipping upgrade: psycopg2-binary==2.9.5 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (2.9.5)\n", + "\u001b[?25hRequirement already satisfied, skipping upgrade: psycopg2-binary==2.9.5 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (2.9.5)\n", + "Requirement already satisfied, skipping upgrade: dill==0.3.6 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (0.3.6)\n", "Building wheels for collected packages: greenplum-python\n", " Building wheel for greenplum-python (PEP 517) ... \u001b[?25ldone\n", - "\u001b[?25h Created wheel for greenplum-python: filename=greenplum_python-1.0.1-py3-none-any.whl size=70509 sha256=995d00c1fdf47e7721a42c1f1f1e0ffa3af7b02ea1403620ee8c2cebdacf69c6\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-34hu1ytc/wheels/bb/1f/99/ff8594e48ec11df99af6e0ee8611a5e560e9f44d1a3fefb351\n", + "\u001b[?25h Created wheel for greenplum-python: filename=greenplum_python-1.0.1-py3-none-any.whl size=70614 sha256=4ee428916d3690ae05c591dff7e48e84e41f04d37fba06dfd6bf9543791e4d4f\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-2iw1rkm7/wheels/bb/1f/99/ff8594e48ec11df99af6e0ee8611a5e560e9f44d1a3fefb351\n", "Successfully built greenplum-python\n", "Installing collected packages: greenplum-python\n", "Successfully installed greenplum-python-1.0.1\n" @@ -32,6 +39,13 @@ "!python3 -m pip install --upgrade ." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Preparing Data" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -43,7 +57,18 @@ "import greenplumpython as gp\n", "\n", "db = gp.database(\"postgresql://localhost:7000\")\n", - "t = db.create_dataframe(columns={\"id\": range(len(content)), \"content\": content})" + "t = (\n", + " db.create_dataframe(columns={\"id\": range(len(content)), \"content\": content})\n", + " .save_as(column_names=[\"id\", \"content\"], distribution_key={\"id\"}, distribution_type=\"hash\")\n", + " .check_unique(columns={\"id\"})\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Generating and Indexing Embeddings" ] }, { @@ -51,111 +76,6 @@ "execution_count": 3, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - " CREATE TABLE \"cte_7721c0fd66c041c4b7c5ff4074b3479b\"\n", - " (id,content)\n", - " \n", - " AS SELECT unnest(ARRAY[0,1]) AS \"id\",unnest(ARRAY['I have a dog.','I like eating apples.']) AS \"content\"\n", - " \n", - " DISTRIBUTED BY (id)\n", - " \n", - " \n", - "WITH cte_d0f82a0d4eee4d86a93fa03651ebcb8a AS (TABLE \"cte_7721c0fd66c041c4b7c5ff4074b3479b\")SELECT to_json(cte_1fd18725c25a4c9a877fd6a0aa180e0a)::TEXT FROM cte_d0f82a0d4eee4d86a93fa03651ebcb8a AS cte_1fd18725c25a4c9a877fd6a0aa180e0a\n", - "----------------------------\n", - " id | content \n", - "----+-----------------------\n", - " 0 | I have a dog. \n", - " 1 | I like eating apples. \n", - "----------------------------\n", - "(2 rows)\n", - "\n", - "CREATE UNIQUE INDEX ON \"cte_7721c0fd66c041c4b7c5ff4074b3479b\" (id)\n", - "CREATE FUNCTION \"pg_temp\".\"func_54569d267ed6412e9026f55a5bf7601b\" (content \"text\",model_name \"text\") RETURNS \"vector\"(384) AS $$\n", - "try:\n", - " return GD['__func_54569d267ed6412e9026f55a5bf7601b'](content=content,model_name=model_name)\n", - "except KeyError:\n", - " try:\n", - " import dill as __lib_64f0f94e2fe74186b2d943881574a343\n", - " import sysconfig as __lib_605eb3b464e84c5999150fadf59627c7\n", - " import base64 as __lib_bddbb4833d4c4ea7879fcdffa7ff8270\n", - " import sys as __lib_6f40e80834a64c8d91808b5226eda448\n", - " if __lib_605eb3b464e84c5999150fadf59627c7.get_python_version() != '3.9':\n", - " raise ModuleNotFoundError\n", - " setattr(__lib_6f40e80834a64c8d91808b5226eda448.modules['plpy'], '_SD', SD)\n", - " GD['__func_54569d267ed6412e9026f55a5bf7601b'] = __lib_64f0f94e2fe74186b2d943881574a343.loads(__lib_bddbb4833d4c4ea7879fcdffa7ff8270.b64decode(b'gASVNgMAAAAAAACMCmRpbGwuX2RpbGyUjBBfY3JlYXRlX2Z1bmN0aW9ulJOUKGgAjAxfY3JlYXRlX2NvZGWUk5QoSwJLAEsASwdLBUtDQ2JkAWQAbAB9AnQBgwCgAmQCfAJqA2QDGQBqBKECfQNkBHwDdgFyRGQBZAVsBW0GfQQBAHwEfAGDAX0FfAV8A2QEPABuCHwDZAQZAH0FfAVqB3wAZAZkB40CfQZ8BqAIoQBTAJQoTksAjAJTRJSMBHBscHmUjAVtb2RlbJSME1NlbnRlbmNlVHJhbnNmb3JtZXKUhZSIjBRub3JtYWxpemVfZW1iZWRkaW5nc5SFlHSUKIwDc3lzlIwHZ2xvYmFsc5SMA2dldJSMB21vZHVsZXOUjANfU0SUjBVzZW50ZW5jZV90cmFuc2Zvcm1lcnOUaAmMBmVuY29kZZSMBnRvbGlzdJR0lCiMB2NvbnRlbnSUjAptb2RlbF9uYW1llGgOaAZoCWgIjANlbWKUdJSMRy9ob21lL2dwYWRtaW4vR3JlZW5wbHVtUHl0aG9uL2dyZWVucGx1bXB5dGhvbi9leHBlcmltZW50YWwvZW1iZWRkaW5nLnB5lIwTX2dlbmVyYXRlX2VtYmVkZGluZ5RLCkMSAAIIAhYBCAEMAggBCgIIAw4BlCkpdJRSlH2UjAhfX25hbWVfX5SMJmdyZWVucGx1bXB5dGhvbi5leHBlcmltZW50YWwuZW1iZWRkaW5nlHNoHE5OdJRSlH2UfZSMD19fYW5ub3RhdGlvbnNfX5R9lChoF2gAjApfbG9hZF90eXBllJOUjANzdHKUhZRSlGgYaC2MBnJldHVybpSMFGdyZWVucGx1bXB5dGhvbi50eXBllIwEVHlwZZSTlCmBlH2UKIwFX25hbWWUjAZ2ZWN0b3KUjAtfYW5ub3RhdGlvbpROjA9fY3JlYXRlZF9pbl9kYnOUTowHX3NjaGVtYZROjAlfbW9kaWZpZXKUTYABjBNfcXVhbGlmaWVkX25hbWVfc3RylIwNInZlY3RvciIoMzg0KZR1YnVzhpRiaCCMB2dsb2JhbHOUjAhidWlsdGluc5SMB2dsb2JhbHOUk5RzMC4='))\n", - " except ModuleNotFoundError:\n", - " exec(\"def __func_54569d267ed6412e9026f55a5bf7601b(content, model_name):\\n import sys\\n SD = globals().get('SD', sys.modules['plpy']._SD)\\n if 'model' not in SD:\\n from sentence_transformers import SentenceTransformer\\n model = SentenceTransformer(model_name)\\n SD['model'] = model\\n else:\\n model = SD['model']\\n emb = model.encode(content, normalize_embeddings=True)\\n return emb.tolist()\", globals())\n", - " GD['__func_54569d267ed6412e9026f55a5bf7601b'] = globals()['__func_54569d267ed6412e9026f55a5bf7601b']\n", - " return GD['__func_54569d267ed6412e9026f55a5bf7601b'](content=content,model_name=model_name)\n", - "$$ LANGUAGE plpython3u;\n", - "\n", - " CREATE TABLE \"cte_25a74c5b96e64b4183b4b55256867459\"\n", - " (id,_emb_23dd83d7748d40cd8eda21e5f2129629)\n", - " \n", - " AS WITH cte_d0f82a0d4eee4d86a93fa03651ebcb8a AS (TABLE \"cte_7721c0fd66c041c4b7c5ff4074b3479b\"),cte_6a96d6faca4846f6ab4870e8dfc80e29 AS (SELECT *, (\"pg_temp\".\"func_54569d267ed6412e9026f55a5bf7601b\"( cte_d0f82a0d4eee4d86a93fa03651ebcb8a.\"content\",'all-MiniLM-L6-v2')::\"vector\"(384)) AS _emb_23dd83d7748d40cd8eda21e5f2129629 FROM cte_d0f82a0d4eee4d86a93fa03651ebcb8a)\n", - " SELECT cte_6a96d6faca4846f6ab4870e8dfc80e29.\"id\",cte_6a96d6faca4846f6ab4870e8dfc80e29.\"_emb_23dd83d7748d40cd8eda21e5f2129629\"\n", - " FROM cte_6a96d6faca4846f6ab4870e8dfc80e29\n", - " \n", - " \n", - " DISTRIBUTED BY (id)\n", - " \n", - " \n", - "WITH cte_14c1fce52f0e4af08972a2eb07df02a9 AS (TABLE \"cte_25a74c5b96e64b4183b4b55256867459\")SELECT to_json(cte_baadcbf448304ef7a6693f98b217a984)::TEXT FROM cte_14c1fce52f0e4af08972a2eb07df02a9 AS cte_baadcbf448304ef7a6693f98b217a984\n", - "-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - " id | _emb_23dd83d7748d40cd8eda21e5f2129629 \n", - "----+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - " 0 | [-0.03659846,-0.012087755,0.08805456,0.061151367,-0.043457735,-0.015592895,0.07047544,-0.002039735,0.082576185,-0.027372142,0.0441429,-0.03269936,0.013636178,0.04161656,0.010410312,-0.0015930303,-0.06705982,-0.044098374,-0.0057846555,-0.06406435,-0.06587656,0.07500747,0.012162317,-0.005788625,-0.109905936,0.027304199,-0.039163478,-0.05016219,0.0029828846,-0.038391765,-0.015229778,-0.055909287,-0.011802607,-0.0048770546,-0.042732246,-0.041694522,0.0065653264,-0.013692718,0.103015944,0.08045541,0.047172282,0.014515043,0.06301976,-0.008371227,-0.00376406,0.037010953,-0.08730184,-0.019860007,0.116164975,-0.009175166,-0.02942207,0.05760907,-0.017986163,0.030363861,-0.018659858,-0.023920199,0.0075364825,0.030293372,-0.0017754125,-0.027292343,0.010452809,0.06776974,0.009428498,0.047237813,0.00020278565,0.02074401,-0.061770048,0.06334932,-0.06630364,0.055175032,0.03636048,0.033629246,0.0447184,0.09541598,-0.035489805,-0.10748742,0.06386299,-0.030471267,0.18042347,0.07920897,-0.08705959,-0.06174667,-0.042777605,0.04772073,0.04045522,0.011489479,0.07283991,0.06658613,-0.117522456,0.011569888,-0.022578657,-0.049202647,-0.03411386,0.01763423,-0.0032649997,-0.010033316,-0.022944551,-0.033948276,-0.021662703,0.089601316,0.008443466,0.028806083,0.07188612,0.045687683,0.09596068,0.023099585,-0.09392849,0.060846657,-0.010293214,0.0019619083,-0.012627958,0.009032658,-0.023953663,0.10200912,0.047290858,0.045499124,-0.07541447,-0.024221145,0.06080323,-0.09191942,0.011989021,0.021896897,-0.044340197,0.02122626,0.019848485,-0.058525886,0.03497772,-9.026044e-33,-0.0053451294,-0.0290533,0.014672238,0.04659998,-0.028272917,0.013217353,-0.038185634,0.030172182,-0.05259568,-0.016775027,0.0034630536,0.00057961704,-0.020373803,-0.034381036,-0.0033685626,0.0013990739,0.051134076,0.01848566,0.08034311,-0.00014362897,-0.013998878,-0.021286957,0.03914335,0.017298121,-0.017837863,-0.012515484,-0.013980058,-0.083431505,-0.02655956,0.024582446,0.028264284,0.020893436,0.045632925,-0.041542996,-0.105518915,-0.03636643,-0.05349343,-0.05543646,-0.043980815,0.052545346,0.08640961,-0.0042671263,0.017281698,-0.0003456163,0.0046999496,-0.034805853,0.008263829,0.020119112,-0.09260091,0.01470312,0.011787518,-0.03307292,0.0042901468,-0.08931934,-0.029248364,-0.041016944,0.059762184,-0.009189991,0.019669637,0.08591937,0.022527453,0.0075523625,-0.030852512,0.029306179,0.051727347,-0.090517536,-0.09521753,-0.041740306,-0.0011757809,0.014292587,-0.024682235,-0.0035219707,0.0077362237,-0.017399674,0.07142882,-0.0123587465,-0.00534226,-0.003308827,-0.01875911,-0.07966144,0.019006351,0.0018609086,0.00706818,0.057706404,0.07751448,0.059841618,-0.029955173,-0.0058063352,-0.023169437,0.0026582994,-0.065715685,-0.043993074,0.03394865,-0.027996289,0.04052676,5.238207e-33,0.010241467,0.03607309,0.046909,0.013635992,-0.005335444,0.0016521218,-0.020371608,0.04564495,-0.082175665,0.06402261,-0.001709206,0.044672277,0.10069537,0.00045673744,0.062299304,0.03769347,-0.039460346,-0.019606683,0.050265815,-0.05616924,-0.18455045,0.08040064,0.07426137,0.019323843,-0.026447829,0.040501535,-0.019648906,-0.02372921,-0.058951914,-0.0853744,-0.045682464,-0.12889871,-0.055900462,-0.068548314,-0.0058031273,0.066947535,-0.023167405,-0.1457526,-0.0123237185,-0.059538133,0.036701616,-0.0021032416,0.04832922,0.078937754,0.014486305,0.029141134,0.014654051,-0.06743171,0.00976345,0.033080045,-0.026131311,-0.008976268,-0.028050678,-0.062519975,-0.0033331455,-0.014157539,-0.07179509,-0.067832775,0.014238785,0.008521254,-0.031684905,0.09964349,-0.05202337,0.13799058,-0.019717641,-0.0868198,-0.0071095424,-0.0557247,0.011921498,-0.07336916,-0.0079654865,0.07029794,-0.031166447,-0.055607356,0.0108316,0.04010841,0.051589135,-0.0015768349,0.03786852,0.015498465,-0.06851167,-0.04085385,0.009224494,-0.010765805,-0.001525135,-0.03769954,-0.00508086,0.05028555,-0.0018060899,0.047179505,-0.032873716,0.0786257,0.0219288,-0.055561442,0.0068103974,-1.6011152e-08,-0.047843613,-0.0016648023,-0.0019612245,-0.0025547266,0.05134095,0.035634715,0.0084129,-0.06416773,-0.03193827,-0.019677935,0.03140499,-0.0173519,-0.043358684,0.02033876,0.10461025,0.025110237,0.017567858,8.451519e-06,0.034815624,0.1194926,-0.071207054,0.014109294,0.079820834,-0.006870605,-0.0052823867,-0.029617261,0.073567234,0.06555545,-0.09733238,0.06841361,-0.032084044,0.10998643,-0.031699374,0.018973608,0.02462254,-0.069597505,0.070999734,-0.050207775,0.044230375,0.021497803,0.05741905,0.12532367,-0.08883319,-0.01811394,0.0011768066,0.06459078,-0.0014821606,-0.09094167,-0.0075864797,-0.00019054905,-0.124157004,-0.064882055,0.09381429,0.051018275,-0.020306546,-0.004231254,-0.018098317,-0.07439528,0.056705363,0.036972076,0.03879501,0.044584196,-0.080352895,-0.030577179] \n", - " 1 | [0.021809116,-0.0155318845,0.011607787,0.08773645,-0.060896672,-0.035311002,0.1109756,-0.05388055,0.015478594,0.025643239,0.034682155,-0.09349968,0.018253846,0.003201303,0.043405153,-0.037074342,0.088959046,-0.0040923767,-0.010021047,0.005995185,-0.078318,0.066143945,0.042326793,-0.027101,0.017702201,0.04703828,0.06959306,-0.037545238,-0.08466894,-0.0149313845,-0.05919541,2.3302038e-05,0.013309361,0.012327695,-0.054391142,0.008196508,0.14044063,-0.07974372,-0.041333504,-0.02224858,0.01838698,0.066759095,0.060005356,0.040904347,-0.057686333,-0.008572934,-0.0006931544,-0.017934252,0.09348519,0.04610809,0.042312067,0.0042564836,-0.035399742,-0.031868268,0.05509771,0.03063401,0.017477227,0.007607798,0.002851456,-0.00848902,0.070586,-0.06596947,-0.003001832,0.017515425,0.03681233,-0.05101503,-0.05168192,-0.007240675,-0.056723353,-0.0003316012,-0.016689977,0.05097667,0.09232242,0.048701957,-0.0233264,0.014426018,0.09440483,-0.08410635,-0.065320976,0.010295285,-0.06000792,-0.0066203084,0.018760884,0.006218706,-0.016821042,-0.051536806,-0.019194003,0.019247968,-0.055921093,0.0744291,0.0011268512,-0.018572511,-0.033866387,0.04826349,0.0018755798,0.02145841,0.026700653,-0.07195235,-0.035215978,0.09375797,0.009641733,0.03153927,-0.006521103,0.059988238,0.02907713,0.006436106,-0.1688827,-0.0121928835,0.00831767,-0.0010369162,0.020289466,-0.015101374,-0.036400627,-0.0053182486,0.016343202,0.04836314,0.052492004,0.0022888575,0.013867832,-0.011067111,-0.0063246978,0.08962686,-0.05633277,-5.0772534e-05,0.0003743617,-0.043979187,0.030548064,-6.112132e-33,-0.100447245,-0.047969893,0.050677996,-0.031848874,0.017650874,0.00557819,0.035132963,0.09510477,0.09157566,-0.02606478,-0.0059387456,-0.023844875,-0.037891146,-0.0062694843,0.024072707,-0.06319935,-0.025684582,0.07265957,-0.04208773,-0.014134076,-0.017349897,-0.092400596,-0.006409122,0.09291194,-0.027069112,-0.08738224,0.042585004,-0.12305705,0.062073916,0.01713978,0.043850727,-0.0055547897,-0.035159733,-0.05796307,-0.0016850628,-0.029315371,0.0721106,0.04989424,-0.028748112,0.0011031058,-0.0070465477,0.02051565,0.0671912,0.021492135,0.06486443,0.006083919,0.025401652,0.07397291,-0.030965947,-0.007620959,-0.045778204,-0.048278432,0.09053187,0.032227647,-0.015725302,-0.010724716,0.013521858,-0.036038376,-0.092461206,0.013104383,-0.078536704,0.049683243,0.0088001145,-0.007872616,-0.11311237,0.11412768,-0.035817996,-0.047303315,0.014969718,0.023965022,-0.042791124,0.03148287,-0.022683943,0.0005804951,-0.11246335,-0.09787001,0.04521075,-0.031591777,-0.055069365,-0.023562724,0.052014757,-0.0024513777,0.0039027003,-0.010034752,0.033652805,0.122117504,-0.067184255,-0.0667508,0.1081975,-0.015414982,0.00400915,0.021052312,0.016455468,0.019499231,-0.12814389,5.5857056e-33,-0.0018572047,-0.080793984,-0.013305316,0.018411051,-0.037682965,-0.06759447,-0.08707167,0.013579468,-0.02803436,-0.03244555,-0.026130449,-0.0068652094,-0.02230584,-0.016416704,0.023153821,0.024428546,-0.011959954,0.09368941,-0.0325776,0.026465558,-0.046098772,0.008481753,-0.006716867,0.019120444,0.016167276,-0.023132937,-0.0042774444,0.04393394,-0.018111901,0.059962098,0.051095933,-0.07903501,-0.059705775,-0.13360043,0.04902078,0.03544222,-0.093780324,-0.056613877,-0.0022577501,0.03077088,0.015449573,0.0032539142,0.031303108,0.11281747,0.036288813,0.093467966,0.03139063,0.058778938,0.02215492,0.05777489,0.0009719842,-0.026091043,-0.06628838,0.015047404,0.03955509,0.05236228,0.0069718575,0.0009398838,-0.03959814,-0.075498044,-0.10264736,0.06432404,0.018766893,0.0139612565,0.060313296,-0.02941947,-0.030336095,-0.05356687,-0.07672768,0.012401418,-0.009276501,-0.054574188,-0.056601916,-0.024081068,-0.0397901,-0.035410695,0.01184497,0.036265045,-0.08490439,0.05896337,-0.030408576,0.10739633,0.0100452835,0.06581673,0.04995253,0.056139104,-0.018259417,0.023479536,-0.04595968,0.038907755,-0.005904816,-0.015094102,0.013457788,-0.039148435,0.01151064,-1.5212933e-08,-0.045827802,-0.029699314,0.03503022,-0.01087893,-0.0031904539,0.07422464,-0.07662781,0.054133236,0.021378785,-0.040636804,0.062867165,0.085515775,-0.08906479,0.056114767,0.048328143,0.008293789,0.08469364,-0.027762378,-0.015386821,0.06791649,-0.09377292,0.018911818,-0.013141011,0.04376481,-0.018527081,0.021828363,0.0024259256,0.020919835,0.10574034,0.063920565,0.05623135,0.053664792,-0.0830025,0.06855377,-0.005921287,-0.0768514,0.010081457,-0.01137772,-0.012504763,-0.10047469,-0.049601573,-0.002936133,0.015598559,-0.04278624,-0.0998226,0.02282328,0.06384441,0.011207105,0.0207268,0.08571721,0.04142781,0.026192738,0.096607804,0.08237023,0.03691293,-0.014799454,0.04348573,-0.07760758,0.01575173,0.078169346,0.11799159,0.05871559,0.021846,-0.016581286] \n", - "-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------\n", - "(2 rows)\n", - "\n", - "CREATE UNIQUE INDEX ON \"cte_25a74c5b96e64b4183b4b55256867459\" (id)\n", - "CREATE INDEX \"idx_f820481a5497411a9b74d29d2fe98ea5\" ON \"cte_25a74c5b96e64b4183b4b55256867459\" USING \"ivfflat\" ( \"_emb_23dd83d7748d40cd8eda21e5f2129629\")\n", - "\n", - " DO $$\n", - " BEGIN\n", - " SET LOCAL allow_system_table_mods TO ON;\n", - "\n", - " WITH embedding_info AS (\n", - " SELECT attrelid, attnum, 'all-MiniLM-L6-v2' AS model\n", - " FROM pg_attribute\n", - " WHERE \n", - " attrelid = '\"cte_7721c0fd66c041c4b7c5ff4074b3479b\"'::regclass::oid AND\n", - " attname = 'content'\n", - " )\n", - " UPDATE pg_class\n", - " SET reloptions = array_append(\n", - " reloptions, \n", - " format('_pygp_emb_%s=%s', attnum::text, to_json(embedding_info))\n", - " )\n", - " FROM embedding_info;\n", - "\n", - " WITH embedding_info AS (\n", - " SELECT attrelid, attnum, 'all-MiniLM-L6-v2' AS model\n", - " FROM pg_attribute\n", - " WHERE \n", - " attrelid = '\"cte_7721c0fd66c041c4b7c5ff4074b3479b\"'::regclass::oid AND\n", - " attname = 'content'\n", - " )\n", - " INSERT INTO pg_depend\n", - " SELECT\n", - " 'pg_class'::regclass::oid AS classid,\n", - " '\"cte_25a74c5b96e64b4183b4b55256867459\"'::regclass::oid AS objid,\n", - " 0::oid AS objsubid,\n", - " 'pg_class'::regclass::oid AS refclassid,\n", - " embedding_info.attrelid AS refobjid,\n", - " embedding_info.attnum AS refobjsubid,\n", - " 'n' AS deptype\n", - " FROM embedding_info;\n", - " END;\n", - " $$;\n", - " \n" - ] - }, { "data": { "text/html": [ @@ -192,67 +112,48 @@ "source": [ "import greenplumpython.experimental.embedding\n", "\n", - "gp.config.print_sql = True\n", - "\n", - "t = t.save_as(\n", - " column_names=[\"id\", \"content\"], distribution_key={\"id\"}, distribution_type=\"hash\"\n", - ").check_unique(columns={\"id\"})\n", "t = t.embedding().create_index(column=\"content\", model=\"all-MiniLM-L6-v2\")\n", "t" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Searching Embeddings" + ] + }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - " WITH embedding_oid AS (\n", - " SELECT attrelid, attnum\n", - " FROM pg_attribute\n", - " WHERE \n", - " attrelid = '\"cte_7721c0fd66c041c4b7c5ff4074b3479b\"'::regclass::oid AND\n", - " attname = 'content'\n", - " ), reloptions AS (\n", - " SELECT unnest(reloptions) AS option\n", - " FROM pg_class, embedding_oid\n", - " WHERE pg_class.oid = attrelid\n", - " ), embedding_info_json AS (\n", - " SELECT split_part(option, '=', 2)::json AS val\n", - " FROM reloptions, embedding_oid\n", - " WHERE option LIKE format('_pygp_emb_%s=%%', attnum)\n", - " ), embedding_info AS (\n", - " SELECT * \n", - " FROM embedding_info_json, json_to_record(val) AS (attnum int4, attrelid oid, model text)\n", - " )\n", - " SELECT nspname, relname, model\n", - " FROM embedding_info, pg_class, pg_namespace\n", - " WHERE \n", - " pg_class.oid = attrelid AND\n", - " relnamespace = pg_namespace.oid;\n", - " \n" - ] - }, - { - "ename": "AssertionError", - "evalue": "Database is required to create function.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m t\u001b[39m.\u001b[39;49membedding()\u001b[39m.\u001b[39;49msearch(column\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mcontent\u001b[39;49m\u001b[39m\"\u001b[39;49m, query\u001b[39m=\u001b[39;49m\u001b[39m\"\u001b[39;49m\u001b[39mapple\u001b[39;49m\u001b[39m\"\u001b[39;49m, top_k\u001b[39m=\u001b[39;49m\u001b[39m1\u001b[39;49m)\n", - "File \u001b[0;32m~/GreenplumPython/greenplumpython/experimental/embedding.py:154\u001b[0m, in \u001b[0;36mEmbedding.search\u001b[0;34m(self, column, query, top_k)\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataframe\u001b[39m.\u001b[39munique_key \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 152\u001b[0m distance \u001b[39m=\u001b[39m gp\u001b[39m.\u001b[39moperator(\u001b[39m\"\u001b[39m\u001b[39m<#>\u001b[39m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 153\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataframe\u001b[39m.\u001b[39mjoin(\n\u001b[0;32m--> 154\u001b[0m embedding_df\u001b[39m.\u001b[39;49massign(\n\u001b[1;32m 155\u001b[0m distance\u001b[39m=\u001b[39;49m\u001b[39mlambda\u001b[39;49;00m t: distance(t[\u001b[39m\"\u001b[39;49m\u001b[39m_emb_\u001b[39;49m\u001b[39m\"\u001b[39;49m], _generate_embedding(query, model))\n\u001b[1;32m 156\u001b[0m )\u001b[39m.\u001b[39morder_by(\u001b[39m\"\u001b[39m\u001b[39mdistance\u001b[39m\u001b[39m\"\u001b[39m)[:top_k],\n\u001b[1;32m 157\u001b[0m how\u001b[39m=\u001b[39m\u001b[39m\"\u001b[39m\u001b[39minner\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 158\u001b[0m on\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_dataframe\u001b[39m.\u001b[39munique_key,\n\u001b[1;32m 159\u001b[0m self_columns\u001b[39m=\u001b[39m{\u001b[39m\"\u001b[39m\u001b[39m*\u001b[39m\u001b[39m\"\u001b[39m},\n\u001b[1;32m 160\u001b[0m other_columns\u001b[39m=\u001b[39m{},\n\u001b[1;32m 161\u001b[0m )\n", - "File \u001b[0;32m~/GreenplumPython/greenplumpython/dataframe.py:492\u001b[0m, in \u001b[0;36mDataFrame.assign\u001b[0;34m(self, **new_columns)\u001b[0m\n\u001b[1;32m 490\u001b[0m other_parents[v\u001b[39m.\u001b[39m_other_dataframe\u001b[39m.\u001b[39m_name] \u001b[39m=\u001b[39m v\u001b[39m.\u001b[39m_other_dataframe\n\u001b[1;32m 491\u001b[0m v \u001b[39m=\u001b[39m v\u001b[39m.\u001b[39m_bind(db\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_db)\n\u001b[0;32m--> 492\u001b[0m targets\u001b[39m.\u001b[39mappend(\u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m{\u001b[39;00m_serialize(v)\u001b[39m}\u001b[39;00m\u001b[39m AS \u001b[39m\u001b[39m{\u001b[39;00mk\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m)\n\u001b[1;32m 493\u001b[0m \u001b[39mreturn\u001b[39;00m DataFrame(\n\u001b[1;32m 494\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mSELECT *, \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39m'\u001b[39m\u001b[39m,\u001b[39m\u001b[39m'\u001b[39m\u001b[39m.\u001b[39mjoin(targets)\u001b[39m}\u001b[39;00m\u001b[39m FROM \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_name\u001b[39m}\u001b[39;00m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 495\u001b[0m parents\u001b[39m=\u001b[39m[\u001b[39mself\u001b[39m] \u001b[39m+\u001b[39m \u001b[39mlist\u001b[39m(other_parents\u001b[39m.\u001b[39mvalues()),\n\u001b[1;32m 496\u001b[0m )\n", - "File \u001b[0;32m~/GreenplumPython/greenplumpython/expr.py:561\u001b[0m, in \u001b[0;36m_serialize\u001b[0;34m(value)\u001b[0m\n\u001b[1;32m 551\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 552\u001b[0m \u001b[39m:meta private:\u001b[39;00m\n\u001b[1;32m 553\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[39m in Python 3 and Python 2 is EOL officially.\u001b[39;00m\n\u001b[1;32m 559\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 560\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(value, Expr):\n\u001b[0;32m--> 561\u001b[0m \u001b[39mreturn\u001b[39;00m value\u001b[39m.\u001b[39;49m_serialize()\n\u001b[1;32m 562\u001b[0m \u001b[39mreturn\u001b[39;00m adapt(value)\u001b[39m.\u001b[39mgetquoted()\u001b[39m.\u001b[39mdecode(\u001b[39m\"\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m\"\u001b[39m)\n", - "File \u001b[0;32m~/GreenplumPython/greenplumpython/expr.py:636\u001b[0m, in \u001b[0;36mBinaryExpr._serialize\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 633\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mgreenplumpython\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mexpr\u001b[39;00m \u001b[39mimport\u001b[39;00m _serialize\n\u001b[1;32m 635\u001b[0m left_str \u001b[39m=\u001b[39m _serialize(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_left)\n\u001b[0;32m--> 636\u001b[0m right_str \u001b[39m=\u001b[39m _serialize(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_right)\n\u001b[1;32m 637\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m(\u001b[39m\u001b[39m{\u001b[39;00mleft_str\u001b[39m}\u001b[39;00m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_operator\u001b[39m}\u001b[39;00m\u001b[39m \u001b[39m\u001b[39m{\u001b[39;00mright_str\u001b[39m}\u001b[39;00m\u001b[39m)\u001b[39m\u001b[39m\"\u001b[39m\n", - "File \u001b[0;32m~/GreenplumPython/greenplumpython/expr.py:561\u001b[0m, in \u001b[0;36m_serialize\u001b[0;34m(value)\u001b[0m\n\u001b[1;32m 551\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 552\u001b[0m \u001b[39m:meta private:\u001b[39;00m\n\u001b[1;32m 553\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 558\u001b[0m \u001b[39m in Python 3 and Python 2 is EOL officially.\u001b[39;00m\n\u001b[1;32m 559\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 560\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39misinstance\u001b[39m(value, Expr):\n\u001b[0;32m--> 561\u001b[0m \u001b[39mreturn\u001b[39;00m value\u001b[39m.\u001b[39;49m_serialize()\n\u001b[1;32m 562\u001b[0m \u001b[39mreturn\u001b[39;00m adapt(value)\u001b[39m.\u001b[39mgetquoted()\u001b[39m.\u001b[39mdecode(\u001b[39m\"\u001b[39m\u001b[39mutf-8\u001b[39m\u001b[39m\"\u001b[39m)\n", - "File \u001b[0;32m~/GreenplumPython/greenplumpython/func.py:86\u001b[0m, in \u001b[0;36mFunctionExpr._serialize\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m_serialize\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mstr\u001b[39m:\n\u001b[1;32m 84\u001b[0m \u001b[39m# noqa D400\u001b[39;00m\n\u001b[1;32m 85\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\":meta private:\"\"\"\u001b[39;00m\n\u001b[0;32m---> 86\u001b[0m \u001b[39massert\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_db \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m, \u001b[39m\"\u001b[39m\u001b[39mDatabase is required to create function.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 87\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_function\u001b[39m.\u001b[39m_create_in_db(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_db)\n\u001b[1;32m 88\u001b[0m distinct \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mDISTINCT\u001b[39m\u001b[39m\"\u001b[39m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_distinct \u001b[39melse\u001b[39;00m \u001b[39m\"\u001b[39m\u001b[39m\"\u001b[39m\n", - "\u001b[0;31mAssertionError\u001b[0m: Database is required to create function." - ] + "data": { + "text/html": [ + "\n", + "\t\n", + "\t\t\n", + "\t\t\n", + "\t\n", + "\t\n", + "\t\t\n", + "\t\t\n", + "\t\n", + "
idcontent
1I like eating apples.
" + ], + "text/plain": [ + "----------------------------\n", + " id | content \n", + "----+-----------------------\n", + " 1 | I like eating apples. \n", + "----------------------------\n", + "(1 row)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py index 5a7469d7..d0234ee8 100644 --- a/greenplumpython/dataframe.py +++ b/greenplumpython/dataframe.py @@ -1074,7 +1074,6 @@ def check_unique(self, columns: set[str]) -> "DataFrame": """ assert self.is_saved, "DataFrame must be saved before checking uniqueness." assert self._db is not None, "Database is required to check uniqueness." - print(self) self._db._execute( f"CREATE UNIQUE INDEX ON {self._qualified_table_name} ({','.join(columns)})", has_results=False, diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index 79028ec6..84bfc449 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -74,7 +74,7 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: SET LOCAL allow_system_table_mods TO ON; WITH embedding_info AS ( - SELECT attrelid, attnum, '{model}' AS model + SELECT '{embedding_df._qualified_table_name}'::regclass::oid AS base_relid, attnum, '{model}' AS model FROM pg_attribute WHERE attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND @@ -88,9 +88,9 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: FROM embedding_info; WITH embedding_info AS ( - SELECT attrelid, attnum, '{model}' AS model + SELECT '{embedding_df._qualified_table_name}'::regclass::oid AS base_relid, attnum, '{model}' AS model FROM pg_attribute - WHERE + WHERE attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND attname = '{column}' ) @@ -100,14 +100,14 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: '{embedding_df._qualified_table_name}'::regclass::oid AS objid, 0::oid AS objsubid, 'pg_class'::regclass::oid AS refclassid, - embedding_info.attrelid AS refobjid, + embedding_info.base_relid AS refobjid, embedding_info.attnum AS refobjsubid, 'n' AS deptype FROM embedding_info; END; $$; """, - has_results=False + has_results=False, ) return self._dataframe @@ -115,7 +115,7 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: assert self._dataframe._db is not None embdedding_info = self._dataframe._db._execute( f""" - WITH embedding_oid AS ( + WITH indexed_col_info AS ( SELECT attrelid, attnum FROM pg_attribute WHERE @@ -123,36 +123,41 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: attname = '{column}' ), reloptions AS ( SELECT unnest(reloptions) AS option - FROM pg_class, embedding_oid + FROM pg_class, indexed_col_info WHERE pg_class.oid = attrelid ), embedding_info_json AS ( SELECT split_part(option, '=', 2)::json AS val - FROM reloptions, embedding_oid + FROM reloptions, indexed_col_info WHERE option LIKE format('_pygp_emb_%s=%%', attnum) ), embedding_info AS ( SELECT * - FROM embedding_info_json, json_to_record(val) AS (attnum int4, attrelid oid, model text) + FROM embedding_info_json, json_to_record(val) AS (attnum int4, base_relid oid, model text) ) - SELECT nspname, relname, model - FROM embedding_info, pg_class, pg_namespace + SELECT nspname, relname, attname, model + FROM embedding_info, pg_class, pg_namespace, pg_attribute WHERE - pg_class.oid = attrelid AND - relnamespace = pg_namespace.oid; + pg_class.oid = base_relid AND + relnamespace = pg_namespace.oid AND + base_relid = attrelid AND + pg_attribute.attnum = 2; """ ) # assert isinstance(embdedding_info, abc.Mapping) - embedding_table_name = None for row in embdedding_info: - embedding_table_name = f'"{row["nspname"]}"."{row["relname"]}"' + schema, embedding_table_name = row["nspname"], row["relname"] model = row["model"] + embedding_col_name = row["attname"] break - assert embedding_table_name is not None - embedding_df = self._dataframe._db.create_dataframe(embedding_table_name) + embedding_df = self._dataframe._db.create_dataframe( + table_name=embedding_table_name, schema=schema + ) assert self._dataframe.unique_key is not None - distance = gp.operator("<#>") + distance = gp.operator("<->") # L2 distance is the default operator class in pgvector return self._dataframe.join( embedding_df.assign( - distance=lambda t: distance(t["_emb_"], _generate_embedding(query, model)) + distance=lambda t: distance( + embedding_df[embedding_col_name], _generate_embedding(query, model) + ) ).order_by("distance")[:top_k], how="inner", on=self._dataframe.unique_key, diff --git a/greenplumpython/expr.py b/greenplumpython/expr.py index dc5222e6..9b0246cf 100644 --- a/greenplumpython/expr.py +++ b/greenplumpython/expr.py @@ -632,8 +632,9 @@ def __init__( def _serialize(self) -> str: from greenplumpython.expr import _serialize - left_str = _serialize(self._left) - right_str = _serialize(self._right) + # FIXME: Move _serialize() to be a method of Database. + left_str = _serialize(self._left._bind(db=self._db)) + right_str = _serialize(self._right._bind(db=self._db)) return f"({left_str} {self._operator} {right_str})" diff --git a/greenplumpython/type.py b/greenplumpython/type.py index 8956aa3b..2565739a 100644 --- a/greenplumpython/type.py +++ b/greenplumpython/type.py @@ -186,6 +186,7 @@ def _qualified_name(self) -> Tuple[Optional[str], str]: bytes: Type(name="bytea"), } + # FIXME: Change to data_type() to make it more clear. def type_(name: str, schema: Optional[str] = None, modifier: Optional[int] = None) -> Type: """ From 56d146f5dd8fbb4b9f234bbf13765840983da531 Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Tue, 25 Jul 2023 23:59:53 -0400 Subject: [PATCH 04/19] Add title to notebook --- doc/source/notebooks/embedding.ipynb | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/doc/source/notebooks/embedding.ipynb b/doc/source/notebooks/embedding.ipynb index feb4661d..c64b6b41 100644 --- a/doc/source/notebooks/embedding.ipynb +++ b/doc/source/notebooks/embedding.ipynb @@ -4,7 +4,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Installing the Package" + "# Generating, Indexing and Searching Embeddings\n", + "\n", + "## Installing the Package" ] }, { @@ -43,7 +45,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Preparing Data" + "## Preparing Data" ] }, { @@ -68,7 +70,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Generating and Indexing Embeddings" + "## Generating and Indexing Embeddings" ] }, { @@ -120,7 +122,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Searching Embeddings" + "## Searching Embeddings" ] }, { @@ -163,7 +165,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -178,8 +180,7 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" - }, - "orig_nbformat": 4 + } }, "nbformat": 4, "nbformat_minor": 2 From 4584d2570af4e034daffeabbf3a9d437c98c3dc1 Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Wed, 26 Jul 2023 05:24:45 -0400 Subject: [PATCH 05/19] Fix bug on existing data types and reloptions --- doc/source/notebooks/embedding.ipynb | 12 ++++++------ greenplumpython/experimental/embedding.py | 15 ++++++++------- greenplumpython/type.py | 2 +- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/doc/source/notebooks/embedding.ipynb b/doc/source/notebooks/embedding.ipynb index c64b6b41..86097623 100644 --- a/doc/source/notebooks/embedding.ipynb +++ b/doc/source/notebooks/embedding.ipynb @@ -28,8 +28,8 @@ "Requirement already satisfied, skipping upgrade: dill==0.3.6 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (0.3.6)\n", "Building wheels for collected packages: greenplum-python\n", " Building wheel for greenplum-python (PEP 517) ... \u001b[?25ldone\n", - "\u001b[?25h Created wheel for greenplum-python: filename=greenplum_python-1.0.1-py3-none-any.whl size=70614 sha256=4ee428916d3690ae05c591dff7e48e84e41f04d37fba06dfd6bf9543791e4d4f\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-2iw1rkm7/wheels/bb/1f/99/ff8594e48ec11df99af6e0ee8611a5e560e9f44d1a3fefb351\n", + "\u001b[?25h Created wheel for greenplum-python: filename=greenplum_python-1.0.1-py3-none-any.whl size=70623 sha256=6bbdfa6fb272db092d6e63def388beb58365fb6b726fc98bf40ff788ebe3143f\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-oqna_kdn/wheels/bb/1f/99/ff8594e48ec11df99af6e0ee8611a5e560e9f44d1a3fefb351\n", "Successfully built greenplum-python\n", "Installing collected packages: greenplum-python\n", "Successfully installed greenplum-python-1.0.1\n" @@ -75,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -106,7 +106,7 @@ "(2 rows)" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -127,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -153,7 +153,7 @@ "(1 row)" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index 84bfc449..f0f68e95 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -74,7 +74,7 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: SET LOCAL allow_system_table_mods TO ON; WITH embedding_info AS ( - SELECT '{embedding_df._qualified_table_name}'::regclass::oid AS base_relid, attnum, '{model}' AS model + SELECT '{embedding_df._qualified_table_name}'::regclass::oid AS embedding_relid, attnum, '{model}' AS model FROM pg_attribute WHERE attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND @@ -85,10 +85,11 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: reloptions, format('_pygp_emb_%s=%s', attnum::text, to_json(embedding_info)) ) - FROM embedding_info; + FROM embedding_info + WHERE oid = '{self._dataframe._qualified_table_name}'::regclass::oid; WITH embedding_info AS ( - SELECT '{embedding_df._qualified_table_name}'::regclass::oid AS base_relid, attnum, '{model}' AS model + SELECT '{embedding_df._qualified_table_name}'::regclass::oid AS embedding_relid, attnum, '{model}' AS model FROM pg_attribute WHERE attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND @@ -100,7 +101,7 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: '{embedding_df._qualified_table_name}'::regclass::oid AS objid, 0::oid AS objsubid, 'pg_class'::regclass::oid AS refclassid, - embedding_info.base_relid AS refobjid, + embedding_info.embedding_relid AS refobjid, embedding_info.attnum AS refobjsubid, 'n' AS deptype FROM embedding_info; @@ -131,14 +132,14 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: WHERE option LIKE format('_pygp_emb_%s=%%', attnum) ), embedding_info AS ( SELECT * - FROM embedding_info_json, json_to_record(val) AS (attnum int4, base_relid oid, model text) + FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text) ) SELECT nspname, relname, attname, model FROM embedding_info, pg_class, pg_namespace, pg_attribute WHERE - pg_class.oid = base_relid AND + pg_class.oid = embedding_relid AND relnamespace = pg_namespace.oid AND - base_relid = attrelid AND + embedding_relid = attrelid AND pg_attribute.attnum = 2; """ ) diff --git a/greenplumpython/type.py b/greenplumpython/type.py index 2565739a..b8cbc411 100644 --- a/greenplumpython/type.py +++ b/greenplumpython/type.py @@ -230,7 +230,7 @@ def to_pg_type( args: Tuple[type, ...] = annotation.__args__ if for_return: return f"SETOF {to_pg_type(args[0], db)}" # type: ignore - if args[0] in _defined_types: + else: return f"{to_pg_type(args[0], db)}[]" # type: ignore raise NotImplementedError() else: From 76174a0a87ada2af0dadd4dabf25ba536de30a86 Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Wed, 26 Jul 2023 06:59:27 -0400 Subject: [PATCH 06/19] DROP CASCADE not work on segments --- doc/source/notebooks/embedding.ipynb | 55 +++++++++++++++++++---- greenplumpython/experimental/embedding.py | 2 +- 2 files changed, 47 insertions(+), 10 deletions(-) diff --git a/doc/source/notebooks/embedding.ipynb b/doc/source/notebooks/embedding.ipynb index 86097623..a9d4e4e2 100644 --- a/doc/source/notebooks/embedding.ipynb +++ b/doc/source/notebooks/embedding.ipynb @@ -24,12 +24,12 @@ " Installing build dependencies ... \u001b[?25ldone\n", "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", "\u001b[?25h Preparing wheel metadata ... \u001b[?25ldone\n", - "\u001b[?25hRequirement already satisfied, skipping upgrade: psycopg2-binary==2.9.5 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (2.9.5)\n", - "Requirement already satisfied, skipping upgrade: dill==0.3.6 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (0.3.6)\n", + "\u001b[?25hRequirement already satisfied, skipping upgrade: dill==0.3.6 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (0.3.6)\n", + "Requirement already satisfied, skipping upgrade: psycopg2-binary==2.9.5 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (2.9.5)\n", "Building wheels for collected packages: greenplum-python\n", " Building wheel for greenplum-python (PEP 517) ... \u001b[?25ldone\n", - "\u001b[?25h Created wheel for greenplum-python: filename=greenplum_python-1.0.1-py3-none-any.whl size=70623 sha256=6bbdfa6fb272db092d6e63def388beb58365fb6b726fc98bf40ff788ebe3143f\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-oqna_kdn/wheels/bb/1f/99/ff8594e48ec11df99af6e0ee8611a5e560e9f44d1a3fefb351\n", + "\u001b[?25h Created wheel for greenplum-python: filename=greenplum_python-1.0.1-py3-none-any.whl size=70618 sha256=22e5718895157b66b4e6a4dd9bbb7164f4039f4d59dbf58fb519da409b246e38\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-q9fmewst/wheels/bb/1f/99/ff8594e48ec11df99af6e0ee8611a5e560e9f44d1a3fefb351\n", "Successfully built greenplum-python\n", "Installing collected packages: greenplum-python\n", "Successfully installed greenplum-python-1.0.1\n" @@ -61,7 +61,7 @@ "db = gp.database(\"postgresql://localhost:7000\")\n", "t = (\n", " db.create_dataframe(columns={\"id\": range(len(content)), \"content\": content})\n", - " .save_as(column_names=[\"id\", \"content\"], distribution_key={\"id\"}, distribution_type=\"hash\")\n", + " .save_as(table_name = 'text_sample', column_names=[\"id\", \"content\"], distribution_key={\"id\"}, distribution_type=\"hash\")\n", " .check_unique(columns={\"id\"})\n", ")" ] @@ -75,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -106,7 +106,7 @@ "(2 rows)" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -127,7 +127,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -153,7 +153,7 @@ "(1 row)" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -161,6 +161,43 @@ "source": [ "t.embedding().search(column=\"content\", query=\"apple\", top_k=1)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleaning All at Once" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " * postgresql://localhost:7000\n", + "Done.\n" + ] + }, + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%reload_ext sql\n", + "%sql postgresql://localhost:7000\n", + "%sql DROP TABLE text_sample CASCADE;" + ] } ], "metadata": { diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index f0f68e95..9438762c 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -101,7 +101,7 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: '{embedding_df._qualified_table_name}'::regclass::oid AS objid, 0::oid AS objsubid, 'pg_class'::regclass::oid AS refclassid, - embedding_info.embedding_relid AS refobjid, + '{self._dataframe._qualified_table_name}'::regclass::oid AS refobjid, embedding_info.attnum AS refobjsubid, 'n' AS deptype FROM embedding_info; From a7f07ea2932a2793a36b59f202a44836bd414d2f Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Thu, 3 Aug 2023 04:36:27 -0400 Subject: [PATCH 07/19] Add interface --- greenplumpython/dataframe.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py index d0234ee8..326fc5ee 100644 --- a/greenplumpython/dataframe.py +++ b/greenplumpython/dataframe.py @@ -47,6 +47,7 @@ if TYPE_CHECKING: from greenplumpython.func import FunctionExpr + from greenplumpython.experimental.embedding import Embedding from uuid import uuid4 @@ -1193,3 +1194,7 @@ def from_columns(cls, columns: Dict[str, Iterable[Any]], db: Database) -> "DataF [f'unnest({_serialize(list(v))}) AS "{k}"' for k, v in columns.items()] ) return cls(f"SELECT {columns_string}", db=db) + + # Add interface here for language servers. + def embedding(self) -> "Embedding": + raise NotImplementedError From ecf616ac97e7dac60a9045574ea4c4142c426ad0 Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Thu, 3 Aug 2023 18:46:35 +0200 Subject: [PATCH 08/19] Add docs for embedding --- doc/source/emb_example.rst | 5 ++++ doc/source/embedding.rst | 7 ++++++ doc/source/index.rst | 1 + doc/source/modules.rst | 1 + greenplumpython/dataframe.py | 8 +++---- greenplumpython/experimental/embedding.py | 29 +++++++++++++++++++++++ 6 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 doc/source/emb_example.rst create mode 100644 doc/source/embedding.rst diff --git a/doc/source/emb_example.rst b/doc/source/emb_example.rst new file mode 100644 index 00000000..ff41bef1 --- /dev/null +++ b/doc/source/emb_example.rst @@ -0,0 +1,5 @@ +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + notebooks/embedding diff --git a/doc/source/embedding.rst b/doc/source/embedding.rst new file mode 100644 index 00000000..0367e200 --- /dev/null +++ b/doc/source/embedding.rst @@ -0,0 +1,7 @@ +Embedding +========= + +.. automodule:: experimental.embedding + :members: + :show-inheritance: + :member-order: bysource \ No newline at end of file diff --git a/doc/source/index.rst b/doc/source/index.rst index 5fce21a3..4213dfab 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -24,6 +24,7 @@ There are explanations about the implementation and examples. modules abalone pandas + emb_example Indices and tables diff --git a/doc/source/modules.rst b/doc/source/modules.rst index 9a702759..18d35dde 100644 --- a/doc/source/modules.rst +++ b/doc/source/modules.rst @@ -17,5 +17,6 @@ The **GreenplumPython** library contains 5 main modules: group order op + embedding pd_df config \ No newline at end of file diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py index 326fc5ee..284c46d6 100644 --- a/greenplumpython/dataframe.py +++ b/greenplumpython/dataframe.py @@ -1067,12 +1067,11 @@ def distinct_on(self, *column_names: str) -> "DataFrame": @property def unique_key(self) -> List[str]: + """Return unique key.""" return self._unique_key def check_unique(self, columns: set[str]) -> "DataFrame": - """ - Check whether a given set of columns, i.e. key, is unique. - """ + """Check whether a given set of columns, i.e. key, is unique.""" assert self.is_saved, "DataFrame must be saved before checking uniqueness." assert self._db is not None, "Database is required to check uniqueness." self._db._execute( @@ -1194,7 +1193,8 @@ def from_columns(cls, columns: Dict[str, Iterable[Any]], db: Database) -> "DataF [f'unnest({_serialize(list(v))}) AS "{k}"' for k, v in columns.items()] ) return cls(f"SELECT {columns_string}", db=db) - + # Add interface here for language servers. def embedding(self) -> "Embedding": + """Allow user to process vector operators thanks to :class:`~experimental.embedding.Embedding`.""" raise NotImplementedError diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index 9438762c..0268211e 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -26,13 +26,23 @@ def _generate_embedding(content: str, model_name: str) -> _vector_type: class Embedding: + """ + Embeddings provide a compact and meaningful representation of objects in a numerical vector space. + They capture the semantic relationships between objects. + + This class enables users to search unstructured data based on semantic similarity and to leverage the power of + the vector index scan. + """ + def __init__(self, dataframe: gp.DataFrame) -> None: self._dataframe = dataframe def create_index(self, column: str, model: str) -> gp.DataFrame: """ Generate embeddings and create index for a column of unstructured data. + This include + - texts, - images, or - videos, etc. @@ -44,6 +54,14 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: column-oriented approach, i.e., separated from the input DataFrame. The input DataFrame must have a **unique key** to identify the tuples in the search results. + + Args: + column: name of column to create index on. + model: name of model to generate embedding. + + Returns: + Dataframe with target column indexed based on embeddings. + """ assert self._dataframe.unique_key is not None, "Unique key is required to create index." @@ -113,6 +131,17 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: return self._dataframe def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: + """ + Searche unstructured data based on semantic similarity on embeddings. + + Args: + column: name of column to search + query: content to be searched + top_k: number of most similar results requested + + Returns: + Dataframe with the top k most similar results in the `column` of `query`. + """ assert self._dataframe._db is not None embdedding_info = self._dataframe._db._execute( f""" From d5db32a1797ffc768b162db276f239e1e9cc952e Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Fri, 4 Aug 2023 13:44:10 +0200 Subject: [PATCH 09/19] Remaining Type -> DataType --- greenplumpython/type.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/greenplumpython/type.py b/greenplumpython/type.py index 4a64497b..891ff6c1 100644 --- a/greenplumpython/type.py +++ b/greenplumpython/type.py @@ -209,7 +209,7 @@ def _serialize_to_type( return f"{_serialize_to_type(args[0], db)}[]" # type: ignore raise NotImplementedError() else: - if isinstance(annotation, Type): + if isinstance(annotation, DataType): return annotation._qualified_name_str assert db is not None, "Database is required to create type" if annotation not in _defined_types: From 18fd2cb6e79fa06c754208895a7a67742180a4bf Mon Sep 17 00:00:00 2001 From: Xuebin Su Date: Wed, 9 Aug 2023 23:12:07 -0400 Subject: [PATCH 10/19] Ignore type errors --- doc/source/notebooks/embedding.ipynb | 7 +++++- greenplumpython/experimental/embedding.py | 30 +++++++++++------------ 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/doc/source/notebooks/embedding.ipynb b/doc/source/notebooks/embedding.ipynb index a9d4e4e2..4bd7a95e 100644 --- a/doc/source/notebooks/embedding.ipynb +++ b/doc/source/notebooks/embedding.ipynb @@ -61,7 +61,12 @@ "db = gp.database(\"postgresql://localhost:7000\")\n", "t = (\n", " db.create_dataframe(columns={\"id\": range(len(content)), \"content\": content})\n", - " .save_as(table_name = 'text_sample', column_names=[\"id\", \"content\"], distribution_key={\"id\"}, distribution_type=\"hash\")\n", + " .save_as(\n", + " table_name=\"text_sample\",\n", + " column_names=[\"id\", \"content\"],\n", + " distribution_key={\"id\"},\n", + " distribution_type=\"hash\",\n", + " )\n", " .check_unique(columns={\"id\"})\n", ")" ] diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index 0268211e..a7c37e03 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -1,28 +1,26 @@ from collections import abc -from typing import Any +from typing import Any, cast, Callable from uuid import uuid4 import greenplumpython as gp - -_vector_type = gp.type_("vector", modifier=384) +from greenplumpython.func import FunctionExpr @gp.create_function -def _generate_embedding(content: str, model_name: str) -> _vector_type: +def _generate_embedding(content: str, model_name: str) -> gp.type_("vector", modifier=384): # type: ignore reportUnknownParameterType import sys + from sentence_transformers import SentenceTransformer # type: ignore reportMissingTypeStubs SD = globals().get("SD", sys.modules["plpy"]._SD) if "model" not in SD: - from sentence_transformers import SentenceTransformer - model = SentenceTransformer(model_name) SD["model"] = model else: model = SD["model"] # Sentences are encoded by calling model.encode() - emb = model.encode(content, normalize_embeddings=True) - return emb.tolist() + emb = model.encode(content, normalize_embeddings=True) # type: ignore reportUnknownVariableType + return emb.tolist() # type: ignore reportUnknownVariableType class Embedding: @@ -71,8 +69,9 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: embedding_df: gp.DataFrame = ( self._dataframe.assign( **{ - embedding_col_name: lambda t: _vector_type( - _generate_embedding(t[column], model) + embedding_col_name: cast( + Callable[[gp.DataFrame], FunctionExpr], + lambda t: _generate_embedding(t[column], model), # type: ignore reportUnknownLambdaType ) }, )[embedding_df_cols] @@ -172,12 +171,11 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: pg_attribute.attnum = 2; """ ) - # assert isinstance(embdedding_info, abc.Mapping) - for row in embdedding_info: - schema, embedding_table_name = row["nspname"], row["relname"] - model = row["model"] - embedding_col_name = row["attname"] - break + assert isinstance(embdedding_info, abc.Mapping[str, Any]) + row = embdedding_info[0] + schema, embedding_table_name = row["nspname"], row["relname"] + model = row["model"] + embedding_col_name = row["attname"] embedding_df = self._dataframe._db.create_dataframe( table_name=embedding_table_name, schema=schema ) From ac92f3378c03aaf55055b0812c78ebe5a9783f2d Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Mon, 21 Aug 2023 16:05:37 +0200 Subject: [PATCH 11/19] Fix lint error --- greenplumpython/experimental/embedding.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index a7c37e03..8d34cc95 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -1,22 +1,24 @@ from collections import abc -from typing import Any, cast, Callable +from typing import Any, Callable, cast from uuid import uuid4 import greenplumpython as gp from greenplumpython.func import FunctionExpr +from greenplumpython.row import Row @gp.create_function def _generate_embedding(content: str, model_name: str) -> gp.type_("vector", modifier=384): # type: ignore reportUnknownParameterType import sys - from sentence_transformers import SentenceTransformer # type: ignore reportMissingTypeStubs - SD = globals().get("SD", sys.modules["plpy"]._SD) + import sentence_transformers.SentenceTransformer as SentenceTransformer # type: ignore reportMissingImports + + SD = globals().get("SD") if globals().get("SD") is not None else sys.modules["plpy"]._SD if "model" not in SD: - model = SentenceTransformer(model_name) - SD["model"] = model + model = SentenceTransformer(model_name) # type: ignore reportUnknownVariableType + SD["model"] = model # type: ignore reportOptionalSubscript else: - model = SD["model"] + model = SD["model"] # type: ignore reportOptionalSubscript # Sentences are encoded by calling model.encode() emb = model.encode(content, normalize_embeddings=True) # type: ignore reportUnknownVariableType @@ -172,8 +174,9 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: """ ) assert isinstance(embdedding_info, abc.Mapping[str, Any]) - row = embdedding_info[0] - schema, embedding_table_name = row["nspname"], row["relname"] + row: Row = embdedding_info[0] + schema: str = row["nspname"] + embedding_table_name: str = row["relname"] model = row["model"] embedding_col_name = row["attname"] embedding_df = self._dataframe._db.create_dataframe( From 2ce05f081b59a32351da5bd9e0e2f3d8f3b742c7 Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Mon, 21 Aug 2023 16:21:04 +0200 Subject: [PATCH 12/19] Add doc --- greenplumpython/dataframe.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py index f2b1aab5..7563b6df 100644 --- a/greenplumpython/dataframe.py +++ b/greenplumpython/dataframe.py @@ -1075,7 +1075,15 @@ def unique_key(self) -> List[str]: return self._unique_key def check_unique(self, columns: set[str]) -> "DataFrame": - """Check whether a given set of columns, i.e. key, is unique.""" + """ + Check whether a given set of columns, i.e. key, is unique. + + Args: + columns: set of columns name to be checked + + Returns: + :class:`~dataframe.DataFrame`: self checked + """ assert self.is_saved, "DataFrame must be saved before checking uniqueness." assert self._db is not None, "Database is required to check uniqueness." self._db._execute( From 8ffc1f752d04654b6a1001f3d864b5e7ce6c2a7a Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Tue, 22 Aug 2023 14:37:40 +0200 Subject: [PATCH 13/19] Add embedding.ipynb in def search() in embedding.py --- doc/source/emb_example.rst | 4 +++- greenplumpython/experimental/embedding.py | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/doc/source/emb_example.rst b/doc/source/emb_example.rst index ff41bef1..9392ee27 100644 --- a/doc/source/emb_example.rst +++ b/doc/source/emb_example.rst @@ -1,5 +1,7 @@ +.. _embedding-example: + .. toctree:: :maxdepth: 2 - :caption: Contents: + :caption: Generating, Indexing and Searching Embeddings notebooks/embedding diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index 8d34cc95..0398c590 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -142,6 +142,8 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: Returns: Dataframe with the top k most similar results in the `column` of `query`. + + See :ref:`embedding-example` for more details. """ assert self._dataframe._db is not None embdedding_info = self._dataframe._db._execute( From 6d125b0b076edca27d615f8b02f02d73f692fbc7 Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Wed, 23 Aug 2023 14:34:43 +0200 Subject: [PATCH 14/19] Temporary fix vector dimension to 384 --- greenplumpython/experimental/embedding.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index 0398c590..a3ae4d6f 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -3,12 +3,12 @@ from uuid import uuid4 import greenplumpython as gp -from greenplumpython.func import FunctionExpr from greenplumpython.row import Row +from greenplumpython.type import TypeCast @gp.create_function -def _generate_embedding(content: str, model_name: str) -> gp.type_("vector", modifier=384): # type: ignore reportUnknownParameterType +def _generate_embedding(content: str, model_name: str) -> gp.type_("vector"): # type: ignore reportUnknownParameterType import sys import sentence_transformers.SentenceTransformer as SentenceTransformer # type: ignore reportMissingImports @@ -72,8 +72,9 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: self._dataframe.assign( **{ embedding_col_name: cast( - Callable[[gp.DataFrame], FunctionExpr], - lambda t: _generate_embedding(t[column], model), # type: ignore reportUnknownLambdaType + Callable[[gp.DataFrame], TypeCast], + # FIXME: Modifier must be adapted to the model + lambda t: gp.type_("vector", modifier=384)(_generate_embedding(t[column], model)), # type: ignore reportUnknownLambdaType ) }, )[embedding_df_cols] From 8ea3b1cdbd047d56ac4b72691c9d9f3c4188dea2 Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Wed, 23 Aug 2023 14:39:22 +0200 Subject: [PATCH 15/19] Add reminder to import embedding to load the implementation --- greenplumpython/dataframe.py | 8 ++++++-- greenplumpython/experimental/embedding.py | 2 -- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py index 7563b6df..8b050dad 100644 --- a/greenplumpython/dataframe.py +++ b/greenplumpython/dataframe.py @@ -1211,5 +1211,9 @@ def from_columns(cls, columns: Dict[str, Iterable[Any]], db: Database) -> "DataF # Add interface here for language servers. def embedding(self) -> "Embedding": - """Allow user to process vector operators thanks to :class:`~experimental.embedding.Embedding`.""" - raise NotImplementedError + """ + Enable embedding-based similarity search on columns of the current :class:`~DataFrame`. + + See :ref:`embedding-example` for more details. + """ + raise NotImplementedError("Please import greenplumpython.experimental.embedding to load the implementation.") diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index a3ae4d6f..a0be13cc 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -143,8 +143,6 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: Returns: Dataframe with the top k most similar results in the `column` of `query`. - - See :ref:`embedding-example` for more details. """ assert self._dataframe._db is not None embdedding_info = self._dataframe._db._execute( From 10ee29d2fe4c2f82023d05a7a628556c55502733 Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Wed, 23 Aug 2023 14:43:29 +0200 Subject: [PATCH 16/19] Set number threads to 4 for embedding generation --- greenplumpython/experimental/embedding.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index a0be13cc..aab4ccc1 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -12,6 +12,11 @@ def _generate_embedding(content: str, model_name: str) -> gp.type_("vector"): # import sys import sentence_transformers.SentenceTransformer as SentenceTransformer # type: ignore reportMissingImports + import torch + + # Limit the degree of parallelism, otherwise the task may not complete. + # FIXME: This number should be set according to the resources available. + torch.set_num_threads(4) SD = globals().get("SD") if globals().get("SD") is not None else sys.modules["plpy"]._SD if "model" not in SD: From 6cbd2ceb7de1d164b6e3a0554726fb4ecb8b8fa1 Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Thu, 24 Aug 2023 17:41:34 +0200 Subject: [PATCH 17/19] Support Batched k-NN search by allowing user to pass a DataFrame column as query --- doc/source/notebooks/embedding.ipynb | 198 +++++++++++++--------- greenplumpython/dataframe.py | 4 +- greenplumpython/experimental/embedding.py | 129 +++++++++----- 3 files changed, 208 insertions(+), 123 deletions(-) diff --git a/doc/source/notebooks/embedding.ipynb b/doc/source/notebooks/embedding.ipynb index 4bd7a95e..2dd22dc1 100644 --- a/doc/source/notebooks/embedding.ipynb +++ b/doc/source/notebooks/embedding.ipynb @@ -12,27 +12,38 @@ { "cell_type": "code", "execution_count": 1, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-24T14:57:03.091010Z", + "start_time": "2023-08-24T14:56:56.508814Z" + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "/home/gpadmin/GreenplumPython\n", - "Defaulting to user installation because normal site-packages is not writeable\n", - "Processing /home/gpadmin/GreenplumPython\n", - " Installing build dependencies ... \u001b[?25ldone\n", - "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", - "\u001b[?25h Preparing wheel metadata ... \u001b[?25ldone\n", - "\u001b[?25hRequirement already satisfied, skipping upgrade: dill==0.3.6 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (0.3.6)\n", - "Requirement already satisfied, skipping upgrade: psycopg2-binary==2.9.5 in /home/gpadmin/.local/lib/python3.9/site-packages (from greenplum-python==1.0.1) (2.9.5)\n", - "Building wheels for collected packages: greenplum-python\n", - " Building wheel for greenplum-python (PEP 517) ... \u001b[?25ldone\n", - "\u001b[?25h Created wheel for greenplum-python: filename=greenplum_python-1.0.1-py3-none-any.whl size=70618 sha256=22e5718895157b66b4e6a4dd9bbb7164f4039f4d59dbf58fb519da409b246e38\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-q9fmewst/wheels/bb/1f/99/ff8594e48ec11df99af6e0ee8611a5e560e9f44d1a3fefb351\n", - "Successfully built greenplum-python\n", - "Installing collected packages: greenplum-python\n", - "Successfully installed greenplum-python-1.0.1\n" + "/Users/ruxuez/Desktop/dev/GreenplumPython\n", + "Processing /Users/ruxuez/Desktop/dev/GreenplumPython\r\n", + " Installing build dependencies ... \u001B[?25ldone\r\n", + "\u001B[?25h Getting requirements to build wheel ... \u001B[?25ldone\r\n", + "\u001B[?25h Preparing metadata (pyproject.toml) ... \u001B[?25ldone\r\n", + "\u001B[?25hRequirement already satisfied: psycopg2-binary==2.9.5 in ./venv/lib/python3.9/site-packages (from greenplum-python==1.0.1) (2.9.5)\r\n", + "Requirement already satisfied: dill==0.3.6 in ./venv/lib/python3.9/site-packages (from greenplum-python==1.0.1) (0.3.6)\r\n", + "Building wheels for collected packages: greenplum-python\r\n", + " Building wheel for greenplum-python (pyproject.toml) ... \u001B[?25ldone\r\n", + "\u001B[?25h Created wheel for greenplum-python: filename=greenplum_python-1.0.1-py3-none-any.whl size=84165 sha256=9746078492ea731ec973e8d87764de0b8fda43f8a9af1de2d4db7a13799ddb8c\r\n", + " Stored in directory: /private/var/folders/jf/ycmq4_px3nj7gcrs015qqhxm0000gq/T/pip-ephem-wheel-cache-d86u32fo/wheels/56/a3/62/fb507748981bea497278b550674de9ab4cfa5150c30722b3d5\r\n", + "Successfully built greenplum-python\r\n", + "Installing collected packages: greenplum-python\r\n", + " Attempting uninstall: greenplum-python\r\n", + " Found existing installation: greenplum-python 1.0.1\r\n", + " Uninstalling greenplum-python-1.0.1:\r\n", + " Successfully uninstalled greenplum-python-1.0.1\r\n", + "Successfully installed greenplum-python-1.0.1\r\n", + "\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m A new release of pip available: \u001B[0m\u001B[31;49m22.3.1\u001B[0m\u001B[39;49m -> \u001B[0m\u001B[32;49m23.2.1\u001B[0m\r\n", + "\u001B[1m[\u001B[0m\u001B[34;49mnotice\u001B[0m\u001B[1;39;49m]\u001B[0m\u001B[39;49m To update, run: \u001B[0m\u001B[32;49mpip install --upgrade pip\u001B[0m\r\n" ] } ], @@ -51,14 +62,19 @@ { "cell_type": "code", "execution_count": 2, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-24T14:57:03.994979Z", + "start_time": "2023-08-24T14:57:03.090193Z" + } + }, "outputs": [], "source": [ "content = [\"I have a dog.\", \"I like eating apples.\"]\n", "\n", "import greenplumpython as gp\n", "\n", - "db = gp.database(\"postgresql://localhost:7000\")\n", + "db = gp.database(\"postgres://localhost:7000\")\n", "t = (\n", " db.create_dataframe(columns={\"id\": range(len(content)), \"content\": content})\n", " .save_as(\n", @@ -81,35 +97,17 @@ { "cell_type": "code", "execution_count": 3, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-24T14:57:09.828879Z", + "start_time": "2023-08-24T14:57:03.997220Z" + } + }, "outputs": [ { "data": { - "text/html": [ - "\n", - "\t\n", - "\t\t\n", - "\t\t\n", - "\t\n", - "\t\n", - "\t\t\n", - "\t\t\n", - "\t\n", - "\t\n", - "\t\t\n", - "\t\t\n", - "\t\n", - "
idcontent
0I have a dog.
1I like eating apples.
" - ], - "text/plain": [ - "----------------------------\n", - " id | content \n", - "----+-----------------------\n", - " 0 | I have a dog. \n", - " 1 | I like eating apples. \n", - "----------------------------\n", - "(2 rows)" - ] + "text/plain": "----------------------------\n id | content \n----+-----------------------\n 0 | I have a dog. \n 1 | I like eating apples. \n----------------------------\n(2 rows)", + "text/html": "\n\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\n
idcontent
0I have a dog.
1I like eating apples.
" }, "execution_count": 3, "metadata": {}, @@ -133,30 +131,17 @@ { "cell_type": "code", "execution_count": 4, - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2023-08-24T14:57:15.090346Z", + "start_time": "2023-08-24T14:57:09.830713Z" + } + }, "outputs": [ { "data": { - "text/html": [ - "\n", - "\t\n", - "\t\t\n", - "\t\t\n", - "\t\n", - "\t\n", - "\t\t\n", - "\t\t\n", - "\t\n", - "
idcontent
1I like eating apples.
" - ], - "text/plain": [ - "----------------------------\n", - " id | content \n", - "----+-----------------------\n", - " 1 | I like eating apples. \n", - "----------------------------\n", - "(1 row)" - ] + "text/plain": "----------------------------\n id | content \n----+-----------------------\n 1 | I like eating apples. \n----------------------------\n(1 row)", + "text/html": "\n\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\n
idcontent
1I like eating apples.
" }, "execution_count": 4, "metadata": {}, @@ -169,40 +154,91 @@ }, { "cell_type": "markdown", - "metadata": {}, "source": [ - "## Cleaning All at Once" - ] + "Batched k-NN search" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 5, + "outputs": [], + "source": [ + "query = (\n", + " db.create_dataframe(columns={\"idd\": range(3), \"query\": [\"apple\", \"dog\", \"banana\"]})\n", + " .save_as(\n", + " table_name=\"query_sample\",\n", + " column_names=[\"idd\", \"query\"],\n", + " distribution_key={\"idd\"},\n", + " distribution_type=\"hash\",\n", + " )\n", + " .check_unique(columns={\"idd\"})\n", + " .embedding()\n", + " .create_index(column=\"query\", model=\"all-MiniLM-L6-v2\")\n", + ")" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-24T14:57:19.072770Z", + "start_time": "2023-08-24T14:57:15.092543Z" + } + } }, { "cell_type": "code", "execution_count": 6, - "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - " * postgresql://localhost:7000\n", - "Done.\n" - ] - }, { "data": { - "text/plain": [ - "[]" - ] + "text/plain": "-------------------------------------------\n idd | id | query | content \n-----+----+--------+-----------------------\n 1 | 0 | dog | I have a dog. \n 2 | 0 | banana | I have a dog. \n 2 | 1 | banana | I like eating apples. \n 0 | 1 | apple | I like eating apples. \n-------------------------------------------\n(4 rows)", + "text/html": "\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n
iddidquerycontent
10dogI have a dog.
20bananaI have a dog.
21bananaI like eating apples.
01appleI like eating apples.
" }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], + "source": [ + "t.embedding().search(column=\"content\", query=query[\"query\"], top_k=2)" + ], + "metadata": { + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-24T14:57:19.825771Z", + "start_time": "2023-08-24T14:57:19.074661Z" + } + } + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleaning All at Once" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "%reload_ext sql\n", "%sql postgresql://localhost:7000\n", - "%sql DROP TABLE text_sample CASCADE;" + "%sql DROP TABLE text_sample CASCADE;\n", + "%sql DROP TABLE query_sample CASCADE;" ] + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false + } } ], "metadata": { diff --git a/greenplumpython/dataframe.py b/greenplumpython/dataframe.py index 8b050dad..8bb78044 100644 --- a/greenplumpython/dataframe.py +++ b/greenplumpython/dataframe.py @@ -1216,4 +1216,6 @@ def embedding(self) -> "Embedding": See :ref:`embedding-example` for more details. """ - raise NotImplementedError("Please import greenplumpython.experimental.embedding to load the implementation.") + raise NotImplementedError( + "Please import greenplumpython.experimental.embedding to load the implementation." + ) diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index aab4ccc1..3f01af15 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -1,5 +1,5 @@ from collections import abc -from typing import Any, Callable, cast +from typing import Any, Callable, List, cast from uuid import uuid4 import greenplumpython as gp @@ -12,7 +12,7 @@ def _generate_embedding(content: str, model_name: str) -> gp.type_("vector"): # import sys import sentence_transformers.SentenceTransformer as SentenceTransformer # type: ignore reportMissingImports - import torch + import torch # type: ignore reportMissingImports; # type: ignore reportUnknownVariableType # Limit the degree of parallelism, otherwise the task may not complete. # FIXME: This number should be set according to the resources available. @@ -149,51 +149,98 @@ def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: Returns: Dataframe with the top k most similar results in the `column` of `query`. """ - assert self._dataframe._db is not None - embdedding_info = self._dataframe._db._execute( - f""" - WITH indexed_col_info AS ( - SELECT attrelid, attnum - FROM pg_attribute + + def find_embedding_df(df: gp.DataFrame, column_c: str): + assert df._db is not None + + embdedding_info = df._db._execute( + f""" + WITH indexed_col_info AS ( + SELECT attrelid, attnum + FROM pg_attribute + WHERE + attrelid = '{df._qualified_table_name}'::regclass::oid AND + attname = '{column_c}' + ), reloptions AS ( + SELECT unnest(reloptions) AS option + FROM pg_class, indexed_col_info + WHERE pg_class.oid = attrelid + ), embedding_info_json AS ( + SELECT split_part(option, '=', 2)::json AS val + FROM reloptions, indexed_col_info + WHERE option LIKE format('_pygp_emb_%s=%%', attnum) + ), embedding_info AS ( + SELECT * + FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text) + ) + SELECT nspname, relname, attname, model + FROM embedding_info, pg_class, pg_namespace, pg_attribute WHERE - attrelid = '{self._dataframe._qualified_table_name}'::regclass::oid AND - attname = '{column}' - ), reloptions AS ( - SELECT unnest(reloptions) AS option - FROM pg_class, indexed_col_info - WHERE pg_class.oid = attrelid - ), embedding_info_json AS ( - SELECT split_part(option, '=', 2)::json AS val - FROM reloptions, indexed_col_info - WHERE option LIKE format('_pygp_emb_%s=%%', attnum) - ), embedding_info AS ( - SELECT * - FROM embedding_info_json, json_to_record(val) AS (attnum int4, embedding_relid oid, model text) + pg_class.oid = embedding_relid AND + relnamespace = pg_namespace.oid AND + embedding_relid = attrelid AND + pg_attribute.attnum = 2; + """ ) - SELECT nspname, relname, attname, model - FROM embedding_info, pg_class, pg_namespace, pg_attribute - WHERE - pg_class.oid = embedding_relid AND - relnamespace = pg_namespace.oid AND - embedding_relid = attrelid AND - pg_attribute.attnum = 2; - """ - ) - assert isinstance(embdedding_info, abc.Mapping[str, Any]) - row: Row = embdedding_info[0] - schema: str = row["nspname"] - embedding_table_name: str = row["relname"] - model = row["model"] - embedding_col_name = row["attname"] - embedding_df = self._dataframe._db.create_dataframe( - table_name=embedding_table_name, schema=schema - ) + # assert isinstance(embdedding_info, abc.Mapping[str, Any]) + row: Row = embdedding_info[0] + schema: str = row["nspname"] + embedding_table_name: str = row["relname"] + model = row["model"] + embedding_col_name = row["attname"] + embedding_df = df._db.create_dataframe(table_name=embedding_table_name, schema=schema) + return embedding_df, embedding_table_name, embedding_col_name, model + + ( + self_embedding_df, + self_embedding_table_name, + self_embedding_col_name, + self_model, + ) = find_embedding_df(self._dataframe, column) assert self._dataframe.unique_key is not None distance = gp.operator("<->") # L2 distance is the default operator class in pgvector + if isinstance(query, gp.Expr): + assert query._dataframe is not None + (_, query_embedding_table_name, query_embedding_col_name, _,) = find_embedding_df( + query._dataframe.embedding()._dataframe, query._name # type: ignore reportUnknownArgumentType + ) + assert query._dataframe.unique_key is not None + joint_table_name = "cte_" + uuid4().hex + query_df_unique_keys: List[str] = list(query._dataframe.unique_key) + self_df_unique_keys: List[str] = list(self._dataframe.unique_key) + assert query_df_unique_keys is not None + assert self_df_unique_keys is not None + lateral_join_df = gp.DataFrame( + query=f""" + WITH {joint_table_name} as ( + SELECT + * + FROM {self_embedding_table_name} CROSS JOIN LATERAL ( + SELECT * FROM {query_embedding_table_name} + ORDER BY {self_embedding_table_name}.{self_embedding_col_name} <-> {query_embedding_table_name}.{query_embedding_col_name} + LIMIT {top_k} + ) AS {"cte_" + uuid4().hex} + ) + SELECT + {",".join([(query._dataframe._qualified_table_name+"."+key) for key in query_df_unique_keys])}, + {",".join([(self._dataframe._qualified_table_name+"."+key) for key in self_df_unique_keys])}, + {query._dataframe._qualified_table_name}.{query._name}, + {self._dataframe._qualified_table_name}.{column} + FROM {joint_table_name} + JOIN {query._dataframe._qualified_table_name} + ON {"AND".join([(query._dataframe._qualified_table_name+"."+key+" = "+joint_table_name+"." + key) for key in query_df_unique_keys])} + JOIN {self._dataframe._qualified_table_name} + ON {"AND".join([(self._dataframe._qualified_table_name+"."+key+" = "+joint_table_name+"." + key) for key in self_df_unique_keys])} + """, + db=self._dataframe._db, + ) + return lateral_join_df + return self._dataframe.join( - embedding_df.assign( + self_embedding_df.assign( distance=lambda t: distance( - embedding_df[embedding_col_name], _generate_embedding(query, model) + self_embedding_df[self_embedding_col_name], + _generate_embedding(query, self_model), ) ).order_by("distance")[:top_k], how="inner", From 5ac174366665aa31cb2198e97b4c458d17e3ced1 Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Fri, 25 Aug 2023 14:58:39 +0200 Subject: [PATCH 18/19] Fix cross lateral join --- doc/source/notebooks/embedding.ipynb | 44 ++++++++++++----------- greenplumpython/experimental/embedding.py | 5 ++- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/doc/source/notebooks/embedding.ipynb b/doc/source/notebooks/embedding.ipynb index 2dd22dc1..69b139a2 100644 --- a/doc/source/notebooks/embedding.ipynb +++ b/doc/source/notebooks/embedding.ipynb @@ -14,8 +14,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2023-08-24T14:57:03.091010Z", - "start_time": "2023-08-24T14:56:56.508814Z" + "end_time": "2023-08-25T12:57:01.715707Z", + "start_time": "2023-08-25T12:56:54.919200Z" } }, "outputs": [ @@ -32,8 +32,8 @@ "Requirement already satisfied: dill==0.3.6 in ./venv/lib/python3.9/site-packages (from greenplum-python==1.0.1) (0.3.6)\r\n", "Building wheels for collected packages: greenplum-python\r\n", " Building wheel for greenplum-python (pyproject.toml) ... \u001B[?25ldone\r\n", - "\u001B[?25h Created wheel for greenplum-python: filename=greenplum_python-1.0.1-py3-none-any.whl size=84165 sha256=9746078492ea731ec973e8d87764de0b8fda43f8a9af1de2d4db7a13799ddb8c\r\n", - " Stored in directory: /private/var/folders/jf/ycmq4_px3nj7gcrs015qqhxm0000gq/T/pip-ephem-wheel-cache-d86u32fo/wheels/56/a3/62/fb507748981bea497278b550674de9ab4cfa5150c30722b3d5\r\n", + "\u001B[?25h Created wheel for greenplum-python: filename=greenplum_python-1.0.1-py3-none-any.whl size=84199 sha256=78992ece399bf042be4143672a186d904cf1603148f91dcd7b0b54c610a3245d\r\n", + " Stored in directory: /private/var/folders/jf/ycmq4_px3nj7gcrs015qqhxm0000gq/T/pip-ephem-wheel-cache-3atrca12/wheels/56/a3/62/fb507748981bea497278b550674de9ab4cfa5150c30722b3d5\r\n", "Successfully built greenplum-python\r\n", "Installing collected packages: greenplum-python\r\n", " Attempting uninstall: greenplum-python\r\n", @@ -64,8 +64,8 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2023-08-24T14:57:03.994979Z", - "start_time": "2023-08-24T14:57:03.090193Z" + "end_time": "2023-08-25T12:57:02.644919Z", + "start_time": "2023-08-25T12:57:01.723149Z" } }, "outputs": [], @@ -99,8 +99,8 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2023-08-24T14:57:09.828879Z", - "start_time": "2023-08-24T14:57:03.997220Z" + "end_time": "2023-08-25T12:57:08.645604Z", + "start_time": "2023-08-25T12:57:02.646625Z" } }, "outputs": [ @@ -133,8 +133,8 @@ "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2023-08-24T14:57:15.090346Z", - "start_time": "2023-08-24T14:57:09.830713Z" + "end_time": "2023-08-25T12:57:14.069009Z", + "start_time": "2023-08-25T12:57:08.643273Z" } }, "outputs": [ @@ -155,7 +155,7 @@ { "cell_type": "markdown", "source": [ - "Batched k-NN search" + "## Batched k-NN search" ], "metadata": { "collapsed": false @@ -182,8 +182,8 @@ "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-08-24T14:57:19.072770Z", - "start_time": "2023-08-24T14:57:15.092543Z" + "end_time": "2023-08-25T12:57:17.400047Z", + "start_time": "2023-08-25T12:57:14.059315Z" } } }, @@ -193,8 +193,8 @@ "outputs": [ { "data": { - "text/plain": "-------------------------------------------\n idd | id | query | content \n-----+----+--------+-----------------------\n 1 | 0 | dog | I have a dog. \n 2 | 0 | banana | I have a dog. \n 2 | 1 | banana | I like eating apples. \n 0 | 1 | apple | I like eating apples. \n-------------------------------------------\n(4 rows)", - "text/html": "\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n
iddidquerycontent
10dogI have a dog.
20bananaI have a dog.
21bananaI like eating apples.
01appleI like eating apples.
" + "text/plain": "-------------------------------------------\n idd | id | query | content \n-----+----+--------+-----------------------\n 1 | 0 | dog | I have a dog. \n 2 | 1 | banana | I like eating apples. \n 0 | 1 | apple | I like eating apples. \n-------------------------------------------\n(3 rows)", + "text/html": "\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n\t\n\t\t\n\t\t\n\t\t\n\t\t\n\t\n
iddidquerycontent
10dogI have a dog.
21bananaI like eating apples.
01appleI like eating apples.
" }, "execution_count": 6, "metadata": {}, @@ -202,13 +202,13 @@ } ], "source": [ - "t.embedding().search(column=\"content\", query=query[\"query\"], top_k=2)" + "t.embedding().search(column=\"content\", query=query[\"query\"], top_k=1)" ], "metadata": { "collapsed": false, "ExecuteTime": { - "end_time": "2023-08-24T14:57:19.825771Z", - "start_time": "2023-08-24T14:57:19.074661Z" + "end_time": "2023-08-25T12:57:18.305871Z", + "start_time": "2023-08-25T12:57:17.402679Z" } } }, @@ -233,11 +233,15 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "outputs": [], "source": [], "metadata": { - "collapsed": false + "collapsed": false, + "ExecuteTime": { + "end_time": "2023-08-25T12:57:19.295165Z", + "start_time": "2023-08-25T12:57:19.292932Z" + } } } ], diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index 3f01af15..a2e1d2d7 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -182,7 +182,6 @@ def find_embedding_df(df: gp.DataFrame, column_c: str): pg_attribute.attnum = 2; """ ) - # assert isinstance(embdedding_info, abc.Mapping[str, Any]) row: Row = embdedding_info[0] schema: str = row["nspname"] embedding_table_name: str = row["relname"] @@ -215,8 +214,8 @@ def find_embedding_df(df: gp.DataFrame, column_c: str): WITH {joint_table_name} as ( SELECT * - FROM {self_embedding_table_name} CROSS JOIN LATERAL ( - SELECT * FROM {query_embedding_table_name} + FROM {query_embedding_table_name} CROSS JOIN LATERAL ( + SELECT * FROM {self_embedding_table_name} ORDER BY {self_embedding_table_name}.{self_embedding_col_name} <-> {query_embedding_table_name}.{query_embedding_col_name} LIMIT {top_k} ) AS {"cte_" + uuid4().hex} From d40cfa57580da9cf07c8e19d1a1d50ffc2453f29 Mon Sep 17 00:00:00 2001 From: Ruxue Zeng Date: Fri, 25 Aug 2023 17:53:56 +0200 Subject: [PATCH 19/19] Fix duplicate unique key in KNN batched search --- greenplumpython/experimental/embedding.py | 45 +++++++++++++++++++---- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/greenplumpython/experimental/embedding.py b/greenplumpython/experimental/embedding.py index a2e1d2d7..b753f1bf 100644 --- a/greenplumpython/experimental/embedding.py +++ b/greenplumpython/experimental/embedding.py @@ -1,8 +1,9 @@ from collections import abc -from typing import Any, Callable, List, cast +from typing import Any, Callable, Dict, List, Optional, Set, Union, cast from uuid import uuid4 import greenplumpython as gp +from greenplumpython.col import Column from greenplumpython.row import Row from greenplumpython.type import TypeCast @@ -137,7 +138,13 @@ def create_index(self, column: str, model: str) -> gp.DataFrame: ) return self._dataframe - def search(self, column: str, query: Any, top_k: int) -> gp.DataFrame: + def search( + self, + column: str, + query: Any, + top_k: int, + query_unique_key_columns: Optional[Union[Dict[str, Optional[str]], Set[str]]] = None, + ) -> gp.DataFrame: """ Searche unstructured data based on semantic similarity on embeddings. @@ -190,6 +197,14 @@ def find_embedding_df(df: gp.DataFrame, column_c: str): embedding_df = df._db.create_dataframe(table_name=embedding_table_name, schema=schema) return embedding_df, embedding_table_name, embedding_col_name, model + def _bind(t: str, columns: Union[Dict[str, Optional[str]], Set[str]]) -> List[str]: + target_list: List[str] = [] + for k in columns: + v = columns[k] if isinstance(columns, dict) else None + col_serialize = t + "." + k + (f' AS "{v}"' if v is not None else "") + target_list.append(col_serialize) + return target_list + ( self_embedding_df, self_embedding_table_name, @@ -198,13 +213,14 @@ def find_embedding_df(df: gp.DataFrame, column_c: str): ) = find_embedding_df(self._dataframe, column) assert self._dataframe.unique_key is not None distance = gp.operator("<->") # L2 distance is the default operator class in pgvector - if isinstance(query, gp.Expr): + if isinstance(query, Column): assert query._dataframe is not None (_, query_embedding_table_name, query_embedding_col_name, _,) = find_embedding_df( query._dataframe.embedding()._dataframe, query._name # type: ignore reportUnknownArgumentType ) assert query._dataframe.unique_key is not None joint_table_name = "cte_" + uuid4().hex + right_join_table_name = "cte_" + uuid4().hex query_df_unique_keys: List[str] = list(query._dataframe.unique_key) self_df_unique_keys: List[str] = list(self._dataframe.unique_key) assert query_df_unique_keys is not None @@ -213,21 +229,36 @@ def find_embedding_df(df: gp.DataFrame, column_c: str): query=f""" WITH {joint_table_name} as ( SELECT - * + {",".join(_bind(query_embedding_table_name, columns=query_unique_key_columns)) + if query_unique_key_columns is not None + else ",".join( + [(query_embedding_table_name+"."+key) for key in query_df_unique_keys] + )}, + {",".join([(right_join_table_name+"."+key) for key in self_df_unique_keys])}, + {query_embedding_table_name}.{query_embedding_col_name}, + {right_join_table_name}.{self_embedding_col_name} FROM {query_embedding_table_name} CROSS JOIN LATERAL ( SELECT * FROM {self_embedding_table_name} ORDER BY {self_embedding_table_name}.{self_embedding_col_name} <-> {query_embedding_table_name}.{query_embedding_col_name} LIMIT {top_k} - ) AS {"cte_" + uuid4().hex} + ) AS {right_join_table_name} ) + SELECT - {",".join([(query._dataframe._qualified_table_name+"."+key) for key in query_df_unique_keys])}, + {",".join(_bind(query._dataframe._qualified_table_name, columns=query_unique_key_columns)) + if query_unique_key_columns is not None + else ",".join( + [(query._dataframe._qualified_table_name+"."+key) for key in query_df_unique_keys] + )}, {",".join([(self._dataframe._qualified_table_name+"."+key) for key in self_df_unique_keys])}, {query._dataframe._qualified_table_name}.{query._name}, {self._dataframe._qualified_table_name}.{column} FROM {joint_table_name} JOIN {query._dataframe._qualified_table_name} - ON {"AND".join([(query._dataframe._qualified_table_name+"."+key+" = "+joint_table_name+"." + key) for key in query_df_unique_keys])} + ON {"AND".join([ + (f"{query._dataframe._qualified_table_name}.{key} = {joint_table_name}.{query_unique_key_columns[key] if query_unique_key_columns is not None and key in query_unique_key_columns else key}") + for key in query_df_unique_keys + ])} JOIN {self._dataframe._qualified_table_name} ON {"AND".join([(self._dataframe._qualified_table_name+"."+key+" = "+joint_table_name+"." + key) for key in self_df_unique_keys])} """,