Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 51 additions & 45 deletions python/python/lance/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import pickle
import sqlite3
from contextlib import closing
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional

import pyarrow as pa
Expand Down Expand Up @@ -105,64 +106,69 @@ class BatchInfo(NamedTuple):

def __init__(self, path):
    """Open (creating if necessary) the SQLite cache stored at ``path``.

    Ensures the two tables used by the other methods exist:
    ``batches`` (one row per computed batch result) and ``fragments``
    (one row per fully written, not yet committed, fragment).

    Parameters
    ----------
    path
        Location of the SQLite database file; passed straight to
        ``sqlite3.connect``.
    """
    self.path = path
    # We don't re-use the connection because it's not thread safe:
    # each method creates and closes its own connection.
    # sqlite3's own context manager only scopes a transaction and leaves
    # the connection open, so closing() is what actually releases the file
    # (required on Windows to avoid file locking issues).
    with closing(sqlite3.connect(path)) as conn:
        # One table to store the results for each batch.
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS batches
            (fragment_id INT, batch_index INT, result BLOB)
            """
        )
        # One table to store fully written (but not committed) fragments.
        conn.execute(
            "CREATE TABLE IF NOT EXISTS fragments (fragment_id INT, data BLOB)"
        )
        conn.commit()

def cleanup(self):
    """Delete the cache database file at ``self.path``.

    Any error from ``os.remove`` (for example ``FileNotFoundError`` when
    the file is already gone) propagates to the caller.
    """
    os.remove(self.path)

def get_batch(self, info: BatchInfo) -> Optional[pa.RecordBatch]:
    """Return the cached result for the batch identified by ``info``.

    Looks up the row in ``batches`` matching ``info.fragment_id`` and
    ``info.batch_index`` and unpickles its ``result`` column. Returns
    ``None`` when no such row exists.
    """
    # closing() guarantees the connection is released even if the query
    # raises; a bare sqlite3.connect() here would leak the handle.
    with closing(sqlite3.connect(self.path)) as conn:
        row = conn.execute(
            "SELECT result FROM batches WHERE fragment_id = ? AND batch_index = ?",
            (info.fragment_id, info.batch_index),
        ).fetchone()
    if row is None:
        return None
    # NOTE(review): pickle.loads can execute arbitrary code from its input.
    # The blob is presumably one written by insert_batch (pickle.dumps);
    # do not point this cache at a database file from an untrusted source.
    return pickle.loads(row[0])

def insert_batch(self, info: BatchInfo, batch: pa.RecordBatch):
    """Store the pickled ``batch`` under ``info`` in the ``batches`` table.

    A single row ``(info.fragment_id, info.batch_index, pickle.dumps(batch))``
    is inserted and committed before the connection is closed.
    """
    # closing() releases the connection even if the insert fails; without
    # the commit the row would be discarded when the connection closes.
    with closing(sqlite3.connect(self.path)) as conn:
        conn.execute(
            "INSERT INTO batches (fragment_id, batch_index, result) "
            "VALUES (?, ?, ?)",
            (info.fragment_id, info.batch_index, pickle.dumps(batch)),
        )
        conn.commit()

def get_fragment(self, fragment_id: int) -> Optional[str]:
    """Retrieves a fragment as a JSON string.

    Returns the ``data`` column of the ``fragments`` row whose
    ``fragment_id`` matches, or ``None`` when there is no such row.
    """
    # closing() guarantees the connection is released even if the query
    # raises; a bare sqlite3.connect() here would leak the handle.
    with closing(sqlite3.connect(self.path)) as conn:
        row = conn.execute(
            "SELECT data FROM fragments WHERE fragment_id = ?", (fragment_id,)
        ).fetchone()
    if row is None:
        return None
    return row[0]

def insert_fragment(self, fragment_id: int, fragment: str):
    """Save a JSON string of a fragment to the cache.

    Inserts ``(fragment_id, fragment)`` into ``fragments`` and deletes
    every ``batches`` row for the same ``fragment_id``; both changes are
    committed together.
    """
    # closing() releases the connection even if a statement fails, in
    # which case nothing has been committed.
    with closing(sqlite3.connect(self.path)) as conn:
        conn.execute(
            "INSERT INTO fragments (fragment_id, data) VALUES (?, ?)",
            (fragment_id, fragment),
        )
        # Clear all batches for the fragment
        conn.execute("DELETE FROM batches WHERE fragment_id = ?", (fragment_id,))
        conn.commit()


def normalize_transform(
Expand Down
Loading