Skip to content

Commit

Permalink
add spiral experiments poisson
Browse files Browse the repository at this point in the history
  • Loading branch information
bartvanerp committed Aug 24, 2024
1 parent 44b6e37 commit 491c96d
Showing 1 changed file with 70 additions and 63 deletions.
133 changes: 70 additions & 63 deletions experiments/spiral_experiments_poisson.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -214,77 +214,83 @@
}
],
"source": [
"function run_experiment(folder, ω, dimmid; epochs=20000, batch_size=256, lr=0.005f0, max_layers=30, N=1024, rng=Random.default_rng())\n",
" \n",
" # loop over dimensions\n",
" for dim in dimmid\n",
"function run_experiment(folder, ω, dimmid, runs; epochs=20000, batch_size=256, lr=0.005f0, max_layers=30, N=1024, rng=Random.default_rng())\n",
"\n",
" mkpath(folder * \"/dim_$(dim)\")\n",
" for run in 1:runs\n",
"\n",
" p = ProgressMeter.Progress(length(ω))\n",
" # set seed\n",
" Random.seed!(run)\n",
" \n",
" # loop over dimensions\n",
" for dim in dimmid\n",
"\n",
" # loop over difficulties\n",
" Threads.@threads for ωi in ω\n",
" mkpath(folder * \"/run_$(run)/dim_$(dim)\")\n",
"\n",
" # generate data\n",
" Random.seed!(ωi)\n",
" x_train, y_train = generate_spiral(N, ωi)\n",
" x_val, y_val = generate_spiral(N, ωi)\n",
" x_test, y_test = generate_spiral(N, ωi)\n",
" y_train_onehot = to_one_hot(y_train)\n",
" y_val_onehot = to_one_hot(y_val)\n",
" y_test_onehot = to_one_hot(y_test)\n",
" p = ProgressMeter.Progress(length(ω))\n",
"\n",
" # create model and optimiser\n",
" model = create_model(2 => 2, dim, max_layers=max_layers)\n",
" opt = create_optimiser(model, lr=lr)\n",
" \n",
" loss_train = zeros(epochs)\n",
" loss_val = zeros(epochs)\n",
" loss_test = zeros(epochs)\n",
" best_val = Inf\n",
" best_model = nothing\n",
" # loop over difficulties\n",
" Threads.@threads for ωi in ω\n",
"\n",
" # generate data\n",
" x_train, y_train = generate_spiral(N, ωi)\n",
" x_val, y_val = generate_spiral(N, ωi)\n",
" x_test, y_test = generate_spiral(N, ωi)\n",
" y_train_onehot = to_one_hot(y_train)\n",
" y_val_onehot = to_one_hot(y_val)\n",
" y_test_onehot = to_one_hot(y_test)\n",
"\n",
" # create model and optimiser\n",
" model = create_model(2 => 2, dim, max_layers=max_layers)\n",
" opt = create_optimiser(model, lr=lr)\n",
" \n",
" loss_train = zeros(epochs)\n",
" loss_val = zeros(epochs)\n",
" loss_test = zeros(epochs)\n",
" best_val = Inf\n",
" best_model = nothing\n",
"\n",
" for e in 1:epochs\n",
" for n in Iterators.partition(randperm(rng, N), batch_size)\n",
" _, gs = Zygote.withgradient(m -> loss(y_train_onehot[:,n], x_train[:,n], m; batch_prop = length(n)/N), model)\n",
" opt, model = Optimisers.update!(opt, model, gs[1])\n",
" end\n",
" loss_train[e] = loss(y_train_onehot, x_train, model)\n",
" loss_val[e] = loss(y_val_onehot, x_val, model)\n",
" loss_test[e] = loss(y_test_onehot, x_test, model)\n",
"\n",
" if loss_val[e] < best_val\n",
" best_val = loss_val[e]\n",
" best_model = model\n",
" end\n",
"\n",
" for e in 1:epochs\n",
" for n in Iterators.partition(randperm(rng, N), batch_size)\n",
" _, gs = Zygote.withgradient(m -> loss(y_train_onehot[:,n], x_train[:,n], m; batch_prop = length(n)/N), model)\n",
" opt, model = Optimisers.update!(opt, model, gs[1])\n",
" end\n",
" loss_train[e] = loss(y_train_onehot, x_train, model)\n",
" loss_val[e] = loss(y_val_onehot, x_val, model)\n",
" loss_test[e] = loss(y_test_onehot, x_test, model)\n",
"\n",
" if loss_val[e] < best_val\n",
" best_val = loss_val[e]\n",
" best_model = model\n",
" #save results\n",
" jldopen(folder * \"/run_$(run)/dim_$(dim)/dim_$(dim)_omega_$(ωi).jld2\", \"w\") do file\n",
" file[\"data/x_train\"] = x_train\n",
" file[\"data/y_train\"] = y_train\n",
" file[\"data/x_val\"] = x_val\n",
" file[\"data/y_val\"] = y_val\n",
" file[\"data/x_test\"] = x_test\n",
" file[\"data/y_test\"] = y_test\n",
" file[\"model\"] = best_model\n",
" file[\"results/loss_train\"] = loss_train\n",
" file[\"results/loss_val\"] = loss_val\n",
" file[\"results/loss_test\"] = loss_test\n",
" file[\"results/predictions_train\"] = predict(best_model, x_train)\n",
" file[\"results/predictions_val\"] = predict(best_model, x_val)\n",
" file[\"results/predictions_test\"] = predict(best_model, x_test)\n",
" file[\"results/accuracy_train\"] = accuracy(y_train, predict(best_model, x_train))\n",
" file[\"results/accuracy_val\"] = accuracy(y_val, predict(best_model, x_val))\n",
" file[\"results/accuracy_test\"] = accuracy(y_test, predict(best_model, x_test))\n",
" file[\"results/posterior\"] = UnboundedBNN.transform(best_model.posterior)\n",
" file[\"results/prior\"] = best_model.prior\n",
" end\n",
"\n",
" end\n",
" ProgressMeter.next!(p)\n",
"\n",
" #save results\n",
" jldopen(\"$folder/dim_$(dim)//dim_$(dim)_omega_$(ωi).jld2\", \"w\") do file\n",
" file[\"data/x_train\"] = x_train\n",
" file[\"data/y_train\"] = y_train\n",
" file[\"data/x_val\"] = x_val\n",
" file[\"data/y_val\"] = y_val\n",
" file[\"data/x_test\"] = x_test\n",
" file[\"data/y_test\"] = y_test\n",
" file[\"model\"] = best_model\n",
" file[\"results/loss_train\"] = loss_train\n",
" file[\"results/loss_val\"] = loss_val\n",
" file[\"results/loss_test\"] = loss_test\n",
" file[\"results/predictions_train\"] = predict(best_model, x_train)\n",
" file[\"results/predictions_val\"] = predict(best_model, x_val)\n",
" file[\"results/predictions_test\"] = predict(best_model, x_test)\n",
" file[\"results/accuracy_train\"] = accuracy(y_train, predict(best_model, x_train))\n",
" file[\"results/accuracy_val\"] = accuracy(y_val, predict(best_model, x_val))\n",
" file[\"results/accuracy_test\"] = accuracy(y_test, predict(best_model, x_test))\n",
" file[\"results/posterior\"] = UnboundedBNN.transform(best_model.posterior)\n",
" file[\"results/prior\"] = best_model.prior\n",
" end\n",
"\n",
" ProgressMeter.next!(p)\n",
"\n",
" end\n",
"\n",
" end\n",
Expand All @@ -301,15 +307,16 @@
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 2:12:23\u001b[39m\u001b[K\n",
"\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 1:50:50\u001b[39m\u001b[K\u001b[K\n",
"\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:59:58\u001b[39m\u001b[K\u001b[K\u001b[K\u001b[K\u001b[K\n",
"\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 1:04:46\u001b[39m\u001b[K\n"
"\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 1:14:22\u001b[39m\u001b[K\n",
"\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 1:09:22\u001b[39m\u001b[K\u001b[K\n",
"\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 1:24:58\u001b[39m\u001b[K\n",
"\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 1:11:08\u001b[39m\u001b[K\n",
"\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 1:10:55\u001b[39m\u001b[K\n"
]
}
],
"source": [
"run_experiment(\"data/spiral/poisson_without_kl\", 0:30, (4, 8, 16, 32))"
"run_experiment(\"data/spiral/poisson\", 0:30, 32, 5)"
]
}
],
Expand Down

0 comments on commit 491c96d

Please sign in to comment.