@@ -89,32 +89,33 @@ def _make_session() -> requests.Session:
8989    return  s 
9090
9191
92- def  parse_stream_helper (line : bytes ):
92+ def  parse_stream_helper (line : bytes )  ->   Optional [ str ] :
9393    if  line :
9494        if  line .strip () ==  b"data: [DONE]" :
9595            # return here will cause GeneratorExit exception in urllib3 
9696            # and it will close http connection with TCP Reset 
9797            return  None 
98-         if  hasattr (line , "decode" ):
99-             line  =  line .decode ("utf-8" )
100-         if  line .startswith ("data: " ):
101-             line  =  line [len ("data: " ) :]
102-         return  line 
98+         if  line .startswith (b"data: " ):
99+             line  =  line [len (b"data: " ) :]
100+         return  line .decode ("utf-8" )
103101    return  None 
104102
105103
106- def  parse_stream (rbody ) :
104+ def  parse_stream (rbody :  Iterator [ bytes ])  ->   Iterator [ str ] :
107105    for  line  in  rbody :
108106        _line  =  parse_stream_helper (line )
109107        if  _line  is  not None :
110108            yield  _line 
111109
112110
113111async  def  parse_stream_async (rbody : aiohttp .StreamReader ):
114-     async  for  line , _  in  rbody .iter_chunks ():
115-         _line  =  parse_stream_helper (line )
116-         if  _line  is  not None :
117-             yield  _line 
112+     async  for  chunk , _  in  rbody .iter_chunks ():
113+         # While the `ChunkTupleAsyncStreamIterator` iterator is meant to iterate over chunks (and thus lines) it seems 
114+         # to still sometimes return multiple lines at a time, so let's split the chunk by lines again. 
115+         for  line  in  chunk .splitlines ():
116+             _line  =  parse_stream_helper (line )
117+             if  _line  is  not None :
118+                 yield  _line 
118119
119120
120121class  APIRequestor :
@@ -296,20 +297,25 @@ async def arequest(
296297    ) ->  Tuple [Union [OpenAIResponse , AsyncGenerator [OpenAIResponse , None ]], bool , str ]:
297298        ctx  =  aiohttp_session ()
298299        session  =  await  ctx .__aenter__ ()
299-         result  =  await  self .arequest_raw (
300-             method .lower (),
301-             url ,
302-             session ,
303-             params = params ,
304-             supplied_headers = headers ,
305-             files = files ,
306-             request_id = request_id ,
307-             request_timeout = request_timeout ,
308-         )
309-         resp , got_stream  =  await  self ._interpret_async_response (result , stream )
300+         try :
301+             result  =  await  self .arequest_raw (
302+                 method .lower (),
303+                 url ,
304+                 session ,
305+                 params = params ,
306+                 supplied_headers = headers ,
307+                 files = files ,
308+                 request_id = request_id ,
309+                 request_timeout = request_timeout ,
310+             )
311+             resp , got_stream  =  await  self ._interpret_async_response (result , stream )
312+         except  Exception :
313+             await  ctx .__aexit__ (None , None , None )
314+             raise 
310315        if  got_stream :
311316
312317            async  def  wrap_resp ():
318+                 assert  isinstance (resp , AsyncGenerator )
313319                try :
314320                    async  for  r  in  resp :
315321                        yield  r 
@@ -612,7 +618,10 @@ def _interpret_response(
612618        else :
613619            return  (
614620                self ._interpret_response_line (
615-                     result .content , result .status_code , result .headers , stream = False 
621+                     result .content .decode ("utf-8" ),
622+                     result .status_code ,
623+                     result .headers ,
624+                     stream = False ,
616625                ),
617626                False ,
618627            )
@@ -635,13 +644,16 @@ async def _interpret_async_response(
635644                util .log_warn (e , body = result .content )
636645            return  (
637646                self ._interpret_response_line (
638-                     await  result .read (), result .status , result .headers , stream = False 
647+                     (await  result .read ()).decode ("utf-8" ),
648+                     result .status ,
649+                     result .headers ,
650+                     stream = False ,
639651                ),
640652                False ,
641653            )
642654
643655    def  _interpret_response_line (
644-         self , rbody , rcode , rheaders , stream : bool 
656+         self , rbody :  str , rcode :  int , rheaders , stream : bool 
645657    ) ->  OpenAIResponse :
646658        # HTTP 204 response code does not have any content in the body. 
647659        if  rcode  ==  204 :
@@ -655,13 +667,11 @@ def _interpret_response_line(
655667                headers = rheaders ,
656668            )
657669        try :
658-             if  hasattr (rbody , "decode" ):
659-                 rbody  =  rbody .decode ("utf-8" )
660670            data  =  json .loads (rbody )
661-         except  (JSONDecodeError , UnicodeDecodeError ):
671+         except  (JSONDecodeError , UnicodeDecodeError )  as   e :
662672            raise  error .APIError (
663673                f"HTTP code { rcode } { rbody }  , rbody , rcode , headers = rheaders 
664-             )
674+             )  from   e 
665675        resp  =  OpenAIResponse (data , rheaders )
666676        # In the future, we might add a "status" parameter to errors 
667677        # to better handle the "error while streaming" case. 
0 commit comments