Skip to content

Commit 2cead94

Browse files
committed
📝 Add Examples
1 parent 71206da commit 2cead94

File tree

66 files changed

+58360
-20
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+58360
-20
lines changed

blades/aggregators/__init__.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66

77
__all__ = [
88
"Mean",
9-
"Median",
10-
"Trimmedmean",
11-
"GeoMed",
12-
"DnC",
13-
"Clippedclustering",
14-
"Signguard",
15-
"Multikrum",
16-
"Centeredclipping",
9+
# "Median",
10+
# "Trimmedmean",
11+
# "GeoMed",
12+
# "DnC",
13+
# "Clippedclustering",
14+
# "Signguard",
15+
# "Multikrum",
16+
# "Centeredclipping",
1717
]

blades/aggregators/aggregators.py

+13
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,29 @@ def _median(inputs: List[torch.Tensor]):
1717

1818

1919
class Mean(object):
20+
"""Computes the ``sample mean`` over the updates from all give clients."""
21+
2022
def __call__(self, inputs: List[torch.Tensor]):
2123
return _mean(inputs)
2224

2325

2426
class Median(object):
27+
"""Partitioner that uses Dirichlet distribution to allocate samples to
28+
clients."""
29+
2530
def __call__(self, inputs: List[torch.Tensor]):
31+
"""A robust aggregator from paper `Byzantine-robust distributed
32+
learning:
33+
34+
Towards optimal statistical rates.<https://proceedings.mlr.press/v80/yin18a>`_.
35+
It computes the coordinate-wise median of the given set of clients
36+
"""
2637
return _median(inputs)
2738

2839

2940
class Trimmedmean(object):
41+
"""A robust aggregator."""
42+
3043
def __init__(self, num_byzantine: int, *, filter_frac=1.0):
3144
if filter_frac > 1.0 or filter_frac < 0.0:
3245
raise ValueError(f"filter_frac should be in [0.0, 1.0], got {filter_frac}.")

blades/aggregators/clippedclustering.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,14 @@
1010

1111

1212
class Clippedclustering(object):
13-
def __init__(self, agg="mean", max_tau=1e5, linkage="average") -> None:
14-
super(Clippedclustering, self).__init__()
15-
13+
"""Clipped clustering aggregator."""
14+
15+
# def __init__(self, agg="mean", max_tau=1e5, linkage="average") -> None:
16+
def __init__(self):
17+
agg = "mean"
18+
max_tau = 1e5
19+
linkage = "average"
20+
# def __init__(self, agg="mean", max_tau=1e5, linkage="average") -> None:
1621
assert linkage in ["average", "single"]
1722
self.tau = max_tau
1823
self.linkage = linkage

blades/aggregators/signguard.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,10 @@
99

1010

1111
class Signguard(object):
12-
r"""A robust aggregator from paper `Xu et al.
12+
"""A robust aggregator from paper `Xu et al. SignGuard: Byzantine-robust
13+
Federated Learning through Collaborative Malicious Gradient Filtering.
1314
14-
SignGuard: Byzantine-robust Federated
15-
Learning through Collaborative Malicious Gradient
16-
Filtering <https://arxiv.org/abs/2109.05872>`_.
15+
<https://arxiv.org/abs/2109.05872>`_.
1716
"""
1817

1918
def __init__(self, agg="mean", max_tau=1e5, linkage="average") -> None:

docs/requirements.txt

+6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
11
git+https://github.com/fedlib/blades_sphinx_theme.git
2+
lxml_html_clean
3+
m2r2
24
nbsphinx
5+
nbsphinx-link
36
numpy>=1.19.5
7+
sphinx==5.1.1
8+
sphinx_autodoc_typehints
9+
sphinx_gallery
8.13 KB
Binary file not shown.
5.57 KB
Binary file not shown.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n# Customize Attack\n"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {
14+
"collapsed": false
15+
},
16+
"outputs": [],
17+
"source": [
18+
"import ray\nfrom ray import tune\nfrom ray.tune.stopper import MaximumIterationStopper\n\nfrom blades.algorithms.fedavg import FedavgConfig, Fedavg\nfrom fedlib.trainers import TrainerConfig\n\n\nfrom fedlib.trainers import Trainer\nfrom fedlib.clients import ClientCallback\nfrom blades.adversaries import Adversary\n\n\nclass LabelFlipAdversary(Adversary):\n def on_trainer_init(self, trainer: Trainer):\n class LabelFlipCallback(ClientCallback):\n def on_train_batch_begin(self, data, target):\n return data, 10 - 1 - target\n\n for client in self.clients:\n client.to_malicious(callbacks_cls=LabelFlipCallback, local_training=True)\n\n\nclass ExampleFedavgConfig(FedavgConfig):\n def __init__(self, algo_class=None):\n \"\"\"Initializes a FedavgConfig instance.\"\"\"\n super().__init__(algo_class=algo_class or ExampleFedavg)\n\n self.dataset_config = {\n \"type\": \"FashionMNIST\",\n \"num_clients\": 10,\n \"train_batch_size\": 32,\n }\n self.global_model = \"cnn\"\n self.num_malicious_clients = 1\n self.adversary_config = {\"type\": LabelFlipAdversary}\n\n\nclass ExampleFedavg(Fedavg):\n @classmethod\n def get_default_config(cls) -> TrainerConfig:\n return ExampleFedavgConfig()\n\n\nif __name__ == \"__main__\":\n ray.init()\n\n config_dict = (\n ExampleFedavgConfig()\n .resources(\n num_gpus_for_driver=0.0,\n num_cpus_for_driver=1,\n num_remote_workers=0,\n num_gpus_per_worker=0.0,\n )\n .to_dict()\n )\n print(config_dict)\n tune.run(\n ExampleFedavg,\n config=config_dict,\n stop=MaximumIterationStopper(100),\n )"
19+
]
20+
}
21+
],
22+
"metadata": {
23+
"kernelspec": {
24+
"display_name": "Python 3",
25+
"language": "python",
26+
"name": "python3"
27+
},
28+
"language_info": {
29+
"codemirror_mode": {
30+
"name": "ipython",
31+
"version": 3
32+
},
33+
"file_extension": ".py",
34+
"mimetype": "text/x-python",
35+
"name": "python",
36+
"nbconvert_exporter": "python",
37+
"pygments_lexer": "ipython3",
38+
"version": "3.10.14"
39+
}
40+
},
41+
"nbformat": 4,
42+
"nbformat_minor": 0
43+
}
+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""
2+
Customize Attack
3+
==================
4+
5+
"""
6+
7+
8+
import ray
9+
from ray import tune
10+
from ray.tune.stopper import MaximumIterationStopper
11+
12+
from blades.algorithms.fedavg import FedavgConfig, Fedavg
13+
from fedlib.trainers import TrainerConfig
14+
15+
16+
from fedlib.trainers import Trainer
17+
from fedlib.clients import ClientCallback
18+
from blades.adversaries import Adversary
19+
20+
21+
class LabelFlipAdversary(Adversary):
22+
def on_trainer_init(self, trainer: Trainer):
23+
class LabelFlipCallback(ClientCallback):
24+
def on_train_batch_begin(self, data, target):
25+
return data, 10 - 1 - target
26+
27+
for client in self.clients:
28+
client.to_malicious(callbacks_cls=LabelFlipCallback, local_training=True)
29+
30+
31+
class ExampleFedavgConfig(FedavgConfig):
32+
def __init__(self, algo_class=None):
33+
"""Initializes a FedavgConfig instance."""
34+
super().__init__(algo_class=algo_class or ExampleFedavg)
35+
36+
self.dataset_config = {
37+
"type": "FashionMNIST",
38+
"num_clients": 10,
39+
"train_batch_size": 32,
40+
}
41+
self.global_model = "cnn"
42+
self.num_malicious_clients = 1
43+
self.adversary_config = {"type": LabelFlipAdversary}
44+
45+
46+
class ExampleFedavg(Fedavg):
47+
@classmethod
48+
def get_default_config(cls) -> TrainerConfig:
49+
return ExampleFedavgConfig()
50+
51+
52+
if __name__ == "__main__":
53+
ray.init()
54+
55+
config_dict = (
56+
ExampleFedavgConfig()
57+
.resources(
58+
num_gpus_for_driver=0.0,
59+
num_cpus_for_driver=1,
60+
num_remote_workers=0,
61+
num_gpus_per_worker=0.0,
62+
)
63+
.to_dict()
64+
)
65+
print(config_dict)
66+
tune.run(
67+
ExampleFedavg,
68+
config=config_dict,
69+
stop=MaximumIterationStopper(100),
70+
)
+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
2+
.. DO NOT EDIT.
3+
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
4+
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
5+
.. "_examples/customize_attack.py"
6+
.. LINE NUMBERS ARE GIVEN BELOW.
7+
8+
.. only:: html
9+
10+
.. note::
11+
:class: sphx-glr-download-link-note
12+
13+
:ref:`Go to the end <sphx_glr_download__examples_customize_attack.py>`
14+
to download the full example code.
15+
16+
.. rst-class:: sphx-glr-example-title
17+
18+
.. _sphx_glr__examples_customize_attack.py:
19+
20+
21+
Customize Attack
22+
==================
23+
24+
.. GENERATED FROM PYTHON SOURCE LINES 6-71
25+
26+
.. code-block:: Python
27+
28+
29+
30+
import ray
31+
from ray import tune
32+
from ray.tune.stopper import MaximumIterationStopper
33+
34+
from blades.algorithms.fedavg import FedavgConfig, Fedavg
35+
from fedlib.trainers import TrainerConfig
36+
37+
38+
from fedlib.trainers import Trainer
39+
from fedlib.clients import ClientCallback
40+
from blades.adversaries import Adversary
41+
42+
43+
class LabelFlipAdversary(Adversary):
44+
def on_trainer_init(self, trainer: Trainer):
45+
class LabelFlipCallback(ClientCallback):
46+
def on_train_batch_begin(self, data, target):
47+
return data, 10 - 1 - target
48+
49+
for client in self.clients:
50+
client.to_malicious(callbacks_cls=LabelFlipCallback, local_training=True)
51+
52+
53+
class ExampleFedavgConfig(FedavgConfig):
54+
def __init__(self, algo_class=None):
55+
"""Initializes a FedavgConfig instance."""
56+
super().__init__(algo_class=algo_class or ExampleFedavg)
57+
58+
self.dataset_config = {
59+
"type": "FashionMNIST",
60+
"num_clients": 10,
61+
"train_batch_size": 32,
62+
}
63+
self.global_model = "cnn"
64+
self.num_malicious_clients = 1
65+
self.adversary_config = {"type": LabelFlipAdversary}
66+
67+
68+
class ExampleFedavg(Fedavg):
69+
@classmethod
70+
def get_default_config(cls) -> TrainerConfig:
71+
return ExampleFedavgConfig()
72+
73+
74+
if __name__ == "__main__":
75+
ray.init()
76+
77+
config_dict = (
78+
ExampleFedavgConfig()
79+
.resources(
80+
num_gpus_for_driver=0.0,
81+
num_cpus_for_driver=1,
82+
num_remote_workers=0,
83+
num_gpus_per_worker=0.0,
84+
)
85+
.to_dict()
86+
)
87+
print(config_dict)
88+
tune.run(
89+
ExampleFedavg,
90+
config=config_dict,
91+
stop=MaximumIterationStopper(100),
92+
)
93+
94+
95+
.. _sphx_glr_download__examples_customize_attack.py:
96+
97+
.. only:: html
98+
99+
.. container:: sphx-glr-footer sphx-glr-footer-example
100+
101+
.. container:: sphx-glr-download sphx-glr-download-jupyter
102+
103+
:download:`Download Jupyter notebook: customize_attack.ipynb <customize_attack.ipynb>`
104+
105+
.. container:: sphx-glr-download sphx-glr-download-python
106+
107+
:download:`Download Python source code: customize_attack.py <customize_attack.py>`
108+
109+
110+
.. only:: html
111+
112+
.. rst-class:: sphx-glr-signature
113+
114+
`Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
Binary file not shown.
Loading
Loading
Loading
Loading

0 commit comments

Comments
 (0)