Skip to content

Commit fd0c654

Browse files
committed
Fix prepared statement already exists on batch prepare failure
When a batch successfully prepared some statements, but then failed to prepare others, the prepared statements that were successfully prepared were not properly cleaned up. This could lead to a "prepared statement already exists" error on subsequent attempts to prepare the same statement. #1847 (comment)
1 parent 672c4a3 commit fd0c654

File tree

2 files changed

+76
-29
lines changed

2 files changed

+76
-29
lines changed

batch_test.go

+30
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,36 @@ func TestSendBatchSimpleProtocol(t *testing.T) {
10081008
assert.False(t, rows.Next())
10091009
}
10101010

1011+
// https://github.com/jackc/pgx/issues/1847#issuecomment-2347858887
1012+
func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) {
1013+
t.Parallel()
1014+
1015+
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
1016+
defer cancel()
1017+
1018+
pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
1019+
pgxtest.SkipCockroachDB(t, conn, "Server serial type is incompatible with test")
1020+
1021+
mustExec(t, conn, `create temporary table foo(col1 text primary key);`)
1022+
1023+
batch := &pgx.Batch{}
1024+
batch.Queue("select col1 from foo")
1025+
batch.Queue("select col1 from baz")
1026+
err := conn.SendBatch(ctx, batch).Close()
1027+
require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`)
1028+
1029+
mustExec(t, conn, `create temporary table baz(col1 text primary key);`)
1030+
1031+
// Since table baz now exists, the batch should succeed.
1032+
1033+
batch = &pgx.Batch{}
1034+
batch.Queue("select col1 from foo")
1035+
batch.Queue("select col1 from baz")
1036+
err = conn.SendBatch(ctx, batch).Close()
1037+
require.NoError(t, err)
1038+
})
1039+
}
1040+
10111041
func ExampleConn_SendBatch() {
10121042
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
10131043
defer cancel()

conn.go

+46-29
Original file line numberDiff line numberDiff line change
@@ -1126,47 +1126,64 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
11261126

11271127
// Prepare any needed queries
11281128
if len(distinctNewQueries) > 0 {
1129-
for _, sd := range distinctNewQueries {
1130-
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
1131-
}
1129+
err := func() (err error) {
1130+
for _, sd := range distinctNewQueries {
1131+
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
1132+
}
11321133

1133-
err := pipeline.Sync()
1134-
if err != nil {
1135-
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
1136-
}
1134+
// Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will
1135+
// clean them up later.
1136+
if sdCache != nil {
1137+
for _, sd := range distinctNewQueries {
1138+
sdCache.Put(sd)
1139+
}
1140+
}
1141+
1142+
// If something goes wrong preparing the statements, we need to invalidate the cache entries we just added.
1143+
defer func() {
1144+
if err != nil && sdCache != nil {
1145+
for _, sd := range distinctNewQueries {
1146+
sdCache.Invalidate(sd.SQL)
1147+
}
1148+
}
1149+
}()
1150+
1151+
err = pipeline.Sync()
1152+
if err != nil {
1153+
return err
1154+
}
1155+
1156+
for _, sd := range distinctNewQueries {
1157+
results, err := pipeline.GetResults()
1158+
if err != nil {
1159+
return err
1160+
}
1161+
1162+
resultSD, ok := results.(*pgconn.StatementDescription)
1163+
if !ok {
1164+
return fmt.Errorf("expected statement description, got %T", results)
1165+
}
1166+
1167+
// Fill in the previously empty / pending statement descriptions.
1168+
sd.ParamOIDs = resultSD.ParamOIDs
1169+
sd.Fields = resultSD.Fields
1170+
}
11371171

1138-
for _, sd := range distinctNewQueries {
11391172
results, err := pipeline.GetResults()
11401173
if err != nil {
1141-
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
1174+
return err
11421175
}
11431176

1144-
resultSD, ok := results.(*pgconn.StatementDescription)
1177+
_, ok := results.(*pgconn.PipelineSync)
11451178
if !ok {
1146-
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results), closed: true}
1179+
return fmt.Errorf("expected sync, got %T", results)
11471180
}
11481181

1149-
// Fill in the previously empty / pending statement descriptions.
1150-
sd.ParamOIDs = resultSD.ParamOIDs
1151-
sd.Fields = resultSD.Fields
1152-
}
1153-
1154-
results, err := pipeline.GetResults()
1182+
return nil
1183+
}()
11551184
if err != nil {
11561185
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
11571186
}
1158-
1159-
_, ok := results.(*pgconn.PipelineSync)
1160-
if !ok {
1161-
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results), closed: true}
1162-
}
1163-
}
1164-
1165-
// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
1166-
if sdCache != nil {
1167-
for _, sd := range distinctNewQueries {
1168-
sdCache.Put(sd)
1169-
}
11701187
}
11711188

11721189
// Queue the queries.

0 commit comments

Comments
 (0)