diff --git a/selene_sdk/utils/config_utils.py b/selene_sdk/utils/config_utils.py index b7546583..1c88e538 100644 --- a/selene_sdk/utils/config_utils.py +++ b/selene_sdk/utils/config_utils.py @@ -9,13 +9,16 @@ from time import strftime import types import random +import shutil, yaml import numpy as np import torch from . import _is_lua_trained_model from . import instantiate +from . import load_path +from selene_sdk import version def class_instantiate(classobj): """Not used currently, but might be useful later for recursive @@ -111,6 +114,7 @@ def initialize_model(model_configs, train=True, lr=None): module = None if os.path.isdir(import_model_from): + import_model_from = import_model_from.rstrip(os.sep) module = module_from_dir(import_model_from) else: module = module_from_file(import_model_from) @@ -260,9 +264,11 @@ def parse_configs_and_run(configs, Parameters ---------- - configs : dict - The dictionary of nested configuration parameters. Will look - for the following top-level parameters: + configs : str or dict + If it is a str, then configs is the name of the configuration YAML file, from which we will read + nested configuration parameters. + If it is a dict, then configs is a dict storing nested configuration parameters. + Will look for the following top-level parameters: * `ops`: A list of 1 or more of the values \ {"train", "evaluate", "analyze"}. The operations specified\ @@ -305,8 +311,19 @@ def parse_configs_and_run(configs, to the dirs specified in each operation's configuration. """ + if isinstance(configs, str): + configs_file = configs + if not os.path.isfile(configs_file): + print("The configuration YAML file {} does not exist!".format(configs_file)) + return + configs = load_path(configs_file, instantiate=False) operations = configs["ops"] + #print selene_sdk version + if "selene_sdk_version" not in configs: + configs["selene_sdk_version"] = version.__version__ + print("Running with selene_sdk version {0}".format(version.__version__)) + if "train" in operations and "lr" not in configs and lr != None: configs["lr"] = float(lr) elif "train" in operations and "lr" in configs and lr != None: @@ -331,8 +348,9 @@ def parse_configs_and_run(configs, if "create_subdirectory" in configs: create_subdirectory = configs["create_subdirectory"] if create_subdirectory: + rand_str = str(random.random())[2:] current_run_output_dir = os.path.join( - current_run_output_dir, strftime("%Y-%m-%d-%H-%M-%S")) + current_run_output_dir, '{}-{}'.format(strftime("%Y-%m-%d-%H-%M-%S"), rand_str)) os.makedirs(current_run_output_dir) print("Outputs and logs saved to {0}".format( current_run_output_dir)) @@ -343,9 +361,25 @@ def parse_configs_and_run(configs, np.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) + #torch.backends.cudnn.deterministic = True + #torch.backends.cudnn.benchmark = False print("Setting random seed = {0}".format(seed)) else: print("Warning: no random seed specified in config file. " "Using a random seed ensures results are reproducible.") + if current_run_output_dir: + # write configs to output directory + with open('{}/{}'.format(current_run_output_dir,'configs.yaml'), 'w') as f: + yaml.dump(configs, f, default_flow_style=None) + # copy model file or directory to output + model_input = configs['model']['path'] + if os.path.isdir(model_input): # copy the directory + shutil.copytree (model_input, + os.path.join(current_run_output_dir, os.path.basename(import_model_from)), + dirs_exist_ok=True) + else: + shutil.copy (model_input, current_run_output_dir) + + execute(operations, configs, current_run_output_dir) diff --git a/setup.py b/setup.py index 96546bd3..e2d8bd4f 100644 --- a/setup.py +++ b/setup.py @@ -63,7 +63,7 @@ "scipy", "seaborn", "statsmodels", - "torch>=0.4.1, <=1.9", + "torch>=0.4.1, <=1.11", ], entry_points={ 'console_scripts': [ diff --git a/tutorials/getting_started_with_selene/simple_train.yml b/tutorials/getting_started_with_selene/simple_train.yml index 4140c61d..651e0d40 100644 --- a/tutorials/getting_started_with_selene/simple_train.yml +++ b/tutorials/getting_started_with_selene/simple_train.yml @@ -47,6 +47,6 @@ train_model: !obj:selene_sdk.TrainModel { } random_seed: 1447 output_dir: ./training_outputs -create_subdirectory: False +create_subdirectory: True load_test_set: False ...