Skip to content

Commit

Permalink
updated linting
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-forte-elastic committed Oct 23, 2023
1 parent efa0174 commit f701315
Showing 1 changed file with 32 additions and 20 deletions.
52 changes: 32 additions & 20 deletions eql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,28 +193,29 @@ class CidrMatch(FunctionSignature):
additional_types = TypeHint.String.require_literal()
return_value = TypeHint.Boolean

octet_re = r'(?:25[0-5]|2[0-4][0-9]|[01]?[0-9]?[0-9])'
ipv4_re = r'\.'.join([octet_re, octet_re, octet_re, octet_re])
ipv4_compiled = re.compile(r'^{}$'.format(ipv4_re))
cidrv4_compiled = re.compile(r'^{}/(?:3[0-2]|2[0-9]|1[0-9]|[0-9])$'.format(ipv4_re))
h16_re = r'[a-fA-F0-9]{1,4}'
ipv6_re = r'^(?:[a-fA-F0-9]{1,4}:){7}[a-fA-F0-9]{1,4}$'
octet_re = r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9]?[0-9])"
ipv4_re = r"\.".join([octet_re, octet_re, octet_re, octet_re])
ipv4_compiled = re.compile(r"^{}$".format(ipv4_re))
cidrv4_compiled = re.compile(r"^{}/(?:3[0-2]|2[0-9]|1[0-9]|[0-9])$".format(ipv4_re))

h16_re = r"[a-fA-F0-9]{1,4}"
ipv6_re = r"^(?:[a-fA-F0-9]{1,4}:){7}[a-fA-F0-9]{1,4}$"
ipv6_compiled = re.compile(ipv6_re)
ipv6_shorthand_re = r'^([a-fA-F0-9]{1,4}:){0,6}(:[a-fA-F0-9]{1,4}){0,6}$'
ipv6_shorthand_re = r"^([a-fA-F0-9]{1,4}:){0,6}(:[a-fA-F0-9]{1,4}){0,6}$"
ipv6_shorthand_compiled = re.compile(ipv6_shorthand_re)
cidrv6_re = r'(?:[a-fA-F0-9]{1,4}:){0,7}[a-fA-F0-9]{1,4}/(?:12[0-8]|1[01][0-9]|[0-9]{1,2})'
cidrv6_compiled = re.compile('^{}$'.format(cidrv6_re))
cidrv6_shorthand_re = r'^([a-fA-F0-9]{1,4}:){0,7}(:[a-fA-F0-9]{1,4}){0,7}/\d{1,3}$'
cidrv6_re = r"(?:[a-fA-F0-9]{1,4}:){0,7}[a-fA-F0-9]{1,4}/(?:12[0-8]|1[01][0-9]|[0-9]{1,2})"
cidrv6_compiled = re.compile("^{}$".format(cidrv6_re))
cidrv6_shorthand_re = r"^([a-fA-F0-9]{1,4}:){0,7}(:[a-fA-F0-9]{1,4}){0,7}/\d{1,3}$"
cidrv6_shorthand_compiled = re.compile(cidrv6_shorthand_re)


# store it in native representation, then recover it in network order
masks4 = [struct.unpack(">L", struct.pack(">L", MAX_IP & ~(MAX_IP >> b)))[0] for b in range(33)]
mask_addresses4 = [socket.inet_ntoa(struct.pack(">L", m)) for m in masks4]

masks6 = [int('1' * i + '0' * (128 - i), 2) for i in range(129)]
mask_addresses6 = [socket.inet_ntop(socket.AF_INET6, struct.pack(">QQ", m >> 64, m & 0xFFFFFFFFFFFFFFFF)) for m in masks6]
masks6 = [int("1" * i + "0" * (128 - i), 2) for i in range(129)]
mask_addresses6 = [
socket.inet_ntop(socket.AF_INET6, struct.pack(">QQ", m >> 64, m & 0xFFFFFFFFFFFFFFFF)) for m in masks6
]

@classmethod
def expand_ipv6(cls, cidr):
Expand Down Expand Up @@ -257,15 +258,14 @@ def expand_ipv6_address(cls, ipv6_address):

return full_address


@classmethod
def to_mask(cls, cidr_string):
"""Split an IP address plus cidr block to the mask."""
ip_string, size = cidr_string.split("/")
size = int(size)
if cls.ipv4_compiled.match(ip_string):
ip_bytes = socket.inet_aton(ip_string)
subnet_int, = struct.unpack(">L", ip_bytes)
(subnet_int,) = struct.unpack(">L", ip_bytes)
mask = cls.masks4[size]
elif cls.ipv6_compiled.match(ip_string) or cls.ipv6_shorthand_compiled.match(ip_string):
ip_string = cls.expand_ipv6_address(ip_string)
Expand Down Expand Up @@ -342,7 +342,9 @@ def to_range(cls, cidr):
mask = cls.masks6[prefix_len]
max_ip_integer = ip_integer | (MAX_IP6 ^ mask)
min_h16s = struct.unpack(">8H", struct.pack(">QQ", ip_integer >> 64, ip_integer & 0xFFFFFFFFFFFFFFFF))
max_h16s = struct.unpack(">8H", struct.pack(">QQ", max_ip_integer >> 64, max_ip_integer & 0xFFFFFFFFFFFFFFFF))
max_h16s = struct.unpack(
">8H", struct.pack(">QQ", max_ip_integer >> 64, max_ip_integer & 0xFFFFFFFFFFFFFFFF)
)
min_octets = [h16 >> 8 for h16 in min_h16s] + [h16 & 0xFF for h16 in min_h16s[6:]]
max_octets = [h16 >> 8 for h16 in max_h16s] + [h16 & 0xFF for h16 in max_h16s[6:]]
else:
Expand All @@ -362,7 +364,11 @@ def get_callback(cls, _, *cidr_matches):
ipv6_masks.append(cls.to_mask(cidr.value))

def callback(source, *_):
if is_string(source) and (cls.ipv4_compiled.match(source) or cls.ipv6_compiled.match(source) or cls.ipv6_shorthand_compiled.match(source)):
if is_string(source) and (
cls.ipv4_compiled.match(source)
or cls.ipv6_compiled.match(source)
or cls.ipv6_shorthand_compiled.match(source)
):
if cls.ipv4_compiled.match(source):
ip_integer, _ = cls.to_mask(source + "/32")
for subnet, mask in ipv4_masks:
Expand All @@ -382,7 +388,11 @@ def callback(source, *_):
@classmethod
def run(cls, ip_address, *cidr_matches):
"""Compare an IP address against a list of cidr blocks."""
if is_string(ip_address) and (cls.ipv4_compiled.match(ip_address) or cls.ipv6_compiled.match(ip_address) or cls.ipv6_shorthand_compiled.match(ip_address)):
if is_string(ip_address) and (
cls.ipv4_compiled.match(ip_address)
or cls.ipv6_compiled.match(ip_address)
or cls.ipv6_shorthand_compiled.match(ip_address)
):
if cls.ipv4_compiled.match(ip_address):
ip_integer, _ = cls.to_mask(ip_address + "/32")
for cidr in cidr_matches:
Expand All @@ -394,7 +404,9 @@ def run(cls, ip_address, *cidr_matches):
ip_address = cls.expand_ipv6_address(ip_address)
ip_integer, _ = cls.to_mask(ip_address + "/128")
for cidr in cidr_matches:
if is_string(cidr) and (cls.cidrv6_compiled.match(cidr) or cls.cidrv6_shorthand_compiled.match(cidr)):
if is_string(cidr) and (
cls.cidrv6_compiled.match(cidr) or cls.cidrv6_shorthand_compiled.match(cidr)
):
cidr = cls.expand_ipv6(cidr)
subnet, mask = cls.to_mask(cidr)
if ip_integer & mask == subnet:
Expand Down

0 comments on commit f701315

Please sign in to comment.