2323from graphql import DocumentNode , ExecutionResult , print_ast
2424from multidict import CIMultiDictProxy
2525
26- from ..utils import extract_files
2726from .appsync_auth import AppSyncAuthentication
2827from .async_transport import AsyncTransport
2928from .common .aiohttp_closed_event import create_aiohttp_closed_event
3332 TransportProtocolError ,
3433 TransportServerError ,
3534)
35+ from .file_upload import FileVar , close_files , extract_files , open_files
3636
3737log = logging .getLogger (__name__ )
3838
@@ -207,6 +207,10 @@ async def execute(
207207 file_classes = self .file_classes ,
208208 )
209209
210+ # Opening the files using the FileVar parameters
211+ open_files (list (files .values ()), transport_supports_streaming = True )
212+ self .files = files
213+
210214 # Save the nulled variable values in the payload
211215 payload ["variables" ] = nulled_variable_values
212216
@@ -220,8 +224,8 @@ async def execute(
220224 file_map = {str (i ): [path ] for i , path in enumerate (files )}
221225
222226 # Enumerate the file streams
223- # Will generate something like {'0': <_io.BufferedReader ...> }
224- file_streams = {str (i ): files [path ] for i , path in enumerate (files )}
227+ # Will generate something like {'0': FileVar object }
228+ file_vars = {str (i ): files [path ] for i , path in enumerate (files )}
225229
226230 # Add the payload to the operations field
227231 operations_str = self .json_serialize (payload )
@@ -235,12 +239,15 @@ async def execute(
235239 log .debug ("file_map %s" , file_map_str )
236240 data .add_field ("map" , file_map_str , content_type = "application/json" )
237241
238- # Add the extracted files as remaining fields
239- for k , f in file_streams .items ():
240- name = getattr (f , "name" , k )
241- content_type = getattr (f , "content_type" , None )
242+ for k , file_var in file_vars .items ():
243+ assert isinstance (file_var , FileVar )
242244
243- data .add_field (k , f , filename = name , content_type = content_type )
245+ data .add_field (
246+ k ,
247+ file_var .f ,
248+ filename = file_var .filename ,
249+ content_type = file_var .content_type ,
250+ )
244251
245252 post_args : Dict [str , Any ] = {"data" : data }
246253
@@ -267,51 +274,59 @@ async def execute(
267274 if self .session is None :
268275 raise TransportClosed ("Transport is not connected" )
269276
270- async with self .session .post (self .url , ssl = self .ssl , ** post_args ) as resp :
271-
272- # Saving latest response headers in the transport
273- self .response_headers = resp .headers
277+ try :
278+ async with self .session .post (self .url , ssl = self .ssl , ** post_args ) as resp :
274279
275- async def raise_response_error (
276- resp : aiohttp .ClientResponse , reason : str
277- ) -> NoReturn :
278- # We raise a TransportServerError if the status code is 400 or higher
279- # We raise a TransportProtocolError in the other cases
280+ # Saving latest response headers in the transport
281+ self .response_headers = resp .headers
280282
281- try :
282- # Raise a ClientResponseError if response status is 400 or higher
283- resp .raise_for_status ()
284- except ClientResponseError as e :
285- raise TransportServerError (str (e ), e .status ) from e
286-
287- result_text = await resp .text ()
288- raise TransportProtocolError (
289- f"Server did not return a GraphQL result: "
290- f"{ reason } : "
291- f"{ result_text } "
292- )
283+ async def raise_response_error (
284+ resp : aiohttp .ClientResponse , reason : str
285+ ) -> NoReturn :
286+ # We raise a TransportServerError if status code is 400 or higher
287+ # We raise a TransportProtocolError in the other cases
293288
294- try :
295- result = await resp .json (loads = self .json_deserialize , content_type = None )
289+ try :
290+ # Raise ClientResponseError if response status is 400 or higher
291+ resp .raise_for_status ()
292+ except ClientResponseError as e :
293+ raise TransportServerError (str (e ), e .status ) from e
296294
297- if log .isEnabledFor (logging .INFO ):
298295 result_text = await resp .text ()
299- log .info ("<<< %s" , result_text )
296+ raise TransportProtocolError (
297+ f"Server did not return a GraphQL result: "
298+ f"{ reason } : "
299+ f"{ result_text } "
300+ )
300301
301- except Exception :
302- await raise_response_error (resp , "Not a JSON answer" )
302+ try :
303+ result = await resp .json (
304+ loads = self .json_deserialize , content_type = None
305+ )
303306
304- if result is None :
305- await raise_response_error (resp , "Not a JSON answer" )
307+ if log .isEnabledFor (logging .INFO ):
308+ result_text = await resp .text ()
309+ log .info ("<<< %s" , result_text )
306310
307- if "errors" not in result and "data" not in result :
308- await raise_response_error (resp , 'No "data" or "errors" keys in answer' )
311+ except Exception :
312+ await raise_response_error (resp , "Not a JSON answer" )
309313
310- return ExecutionResult (
311- errors = result .get ("errors" ),
312- data = result .get ("data" ),
313- extensions = result .get ("extensions" ),
314- )
314+ if result is None :
315+ await raise_response_error (resp , "Not a JSON answer" )
316+
317+ if "errors" not in result and "data" not in result :
318+ await raise_response_error (
319+ resp , 'No "data" or "errors" keys in answer'
320+ )
321+
322+ return ExecutionResult (
323+ errors = result .get ("errors" ),
324+ data = result .get ("data" ),
325+ extensions = result .get ("extensions" ),
326+ )
327+ finally :
328+ if upload_files :
329+ close_files (list (self .files .values ()))
315330
316331 def subscribe (
317332 self ,
0 commit comments