Skip to content

Commit

Permalink
bpo-43553: Improve sqlite3 test coverage (GH-26886)
Browse files Browse the repository at this point in the history
  • Loading branch information
Erlend Egeberg Aasland authored Jun 24, 2021
1 parent 9049ea5 commit 2c1ae09
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 2 deletions.
26 changes: 24 additions & 2 deletions Lib/sqlite3/test/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
import threading
import unittest

from test.support import check_disallow_instantiation
from test.support import check_disallow_instantiation, threading_helper
from test.support.os_helper import TESTFN, unlink
from test.support import threading_helper


# Helper for tests using TESTFN
Expand Down Expand Up @@ -110,6 +109,10 @@ def test_disallow_instantiation(self):
cx = sqlite.connect(":memory:")
check_disallow_instantiation(self, type(cx("select 1")))

def test_complete_statement(self):
self.assertFalse(sqlite.complete_statement("select t"))
self.assertTrue(sqlite.complete_statement("create table t(t);"))


class ConnectionTests(unittest.TestCase):

Expand Down Expand Up @@ -225,6 +228,20 @@ def test_connection_exceptions(self):
self.assertTrue(hasattr(self.cx, exc))
self.assertIs(getattr(sqlite, exc), getattr(self.cx, exc))

def test_interrupt_on_closed_db(self):
cx = sqlite.connect(":memory:")
cx.close()
with self.assertRaises(sqlite.ProgrammingError):
cx.interrupt()

def test_interrupt(self):
self.assertIsNone(self.cx.interrupt())

def test_drop_unused_refs(self):
for n in range(500):
cu = self.cx.execute(f"select {n}")
self.assertEqual(cu.fetchone()[0], n)


class OpenTests(unittest.TestCase):
_sql = "create table test(id integer)"
Expand Down Expand Up @@ -594,6 +611,11 @@ def test_column_count(self):
new_count = len(res.description)
self.assertEqual(new_count - old_count, 1)

def test_same_query_in_multiple_cursors(self):
cursors = [self.cx.execute("select 1") for _ in range(3)]
for cu in cursors:
self.assertEqual(cu.fetchall(), [(1,)])


class ThreadTests(unittest.TestCase):
def setUp(self):
Expand Down
2 changes: 2 additions & 0 deletions Lib/sqlite3/test/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def test_sqlite_row_index(self):
row[-3]
with self.assertRaises(IndexError):
row[2**1000]
with self.assertRaises(IndexError):
row[complex()] # index must be int or string

def test_sqlite_row_index_unicode(self):
self.con.row_factory = sqlite.Row
Expand Down
37 changes: 37 additions & 0 deletions Lib/sqlite3/test/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,43 @@ def test_caster_is_used(self):
val = self.cur.fetchone()[0]
self.assertEqual(type(val), float)

def test_missing_adapter(self):
with self.assertRaises(sqlite.ProgrammingError):
sqlite.adapt(1.) # No float adapter registered

def test_missing_protocol(self):
with self.assertRaises(sqlite.ProgrammingError):
sqlite.adapt(1, None)

def test_defect_proto(self):
class DefectProto():
def __adapt__(self):
return None
with self.assertRaises(sqlite.ProgrammingError):
sqlite.adapt(1., DefectProto)

def test_defect_self_adapt(self):
class DefectSelfAdapt(float):
def __conform__(self, _):
return None
with self.assertRaises(sqlite.ProgrammingError):
sqlite.adapt(DefectSelfAdapt(1.))

def test_custom_proto(self):
class CustomProto():
def __adapt__(self):
return "adapted"
self.assertEqual(sqlite.adapt(1., CustomProto), "adapted")

def test_adapt(self):
val = 42
self.assertEqual(float(val), sqlite.adapt(val))

def test_adapt_alt(self):
alt = "other"
self.assertEqual(alt, sqlite.adapt(1., None, alt))


@unittest.skipUnless(zlib, "requires zlib")
class BinaryConverterTests(unittest.TestCase):
def convert(s):
Expand Down
37 changes: 37 additions & 0 deletions Lib/sqlite3/test/userfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,36 @@
# misrepresented as being the original software.
# 3. This notice may not be removed or altered from any source distribution.

import contextlib
import functools
import io
import unittest
import unittest.mock
import gc
import sqlite3 as sqlite

def with_tracebacks(strings):
"""Convenience decorator for testing callback tracebacks."""
strings.append('Traceback')

def decorator(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
# First, run the test with traceback enabled.
sqlite.enable_callback_tracebacks(True)
buf = io.StringIO()
with contextlib.redirect_stderr(buf):
func(self, *args, **kwargs)
tb = buf.getvalue()
for s in strings:
self.assertIn(s, tb)

# Then run the test with traceback disabled.
sqlite.enable_callback_tracebacks(False)
func(self, *args, **kwargs)
return wrapper
return decorator

def func_returntext():
return "foo"
def func_returnunicode():
Expand Down Expand Up @@ -228,6 +253,7 @@ def test_func_return_long_long(self):
val = cur.fetchone()[0]
self.assertEqual(val, 1<<31)

@with_tracebacks(['func_raiseexception', '5/0', 'ZeroDivisionError'])
def test_func_exception(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
Expand Down Expand Up @@ -387,20 +413,23 @@ def test_aggr_no_finalize(self):
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")

@with_tracebacks(['__init__', '5/0', 'ZeroDivisionError'])
def test_aggr_exception_in_init(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select excInit(t) from test")
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")

@with_tracebacks(['step', '5/0', 'ZeroDivisionError'])
def test_aggr_exception_in_step(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
cur.execute("select excStep(t) from test")
val = cur.fetchone()[0]
self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")

@with_tracebacks(['finalize', '5/0', 'ZeroDivisionError'])
def test_aggr_exception_in_finalize(self):
cur = self.con.cursor()
with self.assertRaises(sqlite.OperationalError) as cm:
Expand Down Expand Up @@ -502,6 +531,14 @@ def authorizer_cb(action, arg1, arg2, dbname, source):
raise ValueError
return sqlite.SQLITE_OK

@with_tracebacks(['authorizer_cb', 'ValueError'])
def test_table_access(self):
super().test_table_access()

@with_tracebacks(['authorizer_cb', 'ValueError'])
def test_column_access(self):
super().test_table_access()

class AuthorizerIllegalTypeTests(AuthorizerTests):
@staticmethod
def authorizer_cb(action, arg1, arg2, dbname, source):
Expand Down

0 comments on commit 2c1ae09

Please sign in to comment.