TensorFlow and PyTorch generator (#105,#106)
This PR adds data loaders which can be used with TensorFlow
and PyTorch to train (or fit) a model. The loaders
get the data from the Orchestrator and can be used
when training in a distributed fashion (e.g. with Horovod).

[ committed by @al-rigazzi ]
[ reviewed by @Spartee and @MattToast ]
al-rigazzi authored Feb 3, 2022
1 parent 649be17 commit e9210d4
Showing 25 changed files with 2,941 additions and 12 deletions.
52 changes: 47 additions & 5 deletions doc/api/smartsim_api.rst
@@ -415,7 +415,6 @@ in an interactive allocation.
Model
=====


.. currentmodule:: smartsim.entity.model

.. autosummary::
@@ -455,23 +454,66 @@ Ensemble
:inherited-members:


Machine Learning
================

.. _ml_api:

SmartSim includes built-in utilities for supporting TensorFlow, Keras, and PyTorch.

TensorFlow
==========
----------

.. _smartsim_tf_api:

SmartSim includes built-in utilities for supporting TensorFlow and Keras in SmartSim.
SmartSim includes built-in utilities for supporting TensorFlow and Keras in training and inference.

.. currentmodule:: smartsim.tf.utils
.. currentmodule:: smartsim.ml.tf.utils

.. autosummary::

    freeze_model

.. automodule:: smartsim.tf.utils
.. automodule:: smartsim.ml.tf.utils
    :members:


.. currentmodule:: smartsim.ml.tf.data

.. autoclass:: StaticDataGenerator
    :show-inheritance:
    :inherited-members:
    :members:

.. autoclass:: DataGenerator
    :members:
    :show-inheritance:
    :inherited-members:

PyTorch
----------

.. _smartsim_torch_api:

SmartSim includes built-in utilities for supporting PyTorch in training and inference.

.. currentmodule:: smartsim.ml.torch.data

.. autoclass:: StaticDataGenerator
    :members:
    :show-inheritance:
    :inherited-members:

.. autoclass:: DataGenerator
    :members:
    :show-inheritance:
    :inherited-members:

.. autoclass:: DataLoader
    :members:
    :show-inheritance:
    :inherited-members:

Slurm
=====

3 changes: 2 additions & 1 deletion doc/index.rst
@@ -17,7 +17,8 @@
tutorials/02_using_clients
tutorials/03_lattice_boltz_analysis
tutorials/04_inference
tutorials/05_starting_ray/05_starting_ray_builtin
tutorials/05_training
tutorials/06_starting_ray/06_starting_ray_builtin


.. toctree::
1 change: 1 addition & 0 deletions doc/tutorials/04_inference.rst
@@ -348,3 +348,4 @@ RandomForestRegressor. As with the other examples, the skl2onnx function
client.set_model("rf_regressor", model, "ONNX", device="CPU")
client.run_model("rf_regressor", inputs="input", outputs="output")
print(client.get_tensor("output"))
153 changes: 153 additions & 0 deletions doc/tutorials/05_training.rst
@@ -0,0 +1,153 @@

=============================
Online training with SmartSim
=============================

A SmartSim ``Orchestrator`` can be used to store and retrieve the samples and targets used to
train an ML model. In a typical workflow, a simulation produces samples at
each time step and another application downloads the samples as they are produced
to train a Deep Neural Network (e.g. a surrogate model).

In this section, we will use components implemented in ``smartsim.ml.data`` and
``smartsim.ml.tf.data`` to train a Neural Network implemented in TensorFlow and Keras.
In particular, we will be using two classes:

- ``smartsim.ml.data.TrainingDataUploader``, which streamlines the uploading of samples and corresponding targets to the DB
- ``smartsim.ml.tf.data.DataGenerator``, a Keras ``Sequence`` that can be used to train a DNN
  and that downloads the samples from the DB, updating the training set at the end of each epoch.

The SmartSim ``Experiment`` will consist of one mock simulation (the ``producer``) uploading samples,
and one application (the ``training_service``) downloading the samples to train a DNN.

A richer example, entirely implemented in Python, is available as a Jupyter Notebook in the
``tutorials`` section of the SmartSim repository.
An equivalent example using PyTorch instead of TensorFlow is available in the same directory.


5.1.1 Producing and uploading the samples
-----------------------------------------

.. _ml_training_producer_code:

The first application in the workflow, the ``producer``, will upload batches of samples at regular intervals,
mimicking the behavior of an iterative simulation.

Since the ``training_service`` will use a ``smartsim.ml.tf.data.DataGenerator`` to download the samples, their
keys need to follow a pre-defined format. Assuming that only one process in the simulation
uploads the data, the key format for samples is ``<sample_prefix>_<iteration>``. For targets
(which can also be integer labels), the key format is ``<target_prefix>_<iteration>``. Both ``<sample_prefix>``
and ``<target_prefix>`` are user-defined, and must be passed to the
``smartsim.ml.tf.data.DataGenerator`` object when it is initialized.

Assuming the simulation is written in Python, the code would look like:

.. code-block:: python

    from smartredis import Client

    # simulation initialization code
    client = Client(cluster=False, address=None)

    for iteration in range(num_iterations):
        # simulation code producing two tensors, data_points
        # and data_values
        client.put_tensor(f"points_{iteration}", data_points)
        client.put_tensor(f"values_{iteration}", data_values)

For simple simulations, this is sufficient. If the simulation
uses MPI, however, each rank could upload a portion of the data set. In that case,
the format for sample and target keys will be ``<sample_prefix>_<sub-index>_<iteration>``
and ``<target_prefix>_<sub-index>_<iteration>``, where ``<sub-index>`` can be, e.g.,
the MPI rank id.
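The two key formats described above can be sketched with a small helper. This is only an illustration: ``make_key`` is a hypothetical function, not part of SmartSim or SmartRedis, and it assumes that prefix, sub-index, and iteration are joined with underscores exactly as described in this section.

```python
from typing import Optional


def make_key(prefix: str, iteration: int, sub_index: Optional[int] = None) -> str:
    """Build a sample or target key (hypothetical helper, not a SmartSim API)."""
    if sub_index is None:
        # Single uploader: <prefix>_<iteration>
        return f"{prefix}_{iteration}"
    # One uploader per MPI rank: <prefix>_<sub-index>_<iteration>
    return f"{prefix}_{sub_index}_{iteration}"


# Keys that rank 2 would use at iteration 5
sample_key = make_key("points", 5, sub_index=2)
target_key = make_key("values", 5, sub_index=2)
print(sample_key, target_key)
```

In an MPI simulation, the rank id obtained from the MPI library would be passed as ``sub_index``, and the resulting strings would be used as the keys in ``client.put_tensor``.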


5.1.2 Downloading the samples and training the model
----------------------------------------------------

The second part of the workflow is the ``training_service``, an application that
downloads the data uploaded by the ``producer`` and uses them to train an ML model.
Most importantly, the ``training_service`` needs to keep looking for new samples
and download them as they become available. The training data set thus grows at
each ``producer`` iteration.

In Keras, a ``Sequence`` represents a data set and can be passed to ``model.fit()``.
The class ``smartsim.ml.tf.data.DataGenerator`` is a Keras ``Sequence`` that updates
its data set at the end of each training epoch, looking for newly produced batches of samples.
A current limitation of the TensorFlow training algorithm is that it does not take
into account changes in the size of the data set once training has started, i.e. it
assumes that the training (and validation) data does not change during training. To
overcome this limitation, we need to train one epoch at a time. Thus,
following what we defined in the :ref:`producer section <ml_training_producer_code>`,
the ``training_service`` would look like

.. code-block:: python

    from tensorflow import keras

    from smartsim.ml.tf.data import DataGenerator

    generator = DataGenerator(
        sample_prefix="points",
        target_prefix="values",
        batch_size=32,
        cluster=False)

    # Placeholder model: replace with the DNN to be trained
    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")

    # Train one epoch at a time, so that each call to fit()
    # sees the samples downloaded at the end of the previous epoch
    for epoch in range(100):
        model.fit(generator,
                  steps_per_epoch=None,
                  epochs=epoch+1,
                  initial_epoch=epoch,
                  batch_size=generator.batch_size,
                  verbose=2)

Again, this is enough for simple simulations. If the simulation uses MPI,
then the ``DataGenerator`` needs to know about the possible sub-indices. For example,
if the simulation runs 8 MPI ranks, the ``DataGenerator`` initialization will
need to be adapted as follows:

.. code-block:: python

    generator = DataGenerator(
        sample_prefix="points",
        target_prefix="values",
        batch_size=32,
        cluster=False,
        uploader_ranks=8)

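To illustrate why training proceeds one epoch at a time, the following sketch mimics a data set that grows at the end of each epoch. It uses no TensorFlow or SmartSim code: ``GrowingDataset`` is a hypothetical stand-in that only borrows the method names of a Keras ``Sequence`` (``__len__``, ``__getitem__``, ``on_epoch_end``), and the "download" is simulated by appending locally generated tuples.

```python
class GrowingDataset:
    """Toy stand-in for a data set that grows between epochs (not a SmartSim class)."""

    def __init__(self, batch_size):
        self.batch_size = batch_size
        self.samples = []
        self.next_iteration = 0
        self.on_epoch_end()  # load the first batch of samples

    def __len__(self):
        # Number of full batches currently available
        return len(self.samples) // self.batch_size

    def __getitem__(self, index):
        start = index * self.batch_size
        return self.samples[start:start + self.batch_size]

    def on_epoch_end(self):
        # Pretend the producer uploaded one more batch of samples
        new = [(self.next_iteration, i) for i in range(self.batch_size)]
        self.samples.extend(new)
        self.next_iteration += 1


dataset = GrowingDataset(batch_size=4)
batches_per_epoch = []
for epoch in range(3):
    # Re-reading len(dataset) every epoch mirrors calling model.fit() once per epoch
    n_batches = len(dataset)
    for index in range(n_batches):
        batch = dataset[index]  # a training step would consume this batch
    batches_per_epoch.append(n_batches)
    dataset.on_epoch_end()

print(batches_per_epoch)  # prints [1, 2, 3]
```

Because the number of batches is read again at the start of every pass, each epoch sees the samples added at the end of the previous one, which is the behavior the ``for epoch in range(100)`` loop around ``model.fit()`` is meant to achieve.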
5.1.3 Launching the experiment
------------------------------

To launch the ``producer`` and the ``training_service`` as models
within a SmartSim ``Experiment``, we can use the following code:

.. code-block:: python

    from smartsim import Experiment
    from smartsim.database import Orchestrator

    db = Orchestrator(port=6780)
    exp = Experiment("online-training", launcher="local")

    # producer
    producer_script = "producer.py"
    settings = exp.create_run_settings("python", exe_args=producer_script)
    uploader_model = exp.create_model("producer", settings, enable_key_prefixing=True)
    uploader_model.attach_generator_files(to_copy=producer_script)

    # training_service
    training_script = "training_service.py"
    settings = exp.create_run_settings("python", exe_args=training_script)
    trainer_model = exp.create_model("training_service", settings)
    trainer_model.register_incoming_entity(uploader_model)

    exp.start(db, uploader_model, block=False, summary=False)
    exp.start(trainer_model, block=True, summary=False)

Two lines require attention, as they are needed for the ``DataGenerator`` to work:

- ``enable_key_prefixing=True``, passed to ``exp.create_model``, ensures that the ``producer`` prefixes
  all tensor keys with its name
- ``trainer_model.register_incoming_entity(uploader_model)`` enables the ``DataGenerator``
  in the ``training_service`` to know that it needs to download samples produced by the ``producer``
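The effect of key prefixing can be pictured with a short sketch. It is an illustration only: ``prefixed_key`` is a hypothetical helper, and the dot used to join the entity name and the key is an assumption that should be checked against the SmartRedis documentation.

```python
def prefixed_key(entity_name: str, key: str, separator: str = ".") -> str:
    """Illustrative only: how a prefixed key might look (separator is an assumption)."""
    return f"{entity_name}{separator}{key}"


# Key as written in the producer script
raw_key = "points_0"
# Key under which the training service would then expect to find the tensor
stored_key = prefixed_key("producer", raw_key)
print(stored_key)
```

Under this assumption, registering the ``producer`` as an incoming entity is what lets the ``training_service`` look for keys carrying the ``producer`` prefix rather than the bare ``points_0``.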

6 changes: 6 additions & 0 deletions smartsim/ml/__init__.py
@@ -0,0 +1,6 @@
from .data import (
DynamicDataDownloader,
StaticDataDownloader,
TrainingDataUploader,
form_name,
)