diff --git a/.github/workflows/github-torch-hub.yml b/.github/workflows/github-torch-hub.yml index a0ee5e4655b..858f7ebb0a3 100644 --- a/.github/workflows/github-torch-hub.yml +++ b/.github/workflows/github-torch-hub.yml @@ -21,7 +21,7 @@ jobs: - name: Install dependencies run: | pip install torch - pip install numpy tokenizers boto3 filelock requests tqdm regex sentencepiece sacremoses + pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses - name: Torch hub list run: | diff --git a/hubconf.py b/hubconf.py index cadad632f17..98d816082b7 100644 --- a/hubconf.py +++ b/hubconf.py @@ -16,7 +16,7 @@ ) -dependencies = ["torch", "numpy", "tokenizers", "boto3", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"] +dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"] @add_start_docstrings(AutoConfig.__doc__) diff --git a/setup.py b/setup.py index 4efa3290039..dd033a9734f 100644 --- a/setup.py +++ b/setup.py @@ -99,8 +99,6 @@ "tokenizers == 0.7.0", # dataclasses for Python versions that don't have it "dataclasses;python_version<'3.7'", - # accessing files from S3 directly - "boto3", # filesystem locks e.g. to prevent parallel downloads "filelock", # for downloading models over HTTPS diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index cf70217a264..52fffc0300c 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -19,10 +19,7 @@ from urllib.parse import urlparse from zipfile import ZipFile, is_zipfile -import boto3 import requests -from botocore.config import Config -from botocore.exceptions import ClientError from filelock import FileLock from tqdm.auto import tqdm @@ -144,7 +141,7 @@ def docstring_decorator(fn): def is_remote_url(url_or_filename): parsed = urlparse(url_or_filename) - return parsed.scheme in ("http", "https", "s3") + return parsed.scheme in ("http", "https") def hf_bucket_url(identifier, postfix=None, cdn=False) -> str: @@ -297,55 +294,6 @@ def cached_path( return output_path -def split_s3_path(url): - """Split a full s3 path into the bucket name and path.""" - parsed = urlparse(url) - if not parsed.netloc or not parsed.path: - raise ValueError("bad s3 path {}".format(url)) - bucket_name = parsed.netloc - s3_path = parsed.path - # Remove '/' at beginning of path. - if s3_path.startswith("/"): - s3_path = s3_path[1:] - return bucket_name, s3_path - - -def s3_request(func): - """ - Wrapper function for s3 requests in order to create more helpful error - messages. - """ - - @wraps(func) - def wrapper(url, *args, **kwargs): - try: - return func(url, *args, **kwargs) - except ClientError as exc: - if int(exc.response["Error"]["Code"]) == 404: - raise EnvironmentError("file {} not found".format(url)) - else: - raise - - return wrapper - - -@s3_request -def s3_etag(url, proxies=None): - """Check ETag on S3 object.""" - s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) - bucket_name, s3_path = split_s3_path(url) - s3_object = s3_resource.Object(bucket_name, s3_path) - return s3_object.e_tag - - -@s3_request -def s3_get(url, temp_file, proxies=None): - """Pull a file directly from S3.""" - s3_resource = boto3.resource("s3", config=Config(proxies=proxies)) - bucket_name, s3_path = split_s3_path(url) - s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) - - def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None): ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0]) if is_torch_available(): @@ -406,17 +354,13 @@ def get_from_cache( etag = None if not local_files_only: - # Get eTag to add to filename, if it exists. - if url.startswith("s3://"): - etag = s3_etag(url, proxies=proxies) - else: - try: - response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) - if response.status_code == 200: - etag = response.headers.get("ETag") - except (EnvironmentError, requests.exceptions.Timeout): - # etag is already None - pass + try: + response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout) + if response.status_code == 200: + etag = response.headers.get("ETag") + except (EnvironmentError, requests.exceptions.Timeout): + # etag is already None + pass filename = url_to_filename(url, etag) @@ -483,13 +427,7 @@ def _resumable_file_manager(): with temp_file_manager() as temp_file: logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name) - # GET file object - if url.startswith("s3://"): - if resume_download: - logger.warn('Warning: resumable downloads are not implemented for "s3://" urls') - s3_get(url, temp_file, proxies=proxies) - else: - http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) + http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent) logger.info("storing %s in cache at %s", url, cache_path) os.replace(temp_file.name, cache_path)