Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,51 @@ def can_write() -> bool:
def can_read() -> bool:
return True

@staticmethod
def _sanitize_for_bigquery(data: Any) -> Any:
"""BigQuery-specific float sanitization for PARSE_JSON compatibility.

BigQuery's PARSE_JSON() requires floats that can "round-trip" through
string representation. This method limits total significant figures to 15
(IEEE 754 double precision safe zone) to ensure clean binary representation.

Args:
data: The data structure to sanitize (dict, list, or primitive)

Returns:
Sanitized data compatible with BigQuery's PARSE_JSON

Example:
>>> BigQuery._sanitize_for_bigquery({"time": 1760509016.282637})
{'time': 1760509016.28264} # Limited to 15 significant figures

>>> BigQuery._sanitize_for_bigquery({"cost": 0.001228})
{'cost': 0.001228} # Unchanged (only 4 significant figures)
"""
import math

if isinstance(data, float):
# Handle special values that BigQuery can't store in JSON
if math.isnan(data) or math.isinf(data):
return None
if data == 0:
return 0.0

# Limit total significant figures to 15 for IEEE 754 compatibility
# BigQuery PARSE_JSON requires values that round-trip cleanly
# For large numbers (like Unix timestamps), this reduces decimal precision
# For small numbers (like costs), full precision is preserved
magnitude = math.floor(math.log10(abs(data))) + 1
safe_decimals = max(0, 15 - magnitude)
return float(f"{data:.{safe_decimals}f}")

elif isinstance(data, dict):
return {k: BigQuery._sanitize_for_bigquery(v) for k, v in data.items()}
elif isinstance(data, list):
return [BigQuery._sanitize_for_bigquery(item) for item in data]
else:
return data

def get_engine(self) -> Any:
return self.bigquery.Client.from_service_account_info( # type: ignore
info=self.json_credentials
Expand Down Expand Up @@ -202,7 +247,13 @@ def execute_query(

if isinstance(value, (dict, list)) and column_type == "JSON":
# For JSON objects in JSON columns, convert to string and use PARSE_JSON
json_str = json.dumps(value) if value else None
# Sanitize floats before serialization to ensure clean JSON for PARSE_JSON
sanitized_value = BigQuery._sanitize_for_bigquery(value)
json_str = (
json.dumps(sanitized_value)
if sanitized_value is not None
else None
)
if json_str:
# Replace @`key` with PARSE_JSON(@`key`) in the SQL query
modified_sql = modified_sql.replace(
Expand All @@ -213,7 +264,13 @@ def execute_query(
)
elif isinstance(value, (dict, list)):
# For dict/list values in STRING columns, serialize to JSON string
json_str = json.dumps(value) if value else None
# Sanitize floats before serialization to ensure clean JSON
sanitized_value = BigQuery._sanitize_for_bigquery(value)
json_str = (
json.dumps(sanitized_value)
if sanitized_value is not None
else None
)
query_parameters.append(
self.bigquery.ScalarQueryParameter(key, "STRING", json_str)
)
Expand Down Expand Up @@ -314,7 +371,10 @@ def get_sql_values_for_query(
# Try to parse JSON strings back to objects for BigQuery
try:
parsed_value = json.loads(value)
sql_values[column] = parsed_value
# Sanitize floats after parsing to prevent precision issues
# json.loads() creates new float objects that may have binary precision problems
sanitized_value = BigQuery._sanitize_for_bigquery(parsed_value)
sql_values[column] = sanitized_value
except (TypeError, ValueError, json.JSONDecodeError):
# Not a JSON string, keep as string
sql_values[column] = f"{value}"
Expand Down
51 changes: 51 additions & 0 deletions unstract/connectors/src/unstract/connectors/databases/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Database Utilities

Common utilities for database connectors to ensure consistent data handling
across all database types (BigQuery, PostgreSQL, MySQL, Snowflake, etc.).
"""

import math
from typing import Any


def sanitize_floats_for_database(data: Any) -> Any:
    """Recursively replace NaN/Infinity floats with None for database storage.

    This is the minimal, database-agnostic pass shared by every connector:
    no JSON backend can represent the IEEE 754 special values, so they are
    mapped to None. Normal floats keep their full precision here; any
    backend-specific rounding (e.g. BigQuery's round-trip requirement)
    belongs in that connector.

    Args:
        data: The data structure to sanitize (dict, list, or primitive)

    Returns:
        A copy of *data* where every NaN, Infinity, and -Infinity float
        has been replaced by None; all other values pass through as-is.

    Example:
        >>> sanitize_floats_for_database({"value": float("nan")})
        {'value': None}

        >>> sanitize_floats_for_database({"value": float("inf")})
        {'value': None}

        >>> sanitize_floats_for_database({"price": 1760509016.282637})
        {'price': 1760509016.282637}  # Unchanged - precision preserved
    """
    # Containers first: rebuild them with each element sanitized.
    if isinstance(data, dict):
        return {key: sanitize_floats_for_database(value) for key, value in data.items()}
    if isinstance(data, list):
        return [sanitize_floats_for_database(element) for element in data]
    # Only non-finite floats are rewritten; finite floats stay untouched.
    if isinstance(data, float) and not math.isfinite(data):
        return None
    # Everything else (int, str, bool, None, finite float, ...) passes through.
    return data
52 changes: 5 additions & 47 deletions workers/shared/infrastructure/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import datetime
import json
import math
from typing import Any

from shared.enums.status_enums import FileProcessingStatus
Expand All @@ -15,6 +14,7 @@
from unstract.connectors.databases import connectors as db_connectors
from unstract.connectors.databases.exceptions import UnstractDBConnectorException
from unstract.connectors.databases.unstract_db import UnstractDB
from unstract.connectors.databases.utils import sanitize_floats_for_database
from unstract.connectors.exceptions import ConnectorError

from ..logging import WorkerLogger
Expand Down Expand Up @@ -78,46 +78,6 @@ def __init__(self, detail: str):
class WorkerDatabaseUtils:
"""Worker-compatible database utilities following production patterns."""

@staticmethod
def _sanitize_floats_for_database(data: Any, precision: int = 6) -> Any:
"""Recursively sanitize float values for database JSON compatibility.

BigQuery's PARSE_JSON() requires floats that can "round-trip" through
string representation. This function normalizes floats to ensure they
serialize cleanly for all database types (BigQuery, PostgreSQL, MySQL, etc.).

Args:
data: The data structure to sanitize (dict, list, or primitive)
precision: Number of decimal places to preserve (default: 6)

Returns:
Sanitized data with normalized float values

Example:
>>> _sanitize_floats_for_database({"time": 22.770092, "count": 5})
{'time': 22.770092, 'count': 5}
"""
if isinstance(data, float):
# Handle special float values that databases don't support in JSON
if math.isnan(data) or math.isinf(data):
return None
# Normalize float representation using string formatting
# This ensures clean binary representation that BigQuery accepts
return float(f"{data:.{precision}f}")
elif isinstance(data, dict):
return {
k: WorkerDatabaseUtils._sanitize_floats_for_database(v, precision)
for k, v in data.items()
}
elif isinstance(data, list):
return [
WorkerDatabaseUtils._sanitize_floats_for_database(item, precision)
for item in data
]
else:
# Return other types unchanged (int, str, bool, None, etc.)
return data

@staticmethod
def get_sql_values_for_query(
conn_cls: Any,
Expand Down Expand Up @@ -334,9 +294,7 @@ def _add_processing_columns(
if metadata and has_metadata_col:
try:
# Sanitize floats for database JSON compatibility (BigQuery, PostgreSQL, etc.)
sanitized_metadata = WorkerDatabaseUtils._sanitize_floats_for_database(
metadata
)
sanitized_metadata = sanitize_floats_for_database(metadata)
values[TableColumns.METADATA] = json.dumps(sanitized_metadata)
except (TypeError, ValueError) as e:
logger.error(f"Failed to serialize metadata to JSON: {e}")
Expand Down Expand Up @@ -404,7 +362,7 @@ def _process_single_column_mode(
values[v2_col_name] = wrapped_dict
else:
# Sanitize floats for database JSON compatibility
sanitized_data = WorkerDatabaseUtils._sanitize_floats_for_database(data)
sanitized_data = sanitize_floats_for_database(data)
values[single_column_name] = sanitized_data
if has_v2_col:
values[v2_col_name] = sanitized_data
Expand All @@ -416,14 +374,14 @@ def _process_split_column_mode(
"""Process data for split column mode."""
if isinstance(data, dict):
# Sanitize floats for database JSON compatibility
sanitized_data = WorkerDatabaseUtils._sanitize_floats_for_database(data)
sanitized_data = sanitize_floats_for_database(data)
values.update(sanitized_data)
elif isinstance(data, str):
values[single_column_name] = data
else:
try:
# Sanitize floats for database JSON compatibility before serialization
sanitized_data = WorkerDatabaseUtils._sanitize_floats_for_database(data)
sanitized_data = sanitize_floats_for_database(data)
values[single_column_name] = json.dumps(sanitized_data)
except (TypeError, ValueError) as e:
logger.error(
Expand Down