@@ -829,15 +829,11 @@ async def connect_async(self, reconnecting=False, **kwargs):
829829
830830 if reconnecting :
831831 self .session = ReconnectingAsyncClientSession (client = self , ** kwargs )
832- await self .session .start_connecting_task ()
833832 else :
834- try :
835- await self .transport .connect ()
836- except Exception as e :
837- await self .transport .close ()
838- raise e
839833 self .session = AsyncClientSession (client = self )
840834
835+ await self .session .connect ()
836+
841837 # Get schema from transport if needed
842838 try :
843839 if self .fetch_schema_from_transport and not self .schema :
@@ -846,18 +842,15 @@ async def connect_async(self, reconnecting=False, **kwargs):
846842 # we don't know what type of exception is thrown here because it
847843 # depends on the underlying transport; we just make sure that the
848844 # transport is closed and re-raise the exception
849- await self .transport .close ()
845+ await self .session .close ()
850846 raise
851847
852848 return self .session
853849
854850 async def close_async (self ):
855851 """Close the async transport and stop the optional reconnecting task."""
856852
857- if isinstance (self .session , ReconnectingAsyncClientSession ):
858- await self .session .stop_connecting_task ()
859-
860- await self .transport .close ()
853+ await self .session .close ()
861854
862855 async def __aenter__ (self ):
863856 return await self .connect_async ()
@@ -1564,12 +1557,17 @@ async def _execute(
15641557 ):
15651558 request = request .serialize_variable_values (self .client .schema )
15661559
1567- # Execute the query with the transport with a timeout
1568- with fail_after (self .client .execute_timeout ):
1569- result = await self .transport .execute (
1570- request ,
1571- ** kwargs ,
1572- )
1560+ # Check if batching is enabled
1561+ if self .client .batching_enabled :
1562+ future_result = await self ._execute_future (request )
1563+ result = await future_result
1564+ else :
1565+ # Execute the query with the transport with a timeout
1566+ with fail_after (self .client .execute_timeout ):
1567+ result = await self .transport .execute (
1568+ request ,
1569+ ** kwargs ,
1570+ )
15731571
15741572 # Unserialize the result if requested
15751573 if self .client .schema :
@@ -1828,6 +1826,134 @@ async def execute_batch(
18281826
18291827 return cast (List [Dict [str , Any ]], [result .data for result in results ])
18301828
1829+ async def _batch_loop (self ) -> None :
1830+ """Main loop of the task used to wait for requests
1831+ to execute them in a batch"""
1832+
1833+ stop_loop = False
1834+
1835+ while not stop_loop :
1836+ # First wait for a first request in from the batch queue
1837+ requests_and_futures : List [Tuple [GraphQLRequest , asyncio .Future ]] = []
1838+
1839+ # Wait for the first request
1840+ request_and_future : Optional [Tuple [GraphQLRequest , asyncio .Future ]] = (
1841+ await self .batch_queue .get ()
1842+ )
1843+
1844+ if request_and_future is None :
1845+ # None is our sentinel value to stop the loop
1846+ break
1847+
1848+ requests_and_futures .append (request_and_future )
1849+
1850+ # Then wait the requested batch interval except if we already
1851+ # have the maximum number of requests in the queue
1852+ if self .batch_queue .qsize () < self .client .batch_max - 1 :
1853+ # Wait for the batch interval
1854+ await asyncio .sleep (self .client .batch_interval )
1855+
1856+ # Then get the requests which had been made during that wait interval
1857+ for _ in range (self .client .batch_max - 1 ):
1858+ try :
1859+ # Use get_nowait since we don't want to wait here
1860+ request_and_future = self .batch_queue .get_nowait ()
1861+
1862+ if request_and_future is None :
1863+ # Sentinel value - stop after processing current batch
1864+ stop_loop = True
1865+ break
1866+
1867+ requests_and_futures .append (request_and_future )
1868+
1869+ except asyncio .QueueEmpty :
1870+ # No more requests in queue, that's fine
1871+ break
1872+
1873+ # Extract requests and futures
1874+ requests = [request for request , _ in requests_and_futures ]
1875+ futures = [future for _ , future in requests_and_futures ]
1876+
1877+ # Execute the batch
1878+ try :
1879+ results : List [ExecutionResult ] = await self ._execute_batch (
1880+ requests ,
1881+ serialize_variables = False , # already done
1882+ parse_result = False , # will be done later
1883+ validate_document = False , # already validated
1884+ )
1885+
1886+ # Set the result for each future
1887+ for result , future in zip (results , futures ):
1888+ if not future .cancelled ():
1889+ future .set_result (result )
1890+
1891+ except Exception as exc :
1892+ # If batch execution fails, propagate the error to all futures
1893+ for future in futures :
1894+ if not future .cancelled ():
1895+ future .set_exception (exc )
1896+
1897+ # Signal that the task has stopped
1898+ self ._batch_task_stopped_event .set ()
1899+
1900+ async def _execute_future (
1901+ self ,
1902+ request : GraphQLRequest ,
1903+ ) -> asyncio .Future :
1904+ """If batching is enabled, this method will put a request in the batching queue
1905+ instead of executing it directly so that the requests could be put in a batch.
1906+ """
1907+
1908+ assert hasattr (self , "batch_queue" ), "Batching is not enabled"
1909+ assert not self ._batch_task_stop_requested , "Batching task has been stopped"
1910+
1911+ future : asyncio .Future = asyncio .Future ()
1912+ await self .batch_queue .put ((request , future ))
1913+
1914+ return future
1915+
1916+ async def _batch_init (self ):
1917+ """Initialize the batch task loop if batching is enabled."""
1918+ if self .client .batching_enabled :
1919+ self .batch_queue : asyncio .Queue = asyncio .Queue ()
1920+ self ._batch_task_stop_requested = False
1921+ self ._batch_task_stopped_event = asyncio .Event ()
1922+ self ._batch_task = asyncio .create_task (self ._batch_loop ())
1923+
1924+ async def _batch_cleanup (self ):
1925+ """Cleanup the batching task if batching is enabled."""
1926+ if hasattr (self , "_batch_task_stopped_event" ):
1927+ # Send a None in the queue to indicate that the batching task must stop
1928+ # after having processed the remaining requests in the queue
1929+ self ._batch_task_stop_requested = True
1930+ await self .batch_queue .put (None )
1931+
1932+ # Wait for the task to process remaining requests and stop
1933+ await self ._batch_task_stopped_event .wait ()
1934+
1935+ async def connect (self ):
1936+ """Connect the transport and initialize the batch task loop if batching
1937+ is enabled."""
1938+
1939+ await self ._batch_init ()
1940+
1941+ try :
1942+ await self .transport .connect ()
1943+ except Exception as e :
1944+ await self .transport .close ()
1945+ raise e
1946+
1947+ async def close (self ):
1948+ """Close the transport and cleanup the batching task if batching is enabled.
1949+
1950+ Will wait until all the remaining requests in the batch processing queue
1951+ have been executed.
1952+ """
1953+ await self ._batch_cleanup ()
1954+
1955+ await self .transport .close ()
1956+
18311957 async def fetch_schema (self ) -> None :
18321958 """Fetch the GraphQL schema explicitly using introspection.
18331959
@@ -1954,6 +2080,23 @@ async def stop_connecting_task(self):
19542080 self ._connect_task .cancel ()
19552081 self ._connect_task = None
19562082
2083+ async def connect (self ):
2084+ """Start the connect task and initialize the batch task loop if batching
2085+ is enabled."""
2086+
2087+ await self ._batch_init ()
2088+
2089+ await self .start_connecting_task ()
2090+
2091+ async def close (self ):
2092+ """Stop the connect task and cleanup the batching task
2093+ if batching is enabled."""
2094+ await self ._batch_cleanup ()
2095+
2096+ await self .stop_connecting_task ()
2097+
2098+ await self .transport .close ()
2099+
19572100 async def _execute_once (
19582101 self ,
19592102 request : GraphQLRequest ,
0 commit comments