Skip to content

Commit

Permalink
Add: Allow to request a specific number of CVEs and CPEs
Browse files Browse the repository at this point in the history
Extend the NVD CVE and CPE API to allow requesting a specific number of
results. This has become necessary to just test the CPE API because
otherwise more then 1 million CPEs would be downloaded.
  • Loading branch information
bjoernricks committed Nov 5, 2023
1 parent 4b72fd8 commit 86e5dd7
Show file tree
Hide file tree
Showing 5 changed files with 311 additions and 101 deletions.
1 change: 1 addition & 0 deletions pontos/nvd/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

Headers = Dict[str, str]
Params = Dict[str, Union[str, int]]
JSON = dict[str, Union[int, str, dict[str, Any]]]

__all__ = (
"convert_camel_case",
Expand Down
47 changes: 35 additions & 12 deletions pontos/nvd/cpe/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
from datetime import datetime
from types import TracebackType
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
List,
Optional,
Expand All @@ -34,14 +32,17 @@
from pontos.errors import PontosError
from pontos.nvd.api import (
DEFAULT_TIMEOUT_CONFIG,
JSON,
NVDApi,
Params,
convert_camel_case,
format_date,
now,
)
from pontos.nvd.models.cpe import CPE

DEFAULT_NIST_NVD_CPES_URL = "https://services.nvd.nist.gov/rest/json/cpes/2.0"
MAX_CPES_PER_PAGE = 10000


class CPEApi(NVDApi):
Expand Down Expand Up @@ -131,6 +132,7 @@ async def cpes(
cpe_match_string: Optional[str] = None,
keywords: Optional[Union[List[str], str]] = None,
match_criteria_id: Optional[str] = None,
request_results: Optional[int] = None,
) -> AsyncIterator[CPE]:
"""
Get all CPEs for the provided arguments
Expand All @@ -148,6 +150,8 @@ async def cpes(
the metadata title or reference links.
match_criteria_id: Returns all CPE records associated with a match
string identified by its UUID.
request_results: Number of CPEs to download. Set to None (default)
to download all available CPEs.
Returns:
An async iterator of CPE model instances.
Expand All @@ -161,9 +165,7 @@ async def cpes(
async for cpe in api.cpes(keywords=["Mac OS X"]):
print(cpe.cpe_name, cpe.cpe_name_id)
"""
total_results = None

params: Dict[str, Union[str, int]] = {}
params: Params = {}
if last_modified_start_date:
params["lastModStartDate"] = format_date(last_modified_start_date)
if not last_modified_end_date:
Expand All @@ -186,9 +188,18 @@ async def cpes(
params["matchCriteriaId"] = match_criteria_id

start_index = 0
results_per_page = None
downloaded_results = 0
results_per_page = (
request_results
if request_results and request_results < MAX_CPES_PER_PAGE
else MAX_CPES_PER_PAGE
)
total_results = None
requested_results = request_results

while total_results is None or start_index < total_results:
while (
requested_results is None or downloaded_results < requested_results
):
params["startIndex"] = start_index

if results_per_page is not None:
Expand All @@ -197,19 +208,31 @@ async def cpes(
response = await self._get(params=params)
response.raise_for_status()

data: Dict[str, Union[int, str, Dict[str, Any]]] = response.json(
object_hook=convert_camel_case
)
data: JSON = response.json(object_hook=convert_camel_case)

results_per_page: int = data["results_per_page"] # type: ignore
total_results: int = data["total_results"] # type: ignore
products: Iterable = data.get("products", []) # type: ignore

if not requested_results:
requested_results = total_results

for product in products:
yield CPE.from_dict(product["cpe"])

if results_per_page is not None:
start_index += results_per_page
if results_per_page is None:
# just be safe here. should never occur
results_per_page = len(products)

start_index += results_per_page
downloaded_results += results_per_page

if (
request_results
and downloaded_results + results_per_page > request_results
):
# avoid downloading more results then requested
results_per_page = request_results - downloaded_results

async def __aenter__(self) -> "CPEApi":
await super().__aenter__()
Expand Down
46 changes: 34 additions & 12 deletions pontos/nvd/cve/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
from datetime import datetime
from types import TracebackType
from typing import (
Any,
AsyncIterator,
Dict,
Iterable,
List,
Optional,
Expand All @@ -33,6 +31,7 @@
from pontos.errors import PontosError
from pontos.nvd.api import (
DEFAULT_TIMEOUT_CONFIG,
JSON,
NVDApi,
Params,
convert_camel_case,
Expand All @@ -46,6 +45,7 @@
__all__ = ("CVEApi",)

DEFAULT_NIST_NVD_CVES_URL = "https://services.nvd.nist.gov/rest/json/cves/2.0"
MAX_CVES_PER_PAGE = 10000


class CVEApi(NVDApi):
Expand Down Expand Up @@ -112,6 +112,7 @@ async def cves(
has_cert_notes: Optional[bool] = None,
has_kev: Optional[bool] = None,
has_oval: Optional[bool] = None,
request_results: Optional[int] = None,
) -> AsyncIterator[CVE]:
"""
Get all CVEs for the provided arguments
Expand Down Expand Up @@ -160,6 +161,8 @@ async def cves(
has_oval: Returns the CVEs that contain information from MITRE's
Open Vulnerability and Assessment Language (OVAL) before this
transitioned to the Center for Internet Security (CIS).
request_results: Number of CVEs to download. Set to None (default)
to download all available CVEs.
Returns:
An async iterator to iterate over CVE model instances
Expand All @@ -173,8 +176,6 @@ async def cves(
async for cve in api.cves(keywords=["Mac OS X", "kernel"]):
print(cve.id)
"""
total_results: Optional[int] = None

params: Params = {}
if last_modified_start_date:
params["lastModStartDate"] = format_date(last_modified_start_date)
Expand Down Expand Up @@ -231,9 +232,18 @@ async def cves(
params["hasOval"] = ""

start_index: int = 0
results_per_page = None
downloaded_results = 0
results_per_page = (
request_results
if request_results and request_results < MAX_CVES_PER_PAGE
else MAX_CVES_PER_PAGE
)
total_results = None
requested_results = request_results

while total_results is None or start_index < total_results:
while (
requested_results is None or downloaded_results < requested_results
):
params["startIndex"] = start_index

if results_per_page is not None:
Expand All @@ -242,21 +252,33 @@ async def cves(
response = await self._get(params=params)
response.raise_for_status()

data: Dict[str, Union[int, str, Dict[str, Any]]] = response.json(
object_hook=convert_camel_case
)
data: JSON = response.json(object_hook=convert_camel_case)

total_results = data["total_results"] # type: ignore
results_per_page: int = data["results_per_page"] # type: ignore
total_results: int = data["total_results"] # type: ignore
vulnerabilities: Iterable = data.get( # type: ignore
"vulnerabilities", []
)

if not requested_results:
requested_results = total_results

for vulnerability in vulnerabilities:
yield CVE.from_dict(vulnerability["cve"])

if results_per_page is not None:
start_index += results_per_page
if results_per_page is None:
# just be safe here. should never occur
results_per_page = len(vulnerabilities)

start_index += results_per_page
downloaded_results += results_per_page

if (
request_results
and downloaded_results + results_per_page > request_results
):
# avoid downloading more results then requested
results_per_page = request_results - downloaded_results

async def cve(self, cve_id: str) -> CVE:
"""
Expand Down
Loading

0 comments on commit 86e5dd7

Please sign in to comment.