From edcbcc329a3ce52b62a4533e16ebf97170344785 Mon Sep 17 00:00:00 2001 From: Sebastian Wagner-Carena Date: Fri, 12 Apr 2024 12:02:05 -0400 Subject: [PATCH 1/3] small changes to generation notebook. --- notebooks/GenerateImages.ipynb | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/notebooks/GenerateImages.ipynb b/notebooks/GenerateImages.ipynb index e831e82..fd7f327 100644 --- a/notebooks/GenerateImages.ipynb +++ b/notebooks/GenerateImages.ipynb @@ -49,6 +49,35 @@ "Let's start by importing a paltax configuration file and disecting some of its values." ] }, + { + "cell_type": "code", + "execution_count": 3, + "id": "035cb524-afa0-4904-9bb4-1ceb2d17e6d5", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "sh: python: command not found\n" + ] + }, + { + "data": { + "text/plain": [ + "32512" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import os\n", + "os.system('python')" + ] + }, { "cell_type": "code", "execution_count": null, From 36a23fefc972c75cd47b52554dcc83b5a9261e0d Mon Sep 17 00:00:00 2001 From: Sebastian Wagner-Carena Date: Fri, 12 Apr 2024 12:04:49 -0400 Subject: [PATCH 2/3] fixed use of FieldReference in config files. --- paltax/TrainConfigs/train_config_npe_base.py | 9 ++++----- paltax/TrainConfigs/train_config_snpe_base.py | 13 ++++++------- paltax/train_snpe_test.py | 13 ++++++------- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/paltax/TrainConfigs/train_config_npe_base.py b/paltax/TrainConfigs/train_config_npe_base.py index 3f7ee85..8479b6e 100644 --- a/paltax/TrainConfigs/train_config_npe_base.py +++ b/paltax/TrainConfigs/train_config_npe_base.py @@ -27,14 +27,13 @@ def get_config(): config.cache = False config.half_precision = False - steps_per_epoch = FieldReference(3900) - config.steps_per_epoch = steps_per_epoch - config.num_train_steps = 500 * steps_per_epoch - config.keep_every_n_steps = steps_per_epoch + config.steps_per_epoch = FieldReference(3900) + config.num_train_steps = 500 * config.get_ref('steps_per_epoch') + config.keep_every_n_steps = config.get_ref('steps_per_epoch') # Parameters of the learning rate schedule config.learning_rate = 0.01 config.schedule_function_type = 'cosine' - config.warmup_steps = 10 * steps_per_epoch + config.warmup_steps = 10 * config.get_ref('steps_per_epoch') return config diff --git a/paltax/TrainConfigs/train_config_snpe_base.py b/paltax/TrainConfigs/train_config_snpe_base.py index 36a3699..4179462 100644 --- a/paltax/TrainConfigs/train_config_snpe_base.py +++ b/paltax/TrainConfigs/train_config_snpe_base.py @@ -30,22 +30,21 @@ def get_config(): # Need to set the boundaries of how long the model will train generically # and when the sequential training will turn on. - steps_per_epoch = FieldReference(3900) - config.steps_per_epoch = steps_per_epoch # Assuming 4 GPUs - config.num_initial_train_steps = steps_per_epoch * 10 - config.num_steps_per_refinement = steps_per_epoch * 10 - config.num_train_steps = steps_per_epoch * 500 + config.steps_per_epoch = FieldReference(3900) # Assuming 4 GPUs + config.num_initial_train_steps = config.get_ref('steps_per_epoch') * 10 + config.num_steps_per_refinement = config.get_ref('steps_per_epoch') * 10 + config.num_train_steps = config.get_ref('steps_per_epoch') * 500 config.num_refinements = (( config.get_ref('num_train_steps') - config.get_ref('num_initial_train_steps')) // config.get_ref('num_steps_per_refinement')) # Decide how often to save the model in checkpoints. - config.keep_every_n_steps = steps_per_epoch + config.keep_every_n_steps = config.get_ref('steps_per_epoch') # Parameters of the learning rate schedule config.learning_rate = 0.01 - config.warmup_steps = 10 * steps_per_epoch + config.warmup_steps = 10 * config.get_ref('steps_per_epoch') config.refinement_base_value_multiplier = 1e-1 # Sequential prior and initial proposal diff --git a/paltax/train_snpe_test.py b/paltax/train_snpe_test.py index 05267f7..c8766b2 100644 --- a/paltax/train_snpe_test.py +++ b/paltax/train_snpe_test.py @@ -52,22 +52,21 @@ def _create_test_config(): config.half_precision = False # One step of training and one step of refinement. - steps_per_epoch = ml_collections.config_dict.FieldReference(1) - config.steps_per_epoch = steps_per_epoch - config.num_initial_train_steps = steps_per_epoch * 1 - config.num_steps_per_refinement = steps_per_epoch * 1 - config.num_train_steps = steps_per_epoch * 2 + config.steps_per_epoch = ml_collections.config_dict.FieldReference(1) + config.num_initial_train_steps = config.get_ref('steps_per_epoch') * 1 + config.num_steps_per_refinement = config.get_ref('steps_per_epoch') * 1 + config.num_train_steps = config.get_ref('steps_per_epoch') * 2 config.num_refinements = (( config.get_ref('num_train_steps') - config.get_ref('num_initial_train_steps')) // config.get_ref('num_steps_per_refinement')) # Decide how often to save the model in checkpoints. - config.keep_every_n_steps = steps_per_epoch + config.keep_every_n_steps = config.get_ref('steps_per_epoch') # Parameters of the learning rate schedule config.learning_rate = 0.01 - config.warmup_steps = 1 * steps_per_epoch + config.warmup_steps = 1 * config.get_ref('steps_per_epoch') config.refinement_base_value_multiplier = 0.5 config.mu_prior = jnp.zeros(5) From b29b19a5d2d1a28b4259f1b3af621525cc98a34c Mon Sep 17 00:00:00 2001 From: Sebastian Wagner-Carena Date: Fri, 12 Apr 2024 12:08:19 -0400 Subject: [PATCH 3/3] undid changes to notebook. --- notebooks/GenerateImages.ipynb | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/notebooks/GenerateImages.ipynb b/notebooks/GenerateImages.ipynb index fd7f327..f610ca9 100644 --- a/notebooks/GenerateImages.ipynb +++ b/notebooks/GenerateImages.ipynb @@ -51,28 +51,10 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "id": "035cb524-afa0-4904-9bb4-1ceb2d17e6d5", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "sh: python: command not found\n" - ] - }, - { - "data": { - "text/plain": [ - "32512" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "import os\n", "os.system('python')"