Skip to content

Commit

Permalink
Fix where() so that it works with importlib.resources when available
Browse files Browse the repository at this point in the history
  • Loading branch information
dstufft committed Apr 5, 2020
1 parent 5efdd48 commit ca60f0f
Showing 1 changed file with 35 additions and 5 deletions.
40 changes: 35 additions & 5 deletions certifi/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,36 @@
import os

try:
from importlib.resources import read_text
from importlib.resources import path as get_path, read_text

_CACERT_CTX = None
_CACERT_PATH = None

def where():
# This is slightly terrible, but we want to delay extracting the file
# in cases where we're inside of a zipimport situation until someone
# actually calls where(), but we don't want to re-extract the file
# on every call of where(), so we'll do it once then store it in a
# global variable.
global _CACERT_CTX
global _CACERT_PATH
if _CACERT_PATH is None:
# This is slightly janky, the importlib.resources API wants you to
# manage the cleanup of this file, so it doesn't actually return a
# path, it returns a context manager that will give you the path
# when you enter it and will do any cleanup when you leave it. In
# the common case of not needing a temporary file, it will just
# return the file system location and the __exit__() is a no-op.
#
# We also have to hold onto the actual context manager, because
# it will do the cleanup whenever it gets garbage collected, so
# we will also store that at the global level as well.
_CACERT_CTX = get_path("certifi", "cacert.pem")
_CACERT_PATH = str(_CACERT_CTX.__enter__())

return _CACERT_PATH


except ImportError:
# This fallback will work for Python versions prior to 3.7 that lack the
# importlib.resources module but relies on the existing `where` function
Expand All @@ -19,11 +48,12 @@ def read_text(_module, _path, encoding="ascii"):
with open(where(), "r", encoding=encoding) as data:
return data.read()

# If we don't have importlib.resources, then we will just do the old logic
# of assuming we're on the filesystem and munge the path directly.
def where():
f = os.path.dirname(__file__)

def where():
f = os.path.dirname(__file__)

return os.path.join(f, "cacert.pem")
return os.path.join(f, "cacert.pem")


def contents():
Expand Down

0 comments on commit ca60f0f

Please sign in to comment.