From 4c43dc8c0247b1aa78869a7ee5b6bb369050b4cc Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Fri, 26 Jul 2019 13:50:26 -0700 Subject: [PATCH 01/45] Add implementation of directory saving --- sacred/experiment.py | 3 +- sacred/observers/__init__.py | 3 +- sacred/observers/s3_observer.py | 272 ++++++++++++++++++++++++++++++++ sacred/run.py | 20 ++- 4 files changed, 289 insertions(+), 9 deletions(-) create mode 100644 sacred/observers/s3_observer.py diff --git a/sacred/experiment.py b/sacred/experiment.py index aa1014ef..4e1a6289 100755 --- a/sacred/experiment.py +++ b/sacred/experiment.py @@ -333,6 +333,7 @@ def add_artifact( filename, name=None, metadata=None, + recursive=False, content_type=None, ): """Add a file as an artifact. @@ -359,7 +360,7 @@ def add_artifact( This only has an effect when using the MongoObserver. """ assert self.current_run is not None, "Can only be called during a run." - self.current_run.add_artifact(filename, name, metadata, content_type) + self.current_run.add_artifact(filename, name, recursive, metadata, content_type) @property def info(self): diff --git a/sacred/observers/__init__.py b/sacred/observers/__init__.py index d0b81b52..a577605f 100644 --- a/sacred/observers/__init__.py +++ b/sacred/observers/__init__.py @@ -8,8 +8,9 @@ from sacred.observers.tinydb_hashfs import TinyDbObserver, TinyDbReader from sacred.observers.slack import SlackObserver from sacred.observers.telegram_obs import TelegramObserver +from sacred.observers.s3_observer import S3FileObserver __all__ = ('FileStorageObserver', 'RunObserver', 'MongoObserver', 'SqlObserver', 'TinyDbObserver', 'TinyDbReader', - 'SlackObserver', 'TelegramObserver') + 'SlackObserver', 'TelegramObserver', 'S3FileObserver') diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py new file mode 100644 index 00000000..a24f7f77 --- /dev/null +++ b/sacred/observers/s3_observer.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python +# coding=utf-8 + +import json +import os +import os.path + +import boto3 +from botocore.errorfactory import ClientError + +from shutil import copyfile + +from sacred.commandline_options import CommandLineOption +from sacred.dependencies import get_digest +from sacred.observers.base import RunObserver +from sacred import optional as opt +from sacred.serializer import flatten +import re + +DEFAULT_S3_PRIORITY = 20 + + +class S3FileObserver(RunObserver): + VERSION = 'S3FileObserver-0.1.0' + + @classmethod + def create(cls, bucket, basedir, resource_dir=None, source_dir=None, priority=DEFAULT_S3_PRIORITY): + resource_dir = resource_dir or os.path.join(basedir, '_resources') + source_dir = source_dir or os.path.join(basedir, '_sources') + + return cls(bucket, basedir, resource_dir, source_dir, priority) + + + def __init__(self, bucket, basedir, resource_dir, source_dir, + priority=DEFAULT_S3_PRIORITY): + self.basedir = basedir + self.bucket = bucket + self.resource_dir = resource_dir + self.source_dir = source_dir + self.priority = priority + self.dir = None + self.run_entry = None + self.config = None + self.info = None + self.cout = "" + self.cout_write_cursor = 0 + self.s3 = boto3.resource('s3') + self.saved_metrics = {} + + def _list_s3_subdirs(self, prefix=None): + if prefix is None: + prefix = self.basedir + try: + bucket = self.s3.Bucket(self.bucket) + all_keys = [el.key for el in bucket.objects.filter(Prefix=prefix)] + except ClientError as er: + if er.response['Error']['Code'] == 'NoSuchBucket': + return None + else: + raise ClientError(er.response['Error']['Code']) + + 
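Aside — a standalone sketch of the prefix-plus-regex subdirectory listing that the next few lines implement. S3 has no real directories, so run folders must be parsed back out of flat object keys; the keys below are hypothetical and this block is an illustration, not code from the patch:

```python
import re

# Flat keys as returned by bucket.objects.filter(Prefix=prefix)
keys = [
    'some-tests/1/run.json',
    'some-tests/1/config.json',
    'some-tests/2/cout.txt',
    'some-tests/_sources/train_abc123.py',
]
prefix = 'some-tests'
subdir_match = r'{prefix}\/(.*)\/'.format(prefix=prefix)

matches = (re.match(subdir_match, key) for key in keys)
distinct_subdirs = {m.group(1) for m in matches if m is not None}
print(sorted(distinct_subdirs))  # ['1', '2', '_sources']
```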
subdir_match = f'{prefix}\/(.*)\/' + distinct_subdirs = set([re.match(subdir_match, key).groups()[0] for key in all_keys]) + return list(distinct_subdirs) + + def _create_bucket(self): + session = boto3.session.Session() + current_region = session.region_name + bucket_response = self.s3.create_bucket( + Bucket=self.bucket, + CreateBucketConfiguration={ + 'LocationConstraint': current_region}) + return bucket_response + + def _determine_run_dir(self, _id): + bucket_path_subdirs = self._list_s3_subdirs() + if bucket_path_subdirs is None: + self._create_bucket() + max_run_id = 0 + else: + max_run_id = max([int(d) for d in bucket_path_subdirs if d.isdigit()]) + + self.dir = None + if _id is None: + _id = max_run_id + 1 + + self.run_dir = os.path.join(self.basedir, str(_id)) + return _id + + def queued_event(self, ex_info, command, host_info, queue_time, config, + meta_info, _id): + _id = self._determine_run_dir(_id) + + self.run_entry = { + 'experiment': dict(ex_info), + 'command': command, + 'host': dict(host_info), + 'meta': meta_info, + 'status': 'QUEUED', + } + self.config = config + self.info = {} + + self.save_json(self.run_entry, 'run.json') + self.save_json(self.config, 'config.json') + + for s, m in ex_info['sources']: + self.save_file(s) + + return _id + + def save_sources(self, ex_info): + base_dir = ex_info['base_dir'] + source_info = [] + for s, m in ex_info['sources']: + abspath = os.path.join(base_dir, s) + store_path, md5sum = self.find_or_save(abspath, self.source_dir) + source_info.append([s, os.path.relpath(store_path, self.basedir)]) + return source_info + + def started_event(self, ex_info, command, host_info, start_time, config, + meta_info, _id): + + _id = self._determine_run_dir(_id) + + ex_info['sources'] = self.save_sources(ex_info) + + self.run_entry = { + 'experiment': dict(ex_info), + 'command': command, + 'host': dict(host_info), + 'start_time': start_time.isoformat(), + 'meta': meta_info, + 'status': 'RUNNING', + 'resources': [], + 'artifacts': [], + 'heartbeat': None + } + self.config = config + self.info = {} + self.cout = "" + self.cout_write_cursor = 0 + + self.save_json(self.run_entry, 'run.json') + self.save_json(self.config, 'config.json') + self.save_cout() + + return _id + + def find_or_save(self, filename, store_dir): + source_name, ext = os.path.splitext(os.path.basename(filename)) + md5sum = get_digest(filename) + store_name = source_name + '_' + md5sum + ext + store_path = os.path.join(store_dir, store_name) + if len(self._list_s3_subdirs(prefix=store_path)) == 0: + self.save_file(filename, store_path) + return store_path, md5sum + + def put_data(self, key, binary_data): + self.s3.Object(self.bucket, key).put(binary_data) + + def save_json(self, obj, filename): + key = os.path.join(self.run_dir, filename) + self.put_data(key, json.dumps(flatten(obj), + sort_keys=True, indent=2)) + + def save_file(self, filename, target_name=None): + target_name = target_name or os.path.basename(filename) + key = os.path.join(self.run_dir, target_name) + self.put_data(key, open(filename, 'rb')) + + def save_directory(self, source_dir, target_name): + # Stolen from: https://github.com/boto/boto3/issues/358#issuecomment-346093506 + target_name = target_name or os.path.basename(source_dir) + all_files = [] + for root, dirs, files in os.walk(source_dir): + all_files += [os.path.join(root, f) for f in files] + s3_resource = boto3.resource('s3') + + for filename in all_files: + s3_resource.Object(self.bucket, + os.path.join(self.run_dir, target_name, 
os.path.relpath(filename, source_dir))) \ + .put(Body=open(filename, 'rb')) + + def save_cout(self): + binary_data = self.cout[self.cout_write_cursor:].encode("utf-8") + key = os.path.join(self.run_dir, 'cout.txt') + self.put_data(key, binary_data) + self.cout_write_cursor = len(self.cout) + + + def heartbeat_event(self, info, captured_out, beat_time, result): + self.info = info + self.run_entry['heartbeat'] = beat_time.isoformat() + self.run_entry['result'] = result + self.cout = captured_out + self.save_cout() + self.save_json(self.run_entry, 'run.json') + if self.info: + self.save_json(self.info, 'info.json') + + def completed_event(self, stop_time, result): + self.run_entry['stop_time'] = stop_time.isoformat() + self.run_entry['result'] = result + self.run_entry['status'] = 'COMPLETED' + + self.save_json(self.run_entry, 'run.json') + + def interrupted_event(self, interrupt_time, status): + self.run_entry['stop_time'] = interrupt_time.isoformat() + self.run_entry['status'] = status + self.save_json(self.run_entry, 'run.json') + + def failed_event(self, fail_time, fail_trace): + self.run_entry['stop_time'] = fail_time.isoformat() + self.run_entry['status'] = 'FAILED' + self.run_entry['fail_trace'] = fail_trace + self.save_json(self.run_entry, 'run.json') + + def resource_event(self, filename): + store_path, md5sum = self.find_or_save(filename, self.resource_dir) + self.run_entry['resources'].append([filename, store_path]) + self.save_json(self.run_entry, 'run.json') + + def artifact_event(self, name, filename, metadata=None, content_type=None): + self.save_file(filename, name) + self.run_entry['artifacts'].append(name) + self.save_json(self.run_entry, 'run.json') + + def artifact_directory_event(self, name, filename): + self.save_directory(filename, name) + self.run_entry['artifacts'].append(name + "/") + self.save_json(self.run_entry, 'run.json') + + def log_metrics(self, metrics_by_name, info): + """Store new measurements into metrics.json. + """ + + for metric_name, metric_ptr in metrics_by_name.items(): + + if metric_name not in self.saved_metrics: + self.saved_metrics[metric_name] = {"values": [], + "steps": [], + "timestamps": []} + + self.saved_metrics[metric_name]["values"] += metric_ptr["values"] + self.saved_metrics[metric_name]["steps"] += metric_ptr["steps"] + + # Manually convert them to avoid passing a datetime dtype handler + # when we're trying to convert into json. 
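The comment above is the whole story in miniature; a minimal standalone demonstration (not part of the patch) of why the timestamps get converted by hand before serialization:

```python
import json
import datetime

ts = datetime.datetime(1999, 5, 4, 3, 2, 1)

try:
    json.dumps({'timestamps': [ts]})
except TypeError as err:
    print(err)  # datetime is not JSON serializable

# Converting to strings up front avoids registering a custom
# `default=` handler on every json.dumps call
print(json.dumps({'timestamps': [ts.isoformat()]}))
# {"timestamps": ["1999-05-04T03:02:01"]}
```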
+ timestamps_norm = [ts.isoformat() + for ts in metric_ptr["timestamps"]] + self.saved_metrics[metric_name]["timestamps"] += timestamps_norm + + self.save_json(self.saved_metrics, 'metrics.json') + + def __eq__(self, other): + if isinstance(other, S3FileObserver): + return self.basedir == other.basedir + return False + + +class S3StorageOption(CommandLineOption): + """Add a S3 File observer to the experiment.""" + + short_flag = 'S3' + arg = 'BUCKET_PATH' + arg_description = "s3:///path/to/exp" + + @classmethod + def apply(cls, args, run): + run.observers.append(S3FileObserver.create(args)) diff --git a/sacred/run.py b/sacred/run.py index 10cc6d25..6c9ad0a9 100755 --- a/sacred/run.py +++ b/sacred/run.py @@ -160,6 +160,7 @@ def add_artifact( self, filename, name=None, + recursive=False, metadata=None, content_type=None, ): @@ -187,7 +188,7 @@ def add_artifact( """ filename = os.path.abspath(filename) name = os.path.basename(filename) if name is None else name - self._emit_artifact_added(name, filename, metadata, content_type) + self._emit_artifact_added(name, filename, recursive, metadata, content_type) def __call__(self, *args): r"""Start this run. @@ -385,13 +386,18 @@ def _emit_resource_added(self, filename): for observer in self.observers: self._safe_call(observer, 'resource_event', filename=filename) - def _emit_artifact_added(self, name, filename, metadata, content_type): + def _emit_artifact_added(self, name, filename, recursive, metadata, content_type): for observer in self.observers: - self._safe_call(observer, 'artifact_event', - name=name, - filename=filename, - metadata=metadata, - content_type=content_type) + if recursive: + self._safe_call(observer, 'artifact_directory_event', + name=name, + filename=filename) + else: + self._safe_call(observer, 'artifact_event', + name=name, + filename=filename, + metadata=metadata, + content_type=content_type) def _safe_call(self, obs, method, **kwargs): if obs not in self._failed_observers and hasattr(obs, method): From 4d18ad36aa9affe72d53cba5586a63f0c261f346 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Fri, 26 Jul 2019 15:38:55 -0700 Subject: [PATCH 02/45] basic tests for s3 observer --- sacred/observers/s3_observer.py | 26 ++++- tests/test_observers/test_s3_observer.py | 116 +++++++++++++++++++++++ 2 files changed, 140 insertions(+), 2 deletions(-) create mode 100644 tests/test_observers/test_s3_observer.py diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index a24f7f77..9a4ea43d 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -16,9 +16,28 @@ from sacred import optional as opt from sacred.serializer import flatten import re +import socket DEFAULT_S3_PRIORITY = 20 +def _is_valid_bucket(bucket_name): + if len(bucket_name) < 3 or len(bucket_name) > 63: + return False + if '..' in bucket_name or '.-' in bucket_name or '-.' in bucket_name: + return False + for char in bucket_name: + if char.isdigit(): + continue + if char.islower(): + continue + if char == '-': + continue + return False + try: + socket.inet_aton(bucket_name) + except: + # congrats, you're a valid bucket name + return True class S3FileObserver(RunObserver): VERSION = 'S3FileObserver-0.1.0' @@ -33,6 +52,8 @@ def create(cls, bucket, basedir, resource_dir=None, source_dir=None, priority=DE def __init__(self, bucket, basedir, resource_dir, source_dir, priority=DEFAULT_S3_PRIORITY): + if not _is_valid_bucket(bucket): + raise ValueError("Your chosen bucket name does not follow AWS rules. 
Consult here to see the requirements: https://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html") self.basedir = basedir self.bucket = bucket self.resource_dir = resource_dir @@ -74,7 +95,7 @@ def _create_bucket(self): def _determine_run_dir(self, _id): bucket_path_subdirs = self._list_s3_subdirs() - if bucket_path_subdirs is None: + if bucket_path_subdirs is None or len(bucket_path_subdirs) == 0: self._create_bucket() max_run_id = 0 else: @@ -157,7 +178,7 @@ def find_or_save(self, filename, store_dir): return store_path, md5sum def put_data(self, key, binary_data): - self.s3.Object(self.bucket, key).put(binary_data) + self.s3.Object(self.bucket, key).put(Body=binary_data) def save_json(self, obj, filename): key = os.path.join(self.run_dir, filename) @@ -167,6 +188,7 @@ def save_json(self, obj, filename): def save_file(self, filename, target_name=None): target_name = target_name or os.path.basename(filename) key = os.path.join(self.run_dir, target_name) + ##import pdb; pdb.set_trace() self.put_data(key, open(filename, 'rb')) def save_directory(self, source_dir, target_name): diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py new file mode 100644 index 00000000..2927de29 --- /dev/null +++ b/tests/test_observers/test_s3_observer.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python +# coding=utf-8 + +import datetime +import hashlib +import os +import tempfile +from copy import copy +import pytest +import json + +from sacred.observers import S3FileObserver +from sacred.metrics_logger import ScalarMetricLogEntry, linearize_metrics + +import boto3 +from botocore.exceptions import ClientError + +T1 = datetime.datetime(1999, 5, 4, 3, 2, 1, 0) +T2 = datetime.datetime(1999, 5, 5, 5, 5, 5, 5) + +BUCKET = 'pytest-s3-observer-bucket' +BASEDIR = 'some-tests' + +@pytest.fixture() +def sample_run(): + exp = {'name': 'test_exp', 'sources': [], 'doc': '', 'base_dir': '/tmp'} + host = {'hostname': 'test_host', 'cpu_count': 1, 'python_version': '3.4'} + config = {'config': 'True', 'foo': 'bar', 'answer': 42} + command = 'run' + meta_info = {'comment': 'test run'} + return { + '_id': 'FEDCBA9876543210', + 'ex_info': exp, + 'command': command, + 'host_info': host, + 'start_time': T1, + 'config': config, + 'meta_info': meta_info, + } + + +@pytest.fixture() +def dir_obs(): + return S3FileObserver.create(bucket=BUCKET, basedir=BASEDIR) + + +""" +Test that reusing the same bucket name doesn't recreate the bucket, + but instead reuses it (check if both _ids went to the same bucket) +Test failing gracefully if you pass in a disallowed S3 bucket name + + + +Is it possible to set up a test with and without a valid credentials file? 
+ I guess you can save ~/.aws/config and ~/.aws/credentials +""" +def _delete_bucket(bucket_name): + s3 = boto3.resource('s3') + bucket = s3.Bucket(bucket_name) + for key in bucket.objects.all(): + key.delete() + bucket.delete() + +def _bucket_exists(bucket_name): + s3 = boto3.resource('s3') + try: + s3.meta.client.head_bucket(Bucket=bucket_name) + except ClientError as e: + if e.response['Error']['Code'] == '404': + return False + return True + +def _key_exists(bucket_name, key): + s3 = boto3.resource('s3') + try: + s3.Object(bucket_name, key).load() + except ClientError as e: + if e.response['Error']['Code'] == '404': + return False + return True + +def _get_file_data(bucket_name, key): + s3 = boto3.resource('s3') + return s3.Object(bucket_name, key).get()['Body'].read() + +def test_fs_observer_started_event_creates_bucket(dir_obs, sample_run): + observer = dir_obs + sample_run['_id'] = None + _id = observer.started_event(**sample_run) + run_dir = os.path.join(BASEDIR, str(_id)) + + assert _key_exists(bucket_name=BUCKET, key=os.path.join(run_dir, 'cout.txt')) + config = _get_file_data(bucket_name=BUCKET, key=os.path.join(run_dir, 'config.json')) + + assert json.loads(config) == sample_run['config'] + run = _get_file_data(bucket_name=BUCKET, key=os.path.join(run_dir, 'run.json')) + assert json.loads(run) == { + 'experiment': sample_run['ex_info'], + 'command': sample_run['command'], + 'host': sample_run['host_info'], + 'start_time': T1.isoformat(), + 'heartbeat': None, + 'meta': sample_run['meta_info'], + "resources": [], + "artifacts": [], + "status": "RUNNING" + } + _delete_bucket(BUCKET) + +def test_fs_observer_started_event_increments_run_id(dir_obs, sample_run): + observer = dir_obs + sample_run['_id'] = None + _id = observer.started_event(**sample_run) + _id2 = observer.started_event(**sample_run) + assert _id + 1 == _id2 + _delete_bucket(BUCKET) From 45243645e29828fbed20e7bcfe8a04ea5f150c3d Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Fri, 26 Jul 2019 15:40:01 -0700 Subject: [PATCH 03/45] remove unused imports --- tests/test_observers/test_s3_observer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 2927de29..ccb2ff46 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -2,15 +2,11 @@ # coding=utf-8 import datetime -import hashlib import os -import tempfile -from copy import copy import pytest import json from sacred.observers import S3FileObserver -from sacred.metrics_logger import ScalarMetricLogEntry, linearize_metrics import boto3 from botocore.exceptions import ClientError @@ -79,10 +75,12 @@ def _key_exists(bucket_name, key): return False return True + def _get_file_data(bucket_name, key): s3 = boto3.resource('s3') return s3.Object(bucket_name, key).get()['Body'].read() + def test_fs_observer_started_event_creates_bucket(dir_obs, sample_run): observer = dir_obs sample_run['_id'] = None @@ -107,6 +105,7 @@ def test_fs_observer_started_event_creates_bucket(dir_obs, sample_run): } _delete_bucket(BUCKET) + def test_fs_observer_started_event_increments_run_id(dir_obs, sample_run): observer = dir_obs sample_run['_id'] = None From 12bb9cca374182563f0dd49c3c57dae6e3d88f73 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Fri, 26 Jul 2019 17:31:30 -0700 Subject: [PATCH 04/45] fix format string --- sacred/observers/s3_observer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 9a4ea43d..3d86bbbf 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -80,7 +80,7 @@ def _list_s3_subdirs(self, prefix=None): else: raise ClientError(er.response['Error']['Code']) - subdir_match = f'{prefix}\/(.*)\/' + subdir_match = '{prefix}\/(.*)\/'.format(prefix=prefix) distinct_subdirs = set([re.match(subdir_match, key).groups()[0] for key in all_keys]) return list(distinct_subdirs) From f687d9d13ae51513fec39a6d65c13e57c3f5c630 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Mon, 29 Jul 2019 10:18:46 -0700 Subject: [PATCH 05/45] Add boto3 to requirements for testing purposes --- requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 1ef969e5..a380191f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ mock>=2.0.0 munch>=2.0.4 pbr>=1.10.0 wrapt>=1.10.8 -packaging>=18.0 \ No newline at end of file +packaging>=18.0 +boto3>=1.9.0 \ No newline at end of file From bb3ecf54b178487f3b84dc1fa4427296f6bb3132 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Mon, 29 Jul 2019 14:50:57 -0700 Subject: [PATCH 06/45] fix flake8 issues and add moto for test mocking --- dev-requirements.txt | 3 +- sacred/experiment.py | 3 +- sacred/observers/s3_observer.py | 35 +++++++++++++----------- sacred/run.py | 6 ++-- tests/test_observers/test_s3_observer.py | 23 ++++++++++++---- 5 files changed, 45 insertions(+), 25 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 003067be..9b7d8baa 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -28,4 +28,5 @@ wrapt==1.10.8 scikit-learn==0.20.3 pymongo==3.8.0 py-cpuinfo==4.0 - +boto3>=1.9.0 +moto>=1.3.13 \ No newline at end of file diff --git a/sacred/experiment.py b/sacred/experiment.py index 4e1a6289..3a3e6cd9 100755 --- a/sacred/experiment.py +++ b/sacred/experiment.py @@ -360,7 +360,8 @@ def add_artifact( This only has an effect when using the MongoObserver. """ assert self.current_run is not None, "Can only be called during a run." 
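For orientation, this is the public entry point being re-wrapped in the hunk below; a usage sketch (the experiment name and file path are made up, not from this PR):

```python
from sacred import Experiment

ex = Experiment('demo')

@ex.automain
def main():
    with open('weights.txt', 'w') as f:
        f.write('0.1 0.2 0.3\n')
    # Dispatched to every attached observer's artifact_event
    ex.add_artifact('weights.txt', name='final-weights')
```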
- self.current_run.add_artifact(filename, name, recursive, metadata, content_type) + self.current_run.add_artifact(filename, name, recursive, metadata, + content_type) @property def info(self): diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 3d86bbbf..d99c6752 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -8,18 +8,16 @@ import boto3 from botocore.errorfactory import ClientError -from shutil import copyfile - from sacred.commandline_options import CommandLineOption from sacred.dependencies import get_digest from sacred.observers.base import RunObserver -from sacred import optional as opt from sacred.serializer import flatten import re import socket DEFAULT_S3_PRIORITY = 20 + def _is_valid_bucket(bucket_name): if len(bucket_name) < 3 or len(bucket_name) > 63: return False @@ -39,21 +37,24 @@ def _is_valid_bucket(bucket_name): # congrats, you're a valid bucket name return True + class S3FileObserver(RunObserver): VERSION = 'S3FileObserver-0.1.0' @classmethod - def create(cls, bucket, basedir, resource_dir=None, source_dir=None, priority=DEFAULT_S3_PRIORITY): + def create(cls, bucket, basedir, resource_dir=None, source_dir=None, + priority=DEFAULT_S3_PRIORITY): resource_dir = resource_dir or os.path.join(basedir, '_resources') source_dir = source_dir or os.path.join(basedir, '_sources') return cls(bucket, basedir, resource_dir, source_dir, priority) - def __init__(self, bucket, basedir, resource_dir, source_dir, priority=DEFAULT_S3_PRIORITY): if not _is_valid_bucket(bucket): - raise ValueError("Your chosen bucket name does not follow AWS rules. Consult here to see the requirements: https://docs.aws.amazon.com/AmazonS3/latest/dev/BucketRestrictions.html") + raise ValueError("Your chosen bucket name does not follow AWS " + "bucket naming rules") + self.basedir = basedir self.bucket = bucket self.resource_dir = resource_dir @@ -80,8 +81,9 @@ def _list_s3_subdirs(self, prefix=None): else: raise ClientError(er.response['Error']['Code']) - subdir_match = '{prefix}\/(.*)\/'.format(prefix=prefix) - distinct_subdirs = set([re.match(subdir_match, key).groups()[0] for key in all_keys]) + subdir_match = r'{prefix}\/(.*)\/'.format(prefix=prefix) + distinct_subdirs = set([re.match(subdir_match, key).groups()[0] for + key in all_keys]) return list(distinct_subdirs) def _create_bucket(self): @@ -99,7 +101,8 @@ def _determine_run_dir(self, _id): self._create_bucket() max_run_id = 0 else: - max_run_id = max([int(d) for d in bucket_path_subdirs if d.isdigit()]) + max_run_id = max([int(d) for d in bucket_path_subdirs + if d.isdigit()]) self.dir = None if _id is None: @@ -188,11 +191,11 @@ def save_json(self, obj, filename): def save_file(self, filename, target_name=None): target_name = target_name or os.path.basename(filename) key = os.path.join(self.run_dir, target_name) - ##import pdb; pdb.set_trace() self.put_data(key, open(filename, 'rb')) def save_directory(self, source_dir, target_name): - # Stolen from: https://github.com/boto/boto3/issues/358#issuecomment-346093506 + # Stolen from: + # https://github.com/boto/boto3/issues/358#issuecomment-346093506 target_name = target_name or os.path.basename(source_dir) all_files = [] for root, dirs, files in os.walk(source_dir): @@ -200,9 +203,10 @@ def save_directory(self, source_dir, target_name): s3_resource = boto3.resource('s3') for filename in all_files: + file_location = os.path.join(self.run_dir, target_name, + os.path.relpath(filename, source_dir)) s3_resource.Object(self.bucket, - 
os.path.join(self.run_dir, target_name, os.path.relpath(filename, source_dir))) \ - .put(Body=open(filename, 'rb')) + file_location).put(Body=open(filename, 'rb')) def save_cout(self): binary_data = self.cout[self.cout_write_cursor:].encode("utf-8") @@ -210,7 +214,6 @@ def save_cout(self): self.put_data(key, binary_data) self.cout_write_cursor = len(self.cout) - def heartbeat_event(self, info, captured_out, beat_time, result): self.info = info self.run_entry['heartbeat'] = beat_time.isoformat() @@ -262,8 +265,8 @@ def log_metrics(self, metrics_by_name, info): if metric_name not in self.saved_metrics: self.saved_metrics[metric_name] = {"values": [], - "steps": [], - "timestamps": []} + "steps": [], + "timestamps": []} self.saved_metrics[metric_name]["values"] += metric_ptr["values"] self.saved_metrics[metric_name]["steps"] += metric_ptr["steps"] diff --git a/sacred/run.py b/sacred/run.py index 6c9ad0a9..99275a04 100755 --- a/sacred/run.py +++ b/sacred/run.py @@ -188,7 +188,8 @@ def add_artifact( """ filename = os.path.abspath(filename) name = os.path.basename(filename) if name is None else name - self._emit_artifact_added(name, filename, recursive, metadata, content_type) + self._emit_artifact_added(name, filename, recursive, + metadata, content_type) def __call__(self, *args): r"""Start this run. @@ -386,7 +387,8 @@ def _emit_resource_added(self, filename): for observer in self.observers: self._safe_call(observer, 'resource_event', filename=filename) - def _emit_artifact_added(self, name, filename, recursive, metadata, content_type): + def _emit_artifact_added(self, name, filename, recursive, metadata, + content_type): for observer in self.observers: if recursive: self._safe_call(observer, 'artifact_directory_event', diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index ccb2ff46..73ce5960 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -1,6 +1,8 @@ #!/usr/bin/env python # coding=utf-8 +from moto import mock_s3 + import datetime import os import pytest @@ -17,6 +19,9 @@ BUCKET = 'pytest-s3-observer-bucket' BASEDIR = 'some-tests' +# how long does mock_s3 have memory for? If it's only a single test, that's bad + + @pytest.fixture() def sample_run(): exp = {'name': 'test_exp', 'sources': [], 'doc': '', 'base_dir': '/tmp'} @@ -35,7 +40,7 @@ def sample_run(): } -@pytest.fixture() +@pytest.fixture def dir_obs(): return S3FileObserver.create(bucket=BUCKET, basedir=BASEDIR) @@ -50,6 +55,7 @@ def dir_obs(): Is it possible to set up a test with and without a valid credentials file? 
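One answer to the question this comment raises, sketched under the usual moto convention of pointing the credential environment variables at dummy values so boto3 never touches a real account (the fixture itself is an assumption, not code from this PR):

```python
import pytest

@pytest.fixture
def aws_credentials(monkeypatch):
    """Fake credentials so boto3 cannot reach a real AWS account."""
    monkeypatch.setenv('AWS_ACCESS_KEY_ID', 'testing')
    monkeypatch.setenv('AWS_SECRET_ACCESS_KEY', 'testing')
    monkeypatch.setenv('AWS_SESSION_TOKEN', 'testing')
    monkeypatch.setenv('AWS_DEFAULT_REGION', 'us-east-1')
```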
I guess you can save ~/.aws/config and ~/.aws/credentials """ + def _delete_bucket(bucket_name): s3 = boto3.resource('s3') bucket = s3.Bucket(bucket_name) @@ -57,6 +63,7 @@ def _delete_bucket(bucket_name): key.delete() bucket.delete() + def _bucket_exists(bucket_name): s3 = boto3.resource('s3') try: @@ -66,6 +73,7 @@ def _bucket_exists(bucket_name): return False return True + def _key_exists(bucket_name, key): s3 = boto3.resource('s3') try: @@ -81,13 +89,19 @@ def _get_file_data(bucket_name, key): return s3.Object(bucket_name, key).get()['Body'].read() +@mock_s3 def test_fs_observer_started_event_creates_bucket(dir_obs, sample_run): observer = dir_obs sample_run['_id'] = None _id = observer.started_event(**sample_run) run_dir = os.path.join(BASEDIR, str(_id)) - - assert _key_exists(bucket_name=BUCKET, key=os.path.join(run_dir, 'cout.txt')) + assert _bucket_exists(bucket_name=BUCKET) + assert _key_exists(bucket_name=BUCKET, + key=os.path.join(run_dir, 'cout.txt')) + assert _key_exists(bucket_name=BUCKET, + key=os.path.join(run_dir, 'config.json')) + assert _key_exists(bucket_name=BUCKET, + key=os.path.join(run_dir, 'run.json')) config = _get_file_data(bucket_name=BUCKET, key=os.path.join(run_dir, 'config.json')) assert json.loads(config) == sample_run['config'] @@ -103,13 +117,12 @@ def test_fs_observer_started_event_creates_bucket(dir_obs, sample_run): "artifacts": [], "status": "RUNNING" } - _delete_bucket(BUCKET) +@mock_s3 def test_fs_observer_started_event_increments_run_id(dir_obs, sample_run): observer = dir_obs sample_run['_id'] = None _id = observer.started_event(**sample_run) _id2 = observer.started_event(**sample_run) assert _id + 1 == _id2 - _delete_bucket(BUCKET) From b688cb6a3d5fb845cd7acda1250e955b91b53527 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Mon, 29 Jul 2019 15:10:14 -0700 Subject: [PATCH 07/45] hopefully fix issue with google compute engine --- .travis.yml | 2 ++ tests/test_observers/test_s3_observer.py | 3 --- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 98e0e8e7..2edc7cc2 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,8 @@ language: python sudo: false dist: xenial +before_install: + - sudo rm -f /etc/boto.cfg install: - pip install tox matrix: diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 73ce5960..316c87ef 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -19,9 +19,6 @@ BUCKET = 'pytest-s3-observer-bucket' BASEDIR = 'some-tests' -# how long does mock_s3 have memory for? 
If it's only a single test, that's bad - - @pytest.fixture() def sample_run(): exp = {'name': 'test_exp', 'sources': [], 'doc': '', 'base_dir': '/tmp'} From 5de6e42cc2257d3fb21b46413b7ff089512ee5dd Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Mon, 29 Jul 2019 15:18:31 -0700 Subject: [PATCH 08/45] give up and install google-compute-engine --- .travis.yml | 2 -- dev-requirements.txt | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 2edc7cc2..98e0e8e7 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,6 @@ language: python sudo: false dist: xenial -before_install: - - sudo rm -f /etc/boto.cfg install: - pip install tox matrix: diff --git a/dev-requirements.txt b/dev-requirements.txt index 9b7d8baa..96ba873a 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -29,4 +29,5 @@ scikit-learn==0.20.3 pymongo==3.8.0 py-cpuinfo==4.0 boto3>=1.9.0 -moto>=1.3.13 \ No newline at end of file +moto>=1.3.13 +google-compute-engine>=2.8.0 \ No newline at end of file From 00276624f2cceb95c9c2c71624fb79f7b56b6ed5 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Mon, 29 Jul 2019 15:23:16 -0700 Subject: [PATCH 09/45] add default location --- sacred/observers/s3_observer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index d99c6752..8bcd6d20 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -88,7 +88,7 @@ def _list_s3_subdirs(self, prefix=None): def _create_bucket(self): session = boto3.session.Session() - current_region = session.region_name + current_region = session.region_name or 'us-west-2' bucket_response = self.s3.create_bucket( Bucket=self.bucket, CreateBucketConfiguration={ From d1d7bae2292016f827d9127f3df848c571c697b3 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Mon, 29 Jul 2019 15:30:07 -0700 Subject: [PATCH 10/45] add decode utf to enforce python 3.5 compatibility --- setup.py | 1 + tests/test_observers/test_s3_observer.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 67dca023..087141bd 100755 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ 'py-cpuinfo>=4.0', 'colorama>=0.4', 'packaging>=18.0', + 'boto3>=1.9.0' ], tests_require=[ 'mock>=0.8, <3.0', diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 316c87ef..e86ab529 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -99,11 +99,13 @@ def test_fs_observer_started_event_creates_bucket(dir_obs, sample_run): key=os.path.join(run_dir, 'config.json')) assert _key_exists(bucket_name=BUCKET, key=os.path.join(run_dir, 'run.json')) - config = _get_file_data(bucket_name=BUCKET, key=os.path.join(run_dir, 'config.json')) + config = _get_file_data(bucket_name=BUCKET, + key=os.path.join(run_dir, 'config.json')) - assert json.loads(config) == sample_run['config'] - run = _get_file_data(bucket_name=BUCKET, key=os.path.join(run_dir, 'run.json')) - assert json.loads(run) == { + assert json.loads(config.decode('utf-8')) == sample_run['config'] + run = _get_file_data(bucket_name=BUCKET, + key=os.path.join(run_dir, 'run.json')) + assert json.loads(run.decode('utf-8')) == { 'experiment': sample_run['ex_info'], 'command': sample_run['command'], 'host': sample_run['host_info'], From 5b993c28ef9c33818093f2e5cae1167cf8373d88 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 30 Jul 2019 11:58:38 -0700 Subject: [PATCH 11/45] Clean 
up and add more tests --- sacred/observers/s3_observer.py | 77 +++++++++++++++++------- tests/test_observers/test_s3_observer.py | 54 ++++++++--------- 2 files changed, 83 insertions(+), 48 deletions(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 8bcd6d20..1f3139c2 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -33,12 +33,14 @@ def _is_valid_bucket(bucket_name): return False try: socket.inet_aton(bucket_name) - except: + except socket.error: # congrats, you're a valid bucket name return True class S3FileObserver(RunObserver): + ## TODO (possibly): make S3FileObserver inherit from FSO to avoid + ## duplicating code. But this might be even messier? VERSION = 'S3FileObserver-0.1.0' @classmethod @@ -69,6 +71,17 @@ def __init__(self, bucket, basedir, resource_dir, source_dir, self.s3 = boto3.resource('s3') self.saved_metrics = {} + def _objects_exist_in_dir(self, prefix): + try: + bucket = self.s3.Bucket(self.bucket) + all_keys = [el.key for el in bucket.objects.filter(Prefix=prefix)] + except ClientError as er: + if er.response['Error']['Code'] == 'NoSuchBucket': + return None + else: + raise ClientError(er.response['Error']['Code']) + return len(all_keys) > 0 + def _list_s3_subdirs(self, prefix=None): if prefix is None: prefix = self.basedir @@ -82,8 +95,15 @@ def _list_s3_subdirs(self, prefix=None): raise ClientError(er.response['Error']['Code']) subdir_match = r'{prefix}\/(.*)\/'.format(prefix=prefix) - distinct_subdirs = set([re.match(subdir_match, key).groups()[0] for - key in all_keys]) + subdirs = [] + for key in all_keys: + match_obj = re.match(subdir_match, key) + if match_obj is None: + import pdb; pdb.set_trace() + continue + else: + subdirs.append(match_obj.groups()[0]) + distinct_subdirs = set(subdirs) return list(distinct_subdirs) def _create_bucket(self): @@ -96,19 +116,26 @@ def _create_bucket(self): return bucket_response def _determine_run_dir(self, _id): - bucket_path_subdirs = self._list_s3_subdirs() - if bucket_path_subdirs is None or len(bucket_path_subdirs) == 0: - self._create_bucket() - max_run_id = 0 - else: - max_run_id = max([int(d) for d in bucket_path_subdirs - if d.isdigit()]) - - self.dir = None if _id is None: + bucket_path_subdirs = self._list_s3_subdirs() + if bucket_path_subdirs is None: + self._create_bucket() + + if bucket_path_subdirs is None or len(bucket_path_subdirs) == 0: + max_run_id = 0 + else: + integer_directories = [int(d) for d in bucket_path_subdirs + if d.isdigit()] + if len(integer_directories) == 0: + max_run_id = 0 + else: + max_run_id = max(integer_directories) + _id = max_run_id + 1 - self.run_dir = os.path.join(self.basedir, str(_id)) + self.dir = os.path.join(self.basedir, str(_id)) + if self._objects_exist_in_dir(self.dir): + raise FileExistsError(f"S3 dir at {self.dir} already exists") return _id def queued_event(self, ex_info, command, host_info, queue_time, config, @@ -184,13 +211,13 @@ def put_data(self, key, binary_data): self.s3.Object(self.bucket, key).put(Body=binary_data) def save_json(self, obj, filename): - key = os.path.join(self.run_dir, filename) + key = os.path.join(self.dir, filename) self.put_data(key, json.dumps(flatten(obj), sort_keys=True, indent=2)) def save_file(self, filename, target_name=None): target_name = target_name or os.path.basename(filename) - key = os.path.join(self.run_dir, target_name) + key = os.path.join(self.dir, target_name) self.put_data(key, open(filename, 'rb')) def save_directory(self, source_dir, target_name): @@ 
-203,17 +230,19 @@ def save_directory(self, source_dir, target_name): s3_resource = boto3.resource('s3') for filename in all_files: - file_location = os.path.join(self.run_dir, target_name, + file_location = os.path.join(self.dir, target_name, os.path.relpath(filename, source_dir)) s3_resource.Object(self.bucket, file_location).put(Body=open(filename, 'rb')) def save_cout(self): binary_data = self.cout[self.cout_write_cursor:].encode("utf-8") - key = os.path.join(self.run_dir, 'cout.txt') + key = os.path.join(self.dir, 'cout.txt') self.put_data(key, binary_data) self.cout_write_cursor = len(self.cout) + + ## same as FSO def heartbeat_event(self, info, captured_out, beat_time, result): self.info = info self.run_entry['heartbeat'] = beat_time.isoformat() @@ -271,8 +300,6 @@ def log_metrics(self, metrics_by_name, info): self.saved_metrics[metric_name]["values"] += metric_ptr["values"] self.saved_metrics[metric_name]["steps"] += metric_ptr["steps"] - # Manually convert them to avoid passing a datetime dtype handler - # when we're trying to convert into json. timestamps_norm = [ts.isoformat() for ts in metric_ptr["timestamps"]] self.saved_metrics[metric_name]["timestamps"] += timestamps_norm @@ -281,7 +308,8 @@ def log_metrics(self, metrics_by_name, info): def __eq__(self, other): if isinstance(other, S3FileObserver): - return self.basedir == other.basedir + return (self.bucket == other.bucket + and self.basedir == other.basedir) return False @@ -294,4 +322,11 @@ class S3StorageOption(CommandLineOption): @classmethod def apply(cls, args, run): - run.observers.append(S3FileObserver.create(args)) + match_obj = re.match(r's3:\/\/([^\/]*)\/(.*)', args) + if match_obj is None or len(match_obj.groups()) != 2: + raise ValueError("Valid bucket specification not found. " + "Enter bucket and directory path like: " + "s3:///path/to/exp") + bucket, basedir = match_obj.groups() + run.observers.append(S3FileObserver.create(bucket=bucket, + basedir=basedir)) diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index e86ab529..ea13e7be 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -27,7 +27,7 @@ def sample_run(): command = 'run' meta_info = {'comment': 'test run'} return { - '_id': 'FEDCBA9876543210', + '_id': None, 'ex_info': exp, 'command': command, 'host_info': host, @@ -38,29 +38,10 @@ def sample_run(): @pytest.fixture -def dir_obs(): +def observer(): return S3FileObserver.create(bucket=BUCKET, basedir=BASEDIR) -""" -Test that reusing the same bucket name doesn't recreate the bucket, - but instead reuses it (check if both _ids went to the same bucket) -Test failing gracefully if you pass in a disallowed S3 bucket name - - - -Is it possible to set up a test with and without a valid credentials file? 
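Stepping back to the command-line option reworked earlier in this commit: a quick standalone check (not part of the patch) of what its bucket-path regex extracts from the `-S3` argument:

```python
import re

args = 's3://my-bucket/path/to/exp'  # hypothetical -S3 argument
match_obj = re.match(r's3:\/\/([^\/]*)\/(.*)', args)
bucket, basedir = match_obj.groups()
print(bucket)   # my-bucket
print(basedir)  # path/to/exp
```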
- I guess you can save ~/.aws/config and ~/.aws/credentials -""" - -def _delete_bucket(bucket_name): - s3 = boto3.resource('s3') - bucket = s3.Bucket(bucket_name) - for key in bucket.objects.all(): - key.delete() - bucket.delete() - - def _bucket_exists(bucket_name): s3 = boto3.resource('s3') try: @@ -87,9 +68,7 @@ def _get_file_data(bucket_name, key): @mock_s3 -def test_fs_observer_started_event_creates_bucket(dir_obs, sample_run): - observer = dir_obs - sample_run['_id'] = None +def test_fs_observer_started_event_creates_bucket(observer, sample_run): _id = observer.started_event(**sample_run) run_dir = os.path.join(BASEDIR, str(_id)) assert _bucket_exists(bucket_name=BUCKET) @@ -119,9 +98,30 @@ def test_fs_observer_started_event_creates_bucket(dir_obs, sample_run): @mock_s3 -def test_fs_observer_started_event_increments_run_id(dir_obs, sample_run): - observer = dir_obs - sample_run['_id'] = None +def test_fs_observer_started_event_increments_run_id(observer, sample_run): _id = observer.started_event(**sample_run) _id2 = observer.started_event(**sample_run) assert _id + 1 == _id2 + + +def test_s3_observer_equality(): + obs_one = S3FileObserver.create(bucket=BUCKET, basedir=BASEDIR) + obs_two = S3FileObserver.create(bucket=BUCKET, basedir=BASEDIR) + different_basedir = S3FileObserver.create(bucket=BUCKET, + basedir="another/dir") + assert obs_one == obs_two + assert obs_one != different_basedir + + +@mock_s3 +def test_raises_error_on_duplicate_id_directory(observer, sample_run): + observer.started_event(**sample_run) + sample_run['_id'] = 1 + with pytest.raises(FileExistsError): + observer.started_event(**sample_run) + + +def test_raises_error_on_invalid_bucket_name(): + with pytest.raises(ValueError): + _ = S3FileObserver.create(bucket="this_bucket_is_invalid", + basedir=BASEDIR) From 5b44c48d5ae6435c3d8a0d8302e451fa8053aed3 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 30 Jul 2019 13:17:46 -0700 Subject: [PATCH 12/45] fix format string --- sacred/observers/s3_observer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 1f3139c2..9e97bd0b 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -135,7 +135,8 @@ def _determine_run_dir(self, _id): self.dir = os.path.join(self.basedir, str(_id)) if self._objects_exist_in_dir(self.dir): - raise FileExistsError(f"S3 dir at {self.dir} already exists") + raise FileExistsError( + "S3 dir at {} already exists".format(self.dir)) return _id def queued_event(self, ex_info, command, host_info, queue_time, config, From 25d26c93855bfae2c410b7d293ba294c704052b3 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 30 Jul 2019 13:19:55 -0700 Subject: [PATCH 13/45] flake8 fixes --- sacred/observers/s3_observer.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 9e97bd0b..5dc5e526 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -39,8 +39,6 @@ def _is_valid_bucket(bucket_name): class S3FileObserver(RunObserver): - ## TODO (possibly): make S3FileObserver inherit from FSO to avoid - ## duplicating code. But this might be even messier? 
VERSION = 'S3FileObserver-0.1.0' @classmethod @@ -99,7 +97,6 @@ def _list_s3_subdirs(self, prefix=None): for key in all_keys: match_obj = re.match(subdir_match, key) if match_obj is None: - import pdb; pdb.set_trace() continue else: subdirs.append(match_obj.groups()[0]) @@ -125,7 +122,7 @@ def _determine_run_dir(self, _id): max_run_id = 0 else: integer_directories = [int(d) for d in bucket_path_subdirs - if d.isdigit()] + if d.isdigit()] if len(integer_directories) == 0: max_run_id = 0 else: @@ -242,8 +239,6 @@ def save_cout(self): self.put_data(key, binary_data) self.cout_write_cursor = len(self.cout) - - ## same as FSO def heartbeat_event(self, info, captured_out, beat_time, result): self.info = info self.run_entry['heartbeat'] = beat_time.isoformat() @@ -309,8 +304,8 @@ def log_metrics(self, metrics_by_name, info): def __eq__(self, other): if isinstance(other, S3FileObserver): - return (self.bucket == other.bucket - and self.basedir == other.basedir) + return (self.bucket == other.bucket and + self.basedir == other.basedir) return False From 0c20f665b867902b81f421a22c2f1dbfe009d071 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 30 Jul 2019 14:32:38 -0700 Subject: [PATCH 14/45] remove recursive artifacts --- sacred/experiment.py | 4 +--- sacred/observers/s3_observer.py | 28 ++++++++++++++-------------- sacred/run.py | 21 +++++++-------------- 3 files changed, 22 insertions(+), 31 deletions(-) diff --git a/sacred/experiment.py b/sacred/experiment.py index 3a3e6cd9..aa1014ef 100755 --- a/sacred/experiment.py +++ b/sacred/experiment.py @@ -333,7 +333,6 @@ def add_artifact( filename, name=None, metadata=None, - recursive=False, content_type=None, ): """Add a file as an artifact. @@ -360,8 +359,7 @@ def add_artifact( This only has an effect when using the MongoObserver. """ assert self.current_run is not None, "Can only be called during a run." - self.current_run.add_artifact(filename, name, recursive, metadata, - content_type) + self.current_run.add_artifact(filename, name, metadata, content_type) @property def info(self): diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 5dc5e526..3f512b17 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -19,6 +19,8 @@ def _is_valid_bucket(bucket_name): + # See https://docs.aws.amazon.com/awscloudtrail/latest/userguide/ + # cloudtrail-s3-bucket-naming-requirements.html if len(bucket_name) < 3 or len(bucket_name) > 63: return False if '..' in bucket_name or '.-' in bucket_name or '-.' 
in bucket_name: @@ -32,9 +34,9 @@ def _is_valid_bucket(bucket_name): continue return False try: + # If a name is a valid IP address, it cannot be a bucket name socket.inet_aton(bucket_name) except socket.error: - # congrats, you're a valid bucket name return True @@ -57,6 +59,8 @@ def __init__(self, bucket, basedir, resource_dir, source_dir, self.basedir = basedir self.bucket = bucket + # Keeping the convention of referring to locations in S3 as `dir` + # because that is a useful mental model and there isn't a better word self.resource_dir = resource_dir self.source_dir = source_dir self.priority = priority @@ -70,14 +74,11 @@ def __init__(self, bucket, basedir, resource_dir, source_dir, self.saved_metrics = {} def _objects_exist_in_dir(self, prefix): - try: - bucket = self.s3.Bucket(self.bucket) - all_keys = [el.key for el in bucket.objects.filter(Prefix=prefix)] - except ClientError as er: - if er.response['Error']['Code'] == 'NoSuchBucket': - return None - else: - raise ClientError(er.response['Error']['Code']) + # This should be run after you've confirmed the bucket + # exists, and will error out if it does not exist + + bucket = self.s3.Bucket(self.bucket) + all_keys = [el.key for el in bucket.objects.filter(Prefix=prefix)] return len(all_keys) > 0 def _list_s3_subdirs(self, prefix=None): @@ -114,7 +115,9 @@ def _create_bucket(self): def _determine_run_dir(self, _id): if _id is None: + # Get all existing subdirectories under s3://bucket/basedir/ bucket_path_subdirs = self._list_s3_subdirs() + # _list_s3_subdirs returns None when the bucket doesn't exist if bucket_path_subdirs is None: self._create_bucket() @@ -126,6 +129,8 @@ def _determine_run_dir(self, _id): if len(integer_directories) == 0: max_run_id = 0 else: + # If there are directories under basedir that aren't + # run directories, ignore those max_run_id = max(integer_directories) _id = max_run_id + 1 @@ -277,11 +282,6 @@ def artifact_event(self, name, filename, metadata=None, content_type=None): self.run_entry['artifacts'].append(name) self.save_json(self.run_entry, 'run.json') - def artifact_directory_event(self, name, filename): - self.save_directory(filename, name) - self.run_entry['artifacts'].append(name + "/") - self.save_json(self.run_entry, 'run.json') - def log_metrics(self, metrics_by_name, info): """Store new measurements into metrics.json. """ diff --git a/sacred/run.py b/sacred/run.py index 99275a04..9b73518f 100755 --- a/sacred/run.py +++ b/sacred/run.py @@ -160,7 +160,6 @@ def add_artifact( self, filename, name=None, - recursive=False, metadata=None, content_type=None, ): @@ -188,8 +187,7 @@ def add_artifact( """ filename = os.path.abspath(filename) name = os.path.basename(filename) if name is None else name - self._emit_artifact_added(name, filename, recursive, - metadata, content_type) + self._emit_artifact_added(name, filename, metadata, content_type) def __call__(self, *args): r"""Start this run. 
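With the recursive flag and artifact_directory_event removed by this commit, a directory can still be captured as a single artifact by archiving it first; a hedged sketch of that workaround (the directory name is hypothetical, and this is not part of the patch):

```python
import shutil
from sacred import Experiment

ex = Experiment('demo')

@ex.automain
def main():
    # Pack ./checkpoints into checkpoints.zip, then attach the archive
    archive_path = shutil.make_archive('checkpoints', 'zip', 'checkpoints')
    ex.add_artifact(archive_path, name='checkpoints.zip')
```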
@@ -387,19 +385,14 @@ def _emit_resource_added(self, filename): for observer in self.observers: self._safe_call(observer, 'resource_event', filename=filename) - def _emit_artifact_added(self, name, filename, recursive, metadata, + def _emit_artifact_added(self, name, filename, metadata, content_type): for observer in self.observers: - if recursive: - self._safe_call(observer, 'artifact_directory_event', - name=name, - filename=filename) - else: - self._safe_call(observer, 'artifact_event', - name=name, - filename=filename, - metadata=metadata, - content_type=content_type) + self._safe_call(observer, 'artifact_event', + name=name, + filename=filename, + metadata=metadata, + content_type=content_type) def _safe_call(self, obs, method, **kwargs): if obs not in self._failed_observers and hasattr(obs, method): From 883ced1305b39dac1b54e468bcefe62d16f4abf5 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 30 Jul 2019 14:46:23 -0700 Subject: [PATCH 15/45] remove newline --- sacred/run.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sacred/run.py b/sacred/run.py index 9b73518f..10cc6d25 100755 --- a/sacred/run.py +++ b/sacred/run.py @@ -385,8 +385,7 @@ def _emit_resource_added(self, filename): for observer in self.observers: self._safe_call(observer, 'resource_event', filename=filename) - def _emit_artifact_added(self, name, filename, metadata, - content_type): + def _emit_artifact_added(self, name, filename, metadata, content_type): for observer in self.observers: self._safe_call(observer, 'artifact_event', name=name, From a8c3f59cf25fa300f97ff46d36c5b1cc56181f52 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 30 Jul 2019 15:48:16 -0700 Subject: [PATCH 16/45] add more test coverage --- tests/test_observers/test_s3_observer.py | 102 +++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index ea13e7be..73e32060 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -9,6 +9,8 @@ import json from sacred.observers import S3FileObserver +import tempfile +import hashlib import boto3 from botocore.exceptions import ClientError @@ -19,6 +21,7 @@ BUCKET = 'pytest-s3-observer-bucket' BASEDIR = 'some-tests' + @pytest.fixture() def sample_run(): exp = {'name': 'test_exp', 'sources': [], 'doc': '', 'base_dir': '/tmp'} @@ -42,6 +45,27 @@ def observer(): return S3FileObserver.create(bucket=BUCKET, basedir=BASEDIR) +@pytest.fixture +def tmpfile(): + # NOTE: instead of using a with block and delete=True we are creating and + # manually deleting the file, such that we can close it before running the + # tests. This is necessary since on Windows we can not open the same file + # twice, so for the FileStorageObserver to read it, we need to close it. 
+ f = tempfile.NamedTemporaryFile(suffix='.py', delete=False) + + f.content = 'import sacred\n' + f.write(f.content.encode()) + f.flush() + f.seek(0) + f.md5sum = hashlib.md5(f.read()).hexdigest() + + f.close() + + yield f + + os.remove(f.name) + + def _bucket_exists(bucket_name): s3 = boto3.resource('s3') try: @@ -121,7 +145,85 @@ def test_raises_error_on_duplicate_id_directory(observer, sample_run): observer.started_event(**sample_run) +@mock_s3 +def test_completed_event_updates_run_json(observer, sample_run): + observer.started_event(**sample_run) + run = json.loads(_get_file_data(bucket_name=BUCKET, + key=os.path.join(observer.dir, + 'run.json')) + .decode('utf-8')) + assert run['status'] == 'RUNNING' + observer.completed_event(T2, 'success!') + run = json.loads(_get_file_data(bucket_name=BUCKET, + key=os.path.join(observer.dir, + 'run.json')) + .decode('utf-8')) + assert run['status'] == 'COMPLETED' + + +@mock_s3 +def test_interrupted_event_updates_run_json(observer, sample_run): + observer.started_event(**sample_run) + run = json.loads(_get_file_data(bucket_name=BUCKET, + key=os.path.join(observer.dir, + 'run.json')) + .decode('utf-8')) + assert run['status'] == 'RUNNING' + observer.interrupted_event(T2, 'SERVER_EXPLODED') + run = json.loads(_get_file_data(bucket_name=BUCKET, + key=os.path.join(observer.dir, + 'run.json')) + .decode('utf-8')) + assert run['status'] == 'SERVER_EXPLODED' + + +@mock_s3 +def test_failed_event_updates_run_json(observer, sample_run): + observer.started_event(**sample_run) + run = json.loads(_get_file_data(bucket_name=BUCKET, + key=os.path.join(observer.dir, + 'run.json')) + .decode('utf-8')) + assert run['status'] == 'RUNNING' + observer.failed_event(T2, 'Everything imaginable went wrong') + run = json.loads(_get_file_data(bucket_name=BUCKET, + key=os.path.join(observer.dir, + 'run.json')) + .decode('utf-8')) + assert run['status'] == 'FAILED' + + +@mock_s3 +def test_queued_event_updates_run_json(observer, sample_run): + del sample_run['start_time'] + sample_run['queue_time'] = T2 + observer.queued_event(**sample_run) + run = json.loads(_get_file_data(bucket_name=BUCKET, + key=os.path.join(observer.dir, + 'run.json')) + .decode('utf-8')) + assert run['status'] == 'QUEUED' + + +@mock_s3 +def test_artifact_event_works(observer, sample_run, tmpfile): + observer.started_event(**sample_run) + observer.artifact_event('test_artifact.py', tmpfile.name) + + assert _key_exists(bucket_name=BUCKET, + key=os.path.join(observer.dir, 'test_artifact.py')) + artifact_data = (_get_file_data(bucket_name=BUCKET, + key=os.path.join(observer.dir, + 'test_artifact.py')) + .decode('utf-8')) + assert artifact_data == tmpfile.content + + def test_raises_error_on_invalid_bucket_name(): with pytest.raises(ValueError): _ = S3FileObserver.create(bucket="this_bucket_is_invalid", basedir=BASEDIR) + + with pytest.raises(ValueError): + _ = S3FileObserver.create(bucket="hi", + basedir=BASEDIR) From 777a7f5d45986442a25452bd374f214f36a8f1ff Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 14:56:00 -0700 Subject: [PATCH 17/45] change observer name, refactor list_s3_subdirs to have an explicit bucket exists method --- sacred/observers/__init__.py | 4 +- sacred/observers/s3_observer.py | 83 ++++++++++++++---------- tests/test_observers/test_s3_observer.py | 35 ++++++---- 3 files changed, 74 insertions(+), 48 deletions(-) diff --git a/sacred/observers/__init__.py b/sacred/observers/__init__.py index a577605f..3aab9f5e 100644 --- a/sacred/observers/__init__.py +++ 
b/sacred/observers/__init__.py @@ -8,9 +8,9 @@ from sacred.observers.tinydb_hashfs import TinyDbObserver, TinyDbReader from sacred.observers.slack import SlackObserver from sacred.observers.telegram_obs import TelegramObserver -from sacred.observers.s3_observer import S3FileObserver +from sacred.observers.s3_observer import S3Observer __all__ = ('FileStorageObserver', 'RunObserver', 'MongoObserver', 'SqlObserver', 'TinyDbObserver', 'TinyDbReader', - 'SlackObserver', 'TelegramObserver', 'S3FileObserver') + 'SlackObserver', 'TelegramObserver', 'S3Observer') diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 3f512b17..197a9dbe 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -1,6 +1,3 @@ -#!/usr/bin/env python -# coding=utf-8 - import json import os import os.path @@ -23,16 +20,18 @@ def _is_valid_bucket(bucket_name): # cloudtrail-s3-bucket-naming-requirements.html if len(bucket_name) < 3 or len(bucket_name) > 63: return False - if '..' in bucket_name or '.-' in bucket_name or '-.' in bucket_name: - return False - for char in bucket_name: - if char.isdigit(): - continue - if char.islower(): - continue - if char == '-': - continue - return False + + labels = bucket_name.split('.') + # A bucket name consists of "labels" separated by periods + for label in labels: + if len(label) == 0 or label[0] == '-' or label[-1] == '-': + # Labels must be of nonzero length, and cannot begin or end with a hyphen + return False + for char in label: + # Labels can only contain digits, lowercase letters, or hyphens. + # Anything else will fail here + if not (char.isdigit() or char.islower() or char == '-'): + return False try: # If a name is a valid IP address, it cannot be a bucket name socket.inet_aton(bucket_name) @@ -40,12 +39,24 @@ def _is_valid_bucket(bucket_name): return True -class S3FileObserver(RunObserver): - VERSION = 'S3FileObserver-0.1.0' +class S3Observer(RunObserver): + VERSION = 'S3Observer-0.1.0' @classmethod def create(cls, bucket, basedir, resource_dir=None, source_dir=None, priority=DEFAULT_S3_PRIORITY): + """ + A factory method to create a S3Observer object + + :param bucket: The name of the bucket you want to store results in. Doesn't need to contain + `s3://`, but does need to be a valid AWS bucket name + :param basedir: The relative path inside your bucket where you want this experiment + to store results + :param resource_dir: TODO what is this anyway? 
+ :param source_dir: + :param priority: + :return: + """ resource_dir = resource_dir or os.path.join(basedir, '_resources') source_dir = source_dir or os.path.join(basedir, '_sources') @@ -81,18 +92,19 @@ def _objects_exist_in_dir(self, prefix): all_keys = [el.key for el in bucket.objects.filter(Prefix=prefix)] return len(all_keys) > 0 - def _list_s3_subdirs(self, prefix=None): - if prefix is None: - prefix = self.basedir + def _bucket_exists(self): try: - bucket = self.s3.Bucket(self.bucket) - all_keys = [el.key for el in bucket.objects.filter(Prefix=prefix)] + self.s3.meta.client.head_bucket(Bucket=self.bucket) except ClientError as er: if er.response['Error']['Code'] == 'NoSuchBucket': - return None - else: - raise ClientError(er.response['Error']['Code']) + return False + return True + def _list_s3_subdirs(self, prefix=None): + if prefix is None: + prefix = self.basedir + bucket = self.s3.Bucket(self.bucket) + all_keys = [obj.key for obj in bucket.objects.filter(Prefix=prefix)] subdir_match = r'{prefix}\/(.*)\/'.format(prefix=prefix) subdirs = [] for key in all_keys: @@ -115,18 +127,20 @@ def _create_bucket(self): def _determine_run_dir(self, _id): if _id is None: - # Get all existing subdirectories under s3://bucket/basedir/ - bucket_path_subdirs = self._list_s3_subdirs() - # _list_s3_subdirs returns None when the bucket doesn't exist - if bucket_path_subdirs is None: + bucket_exists = self._bucket_exists() + + if not bucket_exists: self._create_bucket() + bucket_path_subdirs = [] + else: + bucket_path_subdirs = self._list_s3_subdirs() - if bucket_path_subdirs is None or len(bucket_path_subdirs) == 0: + if not bucket_path_subdirs: max_run_id = 0 else: integer_directories = [int(d) for d in bucket_path_subdirs if d.isdigit()] - if len(integer_directories) == 0: + if not integer_directories: max_run_id = 0 else: # If there are directories under basedir that aren't @@ -303,13 +317,14 @@ def log_metrics(self, metrics_by_name, info): self.save_json(self.saved_metrics, 'metrics.json') def __eq__(self, other): - if isinstance(other, S3FileObserver): + if isinstance(other, S3Observer): return (self.bucket == other.bucket and self.basedir == other.basedir) - return False + else: + return False -class S3StorageOption(CommandLineOption): +class S3Option(CommandLineOption): """Add a S3 File observer to the experiment.""" short_flag = 'S3' @@ -324,5 +339,5 @@ def apply(cls, args, run): "Enter bucket and directory path like: " "s3:///path/to/exp") bucket, basedir = match_obj.groups() - run.observers.append(S3FileObserver.create(bucket=bucket, - basedir=basedir)) + run.observers.append(S3Observer.create(bucket=bucket, + basedir=basedir)) diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 73e32060..e5275af8 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -8,7 +8,7 @@ import pytest import json -from sacred.observers import S3FileObserver +from sacred.observers import S3Observer import tempfile import hashlib @@ -42,7 +42,7 @@ def sample_run(): @pytest.fixture def observer(): - return S3FileObserver.create(bucket=BUCKET, basedir=BASEDIR) + return S3Observer.create(bucket=BUCKET, basedir=BASEDIR) @pytest.fixture @@ -129,12 +129,15 @@ def test_fs_observer_started_event_increments_run_id(observer, sample_run): def test_s3_observer_equality(): - obs_one = S3FileObserver.create(bucket=BUCKET, basedir=BASEDIR) - obs_two = S3FileObserver.create(bucket=BUCKET, basedir=BASEDIR) - different_basedir = 
S3FileObserver.create(bucket=BUCKET, - basedir="another/dir") + obs_one = S3Observer.create(bucket=BUCKET, basedir=BASEDIR) + obs_two = S3Observer.create(bucket=BUCKET, basedir=BASEDIR) + different_basedir = S3Observer.create(bucket=BUCKET, + basedir="another/dir") + different_bucket = S3Observer.create(bucket="some-other-bucket", + basedir=BASEDIR) assert obs_one == obs_two assert obs_one != different_basedir + assert obs_one != different_bucket @mock_s3 @@ -219,11 +222,19 @@ def test_artifact_event_works(observer, sample_run, tmpfile): assert artifact_data == tmpfile.content -def test_raises_error_on_invalid_bucket_name(): - with pytest.raises(ValueError): - _ = S3FileObserver.create(bucket="this_bucket_is_invalid", - basedir=BASEDIR) +test_buckets = [("hi", True), + ("this_bucket_is_invalid", True), + ("this-bucket-is-valid", False), + ("this-bucket.is-valid", False), + ("this-bucket..is-invalid", True)] + - with pytest.raises(ValueError): - _ = S3FileObserver.create(bucket="hi", +@pytest.mark.parametrize("bucket_name, should_raise", test_buckets) +def test_raises_error_on_invalid_bucket_name(bucket_name, should_raise): + if should_raise: + with pytest.raises(ValueError): + _ = S3Observer.create(bucket=bucket_name, basedir=BASEDIR) + else: + _ = S3Observer.create(bucket=bucket_name, + basedir=BASEDIR) From a929db846d65fed3abfeb14486dce9eae8bb3f4c Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 15:04:04 -0700 Subject: [PATCH 18/45] add ability to pass in region, and error on Observer creation if region is not either passed in or set in config file --- sacred/observers/s3_observer.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 197a9dbe..73876028 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -44,7 +44,7 @@ class S3Observer(RunObserver): @classmethod def create(cls, bucket, basedir, resource_dir=None, source_dir=None, - priority=DEFAULT_S3_PRIORITY): + priority=DEFAULT_S3_PRIORITY, region=None): """ A factory method to create a S3Observer object @@ -60,10 +60,10 @@ def create(cls, bucket, basedir, resource_dir=None, source_dir=None, resource_dir = resource_dir or os.path.join(basedir, '_resources') source_dir = source_dir or os.path.join(basedir, '_sources') - return cls(bucket, basedir, resource_dir, source_dir, priority) + return cls(bucket, basedir, resource_dir, source_dir, priority, region) def __init__(self, bucket, basedir, resource_dir, source_dir, - priority=DEFAULT_S3_PRIORITY): + priority=DEFAULT_S3_PRIORITY, region=None): if not _is_valid_bucket(bucket): raise ValueError("Your chosen bucket name does not follow AWS " "bucket naming rules") @@ -81,8 +81,18 @@ def __init__(self, bucket, basedir, resource_dir, source_dir, self.info = None self.cout = "" self.cout_write_cursor = 0 - self.s3 = boto3.resource('s3') self.saved_metrics = {} + if region is not None: + self.region = region + self.s3 = boto3.resource('s3', region_name=region) + else: + session = boto3.session.Session() + if session.region_name is not None: + self.region = session.region_name + self.s3 = boto3.resource('s3') + else: + raise ValueError("You must either pass in an AWS region name, or have a region" + " name specified in your AWS config file") def _objects_exist_in_dir(self, prefix): # This should be run after you've confirmed the bucket @@ -117,12 +127,10 @@ def _list_s3_subdirs(self, prefix=None): return list(distinct_subdirs) def 
_create_bucket(self): - session = boto3.session.Session() - current_region = session.region_name or 'us-west-2' bucket_response = self.s3.create_bucket( Bucket=self.bucket, CreateBucketConfiguration={ - 'LocationConstraint': current_region}) + 'LocationConstraint': self.region}) return bucket_response def _determine_run_dir(self, _id): From 9788e860f8273c81fd8cd4daad359a2d604bb736 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 15:08:25 -0700 Subject: [PATCH 19/45] fix error handling --- sacred/observers/s3_observer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 73876028..a8b9c4d5 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -106,7 +106,7 @@ def _bucket_exists(self): try: self.s3.meta.client.head_bucket(Bucket=self.bucket) except ClientError as er: - if er.response['Error']['Code'] == 'NoSuchBucket': + if er.response['Error']['Code'] == '404': return False return True @@ -136,7 +136,6 @@ def _create_bucket(self): def _determine_run_dir(self, _id): if _id is None: bucket_exists = self._bucket_exists() - if not bucket_exists: self._create_bucket() bucket_path_subdirs = [] From f4a0299c0aca1bd34fe8f6a1865194c59edfe2f7 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 15:09:06 -0700 Subject: [PATCH 20/45] fix comment --- sacred/observers/s3_observer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index a8b9c4d5..a5c776ae 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -151,7 +151,7 @@ def _determine_run_dir(self, _id): max_run_id = 0 else: # If there are directories under basedir that aren't - # run directories, ignore those + # numeric run directories, ignore those max_run_id = max(integer_directories) _id = max_run_id + 1 From 3dd47bf92368c0867685546e88184d7be06062ae Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 15:46:21 -0700 Subject: [PATCH 21/45] fix flake8 issues --- sacred/observers/s3_observer.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index a5c776ae..9144c0a6 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -25,7 +25,8 @@ def _is_valid_bucket(bucket_name): # A bucket name consists of "labels" separated by periods for label in labels: if len(label) == 0 or label[0] == '-' or label[-1] == '-': - # Labels must be of nonzero length, and cannot begin or end with a hyphen + # Labels must be of nonzero length, + # and cannot begin or end with a hyphen return False for char in label: # Labels can only contain digits, lowercase letters, or hyphens. @@ -48,13 +49,19 @@ def create(cls, bucket, basedir, resource_dir=None, source_dir=None, """ A factory method to create a S3Observer object - :param bucket: The name of the bucket you want to store results in. Doesn't need to contain - `s3://`, but does need to be a valid AWS bucket name - :param basedir: The relative path inside your bucket where you want this experiment - to store results - :param resource_dir: TODO what is this anyway? - :param source_dir: - :param priority: + :param bucket: The name of the bucket you want to store results in. 
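
For context on the error-handling fix in patch 19 above: ``head_bucket`` issues an HTTP HEAD request, and HEAD responses carry no error body, so a missing bucket surfaces as a ``ClientError`` whose error code is the status string ``'404'`` rather than ``'NoSuchBucket'``. A minimal sketch of the corrected check, assuming the usual botocore imports:

.. code-block:: python

    import boto3
    from botocore.exceptions import ClientError

    def bucket_exists(bucket_name):
        s3 = boto3.resource("s3")
        try:
            s3.meta.client.head_bucket(Bucket=bucket_name)
        except ClientError as err:
            if err.response["Error"]["Code"] == "404":
                return False
        # As in the patch, any non-404 error (e.g. a 403 on a bucket you
        # cannot access) falls through and is treated as "bucket exists"
        return True
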
+ Doesn't need to contain `s3://`, but needs to be a valid bucket name + :param basedir: The relative path inside your bucket where you want + this experiment to store results + :param resource_dir: Where to store resources for this experiment. By + default, will be /_resources + :param source_dir: Where to store code sources for this experiment. By + default, will be /sources + :param priority: The priority to assign to this observer if + multiple observers are present + :param region: The AWS region in which you want to create and access + buckets. Needs to be either set here or configured in your AWS + config file. :return: """ resource_dir = resource_dir or os.path.join(basedir, '_resources') @@ -91,8 +98,9 @@ def __init__(self, bucket, basedir, resource_dir, source_dir, self.region = session.region_name self.s3 = boto3.resource('s3') else: - raise ValueError("You must either pass in an AWS region name, or have a region" - " name specified in your AWS config file") + raise ValueError("You must either pass in an AWS region name," + " or have a region name specified in your" + " AWS config file") def _objects_exist_in_dir(self, prefix): # This should be run after you've confirmed the bucket From bc5d296f1e04480b9f7f9702bf3a09bdd7bf46f5 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 15:46:36 -0700 Subject: [PATCH 22/45] explicitly set region in tests --- tests/test_observers/test_s3_observer.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index e5275af8..6b5b868a 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -20,6 +20,7 @@ BUCKET = 'pytest-s3-observer-bucket' BASEDIR = 'some-tests' +REGION = 'us-west-2' @pytest.fixture() @@ -42,7 +43,7 @@ def sample_run(): @pytest.fixture def observer(): - return S3Observer.create(bucket=BUCKET, basedir=BASEDIR) + return S3Observer.create(bucket=BUCKET, basedir=BASEDIR, region=REGION) @pytest.fixture @@ -129,19 +130,17 @@ def test_fs_observer_started_event_increments_run_id(observer, sample_run): def test_s3_observer_equality(): - obs_one = S3Observer.create(bucket=BUCKET, basedir=BASEDIR) - obs_two = S3Observer.create(bucket=BUCKET, basedir=BASEDIR) - different_basedir = S3Observer.create(bucket=BUCKET, - basedir="another/dir") - different_bucket = S3Observer.create(bucket="some-other-bucket", - basedir=BASEDIR) + obs_one = S3Observer.create(bucket=BUCKET, basedir=BASEDIR, region=REGION) + obs_two = S3Observer.create(bucket=BUCKET, basedir=BASEDIR, region=REGION) + different_basedir = S3Observer.create(bucket=BUCKET, basedir="another/dir", region=REGION) + different_bucket = S3Observer.create(bucket="other-bucket", basedir=BASEDIR, region=REGION) assert obs_one == obs_two assert obs_one != different_basedir assert obs_one != different_bucket @mock_s3 -def test_raises_error_on_duplicate_id_directory(observer, sample_run): +def test_z_raises_error_on_duplicate_id_directory(observer, sample_run): observer.started_event(**sample_run) sample_run['_id'] = 1 with pytest.raises(FileExistsError): From ab5fafb487a428a1a8fe12ede201dc876091c182 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 15:46:57 -0700 Subject: [PATCH 23/45] start to write s3observer docs --- docs/observers.rst | 39 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/docs/observers.rst b/docs/observers.rst index a0d2fc57..757ecd06 100644 
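
Recapping the region handling added in patches 18 and 22 before the documentation diff below: an explicit ``region`` argument wins, the session's configured region is the fallback, and anything else fails fast at construction time. A rough standalone sketch (``make_s3_resource`` is a hypothetical helper name, not part of the patch):

.. code-block:: python

    import boto3

    def make_s3_resource(region=None):
        # Explicit argument takes precedence over the AWS config file
        if region is not None:
            return region, boto3.resource("s3", region_name=region)
        session = boto3.session.Session()
        if session.region_name is not None:
            return session.region_name, boto3.resource("s3")
        raise ValueError("Pass a region or set one in your AWS config file.")
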
--- a/docs/observers.rst
+++ b/docs/observers.rst
@@ -9,7 +9,7 @@
 Observers have a ``priority`` attribute, and are run in order of descending
 priority. The first observer determines the ``_id`` of the run.
 
-At the moment there are four observers that are shipped with Sacred:
+At the moment there are five observers that are shipped with Sacred:
 
  * The main one is the :ref:`mongo_observer` which stores all information in a
    `MongoDB `_.
@@ -20,6 +20,7 @@ At the moment there are four observers that are shipped with Sacred:
    to store run information in a JSON file.
  * The :ref:`sql_observer` connects to any SQL database and will store the
    relevant information there.
+ * The :ref:`s3_observer` stores run information within an AWS S3 bucket.
 
 But if you want the run information stored some other way, it is easy to
 write your own :ref:`custom_observer`.
@@ -591,6 +592,42 @@ Schema
 
 .. image:: images/sql_schema.png
 
+.. _s3_observer:
+
+S3 Observer
+============
+The S3Observer stores run information in a designated prefix location within an S3 bucket, either
+by using an existing bucket or by creating a new one. Using the S3Observer requires that ``boto3``
+be installed, and that an AWS config file is created with your Access Key and Secret Key.
+An easy way to do this is by installing the AWS command line tools (``pip install awscli``) and
+running ``aws configure``.
+
+Adding an S3Observer
+--------------------
+
+To create an S3Observer in Python:
+
+.. code-block:: python
+
+    from sacred.observers import S3Observer
+    ex.observers.append(S3Observer.create(bucket='my-awesome-bucket',
+                                          basedir='/my-project/my-cool-experiment/'))
+
+By default, an S3Observer will use the region that is set in your AWS config file, but if you'd
+prefer to pass in a specific region, you can use the ``region`` parameter of create to do so.
+If you try to create an S3Observer without this parameter and without a region set in your AWS
+config file, a ``ValueError`` will be raised when the observer object is created.
+
+
+
+Directory Structure
+--------------------
+
+S3Observers follow the same conventions as FileStorageObservers when it comes to directory
+structure within an S3 bucket: within ``s3:///basedir/`` numeric run directories will be
+created in ascending order, and each run directory will contain the files specified within the
+FileStorageObserver Directory Structure documentation above.
+
 Slack Observer
 ==============

From fedf694460a7b2f646dc1f1db3f99bcba4aa28c6 Mon Sep 17 00:00:00 2001
From: Cody Wild 
Date: Tue, 20 Aug 2019 16:01:13 -0700
Subject: [PATCH 24/45] add requirement in tox file

---
 docs/observers.rst | 2 --
 tox.ini            | 1 +
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/docs/observers.rst b/docs/observers.rst
index 49faec77..22556cc0 100644
--- a/docs/observers.rst
+++ b/docs/observers.rst
@@ -618,8 +618,6 @@ prefer to pass in a specific region, you can use the ``region`` parameter of cre
 If you try to create an S3Observer without this parameter and without a region set in your AWS
 config file, a ``ValueError`` will be raised when the observer object is created.
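
To make the documentation above concrete, a complete minimal experiment wired up with the observer might look as follows (bucket and path names are placeholders):

.. code-block:: python

    from sacred import Experiment
    from sacred.observers import S3Observer

    ex = Experiment('hello_s3')
    ex.observers.append(S3Observer.create(bucket='my-awesome-bucket',
                                          basedir='my-project/my-cool-experiment',
                                          region='us-west-2'))

    @ex.automain
    def main():
        return 42

The series also adds an ``S3Option`` command-line option (short flag ``S3``), so the same observer can presumably be attached from the shell with something like ``-S3 s3://my-awesome-bucket/my-project/my-cool-experiment``.
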
- - Directory Structure -------------------- diff --git a/tox.ini b/tox.ini index e665ab52..75ef4fbb 100644 --- a/tox.ini +++ b/tox.ini @@ -57,6 +57,7 @@ basepython = python deps = pytest==4.3.0 mock==2.0.0 + moto==1.3.13 commands = pytest {posargs} From 3889e33cff70523c846ccb1d38c80538ca8df938 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 16:07:54 -0700 Subject: [PATCH 25/45] black reformatting --- sacred/experiment.py | 5 +- sacred/observers/mongo.py | 4 +- sacred/observers/s3_observer.py | 215 +++++++++++---------- sacred/observers/telegram_obs.py | 2 +- tests/test_observers/failing_mongo_mock.py | 2 +- tests/test_observers/test_s3_observer.py | 214 ++++++++++---------- 6 files changed, 236 insertions(+), 206 deletions(-) diff --git a/sacred/experiment.py b/sacred/experiment.py index ef68ec81..47fe8807 100755 --- a/sacred/experiment.py +++ b/sacred/experiment.py @@ -524,8 +524,9 @@ def _create_run( def _check_command(self, cmd_name): commands = dict(self.gather_commands()) if cmd_name is not None and cmd_name not in commands: - return 'Error: Command "{}" not found. Available commands are: ' "{}".format( - cmd_name, ", ".join(commands.keys()) + return ( + 'Error: Command "{}" not found. Available commands are: ' + "{}".format(cmd_name, ", ".join(commands.keys())) ) if cmd_name is None: diff --git a/sacred/observers/mongo.py b/sacred/observers/mongo.py index 3b95f70b..ab488eda 100644 --- a/sacred/observers/mongo.py +++ b/sacred/observers/mongo.py @@ -75,7 +75,7 @@ def create( priority=DEFAULT_MONGO_PRIORITY, client=None, failure_dir=None, - **kwargs + **kwargs, ): """Factory method for MongoObserver. @@ -558,7 +558,7 @@ def create( overwrite=None, priority=DEFAULT_MONGO_PRIORITY, client=None, - **kwargs + **kwargs, ): return cls( QueueCompatibleMongoObserver.create( diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 9144c0a6..005dd341 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -21,17 +21,17 @@ def _is_valid_bucket(bucket_name): if len(bucket_name) < 3 or len(bucket_name) > 63: return False - labels = bucket_name.split('.') + labels = bucket_name.split(".") # A bucket name consists of "labels" separated by periods for label in labels: - if len(label) == 0 or label[0] == '-' or label[-1] == '-': + if len(label) == 0 or label[0] == "-" or label[-1] == "-": # Labels must be of nonzero length, # and cannot begin or end with a hyphen return False for char in label: # Labels can only contain digits, lowercase letters, or hyphens. # Anything else will fail here - if not (char.isdigit() or char.islower() or char == '-'): + if not (char.isdigit() or char.islower() or char == "-"): return False try: # If a name is a valid IP address, it cannot be a bucket name @@ -41,11 +41,18 @@ def _is_valid_bucket(bucket_name): class S3Observer(RunObserver): - VERSION = 'S3Observer-0.1.0' + VERSION = "S3Observer-0.1.0" @classmethod - def create(cls, bucket, basedir, resource_dir=None, source_dir=None, - priority=DEFAULT_S3_PRIORITY, region=None): + def create( + cls, + bucket, + basedir, + resource_dir=None, + source_dir=None, + priority=DEFAULT_S3_PRIORITY, + region=None, + ): """ A factory method to create a S3Observer object @@ -64,16 +71,24 @@ def create(cls, bucket, basedir, resource_dir=None, source_dir=None, config file. 
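
The run-directory selection reformatted a little further below (``_determine_run_dir``) reduces to a simple rule: take the largest all-digit subdirectory under ``basedir`` and add one, ignoring non-numeric names such as ``_sources``. A toy sketch:

.. code-block:: python

    def next_run_id(subdirs):
        # Non-numeric directories under basedir are skipped, not errors
        numeric = [int(d) for d in subdirs if d.isdigit()]
        return (max(numeric) if numeric else 0) + 1

    assert next_run_id([]) == 1
    assert next_run_id(["1", "2", "_sources"]) == 3
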
:return: """ - resource_dir = resource_dir or os.path.join(basedir, '_resources') - source_dir = source_dir or os.path.join(basedir, '_sources') + resource_dir = resource_dir or os.path.join(basedir, "_resources") + source_dir = source_dir or os.path.join(basedir, "_sources") return cls(bucket, basedir, resource_dir, source_dir, priority, region) - def __init__(self, bucket, basedir, resource_dir, source_dir, - priority=DEFAULT_S3_PRIORITY, region=None): + def __init__( + self, + bucket, + basedir, + resource_dir, + source_dir, + priority=DEFAULT_S3_PRIORITY, + region=None, + ): if not _is_valid_bucket(bucket): - raise ValueError("Your chosen bucket name does not follow AWS " - "bucket naming rules") + raise ValueError( + "Your chosen bucket name does not follow AWS " "bucket naming rules" + ) self.basedir = basedir self.bucket = bucket @@ -91,16 +106,18 @@ def __init__(self, bucket, basedir, resource_dir, source_dir, self.saved_metrics = {} if region is not None: self.region = region - self.s3 = boto3.resource('s3', region_name=region) + self.s3 = boto3.resource("s3", region_name=region) else: session = boto3.session.Session() if session.region_name is not None: self.region = session.region_name - self.s3 = boto3.resource('s3') + self.s3 = boto3.resource("s3") else: - raise ValueError("You must either pass in an AWS region name," - " or have a region name specified in your" - " AWS config file") + raise ValueError( + "You must either pass in an AWS region name," + " or have a region name specified in your" + " AWS config file" + ) def _objects_exist_in_dir(self, prefix): # This should be run after you've confirmed the bucket @@ -114,7 +131,7 @@ def _bucket_exists(self): try: self.s3.meta.client.head_bucket(Bucket=self.bucket) except ClientError as er: - if er.response['Error']['Code'] == '404': + if er.response["Error"]["Code"] == "404": return False return True @@ -123,7 +140,7 @@ def _list_s3_subdirs(self, prefix=None): prefix = self.basedir bucket = self.s3.Bucket(self.bucket) all_keys = [obj.key for obj in bucket.objects.filter(Prefix=prefix)] - subdir_match = r'{prefix}\/(.*)\/'.format(prefix=prefix) + subdir_match = r"{prefix}\/(.*)\/".format(prefix=prefix) subdirs = [] for key in all_keys: match_obj = re.match(subdir_match, key) @@ -137,8 +154,8 @@ def _list_s3_subdirs(self, prefix=None): def _create_bucket(self): bucket_response = self.s3.create_bucket( Bucket=self.bucket, - CreateBucketConfiguration={ - 'LocationConstraint': self.region}) + CreateBucketConfiguration={"LocationConstraint": self.region}, + ) return bucket_response def _determine_run_dir(self, _id): @@ -153,8 +170,9 @@ def _determine_run_dir(self, _id): if not bucket_path_subdirs: max_run_id = 0 else: - integer_directories = [int(d) for d in bucket_path_subdirs - if d.isdigit()] + integer_directories = [ + int(d) for d in bucket_path_subdirs if d.isdigit() + ] if not integer_directories: max_run_id = 0 else: @@ -166,66 +184,67 @@ def _determine_run_dir(self, _id): self.dir = os.path.join(self.basedir, str(_id)) if self._objects_exist_in_dir(self.dir): - raise FileExistsError( - "S3 dir at {} already exists".format(self.dir)) + raise FileExistsError("S3 dir at {} already exists".format(self.dir)) return _id - def queued_event(self, ex_info, command, host_info, queue_time, config, - meta_info, _id): + def queued_event( + self, ex_info, command, host_info, queue_time, config, meta_info, _id + ): _id = self._determine_run_dir(_id) self.run_entry = { - 'experiment': dict(ex_info), - 'command': command, - 'host': 
dict(host_info), - 'meta': meta_info, - 'status': 'QUEUED', + "experiment": dict(ex_info), + "command": command, + "host": dict(host_info), + "meta": meta_info, + "status": "QUEUED", } self.config = config self.info = {} - self.save_json(self.run_entry, 'run.json') - self.save_json(self.config, 'config.json') + self.save_json(self.run_entry, "run.json") + self.save_json(self.config, "config.json") - for s, m in ex_info['sources']: + for s, m in ex_info["sources"]: self.save_file(s) return _id def save_sources(self, ex_info): - base_dir = ex_info['base_dir'] + base_dir = ex_info["base_dir"] source_info = [] - for s, m in ex_info['sources']: + for s, m in ex_info["sources"]: abspath = os.path.join(base_dir, s) store_path, md5sum = self.find_or_save(abspath, self.source_dir) source_info.append([s, os.path.relpath(store_path, self.basedir)]) return source_info - def started_event(self, ex_info, command, host_info, start_time, config, - meta_info, _id): + def started_event( + self, ex_info, command, host_info, start_time, config, meta_info, _id + ): _id = self._determine_run_dir(_id) - ex_info['sources'] = self.save_sources(ex_info) + ex_info["sources"] = self.save_sources(ex_info) self.run_entry = { - 'experiment': dict(ex_info), - 'command': command, - 'host': dict(host_info), - 'start_time': start_time.isoformat(), - 'meta': meta_info, - 'status': 'RUNNING', - 'resources': [], - 'artifacts': [], - 'heartbeat': None + "experiment": dict(ex_info), + "command": command, + "host": dict(host_info), + "start_time": start_time.isoformat(), + "meta": meta_info, + "status": "RUNNING", + "resources": [], + "artifacts": [], + "heartbeat": None, } self.config = config self.info = {} self.cout = "" self.cout_write_cursor = 0 - self.save_json(self.run_entry, 'run.json') - self.save_json(self.config, 'config.json') + self.save_json(self.run_entry, "run.json") + self.save_json(self.config, "config.json") self.save_cout() return _id @@ -233,7 +252,7 @@ def started_event(self, ex_info, command, host_info, start_time, config, def find_or_save(self, filename, store_dir): source_name, ext = os.path.splitext(os.path.basename(filename)) md5sum = get_digest(filename) - store_name = source_name + '_' + md5sum + ext + store_name = source_name + "_" + md5sum + ext store_path = os.path.join(store_dir, store_name) if len(self._list_s3_subdirs(prefix=store_path)) == 0: self.save_file(filename, store_path) @@ -244,13 +263,12 @@ def put_data(self, key, binary_data): def save_json(self, obj, filename): key = os.path.join(self.dir, filename) - self.put_data(key, json.dumps(flatten(obj), - sort_keys=True, indent=2)) + self.put_data(key, json.dumps(flatten(obj), sort_keys=True, indent=2)) def save_file(self, filename, target_name=None): target_name = target_name or os.path.basename(filename) key = os.path.join(self.dir, target_name) - self.put_data(key, open(filename, 'rb')) + self.put_data(key, open(filename, "rb")) def save_directory(self, source_dir, target_name): # Stolen from: @@ -259,57 +277,59 @@ def save_directory(self, source_dir, target_name): all_files = [] for root, dirs, files in os.walk(source_dir): all_files += [os.path.join(root, f) for f in files] - s3_resource = boto3.resource('s3') + s3_resource = boto3.resource("s3") for filename in all_files: - file_location = os.path.join(self.dir, target_name, - os.path.relpath(filename, source_dir)) - s3_resource.Object(self.bucket, - file_location).put(Body=open(filename, 'rb')) + file_location = os.path.join( + self.dir, target_name, os.path.relpath(filename, 
source_dir) + ) + s3_resource.Object(self.bucket, file_location).put( + Body=open(filename, "rb") + ) def save_cout(self): - binary_data = self.cout[self.cout_write_cursor:].encode("utf-8") - key = os.path.join(self.dir, 'cout.txt') + binary_data = self.cout[self.cout_write_cursor :].encode("utf-8") + key = os.path.join(self.dir, "cout.txt") self.put_data(key, binary_data) self.cout_write_cursor = len(self.cout) def heartbeat_event(self, info, captured_out, beat_time, result): self.info = info - self.run_entry['heartbeat'] = beat_time.isoformat() - self.run_entry['result'] = result + self.run_entry["heartbeat"] = beat_time.isoformat() + self.run_entry["result"] = result self.cout = captured_out self.save_cout() - self.save_json(self.run_entry, 'run.json') + self.save_json(self.run_entry, "run.json") if self.info: - self.save_json(self.info, 'info.json') + self.save_json(self.info, "info.json") def completed_event(self, stop_time, result): - self.run_entry['stop_time'] = stop_time.isoformat() - self.run_entry['result'] = result - self.run_entry['status'] = 'COMPLETED' + self.run_entry["stop_time"] = stop_time.isoformat() + self.run_entry["result"] = result + self.run_entry["status"] = "COMPLETED" - self.save_json(self.run_entry, 'run.json') + self.save_json(self.run_entry, "run.json") def interrupted_event(self, interrupt_time, status): - self.run_entry['stop_time'] = interrupt_time.isoformat() - self.run_entry['status'] = status - self.save_json(self.run_entry, 'run.json') + self.run_entry["stop_time"] = interrupt_time.isoformat() + self.run_entry["status"] = status + self.save_json(self.run_entry, "run.json") def failed_event(self, fail_time, fail_trace): - self.run_entry['stop_time'] = fail_time.isoformat() - self.run_entry['status'] = 'FAILED' - self.run_entry['fail_trace'] = fail_trace - self.save_json(self.run_entry, 'run.json') + self.run_entry["stop_time"] = fail_time.isoformat() + self.run_entry["status"] = "FAILED" + self.run_entry["fail_trace"] = fail_trace + self.save_json(self.run_entry, "run.json") def resource_event(self, filename): store_path, md5sum = self.find_or_save(filename, self.resource_dir) - self.run_entry['resources'].append([filename, store_path]) - self.save_json(self.run_entry, 'run.json') + self.run_entry["resources"].append([filename, store_path]) + self.save_json(self.run_entry, "run.json") def artifact_event(self, name, filename, metadata=None, content_type=None): self.save_file(filename, name) - self.run_entry['artifacts'].append(name) - self.save_json(self.run_entry, 'run.json') + self.run_entry["artifacts"].append(name) + self.save_json(self.run_entry, "run.json") def log_metrics(self, metrics_by_name, info): """Store new measurements into metrics.json. 
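
The hunk that follows reformats ``log_metrics``; its accumulation behavior boils down to the sketch below, using made-up sample data. Each metric keeps three parallel lists, and timestamps are normalized to ISO strings so the result stays JSON-serializable:

.. code-block:: python

    import datetime

    saved_metrics = {}
    update = {"loss": {"values": [0.9, 0.7],
                       "steps": [0, 1],
                       "timestamps": [datetime.datetime(1999, 5, 4),
                                      datetime.datetime(1999, 5, 5)]}}

    for name, ptr in update.items():
        entry = saved_metrics.setdefault(
            name, {"values": [], "steps": [], "timestamps": []})
        entry["values"] += ptr["values"]
        entry["steps"] += ptr["steps"]
        entry["timestamps"] += [ts.isoformat() for ts in ptr["timestamps"]]

    assert saved_metrics["loss"]["steps"] == [0, 1]
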
@@ -318,23 +338,23 @@ def log_metrics(self, metrics_by_name, info): for metric_name, metric_ptr in metrics_by_name.items(): if metric_name not in self.saved_metrics: - self.saved_metrics[metric_name] = {"values": [], - "steps": [], - "timestamps": []} + self.saved_metrics[metric_name] = { + "values": [], + "steps": [], + "timestamps": [], + } self.saved_metrics[metric_name]["values"] += metric_ptr["values"] self.saved_metrics[metric_name]["steps"] += metric_ptr["steps"] - timestamps_norm = [ts.isoformat() - for ts in metric_ptr["timestamps"]] + timestamps_norm = [ts.isoformat() for ts in metric_ptr["timestamps"]] self.saved_metrics[metric_name]["timestamps"] += timestamps_norm - self.save_json(self.saved_metrics, 'metrics.json') + self.save_json(self.saved_metrics, "metrics.json") def __eq__(self, other): if isinstance(other, S3Observer): - return (self.bucket == other.bucket and - self.basedir == other.basedir) + return self.bucket == other.bucket and self.basedir == other.basedir else: return False @@ -342,17 +362,18 @@ def __eq__(self, other): class S3Option(CommandLineOption): """Add a S3 File observer to the experiment.""" - short_flag = 'S3' - arg = 'BUCKET_PATH' + short_flag = "S3" + arg = "BUCKET_PATH" arg_description = "s3:///path/to/exp" @classmethod def apply(cls, args, run): - match_obj = re.match(r's3:\/\/([^\/]*)\/(.*)', args) + match_obj = re.match(r"s3:\/\/([^\/]*)\/(.*)", args) if match_obj is None or len(match_obj.groups()) != 2: - raise ValueError("Valid bucket specification not found. " - "Enter bucket and directory path like: " - "s3:///path/to/exp") + raise ValueError( + "Valid bucket specification not found. " + "Enter bucket and directory path like: " + "s3:///path/to/exp" + ) bucket, basedir = match_obj.groups() - run.observers.append(S3Observer.create(bucket=bucket, - basedir=basedir)) + run.observers.append(S3Observer.create(bucket=bucket, basedir=basedir)) diff --git a/sacred/observers/telegram_obs.py b/sacred/observers/telegram_obs.py index c216e1fc..e48b9420 100644 --- a/sacred/observers/telegram_obs.py +++ b/sacred/observers/telegram_obs.py @@ -80,7 +80,7 @@ def __init__( chat_id, silent_completion=False, priority=DEFAULT_TELEGRAM_PRIORITY, - **kwargs + **kwargs, ): self.silent_completion = silent_completion self.chat_id = chat_id diff --git a/tests/test_observers/failing_mongo_mock.py b/tests/test_observers/failing_mongo_mock.py index c0bdb541..6695e2e4 100644 --- a/tests/test_observers/failing_mongo_mock.py +++ b/tests/test_observers/failing_mongo_mock.py @@ -8,7 +8,7 @@ def __init__( self, max_calls_before_failure=2, exception_to_raise=pymongo.errors.AutoReconnect, - **kwargs + **kwargs, ): super().__init__(**kwargs) self._max_calls_before_failure = max_calls_before_failure diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 6b5b868a..84455100 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -18,26 +18,26 @@ T1 = datetime.datetime(1999, 5, 4, 3, 2, 1, 0) T2 = datetime.datetime(1999, 5, 5, 5, 5, 5, 5) -BUCKET = 'pytest-s3-observer-bucket' -BASEDIR = 'some-tests' -REGION = 'us-west-2' +BUCKET = "pytest-s3-observer-bucket" +BASEDIR = "some-tests" +REGION = "us-west-2" @pytest.fixture() def sample_run(): - exp = {'name': 'test_exp', 'sources': [], 'doc': '', 'base_dir': '/tmp'} - host = {'hostname': 'test_host', 'cpu_count': 1, 'python_version': '3.4'} - config = {'config': 'True', 'foo': 'bar', 'answer': 42} - command = 'run' - meta_info = {'comment': 
'test run'} + exp = {"name": "test_exp", "sources": [], "doc": "", "base_dir": "/tmp"} + host = {"hostname": "test_host", "cpu_count": 1, "python_version": "3.4"} + config = {"config": "True", "foo": "bar", "answer": 42} + command = "run" + meta_info = {"comment": "test run"} return { - '_id': None, - 'ex_info': exp, - 'command': command, - 'host_info': host, - 'start_time': T1, - 'config': config, - 'meta_info': meta_info, + "_id": None, + "ex_info": exp, + "command": command, + "host_info": host, + "start_time": T1, + "config": config, + "meta_info": meta_info, } @@ -52,9 +52,9 @@ def tmpfile(): # manually deleting the file, such that we can close it before running the # tests. This is necessary since on Windows we can not open the same file # twice, so for the FileStorageObserver to read it, we need to close it. - f = tempfile.NamedTemporaryFile(suffix='.py', delete=False) + f = tempfile.NamedTemporaryFile(suffix=".py", delete=False) - f.content = 'import sacred\n' + f.content = "import sacred\n" f.write(f.content.encode()) f.flush() f.seek(0) @@ -68,28 +68,28 @@ def tmpfile(): def _bucket_exists(bucket_name): - s3 = boto3.resource('s3') + s3 = boto3.resource("s3") try: s3.meta.client.head_bucket(Bucket=bucket_name) except ClientError as e: - if e.response['Error']['Code'] == '404': + if e.response["Error"]["Code"] == "404": return False return True def _key_exists(bucket_name, key): - s3 = boto3.resource('s3') + s3 = boto3.resource("s3") try: s3.Object(bucket_name, key).load() except ClientError as e: - if e.response['Error']['Code'] == '404': + if e.response["Error"]["Code"] == "404": return False return True def _get_file_data(bucket_name, key): - s3 = boto3.resource('s3') - return s3.Object(bucket_name, key).get()['Body'].read() + s3 = boto3.resource("s3") + return s3.Object(bucket_name, key).get()["Body"].read() @mock_s3 @@ -97,28 +97,25 @@ def test_fs_observer_started_event_creates_bucket(observer, sample_run): _id = observer.started_event(**sample_run) run_dir = os.path.join(BASEDIR, str(_id)) assert _bucket_exists(bucket_name=BUCKET) - assert _key_exists(bucket_name=BUCKET, - key=os.path.join(run_dir, 'cout.txt')) - assert _key_exists(bucket_name=BUCKET, - key=os.path.join(run_dir, 'config.json')) - assert _key_exists(bucket_name=BUCKET, - key=os.path.join(run_dir, 'run.json')) - config = _get_file_data(bucket_name=BUCKET, - key=os.path.join(run_dir, 'config.json')) - - assert json.loads(config.decode('utf-8')) == sample_run['config'] - run = _get_file_data(bucket_name=BUCKET, - key=os.path.join(run_dir, 'run.json')) - assert json.loads(run.decode('utf-8')) == { - 'experiment': sample_run['ex_info'], - 'command': sample_run['command'], - 'host': sample_run['host_info'], - 'start_time': T1.isoformat(), - 'heartbeat': None, - 'meta': sample_run['meta_info'], + assert _key_exists(bucket_name=BUCKET, key=os.path.join(run_dir, "cout.txt")) + assert _key_exists(bucket_name=BUCKET, key=os.path.join(run_dir, "config.json")) + assert _key_exists(bucket_name=BUCKET, key=os.path.join(run_dir, "run.json")) + config = _get_file_data( + bucket_name=BUCKET, key=os.path.join(run_dir, "config.json") + ) + + assert json.loads(config.decode("utf-8")) == sample_run["config"] + run = _get_file_data(bucket_name=BUCKET, key=os.path.join(run_dir, "run.json")) + assert json.loads(run.decode("utf-8")) == { + "experiment": sample_run["ex_info"], + "command": sample_run["command"], + "host": sample_run["host_info"], + "start_time": T1.isoformat(), + "heartbeat": None, + "meta": sample_run["meta_info"], 
"resources": [], "artifacts": [], - "status": "RUNNING" + "status": "RUNNING", } @@ -132,8 +129,12 @@ def test_fs_observer_started_event_increments_run_id(observer, sample_run): def test_s3_observer_equality(): obs_one = S3Observer.create(bucket=BUCKET, basedir=BASEDIR, region=REGION) obs_two = S3Observer.create(bucket=BUCKET, basedir=BASEDIR, region=REGION) - different_basedir = S3Observer.create(bucket=BUCKET, basedir="another/dir", region=REGION) - different_bucket = S3Observer.create(bucket="other-bucket", basedir=BASEDIR, region=REGION) + different_basedir = S3Observer.create( + bucket=BUCKET, basedir="another/dir", region=REGION + ) + different_bucket = S3Observer.create( + bucket="other-bucket", basedir=BASEDIR, region=REGION + ) assert obs_one == obs_two assert obs_one != different_basedir assert obs_one != different_bucket @@ -142,7 +143,7 @@ def test_s3_observer_equality(): @mock_s3 def test_z_raises_error_on_duplicate_id_directory(observer, sample_run): observer.started_event(**sample_run) - sample_run['_id'] = 1 + sample_run["_id"] = 1 with pytest.raises(FileExistsError): observer.started_event(**sample_run) @@ -150,90 +151,97 @@ def test_z_raises_error_on_duplicate_id_directory(observer, sample_run): @mock_s3 def test_completed_event_updates_run_json(observer, sample_run): observer.started_event(**sample_run) - run = json.loads(_get_file_data(bucket_name=BUCKET, - key=os.path.join(observer.dir, - 'run.json')) - .decode('utf-8')) - assert run['status'] == 'RUNNING' - observer.completed_event(T2, 'success!') - run = json.loads(_get_file_data(bucket_name=BUCKET, - key=os.path.join(observer.dir, - 'run.json')) - .decode('utf-8')) - assert run['status'] == 'COMPLETED' + run = json.loads( + _get_file_data( + bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + ).decode("utf-8") + ) + assert run["status"] == "RUNNING" + observer.completed_event(T2, "success!") + run = json.loads( + _get_file_data( + bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + ).decode("utf-8") + ) + assert run["status"] == "COMPLETED" @mock_s3 def test_interrupted_event_updates_run_json(observer, sample_run): observer.started_event(**sample_run) - run = json.loads(_get_file_data(bucket_name=BUCKET, - key=os.path.join(observer.dir, - 'run.json')) - .decode('utf-8')) - assert run['status'] == 'RUNNING' - observer.interrupted_event(T2, 'SERVER_EXPLODED') - run = json.loads(_get_file_data(bucket_name=BUCKET, - key=os.path.join(observer.dir, - 'run.json')) - .decode('utf-8')) - assert run['status'] == 'SERVER_EXPLODED' + run = json.loads( + _get_file_data( + bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + ).decode("utf-8") + ) + assert run["status"] == "RUNNING" + observer.interrupted_event(T2, "SERVER_EXPLODED") + run = json.loads( + _get_file_data( + bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + ).decode("utf-8") + ) + assert run["status"] == "SERVER_EXPLODED" @mock_s3 def test_failed_event_updates_run_json(observer, sample_run): observer.started_event(**sample_run) - run = json.loads(_get_file_data(bucket_name=BUCKET, - key=os.path.join(observer.dir, - 'run.json')) - .decode('utf-8')) - assert run['status'] == 'RUNNING' - observer.failed_event(T2, 'Everything imaginable went wrong') - run = json.loads(_get_file_data(bucket_name=BUCKET, - key=os.path.join(observer.dir, - 'run.json')) - .decode('utf-8')) - assert run['status'] == 'FAILED' + run = json.loads( + _get_file_data( + bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + 
).decode("utf-8") + ) + assert run["status"] == "RUNNING" + observer.failed_event(T2, "Everything imaginable went wrong") + run = json.loads( + _get_file_data( + bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + ).decode("utf-8") + ) + assert run["status"] == "FAILED" @mock_s3 def test_queued_event_updates_run_json(observer, sample_run): - del sample_run['start_time'] - sample_run['queue_time'] = T2 + del sample_run["start_time"] + sample_run["queue_time"] = T2 observer.queued_event(**sample_run) - run = json.loads(_get_file_data(bucket_name=BUCKET, - key=os.path.join(observer.dir, - 'run.json')) - .decode('utf-8')) - assert run['status'] == 'QUEUED' + run = json.loads( + _get_file_data( + bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + ).decode("utf-8") + ) + assert run["status"] == "QUEUED" @mock_s3 def test_artifact_event_works(observer, sample_run, tmpfile): observer.started_event(**sample_run) - observer.artifact_event('test_artifact.py', tmpfile.name) - - assert _key_exists(bucket_name=BUCKET, - key=os.path.join(observer.dir, 'test_artifact.py')) - artifact_data = (_get_file_data(bucket_name=BUCKET, - key=os.path.join(observer.dir, - 'test_artifact.py')) - .decode('utf-8')) + observer.artifact_event("test_artifact.py", tmpfile.name) + + assert _key_exists( + bucket_name=BUCKET, key=os.path.join(observer.dir, "test_artifact.py") + ) + artifact_data = _get_file_data( + bucket_name=BUCKET, key=os.path.join(observer.dir, "test_artifact.py") + ).decode("utf-8") assert artifact_data == tmpfile.content -test_buckets = [("hi", True), - ("this_bucket_is_invalid", True), - ("this-bucket-is-valid", False), - ("this-bucket.is-valid", False), - ("this-bucket..is-invalid", True)] +test_buckets = [ + ("hi", True), + ("this_bucket_is_invalid", True), + ("this-bucket-is-valid", False), + ("this-bucket.is-valid", False), + ("this-bucket..is-invalid", True), +] @pytest.mark.parametrize("bucket_name, should_raise", test_buckets) def test_raises_error_on_invalid_bucket_name(bucket_name, should_raise): if should_raise: with pytest.raises(ValueError): - _ = S3Observer.create(bucket=bucket_name, - basedir=BASEDIR) + _ = S3Observer.create(bucket=bucket_name, basedir=BASEDIR) else: - _ = S3Observer.create(bucket=bucket_name, - basedir=BASEDIR) + _ = S3Observer.create(bucket=bucket_name, basedir=BASEDIR) From ed58c960b2524b815fa18e03c7e048e82dfe4de3 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 16:09:24 -0700 Subject: [PATCH 26/45] pass in region for tests --- tests/test_observers/test_s3_observer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 84455100..50987d53 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -242,6 +242,10 @@ def test_artifact_event_works(observer, sample_run, tmpfile): def test_raises_error_on_invalid_bucket_name(bucket_name, should_raise): if should_raise: with pytest.raises(ValueError): - _ = S3Observer.create(bucket=bucket_name, basedir=BASEDIR) + _ = S3Observer.create(bucket=bucket_name, + basedir=BASEDIR, + region=REGION) else: - _ = S3Observer.create(bucket=bucket_name, basedir=BASEDIR) + _ = S3Observer.create(bucket=bucket_name, + basedir=BASEDIR, + region=REGION) From d481d667450cb603acd10f228c1a4abab8e90793 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 16:18:29 -0700 Subject: [PATCH 27/45] remove use of os.path.join to get around issues with 
windows paths --- sacred/observers/s3_observer.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 005dd341..5a4a6fa7 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -40,6 +40,9 @@ def _is_valid_bucket(bucket_name): return True +def s3_join(iterable): + return "/".join(iterable) + class S3Observer(RunObserver): VERSION = "S3Observer-0.1.0" @@ -71,8 +74,8 @@ def create( config file. :return: """ - resource_dir = resource_dir or os.path.join(basedir, "_resources") - source_dir = source_dir or os.path.join(basedir, "_sources") + resource_dir = resource_dir or "/".join([basedir, "_resources"]) + source_dir = source_dir or "/".join([basedir, "_sources"]) return cls(bucket, basedir, resource_dir, source_dir, priority, region) @@ -182,7 +185,7 @@ def _determine_run_dir(self, _id): _id = max_run_id + 1 - self.dir = os.path.join(self.basedir, str(_id)) + self.dir = s3_join([self.basedir, str(_id)]) if self._objects_exist_in_dir(self.dir): raise FileExistsError("S3 dir at {} already exists".format(self.dir)) return _id @@ -253,7 +256,7 @@ def find_or_save(self, filename, store_dir): source_name, ext = os.path.splitext(os.path.basename(filename)) md5sum = get_digest(filename) store_name = source_name + "_" + md5sum + ext - store_path = os.path.join(store_dir, store_name) + store_path = s3_join([store_dir, store_name]) if len(self._list_s3_subdirs(prefix=store_path)) == 0: self.save_file(filename, store_path) return store_path, md5sum @@ -262,12 +265,12 @@ def put_data(self, key, binary_data): self.s3.Object(self.bucket, key).put(Body=binary_data) def save_json(self, obj, filename): - key = os.path.join(self.dir, filename) + key = s3_join([self.dir, filename]) self.put_data(key, json.dumps(flatten(obj), sort_keys=True, indent=2)) def save_file(self, filename, target_name=None): target_name = target_name or os.path.basename(filename) - key = os.path.join(self.dir, target_name) + key = s3_join([self.dir, target_name]) self.put_data(key, open(filename, "rb")) def save_directory(self, source_dir, target_name): @@ -280,8 +283,8 @@ def save_directory(self, source_dir, target_name): s3_resource = boto3.resource("s3") for filename in all_files: - file_location = os.path.join( - self.dir, target_name, os.path.relpath(filename, source_dir) + file_location = s3_join( + [self.dir, target_name, os.path.relpath(filename, source_dir)] ) s3_resource.Object(self.bucket, file_location).put( Body=open(filename, "rb") @@ -289,7 +292,7 @@ def save_directory(self, source_dir, target_name): def save_cout(self): binary_data = self.cout[self.cout_write_cursor :].encode("utf-8") - key = os.path.join(self.dir, "cout.txt") + key = s3_join([self.dir, "cout.txt"]) self.put_data(key, binary_data) self.cout_write_cursor = len(self.cout) From 3c2315c9e621f1642c6ba1d6b4401dbadf91a7bc Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 16:22:34 -0700 Subject: [PATCH 28/45] more black reformatting --- sacred/observers/s3_observer.py | 1 + tests/test_observers/test_s3_observer.py | 8 ++------ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 5a4a6fa7..d6774cf8 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -43,6 +43,7 @@ def _is_valid_bucket(bucket_name): def s3_join(iterable): return "/".join(iterable) + class S3Observer(RunObserver): VERSION = 
"S3Observer-0.1.0" diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 50987d53..3a1e7207 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -242,10 +242,6 @@ def test_artifact_event_works(observer, sample_run, tmpfile): def test_raises_error_on_invalid_bucket_name(bucket_name, should_raise): if should_raise: with pytest.raises(ValueError): - _ = S3Observer.create(bucket=bucket_name, - basedir=BASEDIR, - region=REGION) + _ = S3Observer.create(bucket=bucket_name, basedir=BASEDIR, region=REGION) else: - _ = S3Observer.create(bucket=bucket_name, - basedir=BASEDIR, - region=REGION) + _ = S3Observer.create(bucket=bucket_name, basedir=BASEDIR, region=REGION) From f418e346396761e6ca1542ee94d413e0dc096aeb Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 16:27:49 -0700 Subject: [PATCH 29/45] remove comma because of python 3.5 syntax error --- sacred/observers/mongo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sacred/observers/mongo.py b/sacred/observers/mongo.py index ab488eda..1c5ace3a 100644 --- a/sacred/observers/mongo.py +++ b/sacred/observers/mongo.py @@ -75,7 +75,7 @@ def create( priority=DEFAULT_MONGO_PRIORITY, client=None, failure_dir=None, - **kwargs, + **kwargs ): """Factory method for MongoObserver. From 553b195c6ab95d423c9a222f7f0e837cbc6323e0 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 16:36:52 -0700 Subject: [PATCH 30/45] fix black project toml, update s3_join method, fix tests to use s3_join --- pyproject.toml | 1 - sacred/observers/s3_observer.py | 4 +-- tests/test_observers/test_s3_observer.py | 33 +++++++++++++----------- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b54bec17..4a4b9385 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,4 +16,3 @@ exclude = ''' | dist )/ ) -''' \ No newline at end of file diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index d6774cf8..0fc1b1d3 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -40,8 +40,8 @@ def _is_valid_bucket(bucket_name): return True -def s3_join(iterable): - return "/".join(iterable) +def s3_join(**args): + return "/".join(args) class S3Observer(RunObserver): diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 3a1e7207..874b15a5 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -23,6 +23,9 @@ REGION = "us-west-2" +def s3_join(*args): + return "/".join(args) + @pytest.fixture() def sample_run(): exp = {"name": "test_exp", "sources": [], "doc": "", "base_dir": "/tmp"} @@ -95,17 +98,17 @@ def _get_file_data(bucket_name, key): @mock_s3 def test_fs_observer_started_event_creates_bucket(observer, sample_run): _id = observer.started_event(**sample_run) - run_dir = os.path.join(BASEDIR, str(_id)) + run_dir = s3_join(BASEDIR, str(_id)) assert _bucket_exists(bucket_name=BUCKET) - assert _key_exists(bucket_name=BUCKET, key=os.path.join(run_dir, "cout.txt")) - assert _key_exists(bucket_name=BUCKET, key=os.path.join(run_dir, "config.json")) - assert _key_exists(bucket_name=BUCKET, key=os.path.join(run_dir, "run.json")) + assert _key_exists(bucket_name=BUCKET, key=s3_join(run_dir, "cout.txt")) + assert _key_exists(bucket_name=BUCKET, key=s3_join(run_dir, "config.json")) + assert _key_exists(bucket_name=BUCKET, key=s3_join(run_dir, "run.json")) config = 
_get_file_data( - bucket_name=BUCKET, key=os.path.join(run_dir, "config.json") + bucket_name=BUCKET, key=s3_join(run_dir, "config.json") ) assert json.loads(config.decode("utf-8")) == sample_run["config"] - run = _get_file_data(bucket_name=BUCKET, key=os.path.join(run_dir, "run.json")) + run = _get_file_data(bucket_name=BUCKET, key=s3_join(run_dir, "run.json")) assert json.loads(run.decode("utf-8")) == { "experiment": sample_run["ex_info"], "command": sample_run["command"], @@ -153,14 +156,14 @@ def test_completed_event_updates_run_json(observer, sample_run): observer.started_event(**sample_run) run = json.loads( _get_file_data( - bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + bucket_name=BUCKET, key=s3_join(observer.dir, "run.json") ).decode("utf-8") ) assert run["status"] == "RUNNING" observer.completed_event(T2, "success!") run = json.loads( _get_file_data( - bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + bucket_name=BUCKET, key=s3_join(observer.dir, "run.json") ).decode("utf-8") ) assert run["status"] == "COMPLETED" @@ -171,14 +174,14 @@ def test_interrupted_event_updates_run_json(observer, sample_run): observer.started_event(**sample_run) run = json.loads( _get_file_data( - bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + bucket_name=BUCKET, key=s3_join(observer.dir, "run.json") ).decode("utf-8") ) assert run["status"] == "RUNNING" observer.interrupted_event(T2, "SERVER_EXPLODED") run = json.loads( _get_file_data( - bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + bucket_name=BUCKET, key=s3_join(observer.dir, "run.json") ).decode("utf-8") ) assert run["status"] == "SERVER_EXPLODED" @@ -189,14 +192,14 @@ def test_failed_event_updates_run_json(observer, sample_run): observer.started_event(**sample_run) run = json.loads( _get_file_data( - bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + bucket_name=BUCKET, key=s3_join(observer.dir, "run.json") ).decode("utf-8") ) assert run["status"] == "RUNNING" observer.failed_event(T2, "Everything imaginable went wrong") run = json.loads( _get_file_data( - bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + bucket_name=BUCKET, key=s3_join(observer.dir, "run.json") ).decode("utf-8") ) assert run["status"] == "FAILED" @@ -209,7 +212,7 @@ def test_queued_event_updates_run_json(observer, sample_run): observer.queued_event(**sample_run) run = json.loads( _get_file_data( - bucket_name=BUCKET, key=os.path.join(observer.dir, "run.json") + bucket_name=BUCKET, key=s3_join(observer.dir, "run.json") ).decode("utf-8") ) assert run["status"] == "QUEUED" @@ -221,10 +224,10 @@ def test_artifact_event_works(observer, sample_run, tmpfile): observer.artifact_event("test_artifact.py", tmpfile.name) assert _key_exists( - bucket_name=BUCKET, key=os.path.join(observer.dir, "test_artifact.py") + bucket_name=BUCKET, key=s3_join(observer.dir, "test_artifact.py") ) artifact_data = _get_file_data( - bucket_name=BUCKET, key=os.path.join(observer.dir, "test_artifact.py") + bucket_name=BUCKET, key=s3_join(observer.dir, "test_artifact.py") ).decode("utf-8") assert artifact_data == tmpfile.content From a6b3415c4f8e652fe015bd492b48a47b1b5eef35 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 16:38:24 -0700 Subject: [PATCH 31/45] reformat for black modulo mongo comma issue --- tests/test_observers/test_s3_observer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 
874b15a5..81945e3b 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -26,6 +26,7 @@ def s3_join(*args): return "/".join(args) + @pytest.fixture() def sample_run(): exp = {"name": "test_exp", "sources": [], "doc": "", "base_dir": "/tmp"} @@ -103,9 +104,7 @@ def test_fs_observer_started_event_creates_bucket(observer, sample_run): assert _key_exists(bucket_name=BUCKET, key=s3_join(run_dir, "cout.txt")) assert _key_exists(bucket_name=BUCKET, key=s3_join(run_dir, "config.json")) assert _key_exists(bucket_name=BUCKET, key=s3_join(run_dir, "run.json")) - config = _get_file_data( - bucket_name=BUCKET, key=s3_join(run_dir, "config.json") - ) + config = _get_file_data(bucket_name=BUCKET, key=s3_join(run_dir, "config.json")) assert json.loads(config.decode("utf-8")) == sample_run["config"] run = _get_file_data(bucket_name=BUCKET, key=s3_join(run_dir, "run.json")) From fb07ac229590f6b93aefcf676753105370dc79dc Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 16:46:17 -0700 Subject: [PATCH 32/45] fix s3_join calls to not pass in list --- sacred/observers/s3_observer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index 0fc1b1d3..d6e5e86c 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -40,7 +40,7 @@ def _is_valid_bucket(bucket_name): return True -def s3_join(**args): +def s3_join(*args): return "/".join(args) @@ -186,7 +186,7 @@ def _determine_run_dir(self, _id): _id = max_run_id + 1 - self.dir = s3_join([self.basedir, str(_id)]) + self.dir = s3_join(self.basedir, str(_id)) if self._objects_exist_in_dir(self.dir): raise FileExistsError("S3 dir at {} already exists".format(self.dir)) return _id @@ -257,7 +257,7 @@ def find_or_save(self, filename, store_dir): source_name, ext = os.path.splitext(os.path.basename(filename)) md5sum = get_digest(filename) store_name = source_name + "_" + md5sum + ext - store_path = s3_join([store_dir, store_name]) + store_path = s3_join(store_dir, store_name) if len(self._list_s3_subdirs(prefix=store_path)) == 0: self.save_file(filename, store_path) return store_path, md5sum @@ -266,12 +266,12 @@ def put_data(self, key, binary_data): self.s3.Object(self.bucket, key).put(Body=binary_data) def save_json(self, obj, filename): - key = s3_join([self.dir, filename]) + key = s3_join(self.dir, filename) self.put_data(key, json.dumps(flatten(obj), sort_keys=True, indent=2)) def save_file(self, filename, target_name=None): target_name = target_name or os.path.basename(filename) - key = s3_join([self.dir, target_name]) + key = s3_join(self.dir, target_name) self.put_data(key, open(filename, "rb")) def save_directory(self, source_dir, target_name): @@ -285,7 +285,7 @@ def save_directory(self, source_dir, target_name): for filename in all_files: file_location = s3_join( - [self.dir, target_name, os.path.relpath(filename, source_dir)] + self.dir, target_name, os.path.relpath(filename, source_dir) ) s3_resource.Object(self.bucket, file_location).put( Body=open(filename, "rb") @@ -293,7 +293,7 @@ def save_directory(self, source_dir, target_name): def save_cout(self): binary_data = self.cout[self.cout_write_cursor :].encode("utf-8") - key = s3_join([self.dir, "cout.txt"]) + key = s3_join(self.dir, "cout.txt") self.put_data(key, binary_data) self.cout_write_cursor = len(self.cout) From df571a48a6540a6efbc2e22695cdd2bef27a310e Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 
17:11:10 -0700 Subject: [PATCH 33/45] changes to mongo --- sacred/observers/mongo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sacred/observers/mongo.py b/sacred/observers/mongo.py index 1c5ace3a..940d35c6 100644 --- a/sacred/observers/mongo.py +++ b/sacred/observers/mongo.py @@ -558,7 +558,7 @@ def create( overwrite=None, priority=DEFAULT_MONGO_PRIORITY, client=None, - **kwargs, + **kwargs ): return cls( QueueCompatibleMongoObserver.create( @@ -568,7 +568,7 @@ def create( overwrite=overwrite, priority=priority, client=client, - **kwargs, + **kwargs ), interval=interval, retry_interval=retry_interval, From 3d1c48f4ce1c8dc859f8e83687a5c3f5dd683849 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 17:40:50 -0700 Subject: [PATCH 34/45] fix py35 issues --- sacred/observers/telegram_obs.py | 2 +- tests/test_observers/failing_mongo_mock.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sacred/observers/telegram_obs.py b/sacred/observers/telegram_obs.py index e48b9420..c216e1fc 100644 --- a/sacred/observers/telegram_obs.py +++ b/sacred/observers/telegram_obs.py @@ -80,7 +80,7 @@ def __init__( chat_id, silent_completion=False, priority=DEFAULT_TELEGRAM_PRIORITY, - **kwargs, + **kwargs ): self.silent_completion = silent_completion self.chat_id = chat_id diff --git a/tests/test_observers/failing_mongo_mock.py b/tests/test_observers/failing_mongo_mock.py index 6695e2e4..c0bdb541 100644 --- a/tests/test_observers/failing_mongo_mock.py +++ b/tests/test_observers/failing_mongo_mock.py @@ -8,7 +8,7 @@ def __init__( self, max_calls_before_failure=2, exception_to_raise=pymongo.errors.AutoReconnect, - **kwargs, + **kwargs ): super().__init__(**kwargs) self._max_calls_before_failure = max_calls_before_failure From f5b10a47fee0e3ca239df201606f257cc6dcf555 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 17:48:05 -0700 Subject: [PATCH 35/45] sync py35 error cases with master --- sacred/observers/mongo.py | 4 ++-- sacred/observers/sql.py | 2 +- sacred/observers/telegram_obs.py | 2 +- tests/test_observers/failing_mongo_mock.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sacred/observers/mongo.py b/sacred/observers/mongo.py index 940d35c6..70f047db 100644 --- a/sacred/observers/mongo.py +++ b/sacred/observers/mongo.py @@ -568,8 +568,8 @@ def create( overwrite=overwrite, priority=priority, client=client, - **kwargs + **kwargs, ), interval=interval, retry_interval=retry_interval, - ) + ) \ No newline at end of file diff --git a/sacred/observers/sql.py b/sacred/observers/sql.py index b97810d6..2438455e 100644 --- a/sacred/observers/sql.py +++ b/sacred/observers/sql.py @@ -149,4 +149,4 @@ class SqlOption(CommandLineOption): @classmethod def apply(cls, args, run): - run.observers.append(SqlObserver.create(args)) + run.observers.append(SqlObserver.create(args)) \ No newline at end of file diff --git a/sacred/observers/telegram_obs.py b/sacred/observers/telegram_obs.py index c216e1fc..3ada664d 100644 --- a/sacred/observers/telegram_obs.py +++ b/sacred/observers/telegram_obs.py @@ -213,4 +213,4 @@ def failed_event(self, fail_time, fail_trace): ) except Exception as e: log = logging.getLogger("telegram-observer") - log.warning("failed to send failed_event message via telegram.", exc_info=e) + log.warning("failed to send failed_event message via telegram.", exc_info=e) \ No newline at end of file diff --git a/tests/test_observers/failing_mongo_mock.py b/tests/test_observers/failing_mongo_mock.py index c0bdb541..9e92c54e 
100644 --- a/tests/test_observers/failing_mongo_mock.py +++ b/tests/test_observers/failing_mongo_mock.py @@ -141,4 +141,4 @@ def _is_in_failure_range(self): self._max_calls_before_failure < self._calls <= self._max_calls_before_reconnect - ) + ) \ No newline at end of file From bcd1904b75df6e1f866072f1872ebba3fd7bab44 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 17:52:40 -0700 Subject: [PATCH 36/45] add new lines at end of files --- sacred/observers/mongo.py | 2 +- sacred/observers/sql.py | 2 +- sacred/observers/telegram_obs.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sacred/observers/mongo.py b/sacred/observers/mongo.py index 70f047db..3b95f70b 100644 --- a/sacred/observers/mongo.py +++ b/sacred/observers/mongo.py @@ -572,4 +572,4 @@ def create( ), interval=interval, retry_interval=retry_interval, - ) \ No newline at end of file + ) diff --git a/sacred/observers/sql.py b/sacred/observers/sql.py index 2438455e..b97810d6 100644 --- a/sacred/observers/sql.py +++ b/sacred/observers/sql.py @@ -149,4 +149,4 @@ class SqlOption(CommandLineOption): @classmethod def apply(cls, args, run): - run.observers.append(SqlObserver.create(args)) \ No newline at end of file + run.observers.append(SqlObserver.create(args)) diff --git a/sacred/observers/telegram_obs.py b/sacred/observers/telegram_obs.py index 3ada664d..c216e1fc 100644 --- a/sacred/observers/telegram_obs.py +++ b/sacred/observers/telegram_obs.py @@ -213,4 +213,4 @@ def failed_event(self, fail_time, fail_trace): ) except Exception as e: log = logging.getLogger("telegram-observer") - log.warning("failed to send failed_event message via telegram.", exc_info=e) \ No newline at end of file + log.warning("failed to send failed_event message via telegram.", exc_info=e) From e63720d6a9554828f876f2bcec59aedc6b2a5f4f Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 18:06:19 -0700 Subject: [PATCH 37/45] sync offending black files with master --- sacred/experiment.py | 5 ++--- tests/test_observers/failing_mongo_mock.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sacred/experiment.py b/sacred/experiment.py index 47fe8807..ef68ec81 100755 --- a/sacred/experiment.py +++ b/sacred/experiment.py @@ -524,9 +524,8 @@ def _create_run( def _check_command(self, cmd_name): commands = dict(self.gather_commands()) if cmd_name is not None and cmd_name not in commands: - return ( - 'Error: Command "{}" not found. Available commands are: ' - "{}".format(cmd_name, ", ".join(commands.keys())) + return 'Error: Command "{}" not found. 
Available commands are: ' "{}".format( + cmd_name, ", ".join(commands.keys()) ) if cmd_name is None: diff --git a/tests/test_observers/failing_mongo_mock.py b/tests/test_observers/failing_mongo_mock.py index 9e92c54e..c0bdb541 100644 --- a/tests/test_observers/failing_mongo_mock.py +++ b/tests/test_observers/failing_mongo_mock.py @@ -141,4 +141,4 @@ def _is_in_failure_range(self): self._max_calls_before_failure < self._calls <= self._max_calls_before_reconnect - ) \ No newline at end of file + ) From d83a2903601be1d178dbbb2b036fae62edf1b9a1 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 18:09:31 -0700 Subject: [PATCH 38/45] remove final newline --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4a4b9385..226e1102 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,4 +15,4 @@ exclude = ''' | build | dist )/ -) +) \ No newline at end of file From 558a79c31805ff2d7734a729ddfe8feb44f04c0d Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Tue, 20 Aug 2019 18:10:26 -0700 Subject: [PATCH 39/45] add end quotations to pyproject.toml --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 226e1102..b54bec17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,4 +15,5 @@ exclude = ''' | build | dist )/ -) \ No newline at end of file +) +''' \ No newline at end of file From 6f99985dc7ae0729bc1ecaa4494e205993495777 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Wed, 21 Aug 2019 12:17:40 -0700 Subject: [PATCH 40/45] remove Z from test name --- tests/test_observers/test_s3_observer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 81945e3b..7dae744a 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -143,7 +143,7 @@ def test_s3_observer_equality(): @mock_s3 -def test_z_raises_error_on_duplicate_id_directory(observer, sample_run): +def test_raises_error_on_duplicate_id_directory(observer, sample_run): observer.started_event(**sample_run) sample_run["_id"] = 1 with pytest.raises(FileExistsError): From 365d2d19f5df2320ee23ee40e088fe00169ad538 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Wed, 21 Aug 2019 12:24:36 -0700 Subject: [PATCH 41/45] fix string line breaks --- sacred/observers/s3_observer.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index d6e5e86c..c18bfc97 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -91,7 +91,7 @@ def __init__( ): if not _is_valid_bucket(bucket): raise ValueError( - "Your chosen bucket name does not follow AWS " "bucket naming rules" + "Your chosen bucket name doesn't follow AWS bucket naming rules" ) self.basedir = basedir @@ -118,9 +118,8 @@ def __init__( self.s3 = boto3.resource("s3") else: raise ValueError( - "You must either pass in an AWS region name," - " or have a region name specified in your" - " AWS config file" + "You must either pass in an AWS region name, or have a " + "region name specified in your AWS config file" ) def _objects_exist_in_dir(self, prefix): From 548d7c16e771a052f8bb130085aedf9d42ab19c0 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Wed, 21 Aug 2019 12:32:05 -0700 Subject: [PATCH 42/45] remove create method and update docs accordingly --- docs/observers.rst | 8 +++++--- 
sacred/observers/s3_observer.py | 23 ++++++----------------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/docs/observers.rst b/docs/observers.rst index 22556cc0..b4cb6dfa 100644 --- a/docs/observers.rst +++ b/docs/observers.rst @@ -20,7 +20,9 @@ At the moment there are five observers that are shipped with Sacred: to store run information in a JSON file. * The :ref:`sql_observer` connects to any SQL database and will store the relevant information there. - * The :ref:`s3_observer` stores run information within an AWS S3 bucket + * The :ref:`s3_observer` stores run information in an AWS S3 bucket, within + some specified prefix/directory + But if you want the run information stored some other way, it is easy to write your own :ref:`custom_observer`. @@ -610,8 +612,8 @@ To create an S3Observer in Python: .. code-block:: python from sacred.observers import S3Observer - ex.observers.append(S3Observer.create(bucket='my-awesome-bucket', - basedir='/my-project/my-cool-experiment/')) + ex.observers.append(S3Observer(bucket='my-awesome-bucket', + basedir='/my-project/my-cool-experiment/')) By default, an S3Observer will use the region that is set in your AWS config file, but if you'd prefer to pass in a specific region, you can use the ``region`` parameter of create to do so. diff --git a/sacred/observers/s3_observer.py b/sacred/observers/s3_observer.py index c18bfc97..0df89f7e 100644 --- a/sacred/observers/s3_observer.py +++ b/sacred/observers/s3_observer.py @@ -47,9 +47,8 @@ def s3_join(*args): class S3Observer(RunObserver): VERSION = "S3Observer-0.1.0" - @classmethod - def create( - cls, + def __init__( + self, bucket, basedir, resource_dir=None, @@ -58,7 +57,8 @@ def create( region=None, ): """ - A factory method to create a S3Observer object + Constructor for a S3Observer object. This is run when you + first create the object, before it's used within an experiment. :param bucket: The name of the bucket you want to store results in. Doesn't need to contain `s3://`, but needs to be a valid bucket name @@ -75,24 +75,13 @@ def create( config file. 
:return: """ - resource_dir = resource_dir or "/".join([basedir, "_resources"]) - source_dir = source_dir or "/".join([basedir, "_sources"]) - - return cls(bucket, basedir, resource_dir, source_dir, priority, region) - def __init__( - self, - bucket, - basedir, - resource_dir, - source_dir, - priority=DEFAULT_S3_PRIORITY, - region=None, - ): if not _is_valid_bucket(bucket): raise ValueError( "Your chosen bucket name doesn't follow AWS bucket naming rules" ) + resource_dir = resource_dir or "/".join([basedir, "_resources"]) + source_dir = source_dir or "/".join([basedir, "_sources"]) self.basedir = basedir self.bucket = bucket From 00d1f34bd02bf8c8cdab325b4e5e2abb7bbdd876 Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Wed, 21 Aug 2019 12:32:46 -0700 Subject: [PATCH 43/45] remove create method from s3 tests --- tests/test_observers/test_s3_observer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 7dae744a..5b451ca5 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -47,7 +47,7 @@ def sample_run(): @pytest.fixture def observer(): - return S3Observer.create(bucket=BUCKET, basedir=BASEDIR, region=REGION) + return S3Observer(bucket=BUCKET, basedir=BASEDIR, region=REGION) @pytest.fixture @@ -129,12 +129,12 @@ def test_fs_observer_started_event_increments_run_id(observer, sample_run): def test_s3_observer_equality(): - obs_one = S3Observer.create(bucket=BUCKET, basedir=BASEDIR, region=REGION) - obs_two = S3Observer.create(bucket=BUCKET, basedir=BASEDIR, region=REGION) - different_basedir = S3Observer.create( + obs_one = S3Observer(bucket=BUCKET, basedir=BASEDIR, region=REGION) + obs_two = S3Observer(bucket=BUCKET, basedir=BASEDIR, region=REGION) + different_basedir = S3Observer( bucket=BUCKET, basedir="another/dir", region=REGION ) - different_bucket = S3Observer.create( + different_bucket = S3Observer( bucket="other-bucket", basedir=BASEDIR, region=REGION ) assert obs_one == obs_two @@ -244,6 +244,6 @@ def test_artifact_event_works(observer, sample_run, tmpfile): def test_raises_error_on_invalid_bucket_name(bucket_name, should_raise): if should_raise: with pytest.raises(ValueError): - _ = S3Observer.create(bucket=bucket_name, basedir=BASEDIR, region=REGION) + _ = S3Observer(bucket=bucket_name, basedir=BASEDIR, region=REGION) else: - _ = S3Observer.create(bucket=bucket_name, basedir=BASEDIR, region=REGION) + _ = S3Observer(bucket=bucket_name, basedir=BASEDIR, region=REGION) From 22ce54fb685c03e80f232aea1090cbf0283f880b Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Wed, 21 Aug 2019 12:39:39 -0700 Subject: [PATCH 44/45] clean up docs --- docs/observers.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/observers.rst b/docs/observers.rst index b4cb6dfa..cfe135a1 100644 --- a/docs/observers.rst +++ b/docs/observers.rst @@ -21,7 +21,7 @@ At the moment there are five observers that are shipped with Sacred: * The :ref:`sql_observer` connects to any SQL database and will store the relevant information there. 
* The :ref:`s3_observer` stores run information in an AWS S3 bucket, within - some specified prefix/directory + a given prefix/directory But if you want the run information stored some other way, it is easy to write From 3d52c2fc634b303f70323e0f01ff65f38e8e691a Mon Sep 17 00:00:00 2001 From: Cody Wild Date: Wed, 21 Aug 2019 12:56:45 -0700 Subject: [PATCH 45/45] fix black issue in test_s3_observer, worried this will annoy flake8 --- tests/test_observers/test_s3_observer.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/test_observers/test_s3_observer.py b/tests/test_observers/test_s3_observer.py index 5b451ca5..19cfe284 100644 --- a/tests/test_observers/test_s3_observer.py +++ b/tests/test_observers/test_s3_observer.py @@ -131,12 +131,8 @@ def test_fs_observer_started_event_increments_run_id(observer, sample_run): def test_s3_observer_equality(): obs_one = S3Observer(bucket=BUCKET, basedir=BASEDIR, region=REGION) obs_two = S3Observer(bucket=BUCKET, basedir=BASEDIR, region=REGION) - different_basedir = S3Observer( - bucket=BUCKET, basedir="another/dir", region=REGION - ) - different_bucket = S3Observer( - bucket="other-bucket", basedir=BASEDIR, region=REGION - ) + different_basedir = S3Observer(bucket=BUCKET, basedir="another/dir", region=REGION) + different_bucket = S3Observer(bucket="other-bucket", basedir=BASEDIR, region=REGION) assert obs_one == obs_two assert obs_one != different_basedir assert obs_one != different_bucket
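
A note on the s3_join fix in patch 32: the helper was originally declared as `def s3_join(**args)`, which accepts only keyword arguments, so the list-style calls such as `s3_join([self.basedir, str(_id)])` could never work. The patch flips the signature to `*args` and unpacks each call site to pass the path parts positionally. A minimal standalone sketch of the fixed helper (the key parts below are illustrative):

    def s3_join(*args):
        # S3 object keys always use forward slashes regardless of the
        # local OS, so a plain "/".join is preferable to os.path.join.
        return "/".join(args)

    # Building a run-scoped key the way _determine_run_dir/save_json do:
    assert s3_join("my-project/runs", "1", "run.json") == "my-project/runs/1/run.json"

The same hunk also shows _determine_run_dir raising FileExistsError when objects already exist under the computed run directory, which the test renamed in patch 40 exercises directly.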
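The churn across patches 33-37 traces back to a single portability detail: a trailing comma after `*args` or `**kwargs` in a signature or call, which black emits by default, is a feature black itself gates behind Python 3.6; the 3.5 parser rejects it as a SyntaxError, so the commas had to be stripped again to keep the py35 build green (patch 36 then restores the final newlines that the sync in patch 35 accidentally dropped). A minimal illustration, with placeholder class and method names:

    class Example:
        # SyntaxError on Python 3.5, legal from 3.6 onward:
        #     def create(cls, *args, **kwargs,):
        #
        # Portable spelling, as restored in patches 33-36:
        @classmethod
        def create(cls, *args, **kwargs):
            return cls(*args, **kwargs)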
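Patches 42 and 43 collapse the `S3Observer.create()` factory into `__init__`, so the observer is now constructed directly, with `resource_dir` and `source_dir` defaulting inside the constructor; the docs and tests are updated to match. One leftover worth noting: the docs paragraph kept as unchanged context still says "the ``region`` parameter of create", which reads stale once create() is gone. A minimal end-to-end sketch under the final API of this series — the experiment name and region value are illustrative placeholders, while the bucket/basedir values are the ones from the docs example:

    from sacred import Experiment
    from sacred.observers import S3Observer

    ex = Experiment("my_cool_experiment")

    # Direct construction replaces the removed S3Observer.create() factory.
    # region is optional: when omitted, the observer falls back to the region
    # in your AWS config file and raises ValueError if none is configured.
    ex.observers.append(
        S3Observer(
            bucket="my-awesome-bucket",                 # plain bucket name, no "s3://"
            basedir="/my-project/my-cool-experiment/",  # prefix for runs inside the bucket
            region="us-west-2",                         # illustrative region name
        )
    )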