Skip to content

Commit

Permalink
Added validator for source and destination should beling to same IP a…
Browse files Browse the repository at this point in the history
…dd family
  • Loading branch information
vitthalmagadum committed Feb 4, 2025
1 parent 4b89603 commit 99e91a6
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 5 deletions.
6 changes: 3 additions & 3 deletions anta/input_models/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@

from pydantic import BaseModel, ConfigDict

from anta.custom_types import Hostname, Interface
from anta.custom_types import Interface


class Host(BaseModel):
"""Model for a remote host to ping."""

model_config = ConfigDict(extra="forbid")
destination: IPv4Address | IPv6Address | Hostname
"""Destination address or hostname to ping."""
destination: IPv4Address | IPv6Address
"""Destination address to ping."""
source: IPv4Address | IPv6Address | Interface
"""Source address IP or egress interface to use."""
vrf: str = "default"
Expand Down
14 changes: 12 additions & 2 deletions anta/tests/connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from typing import ClassVar

from pydantic import field_validator

from anta.input_models.connectivity import Host, LLDPNeighbor, Neighbor
from anta.models import AntaCommand, AntaTemplate, AntaTest

Expand Down Expand Up @@ -56,8 +58,16 @@ class Input(AntaTest.Input):

hosts: list[Host]
"""List of host to ping."""
Host: ClassVar[type[Host]] = Host
"""To maintain backward compatibility."""

@field_validator("hosts")
@classmethod
def validate_hosts(cls, hosts: list[Host]) -> list[Host]:
"""Validate the 'destination' and 'source' IP address family in each host."""
for host in hosts:
if not isinstance(host.source, str) and host.destination.version != host.source.version:
msg = f"{host} IP address family for destination does not match source"
raise ValueError(msg)
return hosts

def render(self, template: AntaTemplate) -> list[AntaCommand]:
"""Render the template for each host in the input list."""
Expand Down
43 changes: 43 additions & 0 deletions tests/units/input_models/test_connectivity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright (c) 2023-2025 Arista Networks, Inc.
# Use of this source code is governed by the Apache License 2.0
# that can be found in the LICENSE file.
"""Tests for anta.input_models.connectivity.py."""

# pylint: disable=C0302
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest
from pydantic import ValidationError

from anta.tests.connectivity import VerifyReachability

if TYPE_CHECKING:
from anta.input_models.connectivity import Host


class TestVerifyReachabilityInput:
"""Test anta.tests.connectivity.VerifyReachability.Input."""

@pytest.mark.parametrize(
("hosts"),
[
pytest.param([{"destination": "fd12:3456:789a:1::2", "source": "fd12:3456:789a:1::1"}], id="valid"),
],
)
def test_valid(self, hosts: list[Host]) -> None:
"""Test VerifyReachability.Input valid inputs."""
VerifyReachability.Input(hosts=hosts)

@pytest.mark.parametrize(
("hosts"),
[
pytest.param([{"destination": "fd12:3456:789a:1::2", "source": "192.168.0.10"}], id="invalid-source"),
pytest.param([{"destination": "192.168.0.10", "source": "fd12:3456:789a:1::2"}], id="invalid-destination"),
],
)
def test_invalid(self, hosts: list[Host]) -> None:
"""Test VerifyReachability.Input invalid inputs."""
with pytest.raises(ValidationError):
VerifyReachability.Input(hosts=hosts)

0 comments on commit 99e91a6

Please sign in to comment.