Skip to content

Commit

Permalink
Ensure no execution of slow notebooks.
Browse files Browse the repository at this point in the history
Fixed dpsgd notebook.

PiperOrigin-RevId: 605543405
  • Loading branch information
vroulet authored and OptaxDev committed Feb 9, 2024
1 parent a6f30f2 commit 7b2b1fd
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
6 changes: 4 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,10 @@ def new_process_docstring(app, what, name, obj, options, lines):
nb_execution_allow_errors = False
nb_execution_excludepatterns = [
# slow examples
'_collections/examples/cifar10_resnet.ipynb'
'_collections/examples/adversarial_training.ipynb'
'cifar10_resnet.ipynb',
'adversarial_training.ipynb',
'reduce_on_plateau.ipynb',
'differentially_private_sgd.ipynb'
]

# -- Options for katex ------------------------------------------------------
Expand Down
22 changes: 14 additions & 8 deletions examples/contrib/differentially_private_sgd.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,17 @@
"outputs": [],
"source": [
"# @markdown Whether to use DP-SGD or vanilla SGD:\n",
"DPSGD = False # @param{type:\"boolean\"}\n",
"DPSGD = True # @param{type:\"boolean\"}\n",
"# @markdown Learning rate for the optimizer:\n",
"LEARNING_RATE = 0.1 # @param{type:\"number\"}\n",
"LEARNING_RATE = 0.25 # @param{type:\"number\"}\n",
"# @markdown Noise multiplier for DP-SGD optimizer:\n",
"NOISE_MULTIPLIER = 1.1 # @param{type:\"number\"}\n",
"NOISE_MULTIPLIER = 1.3 # @param{type:\"number\"}\n",
"# @markdown L2 norm clip:\n",
"L2_NORM_CLIP = 1.0 # @param{type:\"number\"}\n",
"L2_NORM_CLIP = 1.5 # @param{type:\"number\"}\n",
"# @markdown Number of samples in each batch:\n",
"BATCH_SIZE = 256 # @param{type:\"integer\"}\n",
"# @markdown Number of epochs:\n",
"NUM_EPOCHS = 20 # @param{type:\"integer\"}\n",
"NUM_EPOCHS = 15 # @param{type:\"integer\"}\n",
"# @markdown Probability of information leakage:\n",
"DELTA = 1e-5 # @param{type:\"number\"}"
]
Expand Down Expand Up @@ -351,6 +351,7 @@
" test_loss, test_acc = test_step(params, test_batch)\n",
" accuracy.append(test_acc)\n",
" loss.append(test_loss)\n",
" print(f\"Epoch {epoch + 1}/{NUM_EPOCHS}, test accuracy: {test_acc}\")\n",
"\n",
" #\n",
" if DPSGD:\n",
Expand Down Expand Up @@ -397,7 +398,10 @@
}
],
"source": [
"_, axs = plt.subplots(ncols=3 if DPSGD else 2)\n",
"if DPSGD:\n",
" _, axs = plt.subplots(ncols=3, figsize=(9, 3))\n",
"else:\n",
" _, axs = plt.subplots(ncols=2, figsize=(6, 3))\n",
"\n",
"axs[0].plot(accuracy)\n",
"axs[0].set_title(\"Test accuracy\")\n",
Expand All @@ -406,7 +410,9 @@
"\n",
"if DPSGD:\n",
" axs[2].plot(epsilon)\n",
" axs[2].set_title(\"Epsilon\")"
" axs[2].set_title(\"Epsilon\")\n",
"\n",
"plt.tight_layout()"
]
},
{
Expand Down Expand Up @@ -439,7 +445,7 @@
}
],
"source": [
"accuracy[-1]"
"print(f'Final accuracy: {accuracy[-1]}')"
]
}
],
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ test = [

examples = [
"tensorflow-datasets>=4.2.0",
"tensorflow>=2.4.0"
"tensorflow>=2.4.0",
"dp_accounting>=0.4"
]

docs = [
Expand Down

0 comments on commit 7b2b1fd

Please sign in to comment.