@@ -156,11 +156,16 @@ cdef class BaseProtocol(CoreProtocol):
156156 self ._check_state()
157157 timeout = self ._get_timeout_impl(timeout)
158158
159- self ._prepare(stmt_name, query)
160- self .last_query = query
161- self .statement = PreparedStatementState(stmt_name, query, self )
162-
163- return await self ._new_waiter(timeout)
159+ waiter = self ._new_waiter(timeout)
160+ try :
161+ self ._prepare(stmt_name, query) # network op
162+ self .last_query = query
163+ self .statement = PreparedStatementState(stmt_name, query, self )
164+ except Exception as ex:
165+ waiter.set_exception(ex)
166+ self ._coreproto_error()
167+ finally :
168+ return await waiter
164169
165170 async def bind_execute(self , PreparedStatementState state, args,
166171 str portal_name, int limit, return_extra,
@@ -174,19 +179,25 @@ cdef class BaseProtocol(CoreProtocol):
174179
175180 self ._check_state()
176181 timeout = self ._get_timeout_impl(timeout)
182+ args_buf = state._encode_bind_msg(args)
177183
178- self ._bind_execute(
179- portal_name,
180- state.name,
181- state._encode_bind_msg(args),
182- limit)
183-
184- self .last_query = state.query
185- self .statement = state
186- self .return_extra = return_extra
187- self .queries_count += 1
188-
189- return await self ._new_waiter(timeout)
184+ waiter = self ._new_waiter(timeout)
185+ try :
186+ self ._bind_execute(
187+ portal_name,
188+ state.name,
189+ args_buf,
190+ limit) # network op
191+
192+ self .last_query = state.query
193+ self .statement = state
194+ self .return_extra = return_extra
195+ self .queries_count += 1
196+ except Exception as ex:
197+ waiter.set_exception(ex)
198+ self ._coreproto_error()
199+ finally :
200+ return await waiter
190201
191202 async def bind_execute_many(self , PreparedStatementState state, args,
192203 str portal_name, timeout):
@@ -207,18 +218,21 @@ cdef class BaseProtocol(CoreProtocol):
207218 arg_bufs = iter (data_gen)
208219
209220 waiter = self ._new_waiter(timeout)
221+ try :
222+ self ._bind_execute_many(
223+ portal_name,
224+ state.name,
225+ arg_bufs) # network op
210226
211- self ._bind_execute_many(
212- portal_name,
213- state.name,
214- arg_bufs)
215-
216- self .last_query = state.query
217- self .statement = state
218- self .return_extra = False
219- self .queries_count += 1
220-
221- return await waiter
227+ self .last_query = state.query
228+ self .statement = state
229+ self .return_extra = False
230+ self .queries_count += 1
231+ except Exception as ex:
232+ waiter.set_exception(ex)
233+ self ._coreproto_error()
234+ finally :
235+ return await waiter
222236
223237 async def bind(self , PreparedStatementState state, args,
224238 str portal_name, timeout):
@@ -231,16 +245,22 @@ cdef class BaseProtocol(CoreProtocol):
231245
232246 self ._check_state()
233247 timeout = self ._get_timeout_impl(timeout)
248+ args_buf = state._encode_bind_msg(args)
234249
235- self ._bind(
236- portal_name,
237- state.name,
238- state._encode_bind_msg(args))
239-
240- self .last_query = state.query
241- self .statement = state
242-
243- return await self ._new_waiter(timeout)
250+ waiter = self ._new_waiter(timeout)
251+ try :
252+ self ._bind(
253+ portal_name,
254+ state.name,
255+ args_buf) # network op
256+
257+ self .last_query = state.query
258+ self .statement = state
259+ except Exception as ex:
260+ waiter.set_exception(ex)
261+ self ._coreproto_error()
262+ finally :
263+ return await waiter
244264
245265 async def execute(self , PreparedStatementState state,
246266 str portal_name, int limit, return_extra,
@@ -255,16 +275,21 @@ cdef class BaseProtocol(CoreProtocol):
255275 self ._check_state()
256276 timeout = self ._get_timeout_impl(timeout)
257277
258- self ._execute(
259- portal_name,
260- limit)
261-
262- self .last_query = state.query
263- self .statement = state
264- self .return_extra = return_extra
265- self .queries_count += 1
266-
267- return await self ._new_waiter(timeout)
278+ waiter = self ._new_waiter(timeout)
279+ try :
280+ self ._execute(
281+ portal_name,
282+ limit) # network op
283+
284+ self .last_query = state.query
285+ self .statement = state
286+ self .return_extra = return_extra
287+ self .queries_count += 1
288+ except Exception as ex:
289+ waiter.set_exception(ex)
290+ self ._coreproto_error()
291+ finally :
292+ return await waiter
268293
269294 async def query(self , query, timeout):
270295 if self .cancel_waiter is not None :
@@ -279,11 +304,16 @@ cdef class BaseProtocol(CoreProtocol):
279304 # prepare/bind/execute methods.
280305 timeout = self ._get_timeout(timeout)
281306
282- self ._simple_query(query)
283- self .last_query = query
284- self .queries_count += 1
285-
286- return await self ._new_waiter(timeout)
307+ waiter = self ._new_waiter(timeout)
308+ try :
309+ self ._simple_query(query) # network op
310+ self .last_query = query
311+ self .queries_count += 1
312+ except Exception as ex:
313+ waiter.set_exception(ex)
314+ self ._coreproto_error()
315+ finally :
316+ return await waiter
287317
288318 async def copy_out(self , copy_stmt, sink, timeout):
289319 if self .cancel_waiter is not None :
@@ -378,7 +408,7 @@ cdef class BaseProtocol(CoreProtocol):
378408 for codec in codecs:
379409 if (not codec.has_encoder() or
380410 codec.format != PG_FORMAT_BINARY):
381- raise RuntimeError (
411+ raise apg_exc.InternalClientError (
382412 ' no binary format encoder for '
383413 ' type {} (OID {})' .format(codec.name, codec.oid))
384414
@@ -439,7 +469,7 @@ cdef class BaseProtocol(CoreProtocol):
439469 except TimeoutError:
440470 raise
441471 else :
442- raise RuntimeError (' TimoutError was not raised' )
472+ raise apg_exc.InternalClientError (' TimoutError was not raised' )
443473
444474 except Exception as e:
445475 self ._write_copy_fail_msg(str (e))
@@ -460,16 +490,22 @@ cdef class BaseProtocol(CoreProtocol):
460490 self .cancel_sent_waiter = None
461491
462492 self ._check_state()
463- timeout = self ._get_timeout_impl(timeout)
464493
465494 if state.refs != 0 :
466- raise RuntimeError (
495+ raise apg_exc.InternalClientError (
467496 ' cannot close prepared statement; refs == {} != 0' .format(
468497 state.refs))
469498
470- self ._close(state.name, False )
471- state.closed = True
472- return await self ._new_waiter(timeout)
499+ timeout = self ._get_timeout_impl(timeout)
500+ waiter = self ._new_waiter(timeout)
501+ try :
502+ self ._close(state.name, False ) # network op
503+ state.closed = True
504+ except Exception as ex:
505+ waiter.set_exception(ex)
506+ self ._coreproto_error()
507+ finally :
508+ return await waiter
473509
474510 def is_closed (self ):
475511 return self .closing
@@ -579,6 +615,17 @@ cdef class BaseProtocol(CoreProtocol):
579615 raise apg_exc.InterfaceError(
580616 ' cannot perform operation: another operation is in progress' )
581617
618+ cdef _coreproto_error(self ):
619+ try :
620+ if self .waiter is not None :
621+ if not self .waiter.done():
622+ raise apg_exc.InternalClientError(
623+ ' waiter is not done while handling critical '
624+ ' protocol error' )
625+ self .waiter = None
626+ finally :
627+ self .abort()
628+
582629 cdef _new_waiter(self , timeout):
583630 if self .waiter is not None :
584631 raise apg_exc.InterfaceError(
@@ -596,7 +643,7 @@ cdef class BaseProtocol(CoreProtocol):
596643 cdef _on_result__prepare(self , object waiter):
597644 if ASYNCPG_DEBUG:
598645 if self .statement is None :
599- raise RuntimeError (
646+ raise apg_exc.InternalClientError (
600647 ' _on_result__prepare: statement is None' )
601648
602649 if self .result_param_desc is not None :
@@ -643,7 +690,7 @@ cdef class BaseProtocol(CoreProtocol):
643690 cdef _decode_row(self , const char * buf, ssize_t buf_len):
644691 if ASYNCPG_DEBUG:
645692 if self .statement is None :
646- raise RuntimeError (
693+ raise apg_exc.InternalClientError (
647694 ' _decode_row: statement is None' )
648695
649696 return self .statement._decode_row(buf, buf_len)
@@ -654,13 +701,13 @@ cdef class BaseProtocol(CoreProtocol):
654701
655702 if ASYNCPG_DEBUG:
656703 if waiter is None :
657- raise RuntimeError (' _on_result: waiter is None' )
704+ raise apg_exc.InternalClientError (' _on_result: waiter is None' )
658705
659706 if waiter.cancelled():
660707 return
661708
662709 if waiter.done():
663- raise RuntimeError (' _on_result: waiter is done' )
710+ raise apg_exc.InternalClientError (' _on_result: waiter is done' )
664711
665712 if self .result_type == RESULT_FAILED:
666713 if isinstance (self .result, dict ):
@@ -704,7 +751,7 @@ cdef class BaseProtocol(CoreProtocol):
704751 self ._on_result__copy_in(waiter)
705752
706753 else :
707- raise RuntimeError (
754+ raise apg_exc.InternalClientError (
708755 ' got result for unknown protocol state {}' .
709756 format(self .state))
710757
0 commit comments