From 592feb494f8c5ef09fbb9273630261878ea4a825 Mon Sep 17 00:00:00 2001 From: tkrabel Date: Sun, 29 Oct 2023 14:31:46 +0100 Subject: [PATCH 1/5] Create one database connection per thread --- rope/contrib/autoimport/sqlite.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/rope/contrib/autoimport/sqlite.py b/rope/contrib/autoimport/sqlite.py index 507352782..05bd94ab5 100644 --- a/rope/contrib/autoimport/sqlite.py +++ b/rope/contrib/autoimport/sqlite.py @@ -11,6 +11,7 @@ from datetime import datetime from itertools import chain from pathlib import Path +from threading import local from typing import Generator, Iterable, Iterator, List, Optional, Set, Tuple from rope.base import exceptions, libutils, resourceobserver, taskhandle, versioning @@ -67,6 +68,7 @@ def filter_package(package: Package) -> bool: _deprecated_default: bool = object() # type: ignore +thread_local = local() class AutoImport: @@ -78,9 +80,10 @@ class AutoImport: """ connection: sqlite3.Connection - underlined: bool + memory: bool project: Project project_package: Package + underlined: bool def __init__( self, @@ -114,8 +117,9 @@ def __init__( assert project_package.path is not None self.project_package = project_package self.underlined = underlined + self.memory = memory if memory is _deprecated_default: - memory = True + self.memory = True warnings.warn( "The default value for `AutoImport(memory)` argument will " "change to use an on-disk database by default in the future. " @@ -158,6 +162,23 @@ def create_database_connection( db_path = str(Path(project.ropefolder.real_path) / "autoimport.db") return sqlite3.connect(db_path) + @property + def connection(self): + """Creates a new connection if called from a new thread. + + This makes sure AutoImport can be shared across threads. + """ + if not hasattr(thread_local, "connection"): + thread_local.connection = self.create_database_connection( + project=self.project, + memory=self.memory, + ) + return thread_local.connection + + @connection.setter + def connection(self, value: sqlite3.Connection): + thread_local.connection = value + def _setup_db(self): models.Metadata.create_table(self.connection) version_hash = list( From 97601dfd4fffdaa68f1332982a980a0490cdcf66 Mon Sep 17 00:00:00 2001 From: tkrabel Date: Sun, 29 Oct 2023 14:47:19 +0100 Subject: [PATCH 2/5] add unit test --- ropetest/contrib/autoimport/autoimporttest.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/ropetest/contrib/autoimport/autoimporttest.py b/ropetest/contrib/autoimport/autoimporttest.py index 4b0430119..65c1235ca 100644 --- a/ropetest/contrib/autoimport/autoimporttest.py +++ b/ropetest/contrib/autoimport/autoimporttest.py @@ -1,3 +1,4 @@ +from concurrent.futures import ThreadPoolExecutor from contextlib import closing, contextmanager from textwrap import dedent from unittest.mock import ANY, patch @@ -85,6 +86,28 @@ def foo(): assert [("from pkg1 import foo", "foo")] == results +def test_multithreading( + autoimport: AutoImport, + project: Project, + pkg1: Folder, + mod1: File, +): + mod1_init = pkg1.get_child("__init__.py") + mod1_init.write(dedent("""\ + def foo(): + pass + """)) + mod1.write(dedent("""\ + foo + """)) + autoimport = AutoImport(project, memory=False) + autoimport.generate_cache([mod1_init]) + + tp = ThreadPoolExecutor(1) + results = tp.submit(autoimport.search, "foo", True).result() + assert [("from pkg1 import foo", "foo")] == results + + @contextmanager def assert_database_is_reset(conn): conn.execute("ALTER TABLE names ADD COLUMN deprecated_column") From 681a9b895db16227ebadf49539556cf2203119f1 Mon Sep 17 00:00:00 2001 From: tkrabel Date: Sun, 29 Oct 2023 14:50:05 +0100 Subject: [PATCH 3/5] add CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b7c0104fa..014de0b65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ - #710, #561 Implement `except*` syntax (@lieryan) - #711 allow building documentation without having rope module installed (@kloczek) +- #720 create one sqlite3.Connection per thread using a thread local (@tkrabel) # Release 1.10.0 From bf8d3a336b14404d484bd707b81f8b5ffa667ad3 Mon Sep 17 00:00:00 2001 From: tkrabel Date: Sun, 29 Oct 2023 14:51:35 +0100 Subject: [PATCH 4/5] fix docstring --- rope/contrib/autoimport/sqlite.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rope/contrib/autoimport/sqlite.py b/rope/contrib/autoimport/sqlite.py index 05bd94ab5..14561610e 100644 --- a/rope/contrib/autoimport/sqlite.py +++ b/rope/contrib/autoimport/sqlite.py @@ -164,7 +164,8 @@ def create_database_connection( @property def connection(self): - """Creates a new connection if called from a new thread. + """ + Creates a new connection if called from a new thread. This makes sure AutoImport can be shared across threads. """ From db187bd6bd0b191db68857f46bbbe093f5170720 Mon Sep 17 00:00:00 2001 From: tkrabel Date: Wed, 1 Nov 2023 08:48:07 +0100 Subject: [PATCH 5/5] bugfix: every autoimport gets its own connection --- rope/contrib/autoimport/sqlite.py | 10 +++++----- ropetest/conftest.py | 7 +++++++ ropetest/contrib/autoimport/autoimporttest.py | 9 +++++++++ 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/rope/contrib/autoimport/sqlite.py b/rope/contrib/autoimport/sqlite.py index 14561610e..5d8554a90 100644 --- a/rope/contrib/autoimport/sqlite.py +++ b/rope/contrib/autoimport/sqlite.py @@ -68,7 +68,6 @@ def filter_package(package: Package) -> bool: _deprecated_default: bool = object() # type: ignore -thread_local = local() class AutoImport: @@ -127,6 +126,7 @@ def __init__( "`AutoImport(memory=True)` explicitly.", DeprecationWarning, ) + self.thread_local = local() self.connection = self.create_database_connection( project=project, memory=memory, @@ -169,16 +169,16 @@ def connection(self): This makes sure AutoImport can be shared across threads. """ - if not hasattr(thread_local, "connection"): - thread_local.connection = self.create_database_connection( + if not hasattr(self.thread_local, "connection"): + self.thread_local.connection = self.create_database_connection( project=self.project, memory=self.memory, ) - return thread_local.connection + return self.thread_local.connection @connection.setter def connection(self, value: sqlite3.Connection): - thread_local.connection = value + self.thread_local.connection = value def _setup_db(self): models.Metadata.create_table(self.connection) diff --git a/ropetest/conftest.py b/ropetest/conftest.py index d2efc68c6..32a35aaa6 100644 --- a/ropetest/conftest.py +++ b/ropetest/conftest.py @@ -18,6 +18,13 @@ def project_path(project): yield pathlib.Path(project.address) +@pytest.fixture +def project2(): + project = testutils.sample_project("sample_project2") + yield project + testutils.remove_project(project) + + """ Standard project structure for pytest fixtures /mod1.py -- mod1 diff --git a/ropetest/contrib/autoimport/autoimporttest.py b/ropetest/contrib/autoimport/autoimporttest.py index 65c1235ca..d65d8b2bf 100644 --- a/ropetest/contrib/autoimport/autoimporttest.py +++ b/ropetest/contrib/autoimport/autoimporttest.py @@ -108,6 +108,15 @@ def foo(): assert [("from pkg1 import foo", "foo")] == results +def test_connection(project: Project, project2: Project): + ai1 = AutoImport(project) + ai2 = AutoImport(project) + ai3 = AutoImport(project2) + + assert ai1.connection is not ai2.connection + assert ai1.connection is not ai3.connection + + @contextmanager def assert_database_is_reset(conn): conn.execute("ALTER TABLE names ADD COLUMN deprecated_column")