Skip to content

Commit

Permalink
Merge pull request #19 from swagnercarena/image_draw_notebook
Browse files Browse the repository at this point in the history
Image draw notebook
  • Loading branch information
swagnercarena authored Apr 19, 2024
2 parents 0a85258 + b29b19a commit ac49700
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 19 deletions.
11 changes: 11 additions & 0 deletions notebooks/GenerateImages.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@
"Let's start by importing a paltax configuration file and dissecting some of its values."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "035cb524-afa0-4904-9bb4-1ceb2d17e6d5",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.system('python')"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
9 changes: 4 additions & 5 deletions paltax/TrainConfigs/train_config_npe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,13 @@ def get_config():
config.cache = False
config.half_precision = False

steps_per_epoch = FieldReference(3900)
config.steps_per_epoch = steps_per_epoch
config.num_train_steps = 500 * steps_per_epoch
config.keep_every_n_steps = steps_per_epoch
config.steps_per_epoch = FieldReference(3900)
config.num_train_steps = 500 * config.get_ref('steps_per_epoch')
config.keep_every_n_steps = config.get_ref('steps_per_epoch')

# Parameters of the learning rate schedule
config.learning_rate = 0.01
config.schedule_function_type = 'cosine'
config.warmup_steps = 10 * steps_per_epoch
config.warmup_steps = 10 * config.get_ref('steps_per_epoch')

return config
13 changes: 6 additions & 7 deletions paltax/TrainConfigs/train_config_snpe_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,21 @@ def get_config():

# Need to set the boundaries of how long the model will train generically
# and when the sequential training will turn on.
steps_per_epoch = FieldReference(3900)
config.steps_per_epoch = steps_per_epoch # Assuming 4 GPUs
config.num_initial_train_steps = steps_per_epoch * 10
config.num_steps_per_refinement = steps_per_epoch * 10
config.num_train_steps = steps_per_epoch * 500
config.steps_per_epoch = FieldReference(3900) # Assuming 4 GPUs
config.num_initial_train_steps = config.get_ref('steps_per_epoch') * 10
config.num_steps_per_refinement = config.get_ref('steps_per_epoch') * 10
config.num_train_steps = config.get_ref('steps_per_epoch') * 500
config.num_refinements = ((
config.get_ref('num_train_steps') -
config.get_ref('num_initial_train_steps')) //
config.get_ref('num_steps_per_refinement'))

# Decide how often to save the model in checkpoints.
config.keep_every_n_steps = steps_per_epoch
config.keep_every_n_steps = config.get_ref('steps_per_epoch')

# Parameters of the learning rate schedule
config.learning_rate = 0.01
config.warmup_steps = 10 * steps_per_epoch
config.warmup_steps = 10 * config.get_ref('steps_per_epoch')
config.refinement_base_value_multiplier = 1e-1

# Sequential prior and initial proposal
Expand Down
13 changes: 6 additions & 7 deletions paltax/train_snpe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,21 @@ def _create_test_config():
config.half_precision = False

# One step of training and one step of refinement.
steps_per_epoch = ml_collections.config_dict.FieldReference(1)
config.steps_per_epoch = steps_per_epoch
config.num_initial_train_steps = steps_per_epoch * 1
config.num_steps_per_refinement = steps_per_epoch * 1
config.num_train_steps = steps_per_epoch * 2
config.steps_per_epoch = ml_collections.config_dict.FieldReference(1)
config.num_initial_train_steps = config.get_ref('steps_per_epoch') * 1
config.num_steps_per_refinement = config.get_ref('steps_per_epoch') * 1
config.num_train_steps = config.get_ref('steps_per_epoch') * 2
config.num_refinements = ((
config.get_ref('num_train_steps') -
config.get_ref('num_initial_train_steps')) //
config.get_ref('num_steps_per_refinement'))

# Decide how often to save the model in checkpoints.
config.keep_every_n_steps = steps_per_epoch
config.keep_every_n_steps = config.get_ref('steps_per_epoch')

# Parameters of the learning rate schedule
config.learning_rate = 0.01
config.warmup_steps = 1 * steps_per_epoch
config.warmup_steps = 1 * config.get_ref('steps_per_epoch')
config.refinement_base_value_multiplier = 0.5

config.mu_prior = jnp.zeros(5)
Expand Down

0 comments on commit ac49700

Please sign in to comment.