Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions testing/PostgresDockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ RUN apt update -y && \
apt install -y \
python3.8 \
python3-pip \
python3-psycopg2 \
curl \
wget \
pkg-config \
Expand Down
6 changes: 5 additions & 1 deletion testing/postgres-client-tests/postgres-client-tests.bats
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ teardown() {
node $BATS_TEST_DIRNAME/node/workbench.js $USER $PORT $DOLTGRES_VERSION $BATS_TEST_DIRNAME/node/testdata
}


@test "perl DBI:Pg client" {
perl $BATS_TEST_DIRNAME/perl/postgres-test.pl $USER $PORT
}
Expand All @@ -68,3 +67,8 @@ teardown() {
(cd $BATS_TEST_DIRNAME/c; make clean; make)
$BATS_TEST_DIRNAME/c/postgres-c-connector-test $USER $PORT
}

@test "python postgres: psycopg2 client" {
cd $BATS_TEST_DIRNAME/python
python3 psycopg2_test.py $USER $PORT
}
115 changes: 115 additions & 0 deletions testing/postgres-client-tests/python/psycopg2_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#!/usr/bin/env python3
import os
import sys
import traceback
import psycopg2

# ---------------------------------------------------------------------------
# Query list (kept at top for consistency with other tests)
# ---------------------------------------------------------------------------

TEST_QUERIES = [
"DROP TABLE IF EXISTS test",
"create table test (pk int, value int, d1 decimal(9, 3), f1 float, c1 char(10), t1 text, primary key(pk))",
"select * from test",
"insert into test (pk, value, d1, f1, c1, t1) values (0,0,0.0,0.0,'abc','a1')",
"select * from test",
"select dolt_add('-A');",
"select dolt_commit('-m', 'my commit')",
"select COUNT(*) FROM dolt.log",
"select dolt_checkout('-b', 'mybranch')",
"insert into test (pk, value, d1, f1, c1, t1) values (10,10, 123456.789, 420.42,'example','some text')",
"select dolt_commit('-a', '-m', 'my commit2')",
"select dolt_checkout('main')",
"select dolt_merge('mybranch')",
"select COUNT(*) FROM dolt.log",
]

# ---------------------------------------------------------------------------

def env(name, default=None):
return os.getenv(name, default)


def connect(user: str, port: int):
conn = psycopg2.connect(
host=env("PGHOST", "localhost"),
port=port,
dbname="postgres",
user=user,
password=env("PGPASSWORD", "password"),
connect_timeout=int(env("PGCONNECT_TIMEOUT", "10")),
sslmode=env("PGSSLMODE"),
)
conn.autocommit = True
return conn


def run(cur, q):
print(f"SQL> {q}", flush=True)
cur.execute(q)
if cur.description is not None:
cur.fetchall() # drain result set

# load_test creates a table with |n_rows| and asserts that all rows are correctly returned.
def load_test(cur, n_rows=1000):
print("\n=== Part 1: Load test ===", flush=True)

rows = max(1000, int(n_rows))

run(cur, "DROP TABLE IF EXISTS load_test")
run(cur, "CREATE TABLE load_test (id INT PRIMARY KEY, val INT NOT NULL)")

data = [(i, i * 10) for i in range(rows)]
cur.executemany(
"INSERT INTO load_test (id, val) VALUES (%s, %s)",
data,
)

cur.execute("SELECT COUNT(*) FROM load_test")
cnt = cur.fetchone()[0]
if cnt != rows:
raise AssertionError(f"COUNT(*) mismatch: expected {rows}, got {cnt}")

cur.execute("SELECT id FROM load_test ORDER BY id")
got = cur.fetchall()
if len(got) != rows:
raise AssertionError(f"fetchall mismatch: expected {rows}, got {len(got)}")

print(f"Inserted and selected {rows} rows OK.", flush=True)


def compliance_test(cur):
print("\n=== Part 2: Test Queries ===", flush=True)
for q in TEST_QUERIES:
run(cur, q)
print("Compliance queries executed OK.", flush=True)


def main():
if len(sys.argv) != 3:
print("Usage: python3 psycopg2_test.py <user> <port>")
return 2

user = sys.argv[1]
port = int(sys.argv[2])
load_rows = int(env("LOAD_ROWS", "1000"))

try:
with connect(user, port) as conn:
with conn.cursor() as cur:
load_test(cur, load_rows)
compliance_test(cur)

print("\n✅ All tests passed.", flush=True)
return 0

except Exception as e:
print("\n❌ Test failed.", flush=True)
print(f"Error: {e}", flush=True)
traceback.print_exc()
return 1


if __name__ == "__main__":
sys.exit(main())