@@ -89,32 +89,33 @@ def _make_session() -> requests.Session:
8989 return s
9090
9191
92- def parse_stream_helper (line : bytes ):
92+ def parse_stream_helper (line : bytes ) -> Optional [ str ] :
9393 if line :
9494 if line .strip () == b"data: [DONE]" :
9595 # return here will cause GeneratorExit exception in urllib3
9696 # and it will close http connection with TCP Reset
9797 return None
98- if hasattr (line , "decode" ):
99- line = line .decode ("utf-8" )
100- if line .startswith ("data: " ):
101- line = line [len ("data: " ) :]
102- return line
98+ if line .startswith (b"data: " ):
99+ line = line [len (b"data: " ) :]
100+ return line .decode ("utf-8" )
103101 return None
104102
105103
106- def parse_stream (rbody ) :
104+ def parse_stream (rbody : Iterator [ bytes ]) -> Iterator [ str ] :
107105 for line in rbody :
108106 _line = parse_stream_helper (line )
109107 if _line is not None :
110108 yield _line
111109
112110
113111async def parse_stream_async (rbody : aiohttp .StreamReader ):
114- async for line , _ in rbody .iter_chunks ():
115- _line = parse_stream_helper (line )
116- if _line is not None :
117- yield _line
112+ async for chunk , _ in rbody .iter_chunks ():
113+ # While the `ChunkTupleAsyncStreamIterator` iterator is meant to iterate over chunks (and thus lines) it seems
114+ # to still sometimes return multiple lines at a time, so let's split the chunk by lines again.
115+ for line in chunk .splitlines ():
116+ _line = parse_stream_helper (line )
117+ if _line is not None :
118+ yield _line
118119
119120
120121class APIRequestor :
@@ -296,20 +297,25 @@ async def arequest(
296297 ) -> Tuple [Union [OpenAIResponse , AsyncGenerator [OpenAIResponse , None ]], bool , str ]:
297298 ctx = aiohttp_session ()
298299 session = await ctx .__aenter__ ()
299- result = await self .arequest_raw (
300- method .lower (),
301- url ,
302- session ,
303- params = params ,
304- supplied_headers = headers ,
305- files = files ,
306- request_id = request_id ,
307- request_timeout = request_timeout ,
308- )
309- resp , got_stream = await self ._interpret_async_response (result , stream )
300+ try :
301+ result = await self .arequest_raw (
302+ method .lower (),
303+ url ,
304+ session ,
305+ params = params ,
306+ supplied_headers = headers ,
307+ files = files ,
308+ request_id = request_id ,
309+ request_timeout = request_timeout ,
310+ )
311+ resp , got_stream = await self ._interpret_async_response (result , stream )
312+ except Exception :
313+ await ctx .__aexit__ (None , None , None )
314+ raise
310315 if got_stream :
311316
312317 async def wrap_resp ():
318+ assert isinstance (resp , AsyncGenerator )
313319 try :
314320 async for r in resp :
315321 yield r
@@ -612,7 +618,10 @@ def _interpret_response(
612618 else :
613619 return (
614620 self ._interpret_response_line (
615- result .content , result .status_code , result .headers , stream = False
621+ result .content .decode ("utf-8" ),
622+ result .status_code ,
623+ result .headers ,
624+ stream = False ,
616625 ),
617626 False ,
618627 )
@@ -635,13 +644,16 @@ async def _interpret_async_response(
635644 util .log_warn (e , body = result .content )
636645 return (
637646 self ._interpret_response_line (
638- await result .read (), result .status , result .headers , stream = False
647+ (await result .read ()).decode ("utf-8" ),
648+ result .status ,
649+ result .headers ,
650+ stream = False ,
639651 ),
640652 False ,
641653 )
642654
643655 def _interpret_response_line (
644- self , rbody , rcode , rheaders , stream : bool
656+ self , rbody : str , rcode : int , rheaders , stream : bool
645657 ) -> OpenAIResponse :
646658 # HTTP 204 response code does not have any content in the body.
647659 if rcode == 204 :
@@ -655,13 +667,11 @@ def _interpret_response_line(
655667 headers = rheaders ,
656668 )
657669 try :
658- if hasattr (rbody , "decode" ):
659- rbody = rbody .decode ("utf-8" )
660670 data = json .loads (rbody )
661- except (JSONDecodeError , UnicodeDecodeError ):
671+ except (JSONDecodeError , UnicodeDecodeError ) as e :
662672 raise error .APIError (
663673 f"HTTP code { rcode } from API ({ rbody } )" , rbody , rcode , headers = rheaders
664- )
674+ ) from e
665675 resp = OpenAIResponse (data , rheaders )
666676 # In the future, we might add a "status" parameter to errors
667677 # to better handle the "error while streaming" case.
0 commit comments