diff --git a/python/python/lance/udf.py b/python/python/lance/udf.py index 525c3346967..de6c7c4ff59 100644 --- a/python/python/lance/udf.py +++ b/python/python/lance/udf.py @@ -6,6 +6,7 @@ import os import pickle import sqlite3 +from contextlib import closing from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional import pyarrow as pa @@ -105,64 +106,69 @@ class BatchInfo(NamedTuple): def __init__(self, path): self.path = path - # We don't re-use the connection because it's not thread safe - conn = sqlite3.connect(path) - # One table to store the results for each batch. - conn.execute( - """ - CREATE TABLE IF NOT EXISTS batches - (fragment_id INT, batch_index INT, result BLOB) - """ - ) - # One table to store fully written (but not committed) fragments. - conn.execute( - "CREATE TABLE IF NOT EXISTS fragments (fragment_id INT, data BLOB)" - ) - conn.commit() + # We don't re-use the connection because it's not thread safe. + # Each method creates and closes its own connection. + # Note: sqlite3's context manager only handles transactions, not connection + # closing. We use closing() to ensure connections are closed, which is + # required on Windows to avoid file locking issues. + with closing(sqlite3.connect(path)) as conn: + # One table to store the results for each batch. + conn.execute( + """ + CREATE TABLE IF NOT EXISTS batches + (fragment_id INT, batch_index INT, result BLOB) + """ + ) + # One table to store fully written (but not committed) fragments. + conn.execute( + "CREATE TABLE IF NOT EXISTS fragments (fragment_id INT, data BLOB)" + ) + conn.commit() def cleanup(self): os.remove(self.path) def get_batch(self, info: BatchInfo) -> Optional[pa.RecordBatch]: - conn = sqlite3.connect(self.path) - cursor = conn.execute( - "SELECT result FROM batches WHERE fragment_id = ? AND batch_index = ?", - (info.fragment_id, info.batch_index), - ) - row = cursor.fetchone() - if row is not None: - return pickle.loads(row[0]) - return None + with closing(sqlite3.connect(self.path)) as conn: + cursor = conn.execute( + "SELECT result FROM batches WHERE fragment_id = ? AND batch_index = ?", + (info.fragment_id, info.batch_index), + ) + row = cursor.fetchone() + if row is not None: + return pickle.loads(row[0]) + return None def insert_batch(self, info: BatchInfo, batch: pa.RecordBatch): - conn = sqlite3.connect(self.path) - conn.execute( - "INSERT INTO batches (fragment_id, batch_index, result) VALUES (?, ?, ?)", - (info.fragment_id, info.batch_index, pickle.dumps(batch)), - ) - conn.commit() + with closing(sqlite3.connect(self.path)) as conn: + conn.execute( + "INSERT INTO batches (fragment_id, batch_index, result) " + "VALUES (?, ?, ?)", + (info.fragment_id, info.batch_index, pickle.dumps(batch)), + ) + conn.commit() def get_fragment(self, fragment_id: int) -> Optional[str]: """Retrieves a fragment as a JSON string.""" - conn = sqlite3.connect(self.path) - cursor = conn.execute( - "SELECT data FROM fragments WHERE fragment_id = ?", (fragment_id,) - ) - row = cursor.fetchone() - if row is not None: - return row[0] - return None + with closing(sqlite3.connect(self.path)) as conn: + cursor = conn.execute( + "SELECT data FROM fragments WHERE fragment_id = ?", (fragment_id,) + ) + row = cursor.fetchone() + if row is not None: + return row[0] + return None def insert_fragment(self, fragment_id: int, fragment: str): """Save a JSON string of a fragment to the cache.""" - # Clear all batches for the fragment - conn = sqlite3.connect(self.path) - conn.execute( - "INSERT INTO fragments (fragment_id, data) VALUES (?, ?)", - (fragment_id, fragment), - ) - conn.execute("DELETE FROM batches WHERE fragment_id = ?", (fragment_id,)) - conn.commit() + with closing(sqlite3.connect(self.path)) as conn: + conn.execute( + "INSERT INTO fragments (fragment_id, data) VALUES (?, ?)", + (fragment_id, fragment), + ) + # Clear all batches for the fragment + conn.execute("DELETE FROM batches WHERE fragment_id = ?", (fragment_id,)) + conn.commit() def normalize_transform(