fix(cache): handle get_from_cache=None and ensure directory exists (#544)

Signed-off-by: Dylan Pulver <[email protected]>
dylanpulver authored Jul 4, 2024
1 parent f15d790 commit b2f0a16
Showing 5 changed files with 122 additions and 97 deletions.
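Only the diff for safety/safety.py is rendered below. Per the commit title, callers now receive an explicit None on a cache miss instead of the old falsy False. The snippet that follows is a hypothetical caller, written for illustration and not taken from this commit, showing the intended check:

# Hypothetical caller (not part of this diff): consume the new
# Optional[Dict[str, Any]] return value of get_from_cache().
from safety.safety import get_from_cache  # module changed in this commit

def load_cached_db(db_name: str, cache_valid_seconds: int = 3600):
    cached = get_from_cache(db_name=db_name,
                            cache_valid_seconds=cache_valid_seconds)
    if cached is not None:  # explicit None check rather than a truthiness test
        return cached
    return None  # cache miss: caller falls back to fetching from the mirror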
123 changes: 67 additions & 56 deletions safety/safety.py
@@ -12,7 +12,7 @@
import time
from collections import defaultdict
from datetime import datetime
from typing import Dict, Optional, List
from typing import Dict, Optional, List, Any

import click
import requests
@@ -21,6 +21,7 @@
from packaging.utils import canonicalize_name
from packaging.version import parse as parse_version, Version
from pydantic.json import pydantic_encoder
from filelock import FileLock

from safety_schemas.models import Ecosystem, FileType

@@ -41,34 +42,38 @@
LOG = logging.getLogger(__name__)


def get_from_cache(db_name, cache_valid_seconds=0, skip_time_verification=False):
if os.path.exists(DB_CACHE_FILE):
with open(DB_CACHE_FILE) as f:
try:
data = json.loads(f.read())
if db_name in data:
def get_from_cache(db_name: str, cache_valid_seconds: int = 0, skip_time_verification: bool = False) -> Optional[Dict[str, Any]]:
cache_file_lock = f"{DB_CACHE_FILE}.lock"
os.makedirs(os.path.dirname(cache_file_lock), exist_ok=True)
lock = FileLock(cache_file_lock, timeout=10)
with lock:
if os.path.exists(DB_CACHE_FILE):
with open(DB_CACHE_FILE) as f:
try:
data = json.loads(f.read())
if db_name in data:

if "cached_at" in data[db_name]:
if data[db_name]["cached_at"] + cache_valid_seconds > time.time() or skip_time_verification:
LOG.debug('Getting the database from cache at %s, cache setting: %s',
data[db_name]["cached_at"], cache_valid_seconds)

try:
data[db_name]["db"]["meta"]["base_domain"] = "https://data.safetycli.com"
except KeyError as e:
pass
if "cached_at" in data[db_name]:
if data[db_name]["cached_at"] + cache_valid_seconds > time.time() or skip_time_verification:
LOG.debug('Getting the database from cache at %s, cache setting: %s',
data[db_name]["cached_at"], cache_valid_seconds)

return data[db_name]["db"]
try:
data[db_name]["db"]["meta"]["base_domain"] = "https://data.safetycli.com"
except KeyError as e:
pass

LOG.debug('Cached file is too old, it was cached at %s', data[db_name]["cached_at"])
else:
LOG.debug('There is not the cached_at key in %s database', data[db_name])
return data[db_name]["db"]

except json.JSONDecodeError:
LOG.debug('JSONDecodeError trying to get the cached database.')
else:
LOG.debug("Cache file doesn't exist...")
return False
LOG.debug('Cached file is too old, it was cached at %s', data[db_name]["cached_at"])
else:
LOG.debug('There is not the cached_at key in %s database', data[db_name])

except json.JSONDecodeError:
LOG.debug('JSONDecodeError trying to get the cached database.')
else:
LOG.debug("Cache file doesn't exist...")
return None


def write_to_cache(db_name, data):
@@ -95,25 +100,31 @@ def write_to_cache(db_name, data):
if exc.errno != errno.EEXIST:
raise

with open(DB_CACHE_FILE, "r") as f:
try:
cache = json.loads(f.read())
except json.JSONDecodeError:
LOG.debug('JSONDecodeError in the local cache, dumping the full cache file.')
cache_file_lock = f"{DB_CACHE_FILE}.lock"
lock = FileLock(cache_file_lock, timeout=10)
with lock:
if os.path.exists(DB_CACHE_FILE):
with open(DB_CACHE_FILE, "r") as f:
try:
cache = json.loads(f.read())
except json.JSONDecodeError:
LOG.debug('JSONDecodeError in the local cache, dumping the full cache file.')
cache = {}
else:
cache = {}

with open(DB_CACHE_FILE, "w") as f:
cache[db_name] = {
"cached_at": time.time(),
"db": data
}
f.write(json.dumps(cache))
LOG.debug('Safety updated the cache file for %s database.', db_name)
with open(DB_CACHE_FILE, "w") as f:
cache[db_name] = {
"cached_at": time.time(),
"db": data
}
f.write(json.dumps(cache))
LOG.debug('Safety updated the cache file for %s database.', db_name)


def fetch_database_url(session, mirror, db_name, cached, telemetry=True,
ecosystem: Ecosystem = Ecosystem.PYTHON, from_cache=True):
headers = {'schema-version': JSON_SCHEMA_VERSION, 'ecosystem': ecosystem.value}
headers = {'schema-version': JSON_SCHEMA_VERSION, 'ecosystem': ecosystem.value}

if cached and from_cache:
cached_data = get_from_cache(db_name=db_name, cache_valid_seconds=cached)
@@ -122,13 +133,13 @@ def fetch_database_url(session, mirror, db_name, cached, telemetry=True,
return cached_data
url = mirror + db_name


telemetry_data = {
'telemetry': json.dumps(build_telemetry_data(telemetry=telemetry),
'telemetry': json.dumps(build_telemetry_data(telemetry=telemetry),
default=pydantic_encoder)}

try:
r = session.get(url=url, timeout=REQUEST_TIMEOUT,
r = session.get(url=url, timeout=REQUEST_TIMEOUT,
headers=headers, params=telemetry_data)
except requests.exceptions.ConnectionError:
raise NetworkConnectionError()
@@ -205,10 +216,10 @@ def fetch_database_file(path: str, db_name: str, cached = 0,

if not full_path.exists():
raise DatabaseFileNotFoundError(db=path)

with open(full_path) as f:
data = json.loads(f.read())

if cached:
LOG.info('Writing %s to cache because cached value was %s', db_name, cached)
write_to_cache(db_name, data)
@@ -226,7 +237,7 @@ def is_valid_database(db) -> bool:
return False


def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
ecosystem: Optional[Ecosystem] = None, from_cache=True):

if session.is_using_auth_credentials():
@@ -242,7 +253,7 @@ def fetch_database(session, full=False, db=False, cached=0, telemetry=True,
if is_a_remote_mirror(mirror):
if ecosystem is None:
ecosystem = Ecosystem.PYTHON
data = fetch_database_url(session, mirror, db_name=db_name, cached=cached,
data = fetch_database_url(session, mirror, db_name=db_name, cached=cached,
telemetry=telemetry, ecosystem=ecosystem, from_cache=from_cache)
else:
data = fetch_database_file(mirror, db_name=db_name, cached=cached,
@@ -562,16 +573,16 @@ def compute_sec_ver(remediations, packages: Dict[str, Package], secure_vulns_by_
secure_v = compute_sec_ver_for_user(package=pkg, secure_vulns_by_user=secure_vulns_by_user, db_full=db_full)

rem['closest_secure_version'] = get_closest_ver(secure_v, version, spec)

upgrade = rem['closest_secure_version'].get('upper', None)
downgrade = rem['closest_secure_version'].get('lower', None)
recommended_version = None

if upgrade:
recommended_version = upgrade
elif downgrade:
recommended_version = downgrade

rem['recommended_version'] = recommended_version
rem['other_recommended_versions'] = [other_v for other_v in secure_v if
other_v != str(recommended_version)]
@@ -645,12 +656,12 @@ def process_fixes(files, remediations, auto_remediation_limit, output, no_output

def process_fixes_scan(file_to_fix, to_fix_spec, auto_remediation_limit, output, no_output=True, prompt=False):
to_fix_remediations = []

def get_remmediation_from(spec):
upper = None
lower = None
recommended = None

try:
upper = Version(spec.remediation.closest_secure.upper) if spec.remediation.closest_secure.upper else None
except Exception as e:
@@ -664,15 +675,15 @@ def get_remmediation_from(spec):
try:
recommended = Version(spec.remediation.recommended)
except Exception as e:
LOG.error(f'Error getting recommended version for remediation, ignoring', exc_info=True)
LOG.error(f'Error getting recommended version for remediation, ignoring', exc_info=True)

return {
"vulnerabilities_found": spec.remediation.vulnerabilities_found,
"version": next(iter(spec.specifier)).version if spec.is_pinned() else None,
"requirement": spec,
"more_info_url": spec.remediation.more_info_url,
"closest_secure_version": {
'upper': upper,
'upper': upper,
'lower': lower
},
"recommended_version": recommended,
@@ -690,7 +701,7 @@ def get_remmediation_from(spec):
'files': {str(file_to_fix.location): {'content': None, 'fixes': {'TO_SKIP': [], 'TO_APPLY': [], 'TO_CONFIRM': []}, 'supported': False, 'filename': file_to_fix.location.name}},
'dependencies': defaultdict(dict),
}

fixes = apply_fixes(requirements, output, no_output, prompt, scan_flow=True, auto_remediation_limit=auto_remediation_limit)

return fixes
@@ -822,7 +833,7 @@ def apply_fixes(requirements, out_type, no_output, prompt, scan_flow=False, auto
for name, data in requirements['files'].items():
output = [('', {}),
(f"Analyzing {name}... [{get_fix_opt_used_msg(auto_remediation_limit)} limit]", {'styling': {'bold': True}, 'start_line_decorator': '->', 'indent': ' '})]

r_skip = data['fixes']['TO_SKIP']
r_apply = data['fixes']['TO_APPLY']
r_confirm = data['fixes']['TO_CONFIRM']
@@ -901,7 +912,7 @@ def apply_fixes(requirements, out_type, no_output, prompt, scan_flow=False, auto
else:
not_supported_filename = data.get('filename', name)
output.append(
(f"{not_supported_filename} updates not supported: Please update these dependencies using your package manager.",
(f"{not_supported_filename} updates not supported: Please update these dependencies using your package manager.",
{'start_line_decorator': ' -', 'indent': ' '}))
output.append(('', {}))

@@ -999,7 +1010,7 @@ def review(*, report=None, params=None):

@sync_safety_context
def get_licenses(*, session=None, db_mirror=False, cached=0, telemetry=True):

if db_mirror:
mirrors = [db_mirror]
else:
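Both cache helpers now serialize access through the same lock file next to DB_CACHE_FILE. The sketch below condenses the new write_to_cache flow (paths simplified, constants assumed from the module) to make the design choice visible: the lock is held across the read and the write, so two concurrent safety processes cannot both load the old cache and then overwrite each other's entries.

# Condensed sketch of the locked read-modify-write in write_to_cache;
# not the verbatim diff, and DB_CACHE_FILE is an assumed path.
import json
import os
import time

from filelock import FileLock

DB_CACHE_FILE = os.path.expanduser("~/.safety/cache.json")

def write_to_cache(db_name, data):
    os.makedirs(os.path.dirname(DB_CACHE_FILE), exist_ok=True)  # ensure the cache directory exists

    # Hold the lock across read *and* write so concurrent processes cannot
    # interleave between loading the old cache and saving the merged one.
    with FileLock(f"{DB_CACHE_FILE}.lock", timeout=10):
        cache = {}
        if os.path.exists(DB_CACHE_FILE):
            try:
                with open(DB_CACHE_FILE) as f:
                    cache = json.loads(f.read())
            except json.JSONDecodeError:
                cache = {}

        cache[db_name] = {"cached_at": time.time(), "db": data}
        with open(DB_CACHE_FILE, "w") as f:
            f.write(json.dumps(cache))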