@@ -1126,47 +1126,64 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
1126
1126
1127
1127
// Prepare any needed queries
1128
1128
if len (distinctNewQueries ) > 0 {
1129
- for _ , sd := range distinctNewQueries {
1130
- pipeline .SendPrepare (sd .Name , sd .SQL , nil )
1131
- }
1129
+ err := func () (err error ) {
1130
+ for _ , sd := range distinctNewQueries {
1131
+ pipeline .SendPrepare (sd .Name , sd .SQL , nil )
1132
+ }
1132
1133
1133
- err := pipeline .Sync ()
1134
- if err != nil {
1135
- return & pipelineBatchResults {ctx : ctx , conn : c , err : err , closed : true }
1136
- }
1134
+ // Store all statements we are preparing into the cache. It's fine if it overflows because HandleInvalidated will
1135
+ // clean them up later.
1136
+ if sdCache != nil {
1137
+ for _ , sd := range distinctNewQueries {
1138
+ sdCache .Put (sd )
1139
+ }
1140
+ }
1141
+
1142
+ // If something goes wrong preparing the statements, we need to invalidate the cache entries we just added.
1143
+ defer func () {
1144
+ if err != nil && sdCache != nil {
1145
+ for _ , sd := range distinctNewQueries {
1146
+ sdCache .Invalidate (sd .SQL )
1147
+ }
1148
+ }
1149
+ }()
1150
+
1151
+ err = pipeline .Sync ()
1152
+ if err != nil {
1153
+ return err
1154
+ }
1155
+
1156
+ for _ , sd := range distinctNewQueries {
1157
+ results , err := pipeline .GetResults ()
1158
+ if err != nil {
1159
+ return err
1160
+ }
1161
+
1162
+ resultSD , ok := results .(* pgconn.StatementDescription )
1163
+ if ! ok {
1164
+ return fmt .Errorf ("expected statement description, got %T" , results )
1165
+ }
1166
+
1167
+ // Fill in the previously empty / pending statement descriptions.
1168
+ sd .ParamOIDs = resultSD .ParamOIDs
1169
+ sd .Fields = resultSD .Fields
1170
+ }
1137
1171
1138
- for _ , sd := range distinctNewQueries {
1139
1172
results , err := pipeline .GetResults ()
1140
1173
if err != nil {
1141
- return & pipelineBatchResults { ctx : ctx , conn : c , err : err , closed : true }
1174
+ return err
1142
1175
}
1143
1176
1144
- resultSD , ok := results .(* pgconn.StatementDescription )
1177
+ _ , ok := results .(* pgconn.PipelineSync )
1145
1178
if ! ok {
1146
- return & pipelineBatchResults { ctx : ctx , conn : c , err : fmt .Errorf ("expected statement description , got %T" , results ), closed : true }
1179
+ return fmt .Errorf ("expected sync , got %T" , results )
1147
1180
}
1148
1181
1149
- // Fill in the previously empty / pending statement descriptions.
1150
- sd .ParamOIDs = resultSD .ParamOIDs
1151
- sd .Fields = resultSD .Fields
1152
- }
1153
-
1154
- results , err := pipeline .GetResults ()
1182
+ return nil
1183
+ }()
1155
1184
if err != nil {
1156
1185
return & pipelineBatchResults {ctx : ctx , conn : c , err : err , closed : true }
1157
1186
}
1158
-
1159
- _ , ok := results .(* pgconn.PipelineSync )
1160
- if ! ok {
1161
- return & pipelineBatchResults {ctx : ctx , conn : c , err : fmt .Errorf ("expected sync, got %T" , results ), closed : true }
1162
- }
1163
- }
1164
-
1165
- // Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
1166
- if sdCache != nil {
1167
- for _ , sd := range distinctNewQueries {
1168
- sdCache .Put (sd )
1169
- }
1170
1187
}
1171
1188
1172
1189
// Queue the queries.
0 commit comments