Skip to content

Commit

Permalink
Add ingest and upsert for weavaite
Browse files Browse the repository at this point in the history
  • Loading branch information
sunank200 committed Nov 20, 2023
1 parent 7325054 commit b76ab3f
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 85 deletions.
2 changes: 1 addition & 1 deletion airflow/dags/ingestion/ask-astro-load.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def ask_astro_load_bulk():
This DAG performs the initial load of data from sources.
If seed_baseline_url (set above) points to a parquet file with pre-embedded data it will be
ingested. Otherwise new data is extracted, split, embedded and ingested.
ingested. Otherwise, new data is extracted, split, embedded and ingested.
The first time this DAG runs (without seeded baseline) it will take at lease 90 minutes to
extract data from all sources. Extracted data is then serialized to disk in the project
Expand Down
248 changes: 164 additions & 84 deletions airflow/include/tasks/extract/utils/weaviate/ask_astro_weaviate_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,54 +136,30 @@ def create_schema(self, class_objects: list, existing: str = "ignore") -> None:
self.client.schema.create_class(class_object)
self.logger.info(f"Created/updated class {class_name}")

def ingest_data(
self,
dfs: list[pd.DataFrame],
class_name: str,
existing: str = "skip",
doc_key: str = None,
uuid_column: str = None,
vector_column: str = None,
batch_params: dict = None,
verbose: bool = True,
) -> list:
def prepare_data_for_ingestion(
self, dfs: list[pd.DataFrame], class_name: str, existing: str, uuid_column: str, vector_column: str
) -> tuple(pd.DataFrame, str):
"""
This task concatenates multiple dataframes from upstream dynamic tasks and vectorizes with import to weaviate.
The operator returns a list of any objects that failed to import.
A 'uuid' is generated based on the content and metadata (the git sha, document url, the document source and a
concatenation of the headers) and Weaviate will create the vectors.
Upsert and logic relies on a 'doc_key' which is a uniue representation of the document. Because documents can
be represented as multiple chunks (each with a UUID which is unique in the DB) the doc_key is a way to represent
all chunks associated with an ingested document.
:param dfs: A list of dataframes from downstream dynamic tasks
:param class_name: The name of the class to import data. Class should be created with weaviate schema.
:param existing: Whether to 'upsert', 'skip' or 'replace' any existing documents. Default is 'skip'.
:param doc_key: If using upsert you must specify a doc_key which uniquely identifies a document which may or may
not include multiple (unique) chunks.
:param vector_column: For pre-embedded data specify the name of the column containing the embedding vector
:param uuid_column: For data with pre-generated UUID specify the name of the column containing the UUID
:param batch_params: Additional parameters to pass to the weaviate batch configuration
:param verbose: Whether to print verbose output
Prepares data for ingestion into Weaviate.
:param dfs: A list of dataframes from downstream dynamic tasks.
:param class_name: The name of the class to import data.
:param existing: Strategy to handle existing data ('skip', 'replace', 'upsert').
:param uuid_column: Name of the column containing the UUID.
:param vector_column: Name of the column containing the vector data.
:return: A concatenated and processed DataFrame ready for ingestion.
"""

global objects_to_upsert
if existing not in ["skip", "replace", "upsert"]:
raise AirflowException("Invalid parameter for 'existing'. Choices are 'skip', 'replace', 'upsert'")
raise AirflowException("Invalid parameter for 'existing'. Choices are 'skip', 'replace', 'upsert'")

df = pd.concat(dfs, ignore_index=True)

# Without a pre-generated UUID weaviate ingest just creates one with uuid.uuid4()
# This will lead to duplicates in vector db.
if uuid_column is None:
# reorder columns alphabetically for consistent uuid mapping
column_names = df.columns.to_list()
column_names.sort()
df = df[column_names]

self.logger.info("No uuid_column provided Generating UUIDs for ingest.")
self.logger.info("No uuid_column provided. Generating UUIDs for ingest.")
if "id" in column_names:
raise AirflowException("Property 'id' already in dataset. Consider renaming or specify 'uuid_column'.")
else:
Expand All @@ -196,82 +172,183 @@ def ingest_data(
df.drop_duplicates(inplace=True)

if df[uuid_column].duplicated().any():
raise AirflowException("Duplicate rows found. Remove duplicates before ingest.")
raise AirflowException("Duplicate rows found. Remove duplicates before ingest.")

if existing == "upsert":
if doc_key is None:
raise AirflowException("Must specify 'doc_key' if 'existing=upsert'.")
else:
if df[[doc_key, uuid_column]].duplicated().any():
raise AirflowException("Duplicate rows found. Remove duplicates before ingest.")

current_schema = self.client.schema.get(class_name=class_name)
doc_key_schema = [prop for prop in current_schema["properties"] if prop["name"] == doc_key]

if len(doc_key_schema) < 1:
raise AirflowException("doc_key does not exist in current schema.")
elif doc_key_schema[0]["tokenization"] != "field":
raise AirflowException(
"Tokenization for provided doc_key is not set to 'field'. Cannot upsert safely."
)

# get a list of any UUIDs which need to be removed later
objects_to_upsert = self._objects_to_upsert(
df=df, class_name=class_name, doc_key=doc_key, uuid_column=uuid_column
)
return df, uuid_column

df = df[df[uuid_column].isin(objects_to_upsert["objects_to_insert"])]
def handle_upsert(self, df: pd.DataFrame, class_name: str, doc_key: str, uuid_column: str) -> (pd.DataFrame, dict):
"""
Handles the 'upsert' operation for data ingestion.
self.logger.info(f"Passing {len(df)} objects for ingest.")
:param df: The DataFrame containing the data to be upserted.
:param class_name: The name of the class to import data.
:param doc_key: The document key used for upsert operation.
:param uuid_column: The column name containing the UUID.
:return: The DataFrame filtered for objects to insert.
"""
if doc_key is None:
raise AirflowException("Must specify 'doc_key' if 'existing=upsert'.")

if df[[doc_key, uuid_column]].duplicated().any():
raise AirflowException("Duplicate rows found. Remove duplicates before ingest.")

current_schema = self.client.schema.get(class_name=class_name)
doc_key_schema = [prop for prop in current_schema["properties"] if prop["name"] == doc_key]

if len(doc_key_schema) < 1:
raise AirflowException("doc_key does not exist in current schema.")
elif doc_key_schema[0]["tokenization"] != "field":
raise AirflowException("Tokenization for provided doc_key is not set to 'field'. Cannot upsert safely.")

objects_to_upsert = self._objects_to_upsert(df, class_name, doc_key, uuid_column)

return df[df[uuid_column].isin(objects_to_upsert["objects_to_insert"])], objects_to_upsert

def batch_process_data(
self,
df: pd.DataFrame,
class_name: str,
uuid_column: str,
vector_column: str,
batch_params: dict,
existing: str,
verbose: bool,
) -> (list, Any):
"""
Processes the DataFrame and batches the data for ingestion into Weaviate.
:param df: DataFrame containing the data to be ingested.
:param class_name: The name of the class in Weaviate to which data will be ingested.
:param uuid_column: Name of the column containing the UUID.
:param vector_column: Name of the column containing the vector data.
:param batch_params: Parameters for batch configuration.
:param existing: Strategy to handle existing data ('skip', 'replace', 'upsert').
:param verbose: Whether to print verbose output.
:return: List of any objects that failed to be added to the batch.
"""
batch = self.client.batch.configure(**batch_params)
batch_errors = []

for row_id, row in df.iterrows():
data_object = row.to_dict()
uuid = data_object[uuid_column]

# if the uuid exists we know that the properties are the same
if self.client.data_object.exists(uuid=uuid, class_name=class_name) is True:
# Check if the uuid exists and handle accordingly
if self.client.data_object.exists(uuid=uuid, class_name=class_name):
if existing == "skip":
if verbose is True:
self.logger.warning(f"UUID {uuid} exists. Skipping.")
if verbose:
self.logger.warning(f"UUID {uuid} exists. Skipping.")
continue
elif existing == "replace":
# Default for weaviate is replacing existing
if verbose is True:
self.logger.warning(f"UUID {uuid} exists. Overwriting.")
if verbose:
self.logger.warning(f"UUID {uuid} exists. Overwriting.")

vector = data_object.pop(vector_column, None)
uuid = data_object.pop(uuid_column)

added_row = batch.add_data_object(class_name=class_name, uuid=uuid, data_object=data_object, vector=vector)
if verbose is True:
self.logger.info(f"Added row {row_id} with UUID {added_row} for batch import.")
try:
batch.add_data_object(class_name=class_name, uuid=uuid, data_object=data_object, vector=vector)
if verbose:
self.logger.info(f"Added row {row_id} with UUID {uuid} for batch import.")
except Exception as e:
if verbose:
self.logger.error(f"Failed to add row {row_id} with UUID {uuid}. Error: {e}")
batch_errors.append({"row_id": row_id, "uuid": uuid, "error": str(e)})

results = batch.create_objects()
return batch_errors + [item for result in results for item in result.get("errors", [])], results

def process_batch_errors(self, results: list, verbose: bool) -> list:
"""
Processes the results from batch operation and collects any errors.
:param results: Results from the batch operation.
:param verbose: Flag to enable verbose logging.
:return: List of error messages.
"""
batch_errors = []
for item in results:
if "errors" in item["result"]:
item_error = {"id": item["id"], "errors": item["result"]["errors"]}
if verbose:
self.logger.info(item_error)
batch_errors.append(item_error)
return batch_errors

def handle_upsert_rollback(self, objects_to_upsert: dict, class_name: str, verbose: bool):
"""
Handles rollback of inserts in case of errors during upsert operation.
:param objects_to_upsert: Dictionary of objects to upsert.
:param class_name: Name of the class in Weaviate.
:param verbose: Flag to enable verbose logging.
"""
for uuid in objects_to_upsert["objects_to_insert"]:
self.logger.info(f"Removing id {uuid} for rollback.")
if self.client.data_object.exists(uuid=uuid, class_name=class_name):
self.client.data_object.delete(uuid=uuid, class_name=class_name, consistency_level="ALL")
elif verbose:
self.logger.info(f"UUID {uuid} does not exist. Skipping deletion.")

for uuid in objects_to_upsert["objects_to_delete"]:
if verbose:
self.logger.info(f"Deleting id {uuid} for successful upsert.")
if self.client.data_object.exists(uuid=uuid, class_name=class_name):
self.client.data_object.delete(uuid=uuid, class_name=class_name)
elif verbose:
self.logger.info(f"UUID {uuid} does not exist. Skipping deletion.")

def ingest_data(
self,
dfs: list[pd.DataFrame],
class_name: str,
existing: str = "skip",
doc_key: str = None,
uuid_column: str = None,
vector_column: str = None,
batch_params: dict = None,
verbose: bool = True,
) -> list:
"""
This task concatenates multiple dataframes from upstream dynamic tasks and vectorized with import to weaviate.
The operator returns a list of any objects that failed to import.
A 'uuid' is generated based on the content and metadata (the git sha, document url, the document source and a
concatenation of the headers) and Weaviate will create the vectors.
Upsert and logic relies on a 'doc_key' which is a uniue representation of the document. Because documents can
be represented as multiple chunks (each with a UUID which is unique in the DB) the doc_key is a way to represent
all chunks associated with an ingested document.
:param dfs: A list of dataframes from downstream dynamic tasks
:param class_name: The name of the class to import data. Class should be created with weaviate schema.
:param existing: Whether to 'upsert', 'skip' or 'replace' any existing documents. Default is 'skip'.
:param doc_key: If using upsert you must specify a doc_key which uniquely identifies a document which may or may
not include multiple (unique) chunks.
:param vector_column: For pre-embedded data specify the name of the column containing the embedding vector
:param uuid_column: For data with pre-generated UUID specify the name of the column containing the UUID
:param batch_params: Additional parameters to pass to the weaviate batch configuration
:param verbose: Whether to print verbose output
"""

global objects_to_upsert

df, uuid_column = self.prepare_data_for_ingestion(dfs, class_name, existing, uuid_column, vector_column)

# check errors from callback
if existing == "upsert":
if len(batch_errors) > 0:
self.logger.warning("Error during upsert. Rollling back all inserts.")
# rollback inserts
for uuid in objects_to_upsert["objects_to_insert"]:
self.logger.info(f"Removing id {uuid} for rollback.")
self.client.data_object.delete(uuid=uuid, class_name=class_name, consistency_level="ALL")

elif len(objects_to_upsert["objects_to_delete"]) > 0:
for uuid in objects_to_upsert["objects_to_delete"]:
if verbose:
self.logger.info(f"Deleting id {uuid} for successful upsert.")
self.client.data_object.delete(uuid=uuid, class_name=class_name)
df, objects_to_upsert = self.handle_upsert(df, class_name, doc_key, uuid_column)

self.logger.info(f"Passing {len(df)} objects for ingest.")

batch_errors, results = self.batch_process_data(
df, class_name, uuid_column, vector_column, batch_params, existing, verbose
)

batch_errors += self.process_batch_errors(results, verbose)

if existing == "upsert" and batch_errors:
self.logger.warning("Error during upsert. Rolling back all inserts.")
self.handle_upsert_rollback(objects_to_upsert, class_name, verbose)

return batch_errors

Expand Down Expand Up @@ -306,6 +383,9 @@ def _objects_to_upsert(self, df: pd.DataFrame, class_name: str, doc_key: str, uu
:param doc_key: The name of the property to query.
:param uuid_column: The name of the column containing the UUID.
"""
if doc_key is None:
# Return an empty dictionary or handle the situation as needed
return {}
ids_df = df.groupby(doc_key)[uuid_column].apply(set).reset_index(name="new_ids")
ids_df["existing_ids"] = ids_df[doc_key].apply(
lambda x: self._query_objects(value=x, doc_key=doc_key, uuid_column=uuid_column, class_name=class_name)
Expand Down

0 comments on commit b76ab3f

Please sign in to comment.