From 99e91a65e13d2903d26b910a12915741ab9e51a6 Mon Sep 17 00:00:00 2001 From: vitthalmagadum Date: Mon, 3 Feb 2025 23:00:54 -0500 Subject: [PATCH] Added validator for source and destination should beling to same IP add family --- anta/input_models/connectivity.py | 6 +-- anta/tests/connectivity.py | 14 +++++- tests/units/input_models/test_connectivity.py | 43 +++++++++++++++++++ 3 files changed, 58 insertions(+), 5 deletions(-) create mode 100644 tests/units/input_models/test_connectivity.py diff --git a/anta/input_models/connectivity.py b/anta/input_models/connectivity.py index 50e97f0f2..1a904ac1d 100644 --- a/anta/input_models/connectivity.py +++ b/anta/input_models/connectivity.py @@ -11,15 +11,15 @@ from pydantic import BaseModel, ConfigDict -from anta.custom_types import Hostname, Interface +from anta.custom_types import Interface class Host(BaseModel): """Model for a remote host to ping.""" model_config = ConfigDict(extra="forbid") - destination: IPv4Address | IPv6Address | Hostname - """Destination address or hostname to ping.""" + destination: IPv4Address | IPv6Address + """Destination address to ping.""" source: IPv4Address | IPv6Address | Interface """Source address IP or egress interface to use.""" vrf: str = "default" diff --git a/anta/tests/connectivity.py b/anta/tests/connectivity.py index f3bc30e6a..6aad75ce9 100644 --- a/anta/tests/connectivity.py +++ b/anta/tests/connectivity.py @@ -9,6 +9,8 @@ from typing import ClassVar +from pydantic import field_validator + from anta.input_models.connectivity import Host, LLDPNeighbor, Neighbor from anta.models import AntaCommand, AntaTemplate, AntaTest @@ -56,8 +58,16 @@ class Input(AntaTest.Input): hosts: list[Host] """List of host to ping.""" - Host: ClassVar[type[Host]] = Host - """To maintain backward compatibility.""" + + @field_validator("hosts") + @classmethod + def validate_hosts(cls, hosts: list[Host]) -> list[Host]: + """Validate the 'destination' and 'source' IP address family in each host.""" + for host in hosts: + if not isinstance(host.source, str) and host.destination.version != host.source.version: + msg = f"{host} IP address family for destination does not match source" + raise ValueError(msg) + return hosts def render(self, template: AntaTemplate) -> list[AntaCommand]: """Render the template for each host in the input list.""" diff --git a/tests/units/input_models/test_connectivity.py b/tests/units/input_models/test_connectivity.py new file mode 100644 index 000000000..9e4288cf8 --- /dev/null +++ b/tests/units/input_models/test_connectivity.py @@ -0,0 +1,43 @@ +# Copyright (c) 2023-2025 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Tests for anta.input_models.connectivity.py.""" + +# pylint: disable=C0302 +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +from pydantic import ValidationError + +from anta.tests.connectivity import VerifyReachability + +if TYPE_CHECKING: + from anta.input_models.connectivity import Host + + +class TestVerifyReachabilityInput: + """Test anta.tests.connectivity.VerifyReachability.Input.""" + + @pytest.mark.parametrize( + ("hosts"), + [ + pytest.param([{"destination": "fd12:3456:789a:1::2", "source": "fd12:3456:789a:1::1"}], id="valid"), + ], + ) + def test_valid(self, hosts: list[Host]) -> None: + """Test VerifyReachability.Input valid inputs.""" + VerifyReachability.Input(hosts=hosts) + + @pytest.mark.parametrize( + ("hosts"), + [ + pytest.param([{"destination": "fd12:3456:789a:1::2", "source": "192.168.0.10"}], id="invalid-source"), + pytest.param([{"destination": "192.168.0.10", "source": "fd12:3456:789a:1::2"}], id="invalid-destination"), + ], + ) + def test_invalid(self, hosts: list[Host]) -> None: + """Test VerifyReachability.Input invalid inputs.""" + with pytest.raises(ValidationError): + VerifyReachability.Input(hosts=hosts)