@@ -81,6 +81,10 @@ def __init__(self, name: str, tags: set[str] | None = None, *, disable_cache: bo
81
81
self .established : bool = False
82
82
self .cache : Cache | None = None
83
83
self .cache_locks : defaultdict [str , asyncio .Lock ] | None = None
84
+ self .command_queue : asyncio .Queue [AntaCommand ] = asyncio .Queue ()
85
+ self .batch_task : asyncio .Task [None ] | None = None
86
+ # TODO: Check if we want to make the batch size configurable
87
+ self .batch_size : int = 100
84
88
85
89
# Initialize cache if not disabled
86
90
if not disable_cache :
@@ -104,6 +108,12 @@ def _init_cache(self) -> None:
104
108
self .cache = Cache (cache_class = Cache .MEMORY , ttl = 60 , namespace = self .name , plugins = [HitMissRatioPlugin ()])
105
109
self .cache_locks = defaultdict (asyncio .Lock )
106
110
111
+ def init_batch_task (self ) -> None :
112
+ """Initialize the batch task for the device."""
113
+ if self .batch_task is None :
114
+ logger .debug ("<%s>: Starting the batch task" , self .name )
115
+ self .batch_task = asyncio .create_task (self ._batch_task ())
116
+
107
117
@property
108
118
def cache_statistics (self ) -> dict [str , Any ] | None :
109
119
"""Return the device cache statistics for logging purposes."""
@@ -137,6 +147,72 @@ def __repr__(self) -> str:
137
147
f"disable_cache={ self .cache is None !r} )"
138
148
)
139
149
150
+ async def _batch_task (self ) -> None :
151
+ """Background task to retrieve commands put by tests from the command queue of this device.
152
+
153
+ Test coroutines put their AntaCommand instances in the queue, this task retrieves them. Once they stop coming,
154
+ the instances are grouped by UID, split into JSON and text batches, and collected in batches of `batch_size`.
155
+ """
156
+ collection_tasks : list [asyncio .Task [None ]] = []
157
+ all_commands : list [AntaCommand ] = []
158
+
159
+ while True :
160
+ try :
161
+ get_await = self .command_queue .get ()
162
+ command = await asyncio .wait_for (get_await , timeout = 0.5 )
163
+ logger .debug ("<%s>: Command retrieved from the queue: %s" , self .name , command )
164
+ all_commands .append (command )
165
+ except asyncio .TimeoutError : # noqa: PERF203
166
+ logger .debug ("<%s>: All test commands have been retrieved from the queue" , self .name )
167
+ break
168
+
169
+ # Group all command instances by UID
170
+ command_groups : defaultdict [str , list [AntaCommand ]] = defaultdict (list [AntaCommand ])
171
+ for command in all_commands :
172
+ command_groups [command .uid ].append (command )
173
+
174
+ # Split into JSON and text batches. We can safely take the first command instance from each UID as they are the same.
175
+ json_commands = {uid : commands for uid , commands in command_groups .items () if commands [0 ].ofmt == "json" }
176
+ text_commands = {uid : commands for uid , commands in command_groups .items () if commands [0 ].ofmt == "text" }
177
+
178
+ # Process JSON batches
179
+ for i in range (0 , len (json_commands ), self .batch_size ):
180
+ batch = dict (list (json_commands .items ())[i : i + self .batch_size ])
181
+ task = asyncio .create_task (self ._collect_batch (batch , ofmt = "json" ))
182
+ collection_tasks .append (task )
183
+
184
+ # Process text batches
185
+ for i in range (0 , len (text_commands ), self .batch_size ):
186
+ batch = dict (list (text_commands .items ())[i : i + self .batch_size ])
187
+ task = asyncio .create_task (self ._collect_batch (batch , ofmt = "text" ))
188
+ collection_tasks .append (task )
189
+
190
+ # Wait for all collection tasks to complete
191
+ if collection_tasks :
192
+ logger .debug ("<%s>: Waiting for %d collection tasks to complete" , self .name , len (collection_tasks ))
193
+ await asyncio .gather (* collection_tasks )
194
+
195
+ # TODO: Handle other exceptions
196
+
197
+ logger .debug ("<%s>: Stopping the batch task" , self .name )
198
+
199
+ async def _collect_batch (self , command_groups : dict [str , list [AntaCommand ]], ofmt : Literal ["json" , "text" ] = "json" ) -> None :
200
+ """Collect a batch of device commands.
201
+
202
+ This coroutine must be implemented by subclasses that want to support command queuing
203
+ in conjunction with the `_batch_task()` method.
204
+
205
+ Parameters
206
+ ----------
207
+ command_groups
208
+ Mapping of command instances grouped by UID to avoid duplicate commands.
209
+ ofmt
210
+ The output format of the batch.
211
+ """
212
+ _ = (command_groups , ofmt )
213
+ msg = f"_collect_batch method has not been implemented in { self .__class__ .__name__ } definition"
214
+ raise NotImplementedError (msg )
215
+
140
216
@abstractmethod
141
217
async def _collect (self , command : AntaCommand , * , collection_id : str | None = None ) -> None :
142
218
"""Collect device command output.
@@ -192,16 +268,38 @@ async def collect(self, command: AntaCommand, *, collection_id: str | None = Non
192
268
else :
193
269
await self ._collect (command = command , collection_id = collection_id )
194
270
195
- async def collect_commands (self , commands : list [AntaCommand ], * , collection_id : str | None = None ) -> None :
271
+ async def collect_commands (self , commands : list [AntaCommand ], * , command_queuing : bool = False , collection_id : str | None = None ) -> None :
196
272
"""Collect multiple commands.
197
273
198
274
Parameters
199
275
----------
200
276
commands
201
277
The commands to collect.
278
+ command_queuing
279
+ If True, the commands are put in a queue and collected in batches. Default is False.
202
280
collection_id
203
- An identifier used to build the eAPI request ID.
281
+ An identifier used to build the eAPI request ID. Not used when command queuing is enabled.
204
282
"""
283
+ # Collect the commands with queuing
284
+ if command_queuing :
285
+ # Disable cache for this device as it is not needed when using command queuing
286
+ self .cache = None
287
+ self .cache_locks = None
288
+
289
+ # Initialize the device batch task if not already running
290
+ self .init_batch_task ()
291
+
292
+ # Put the commands in the queue
293
+ for command in commands :
294
+ logger .debug ("<%s>: Putting command in the queue: %s" , self .name , command )
295
+ await self .command_queue .put (command )
296
+
297
+ # Wait for all commands to be collected.
298
+ logger .debug ("<%s>: Waiting for all commands to be collected" , self .name )
299
+ await asyncio .gather (* [command .event .wait () for command in commands ])
300
+ return
301
+
302
+ # Collect the commands without queuing. Default behavior.
205
303
await asyncio .gather (* (self .collect (command = command , collection_id = collection_id ) for command in commands ))
206
304
207
305
@abstractmethod
@@ -372,6 +470,78 @@ def _keys(self) -> tuple[Any, ...]:
372
470
"""
373
471
return (self ._session .host , self ._session .port )
374
472
473
+ async def _collect_batch (self , command_groups : dict [str , list [AntaCommand ]], ofmt : Literal ["json" , "text" ] = "json" ) -> None : # noqa: C901
474
+ """Collect a batch of device commands.
475
+
476
+ Parameters
477
+ ----------
478
+ command_groups
479
+ Mapping of command instances grouped by UID to avoid duplicate commands.
480
+ ofmt
481
+ The output format of the batch.
482
+ """
483
+ # Add 'enable' command if required
484
+ cmds = []
485
+ if self .enable and self ._enable_password is not None :
486
+ cmds .append ({"cmd" : "enable" , "input" : str (self ._enable_password )})
487
+ elif self .enable :
488
+ # No password
489
+ cmds .append ({"cmd" : "enable" })
490
+
491
+ # Take first instance from each group for the actual commands
492
+ cmds .extend (
493
+ [
494
+ {"cmd" : instances [0 ].command , "revision" : instances [0 ].revision } if instances [0 ].revision else {"cmd" : instances [0 ].command }
495
+ for instances in command_groups .values ()
496
+ ]
497
+ )
498
+
499
+ try :
500
+ response = await self ._session .cli (
501
+ commands = cmds ,
502
+ ofmt = ofmt ,
503
+ # TODO: See if we want to have different batches for different versions
504
+ version = 1 ,
505
+ # TODO: See if want to have a different req_id for each batch
506
+ req_id = f"ANTA-{ id (command_groups )} " ,
507
+ )
508
+
509
+ # Do not keep response of 'enable' command
510
+ if self .enable :
511
+ response = response [1 :]
512
+
513
+ # Update all AntaCommand instances with their output and signal their completion
514
+ logger .debug ("<%s>: Collected batch of commands, signaling their completion" , self .name )
515
+ for idx , instances in enumerate (command_groups .values ()):
516
+ output = response [idx ]
517
+ for cmd_instance in instances :
518
+ cmd_instance .output = output
519
+ cmd_instance .event .set ()
520
+
521
+ except asynceapi .EapiCommandError as e :
522
+ # TODO: Handle commands that passed
523
+ for instances in command_groups .values ():
524
+ for cmd_instance in instances :
525
+ cmd_instance .errors = e .errors
526
+ if cmd_instance .requires_privileges :
527
+ logger .error (
528
+ "Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required." ,
529
+ cmd_instance .command ,
530
+ self .name ,
531
+ )
532
+ if cmd_instance .supported :
533
+ logger .error ("Command '%s' failed on %s: %s" , cmd_instance .command , self .name , e .errors [0 ] if len (e .errors ) == 1 else e .errors )
534
+ else :
535
+ logger .debug ("Command '%s' is not supported on '%s' (%s)" , cmd_instance .command , self .name , self .hw_model )
536
+ cmd_instance .event .set ()
537
+
538
+ # TODO: Handle other exceptions
539
+ except Exception as e :
540
+ for instances in command_groups .values ():
541
+ for cmd_instance in instances :
542
+ cmd_instance .errors = [exc_to_str (e )]
543
+ cmd_instance .event .set ()
544
+
375
545
async def _collect (self , command : AntaCommand , * , collection_id : str | None = None ) -> None : # noqa: C901 function is too complex - because of many required except blocks
376
546
"""Collect device command output from EOS using aio-eapi.
377
547
0 commit comments