Skip to content

Commit

Permalink
Use native ntlk download (#3796)
Browse files Browse the repository at this point in the history
This PR changes how we download NLTK data to use the native nltk
downloader.

We had moved to our own hosted NLTK dataset because of this CVE:
https://nvd.nist.gov/vuln/detail/CVE-2024-39705

Ref: #3361

Latest versions of NLTK have fixed this issue:
https://github.com/nltk/nltk/blob/develop/ChangeLog
  • Loading branch information
vangheem authored Dec 2, 2024
1 parent 9445a2d commit 0fb814d
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 86 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
## 0.16.9-dev0
## 0.16.9

### Enhancements

### Features

### Fixes

- **Fix NLTK Download** to not download from unstructured S3 Bucket

## 0.16.8

### Enhancements
Expand Down
2 changes: 2 additions & 0 deletions test_unstructured/nlp/test_tokenize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


def test_nltk_packages_download_if_not_present():
tokenize._download_nltk_packages_if_not_present.cache_clear()
with patch.object(nltk, "find", side_effect=LookupError):
with patch.object(tokenize, "download_nltk_packages") as mock_download:
tokenize._download_nltk_packages_if_not_present()
Expand All @@ -16,6 +17,7 @@ def test_nltk_packages_download_if_not_present():


def test_nltk_packages_do_not_download_if():
tokenize._download_nltk_packages_if_not_present.cache_clear()
with patch.object(nltk, "find"), patch.object(nltk, "download") as mock_download:
tokenize._download_nltk_packages_if_not_present()

Expand Down
2 changes: 1 addition & 1 deletion unstructured/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.16.9-dev0" # pragma: no cover
__version__ = "0.16.9" # pragma: no cover
90 changes: 6 additions & 84 deletions unstructured/nlp/tokenize.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
from __future__ import annotations

import hashlib
import os
import sys
import tarfile
import tempfile
import urllib.request
from functools import lru_cache
from typing import Final, List, Tuple

Expand All @@ -16,86 +11,10 @@

CACHE_MAX_SIZE: Final[int] = 128

NLTK_DATA_FILENAME = "nltk_data_3.8.2.tar.gz"
NLTK_DATA_URL = f"https://utic-public-cf.s3.amazonaws.com/{NLTK_DATA_FILENAME}"
NLTK_DATA_SHA256 = "ba2ca627c8fb1f1458c15d5a476377a5b664c19deeb99fd088ebf83e140c1663"


# NOTE(robinson) - mimic default dir logic from NLTK
# https://github.com/nltk/nltk/
# blob/8c233dc585b91c7a0c58f96a9d99244a379740d5/nltk/downloader.py#L1046
def get_nltk_data_dir() -> str | None:
"""Locates the directory the nltk data will be saved too. The directory
set by the NLTK environment variable takes highest precedence. Otherwise
the default is determined by the rules indicated below. Returns None when
the directory is not writable.
On Windows, the default download directory is
``PYTHONHOME/lib/nltk``, where *PYTHONHOME* is the
directory containing Python, e.g. ``C:\\Python311``.
On all other platforms, the default directory is the first of
the following which exists or which can be created with write
permission: ``/usr/share/nltk_data``, ``/usr/local/share/nltk_data``,
``/usr/lib/nltk_data``, ``/usr/local/lib/nltk_data``, ``~/nltk_data``.
"""
# Check if we are on GAE where we cannot write into filesystem.
if "APPENGINE_RUNTIME" in os.environ:
return

# Check if we have sufficient permissions to install in a
# variety of system-wide locations.
for nltkdir in nltk.data.path:
if os.path.exists(nltkdir) and nltk.internals.is_writable(nltkdir):
return nltkdir

# On Windows, use %APPDATA%
if sys.platform == "win32" and "APPDATA" in os.environ:
homedir = os.environ["APPDATA"]

# Otherwise, install in the user's home directory.
else:
homedir = os.path.expanduser("~/")
if homedir == "~/":
raise ValueError("Could not find a default download directory")

# NOTE(robinson) - NLTK appends nltk_data to the homedir. That's already
# present in the tar file so we don't have to do that here.
return homedir


def download_nltk_packages():
nltk_data_dir = get_nltk_data_dir()

if nltk_data_dir is None:
raise OSError("NLTK data directory does not exist or is not writable.")

# Check if the path ends with "nltk_data" and remove it if it does
if nltk_data_dir.endswith("nltk_data"):
nltk_data_dir = os.path.dirname(nltk_data_dir)

def sha256_checksum(filename: str, block_size: int = 65536):
sha256 = hashlib.sha256()
with open(filename, "rb") as f:
for block in iter(lambda: f.read(block_size), b""):
sha256.update(block)
return sha256.hexdigest()

with tempfile.TemporaryDirectory() as temp_dir_path:
tgz_file_path = os.path.join(temp_dir_path, NLTK_DATA_FILENAME)
urllib.request.urlretrieve(NLTK_DATA_URL, tgz_file_path)

file_hash = sha256_checksum(tgz_file_path)
if file_hash != NLTK_DATA_SHA256:
os.remove(tgz_file_path)
raise ValueError(f"SHA-256 mismatch: expected {NLTK_DATA_SHA256}, got {file_hash}")

# Extract the contents
if not os.path.exists(nltk_data_dir):
os.makedirs(nltk_data_dir)

with tarfile.open(tgz_file_path, "r:gz") as tar:
tar.extractall(path=nltk_data_dir)
nltk.download("averaged_perceptron_tagger_eng", quiet=True)
nltk.download("punkt_tab", quiet=True)


def check_for_nltk_package(package_name: str, package_category: str) -> bool:
Expand All @@ -109,10 +28,13 @@ def check_for_nltk_package(package_name: str, package_category: str) -> bool:
try:
nltk.find(f"{package_category}/{package_name}", paths=paths)
return True
except LookupError:
except (LookupError, OSError):
return False


# We cache this because we do not want to attempt
# downloading the packages multiple times
@lru_cache()
def _download_nltk_packages_if_not_present():
"""If required NLTK packages are not available, download them."""

Expand Down

0 comments on commit 0fb814d

Please sign in to comment.