diff --git a/tests/test_copy.py b/tests/test_copy.py index 185283ac..be2aabaf 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -10,6 +10,7 @@ import io import os import tempfile +import unittest import asyncpg from asyncpg import _testbase as tb @@ -415,7 +416,6 @@ async def test_copy_to_table_basics(self): '*a5*|b5', '*!**|*n-u-l-l*', 'n-u-l-l|bb', - '_-_filtered_-_value_-_|never-here' ]).encode('utf-8') ) f.seek(0) @@ -432,7 +432,7 @@ async def test_copy_to_table_basics(self): schema_name='public', format='csv', delimiter='|', null='n-u-l-l', header=True, quote='*', escape='!', force_not_null=('a',), - force_null=force_null, where='a <> \'_-_filtered_-_value_-_\'') + force_null=force_null) self.assertEqual(res, 'COPY 7') @@ -636,16 +636,44 @@ async def test_copy_records_to_table_1(self): ] records.append(('a-100', None, None)) - records.append(('b-999', None, None)) res = await self.con.copy_records_to_table( - 'copytab', records=records, where='a <> \'b-999\'') + 'copytab', records=records) self.assertEqual(res, 'COPY 101') finally: await self.con.execute('DROP TABLE copytab') + async def test_copy_records_to_table_where(self): + if not self.con._server_caps.sql_copy_from_where: + raise unittest.SkipTest( + 'COPY WHERE not supported on server') + + await self.con.execute(''' + CREATE TABLE copytab_where(a text, b int, c timestamptz); + ''') + + try: + date = datetime.datetime.now(tz=datetime.timezone.utc) + delta = datetime.timedelta(days=1) + + records = [ + ('a-{}'.format(i), i, date + delta) + for i in range(100) + ] + + records.append(('a-100', None, None)) + records.append(('b-999', None, None)) + + res = await self.con.copy_records_to_table( + 'copytab_where', records=records, where='a <> \'b-999\'') + + self.assertEqual(res, 'COPY 101') + + finally: + await self.con.execute('DROP TABLE copytab_where') + async def test_copy_records_to_table_async(self): await self.con.execute(''' CREATE TABLE copytab_async(a text, b int, c timestamptz); @@ -660,11 +688,9 @@ async def record_generator(): yield ('a-{}'.format(i), i, date + delta) yield ('a-100', None, None) - yield ('b-999', None, None) res = await self.con.copy_records_to_table( 'copytab_async', records=record_generator(), - where='a <> \'b-999\'' ) self.assertEqual(res, 'COPY 101')