Skip to content

Commit 756b942

Browse files
WIP asyncio.Queue
1 parent 5f65676 commit 756b942

File tree

3 files changed

+213
-15
lines changed

3 files changed

+213
-15
lines changed

anta/device.py

+172-2
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def __init__(self, name: str, tags: set[str] | None = None, *, disable_cache: bo
8181
self.established: bool = False
8282
self.cache: Cache | None = None
8383
self.cache_locks: defaultdict[str, asyncio.Lock] | None = None
84+
self.command_queue: asyncio.Queue[AntaCommand] = asyncio.Queue()
85+
self.batch_task: asyncio.Task[None] | None = None
86+
# TODO: Check if we want to make the batch size configurable
87+
self.batch_size: int = 100
8488

8589
# Initialize cache if not disabled
8690
if not disable_cache:
@@ -104,6 +108,12 @@ def _init_cache(self) -> None:
104108
self.cache = Cache(cache_class=Cache.MEMORY, ttl=60, namespace=self.name, plugins=[HitMissRatioPlugin()])
105109
self.cache_locks = defaultdict(asyncio.Lock)
106110

111+
def init_batch_task(self) -> None:
    """Start the background batch task of this device unless one is currently running.

    A task is created when none has been started yet, or when the previous one has already
    finished. `_batch_task()` returns once the command queue stays empty, so a finished task
    has to be replaced; otherwise commands queued afterwards would never be collected.
    """
    # Nothing to do while a batch task is still alive: it will pick up the queued commands.
    if self.batch_task is not None and not self.batch_task.done():
        return

    logger.debug("<%s>: Starting the batch task", self.name)
    self.batch_task = asyncio.create_task(self._batch_task())
116+
107117
@property
108118
def cache_statistics(self) -> dict[str, Any] | None:
109119
"""Return the device cache statistics for logging purposes."""
@@ -137,6 +147,72 @@ def __repr__(self) -> str:
137147
f"disable_cache={self.cache is None!r})"
138148
)
139149

150+
async def _batch_task(self) -> None:
151+
"""Background task to retrieve commands put by tests from the command queue of this device.
152+
153+
Test coroutines put their AntaCommand instances in the queue, this task retrieves them. Once they stop coming,
154+
the instances are grouped by UID, split into JSON and text batches, and collected in batches of `batch_size`.
155+
"""
156+
collection_tasks: list[asyncio.Task[None]] = []
157+
all_commands: list[AntaCommand] = []
158+
159+
while True:
160+
try:
161+
get_await = self.command_queue.get()
162+
command = await asyncio.wait_for(get_await, timeout=0.5)
163+
logger.debug("<%s>: Command retrieved from the queue: %s", self.name, command)
164+
all_commands.append(command)
165+
except asyncio.TimeoutError: # noqa: PERF203
166+
logger.debug("<%s>: All test commands have been retrieved from the queue", self.name)
167+
break
168+
169+
# Group all command instances by UID
170+
command_groups: defaultdict[str, list[AntaCommand]] = defaultdict(list[AntaCommand])
171+
for command in all_commands:
172+
command_groups[command.uid].append(command)
173+
174+
# Split into JSON and text batches. We can safely take the first command instance from each UID as they are the same.
175+
json_commands = {uid: commands for uid, commands in command_groups.items() if commands[0].ofmt == "json"}
176+
text_commands = {uid: commands for uid, commands in command_groups.items() if commands[0].ofmt == "text"}
177+
178+
# Process JSON batches
179+
for i in range(0, len(json_commands), self.batch_size):
180+
batch = dict(list(json_commands.items())[i : i + self.batch_size])
181+
task = asyncio.create_task(self._collect_batch(batch, ofmt="json"))
182+
collection_tasks.append(task)
183+
184+
# Process text batches
185+
for i in range(0, len(text_commands), self.batch_size):
186+
batch = dict(list(text_commands.items())[i : i + self.batch_size])
187+
task = asyncio.create_task(self._collect_batch(batch, ofmt="text"))
188+
collection_tasks.append(task)
189+
190+
# Wait for all collection tasks to complete
191+
if collection_tasks:
192+
logger.debug("<%s>: Waiting for %d collection tasks to complete", self.name, len(collection_tasks))
193+
await asyncio.gather(*collection_tasks)
194+
195+
# TODO: Handle other exceptions
196+
197+
logger.debug("<%s>: Stopping the batch task", self.name)
198+
199+
async def _collect_batch(self, command_groups: dict[str, list[AntaCommand]], ofmt: Literal["json", "text"] = "json") -> None:
200+
"""Collect a batch of device commands.
201+
202+
This coroutine must be implemented by subclasses that want to support command queuing
203+
in conjunction with the `_batch_task()` method.
204+
205+
Parameters
206+
----------
207+
command_groups
208+
Mapping of command instances grouped by UID to avoid duplicate commands.
209+
ofmt
210+
The output format of the batch.
211+
"""
212+
_ = (command_groups, ofmt)
213+
msg = f"_collect_batch method has not been implemented in {self.__class__.__name__} definition"
214+
raise NotImplementedError(msg)
215+
140216
@abstractmethod
141217
async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None:
142218
"""Collect device command output.
@@ -192,16 +268,38 @@ async def collect(self, command: AntaCommand, *, collection_id: str | None = Non
192268
else:
193269
await self._collect(command=command, collection_id=collection_id)
194270

195-
async def collect_commands(self, commands: list[AntaCommand], *, collection_id: str | None = None) -> None:
271+
async def collect_commands(self, commands: list[AntaCommand], *, command_queuing: bool = False, collection_id: str | None = None) -> None:
    """Collect multiple commands.

    Without command queuing (default), every command is collected individually through `collect()`.
    With command queuing, the commands are put in the device command queue and this coroutine
    returns once the event of every command has been set.

    Parameters
    ----------
    commands
        The commands to collect.
    command_queuing
        If True, the commands are put in a queue and collected in batches. Default is False.
    collection_id
        An identifier used to build the eAPI request ID. Not used when command queuing is enabled.
    """
    # Collect the commands without queuing. Default behavior.
    if not command_queuing:
        await asyncio.gather(*(self.collect(command=command, collection_id=collection_id) for command in commands))
        return

    # Disable cache for this device as it is not needed when using command queuing
    self.cache = None
    self.cache_locks = None

    # A command built without command queuing has no event yet (the field defaults to None).
    # Create it before queuing the command, otherwise there would be nothing to wait on below.
    for command in commands:
        if command.event is None:
            command.event = asyncio.Event()

    # Initialize the device batch task if not already running
    self.init_batch_task()

    # Put the commands in the queue
    for command in commands:
        logger.debug("<%s>: Putting command in the queue: %s", self.name, command)
        await self.command_queue.put(command)

    # Wait for all commands to be collected.
    logger.debug("<%s>: Waiting for all commands to be collected", self.name)
    await asyncio.gather(*(command.event.wait() for command in commands))
206304

207305
@abstractmethod
@@ -372,6 +470,78 @@ def _keys(self) -> tuple[Any, ...]:
372470
"""
373471
return (self._session.host, self._session.port)
374472

473+
async def _collect_batch(self, command_groups: dict[str, list[AntaCommand]], ofmt: Literal["json", "text"] = "json") -> None: # noqa: C901
474+
"""Collect a batch of device commands.
475+
476+
Parameters
477+
----------
478+
command_groups
479+
Mapping of command instances grouped by UID to avoid duplicate commands.
480+
ofmt
481+
The output format of the batch.
482+
"""
483+
# Add 'enable' command if required
484+
cmds = []
485+
if self.enable and self._enable_password is not None:
486+
cmds.append({"cmd": "enable", "input": str(self._enable_password)})
487+
elif self.enable:
488+
# No password
489+
cmds.append({"cmd": "enable"})
490+
491+
# Take first instance from each group for the actual commands
492+
cmds.extend(
493+
[
494+
{"cmd": instances[0].command, "revision": instances[0].revision} if instances[0].revision else {"cmd": instances[0].command}
495+
for instances in command_groups.values()
496+
]
497+
)
498+
499+
try:
500+
response = await self._session.cli(
501+
commands=cmds,
502+
ofmt=ofmt,
503+
# TODO: See if we want to have different batches for different versions
504+
version=1,
505+
# TODO: See if want to have a different req_id for each batch
506+
req_id=f"ANTA-{id(command_groups)}",
507+
)
508+
509+
# Do not keep response of 'enable' command
510+
if self.enable:
511+
response = response[1:]
512+
513+
# Update all AntaCommand instances with their output and signal their completion
514+
logger.debug("<%s>: Collected batch of commands, signaling their completion", self.name)
515+
for idx, instances in enumerate(command_groups.values()):
516+
output = response[idx]
517+
for cmd_instance in instances:
518+
cmd_instance.output = output
519+
cmd_instance.event.set()
520+
521+
except asynceapi.EapiCommandError as e:
522+
# TODO: Handle commands that passed
523+
for instances in command_groups.values():
524+
for cmd_instance in instances:
525+
cmd_instance.errors = e.errors
526+
if cmd_instance.requires_privileges:
527+
logger.error(
528+
"Command '%s' requires privileged mode on %s. Verify user permissions and if the `enable` option is required.",
529+
cmd_instance.command,
530+
self.name,
531+
)
532+
if cmd_instance.supported:
533+
logger.error("Command '%s' failed on %s: %s", cmd_instance.command, self.name, e.errors[0] if len(e.errors) == 1 else e.errors)
534+
else:
535+
logger.debug("Command '%s' is not supported on '%s' (%s)", cmd_instance.command, self.name, self.hw_model)
536+
cmd_instance.event.set()
537+
538+
# TODO: Handle other exceptions
539+
except Exception as e:
540+
for instances in command_groups.values():
541+
for cmd_instance in instances:
542+
cmd_instance.errors = [exc_to_str(e)]
543+
cmd_instance.event.set()
544+
375545
async def _collect(self, command: AntaCommand, *, collection_id: str | None = None) -> None: # noqa: C901 function is too complex - because of many required except blocks
376546
"""Collect device command output from EOS using aio-eapi.
377547

anta/models.py

+26-11
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55

66
from __future__ import annotations
77

8+
import asyncio
89
import hashlib
910
import logging
1011
import re
1112
from abc import ABC, abstractmethod
12-
from functools import wraps
13+
from functools import cached_property, wraps
1314
from string import Formatter
1415
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, TypeVar
1516

@@ -165,7 +166,9 @@ class AntaCommand(BaseModel):
165166
Pydantic Model containing the variables values used to render the template.
166167
use_cache
167168
Enable or disable caching for this AntaCommand if the AntaDevice supports it.
168-
169+
event
170+
Event to signal that the command has been collected. Used by an AntaDevice to signal an AntaTest that the command has been collected.
171+
Only relevant when an AntaTest runs with `command_queuing=True`.
169172
"""
170173

171174
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -179,13 +182,13 @@ class AntaCommand(BaseModel):
179182
errors: list[str] = []
180183
params: AntaParamsBaseModel = AntaParamsBaseModel()
181184
use_cache: bool = True
185+
event: asyncio.Event | None = None
182186

183-
@property
187+
@cached_property
def uid(self) -> str:
    """Return the identifier of this command.

    The identifier is the SHA-256 hex digest of the command string, version, revision
    ('NA' when there is no revision) and output format joined with underscores.
    """
    fingerprint = f"{self.command}_{self.version}_{self.revision or 'NA'}_{self.ofmt}"
    hasher = hashlib.sha256()
    hasher.update(fingerprint.encode())
    return hasher.hexdigest()
189192

190193
@property
191194
def json_output(self) -> dict[str, Any]:
@@ -409,6 +412,8 @@ def __init__(
409412
device: AntaDevice,
410413
inputs: dict[str, Any] | AntaTest.Input | None = None,
411414
eos_data: list[dict[Any, Any] | str] | None = None,
415+
*,
416+
command_queuing: bool = False,
412417
) -> None:
413418
"""AntaTest Constructor.
414419
@@ -421,10 +426,14 @@ def __init__(
421426
eos_data
422427
Populate outputs of the test commands instead of collecting from devices.
423428
This list must have the same length and order as the `instance_commands` instance attribute.
429+
command_queuing
430+
If True, the commands of this test will be queued in the device command queue and be sent in batches.
431+
Default is False, which means the commands will be sent one by one to the device.
424432
"""
425433
self.logger: logging.Logger = logging.getLogger(f"{self.module}.{self.__class__.__name__}")
426434
self.device: AntaDevice = device
427435
self.inputs: AntaTest.Input
436+
self.command_queuing = command_queuing
428437
self.instance_commands: list[AntaCommand] = []
429438
self.result: TestResult = TestResult(
430439
name=device.name,
@@ -474,10 +483,17 @@ def _init_commands(self, eos_data: list[dict[Any, Any] | str] | None) -> None:
474483
if self.__class__.commands:
475484
for cmd in self.__class__.commands:
476485
if isinstance(cmd, AntaCommand):
477-
self.instance_commands.append(cmd.model_copy())
486+
command = cmd.model_copy()
487+
if self.command_queuing:
488+
command.event = asyncio.Event()
489+
self.instance_commands.append(command)
478490
elif isinstance(cmd, AntaTemplate):
479491
try:
480-
self.instance_commands.extend(self.render(cmd))
492+
rendered_commands = self.render(cmd)
493+
if self.command_queuing:
494+
for command in rendered_commands:
495+
command.event = asyncio.Event()
496+
self.instance_commands.extend(rendered_commands)
481497
except AntaTemplateRenderError as e:
482498
self.result.is_error(message=f"Cannot render template {{{e.template}}}")
483499
return
@@ -568,7 +584,7 @@ async def collect(self) -> None:
568584
"""Collect outputs of all commands of this test class from the device of this test instance."""
569585
try:
570586
if self.blocked is False:
571-
await self.device.collect_commands(self.instance_commands, collection_id=self.name)
587+
await self.device.collect_commands(self.instance_commands, collection_id=self.name, command_queuing=self.command_queuing)
572588
except Exception as e: # noqa: BLE001
573589
# device._collect() is user-defined code.
574590
# We need to catch everything if we want the AntaTest object
@@ -593,7 +609,6 @@ def anta_test(function: F) -> Callable[..., Coroutine[Any, Any, TestResult]]:
593609
async def wrapper(
594610
self: AntaTest,
595611
eos_data: list[dict[Any, Any] | str] | None = None,
596-
**kwargs: dict[str, Any],
597612
) -> TestResult:
598613
"""Inner function for the anta_test decorator.
599614
@@ -640,7 +655,7 @@ async def wrapper(
640655
return self.result
641656

642657
try:
643-
function(self, **kwargs)
658+
function(self)
644659
except Exception as e: # noqa: BLE001
645660
# test() is user-defined code.
646661
# We need to catch everything if we want the AntaTest object
@@ -662,7 +677,7 @@ def update_progress(cls: type[AntaTest]) -> None:
662677
cls.progress.update(cls.nrfu_task, advance=1)
663678

664679
@abstractmethod
665-
def test(self) -> Coroutine[Any, Any, TestResult]:
680+
def test(self) -> None:
666681
"""Core of the test logic.
667682
668683
This is an abstractmethod that must be implemented by child classes.

anta/runner.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,17 @@
2929
logger = logging.getLogger(__name__)
3030

3131
DEFAULT_NOFILE = 16384
32+
COMMAND_QUEUING = False


def get_command_queuing() -> bool:
    """Return the command queuing flag from the `ANTA_COMMAND_QUEUING` environment variable.

    The value is compared case-insensitively, ignoring surrounding whitespace:
    `1`, `true`, `yes` and `on` enable command queuing; `0`, `false`, `no`, `off` and the empty
    string disable it. When the variable is not set, `COMMAND_QUEUING` is returned. Any other
    value is reported with a warning and `COMMAND_QUEUING` is returned.

    Returns
    -------
    bool
        True if command queuing is enabled, False otherwise.
    """
    raw_value = os.environ.get("ANTA_COMMAND_QUEUING")
    if raw_value is None:
        return COMMAND_QUEUING

    # bool() on a string is True for any non-empty value, including "false" or "0",
    # so the value has to be interpreted explicitly.
    value = raw_value.strip().lower()
    if value in {"1", "true", "yes", "on"}:
        return True
    if value in {"", "0", "false", "no", "off"}:
        return False

    logger.warning("The ANTA_COMMAND_QUEUING environment variable value is invalid: %s\nDefault to %s.", raw_value, COMMAND_QUEUING)
    return COMMAND_QUEUING
3243

3344

3445
def adjust_rlimit_nofile() -> tuple[int, int]:
@@ -190,11 +201,12 @@ def get_coroutines(selected_tests: defaultdict[AntaDevice, set[AntaTestDefinitio
190201
list[Coroutine[Any, Any, TestResult]]
191202
The list of coroutines to run.
192203
"""
204+
command_queuing = get_command_queuing()
193205
coros = []
194206
for device, test_definitions in selected_tests.items():
195207
for test in test_definitions:
196208
try:
197-
test_instance = test.test(device=device, inputs=test.inputs)
209+
test_instance = test.test(device=device, inputs=test.inputs, command_queuing=command_queuing)
198210
manager.add(test_instance.result)
199211
coros.append(test_instance.test())
200212
except Exception as e: # noqa: PERF203, BLE001
@@ -296,4 +308,5 @@ async def main( # noqa: PLR0913
296308
with Catchtime(logger=logger, message="Running ANTA tests"):
297309
await asyncio.gather(*coroutines)
298310

299-
log_cache_statistics(selected_inventory.devices)
311+
if not get_command_queuing():
312+
log_cache_statistics(selected_inventory.devices)

0 commit comments

Comments
 (0)