diff --git a/rfcs/20200117-tfx-generic-trainer.md b/rfcs/20200117-tfx-generic-trainer.md
new file mode 100644
index 000000000..f479a1e00
--- /dev/null
+++ b/rfcs/20200117-tfx-generic-trainer.md
@@ -0,0 +1,253 @@
+# TFX Generic Trainer
+
+| Status | Proposed |
+| :------------ | :-------------------------------------------------------- |
+| **Author(s)** | Jiayi Zhao (jyzhao@google.com) |
+| **Sponsor** | Konstantinos Katsiapis (katsiapis@google.com), Zhitao Li (zhitaoli@google.com), Karmel Allison (karmel@google.com) |
+| **Updated** | 2020-01-17 |
+
+## Objective
+
+### Goal
+
+*   Support any TensorFlow training loop in TFX Trainer in addition to
+    tf.estimator, with a primary focus on native Keras models.
+
+### Non Goal
+
+* Natively support multi-worker distributed training by the system.
+*   Non-TF training that generates a SavedModel.
+
+## Background and Motivation
+
+In the current TFX Trainer component, only tf.estimator is supported for
+training and generating models. The user provides a module file which contains
+a `trainer_fn`; the trainer calls this function to get the estimator model and
+related specs for training, and generates a saved model by
+`tf.estimator.train_and_evaluate`.
+
+[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level
+API for building and training models. It’s currently supported in TFX by using
+`tf.keras.estimator.model_to_estimator` in the module file. Users can create a
+Keras model in their `trainer_fn` but need to convert it to an estimator before
+returning it (for example,
+[cifar10](https://github.com/tensorflow/tfx/blob/r0.15/tfx/examples/cifar10/cifar10_utils.py)).
+
+This doc will focus on native Keras support (without model_to_estimator) in TFX.
+We propose changing the user-facing API to be more generic so that users can do
+(single-node) native Keras model training within TFX.
+
+## User Benefit
+
+*   Allows non-estimator-based training, especially Keras, as TensorFlow is
+    establishing Keras as the
+    [Standardized high-level API](https://medium.com/tensorflow/standardizing-on-keras-guidance-on-high-level-apis-in-tensorflow-2-0-bad2b04c819a).
+*   Allows
+    [custom training](https://www.tensorflow.org/tutorials/customization/custom_training)
+    for customization of the training loop.
+
+## Detailed Design
+
+Below is the pseudo code for the current TFX Trainer’s executor:
+
+```python
+class Executor(base_executor.BaseExecutor):
+
+ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
+ output_dict: Dict[Text, List[types.Artifact]],
+ exec_properties: Dict[Text, Any]) -> None:
+ """Uses a user-supplied tf.estimator to train a tf model locally."""
+ trainer_fn = self._GetFn(exec_properties) # load from module file
+ trainer_fn_args = self._GetFnArgs(
+ input_dict, output_dict, exec_properties)
+
+ training_spec = trainer_fn(trainer_fn_args)
+ tf.estimator.train_and_evaluate(training_spec['estimator'], ...)
+ # For TFMA (downstream evaluator and model validator component).
+ tfma.export.export_eval_savedmodel(training_spec['estimator'], ...)
+```
+
+The user-supplied module file contains a function called `trainer_fn` which
+returns an estimator:
+
+```python
+def _build_keras_model() -> tf.keras.Model:
+ model = keras.XXX
+ model.compile(...)
+ return model
+
+def trainer_fn(
+ trainer_fn_args: trainer.executor.TrainerFnArgs) -> Dict[Text, Any]:
+ """Build the estimator using the high level API.
+
+ Args:
+ trainer_fn_args: Holds args used to train the model as name/value pairs.
+
+ Returns:
+ A dict of the following:
+ - estimator: The estimator that will be used for training and eval.
+ - train_spec: Spec for training.
+ - eval_spec: Spec for eval.
+ - eval_input_receiver_fn: Input function for eval.
+ """
+ ...
+
+ estimator = tf.keras.estimator.model_to_estimator(
+ keras_model=_build_keras_model(), ...)
+
+ return {
+ 'estimator': estimator,
+ 'train_spec': ...,
+ 'eval_spec': ...,
+ 'eval_input_receiver_fn': ...
+ }
+
+```
+
+We propose that in the generic trainer's module file, users not only provide
+the model, but also control how the model is trained (`train_and_evaluate` for
+estimator and `model.fit` for Keras live in the user module file instead of in
+the executor). The executor thus becomes agnostic to the model, and users can
+customize the
+[training loop](https://www.tensorflow.org/tutorials/customization/custom_training_walkthrough#training_loop).
+The executor pseudo code would look like below:
+
+```python
+class Executor(base_executor.BaseExecutor):
+
+ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
+ output_dict: Dict[Text, List[types.Artifact]],
+ exec_properties: Dict[Text, Any]) -> None:
+ """Train a user-supplied tf model."""
+ run_fn = self._GetRunFn(exec_properties) # load from module file
+
+ # run_fn_args contains
+ # 1. input train and eval data path.
+ # 2. desired output model path for the trained savedmodel.
+ # 3. training args, e.g., train/eval steps.
+ # 4. optional base model.
+ # 5. optional tuning result (kerastuner.HyperParameters config).
+ # 6. optional custom config for passing params from component.
+ run_fn_args = self._GetRunFnArgs(
+ input_dict, output_dict, exec_properties)
+
+ run_fn(run_fn_args)
+ # Validates the existence of run_fn's output savedmodel.
+ ...
+```
+
+In the module file, users need to provide a `run_fn` instead of the previous
+`trainer_fn`. The `trainer_fn` was only responsible for creating the model; in
+addition to that, `run_fn` also needs to handle the training part and output
+the trained model to a desired location given by the run args:
+
+```python
+def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
+ """Build the TF model and train it."""
+ model = _build_keras_model()
+ model.fit(...)
+ # Save model to args.serving_model_dir.
+ model.save(...)
+```
+
+In the generic trainer, the executor is mainly for handling the
+[artifact](https://github.com/tensorflow/tfx/blob/r0.21/docs/guide/index.md#artifacts)
+(a unit of data that is passed between components); all model-related logic is
+user supplied.
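+
+Because training is now driven from the module file, a
+[custom training loop](https://www.tensorflow.org/tutorials/customization/custom_training)
+can also be used inside `run_fn`. Below is a minimal sketch; the `_input_fn`
+helper and the exact attribute names on the run args (`train_files`,
+`train_steps`, `serving_model_dir`) are assumptions for illustration:
+
+```python
+def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
+  """Trains a Keras model with a custom training loop (sketch)."""
+  model = _build_keras_model()
+  optimizer = tf.keras.optimizers.Adam()
+  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+
+  # Assumed helper returning a batched tf.data.Dataset of (features, labels).
+  train_dataset = _input_fn(args.train_files, batch_size=32)
+
+  for features, labels in train_dataset.take(args.train_steps):
+    with tf.GradientTape() as tape:
+      logits = model(features, training=True)
+      loss = loss_fn(labels, logits)
+    grads = tape.gradient(loss, model.trainable_variables)
+    optimizer.apply_gradients(zip(grads, model.trainable_variables))
+
+  # Save the SavedModel to the location validated by the executor.
+  model.save(args.serving_model_dir, save_format='tf')
+```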
+
+A separate GenericExecutor will be created, and the existing estimator-based
+trainer executor will be sunsetted: we plan to keep it for one more version and
+then deprecate it.
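+
+For illustration, a pipeline could opt into the new behavior roughly as follows
+(a sketch only; the import paths and the `custom_executor_spec` mechanism are
+assumptions based on how TFX 0.21 overrides component executors, and the
+`GenericExecutor` name may change):
+
+```python
+from tfx.components import Trainer
+from tfx.components.base import executor_spec
+from tfx.components.trainer.executor import GenericExecutor
+from tfx.proto import trainer_pb2
+
+trainer = Trainer(
+    module_file=module_file,  # Contains the user-defined `run_fn`.
+    custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
+    examples=example_gen.outputs['examples'],
+    schema=schema_gen.outputs['schema'],
+    train_args=trainer_pb2.TrainArgs(num_steps=10000),
+    eval_args=trainer_pb2.EvalArgs(num_steps=5000))
+```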
+
+### How to convert current estimator based module file
+
+To convert a current estimator-based module file (e.g.,
+[iris](https://github.com/tensorflow/tfx/blob/r0.15/tfx/examples/iris/iris_utils.py))
+for the generic trainer, simply add a `run_fn` that calls the `trainer_fn` and
+trains the returned model (code that used to live in the trainer executor's `Do`).
+
+```python
+def run_fn(fn_args: executor.TrainerFnArgs):
+ """Train the model based on given args.
+
+ Args:
+ fn_args: Holds args used to train the model as name/value pairs.
+ """
+ schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema())
+
+ # Reuse the trainer_fn.
+ training_spec = trainer_fn(fn_args, schema)
+
+ # Train the model
+ absl.logging.info('Training model.')
+ tf.estimator.train_and_evaluate(training_spec['estimator'],
+ training_spec['train_spec'],
+ training_spec['eval_spec'])
+ absl.logging.info('Training complete. Model written to %s',
+ fn_args.serving_model_dir)
+
+  # Export an eval savedmodel for TFMA. Note that for Keras, an eval savedmodel
+  # is not needed as TFMA2 can use the serving model for evaluation.
+ absl.logging.info('Exporting eval_savedmodel for TFMA.')
+ tfma.export.export_eval_savedmodel(
+ estimator=training_spec['estimator'],
+ export_dir_base=fn_args.eval_model_dir,
+ eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])
+
+ absl.logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir)
+```
+
+### tf.distribute.Strategy
+
+Distribution strategy will be the user module's responsibility with the new
+generic trainer interface. To use it, users need to modify the `run_fn()` in
+the module file; below are pseudo code examples for single-worker and
+multi-worker distribution strategies.
+
+For a single-worker distribution strategy, create an appropriate
+[tf.distribute.Strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy),
+and move the creation and compiling of the Keras model inside `strategy.scope`:
+
+```python
+def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
+ """Build the TF model and train it."""
+ mirrored_strategy = tf.distribute.MirroredStrategy()
+ with mirrored_strategy.scope():
+ model = _build_keras_model()
+ model.fit(...)
+ model.save(...)
+```
+
+For multi-worker distribution strategies, the TFX Trainer's
+[current executor](https://github.com/tensorflow/tfx/blob/r0.21/tfx/components/trainer/executor.py)
+does not have the ability to spawn a multi-worker cluster, hence this is not
+covered in the scope of this RFC. If the execution environment of a TFX Trainer
+implementation has the ability to bring up a cluster of worker machines and
+execute the user function on the workers with the correct
+[TF_CONFIG setup](https://www.tensorflow.org/guide/distributed_training#setting_up_tf_config_environment_variable),
+such as the GCP AI Platform Training service via
+[extensions/google_cloud_ai_platform/trainer/executor.py](https://github.com/tensorflow/tfx/blob/r0.21/tfx/extensions/google_cloud_ai_platform/trainer/executor.py),
+the `run_fn()` would look like below:
+
+```python
+def _is_chief() -> bool:
+ """Decide whether the current worker's role is chief."""
+  # Check TF_CONFIG (set by TFX when bringing up the worker) in the execution env.
+ ...
+
+def run_fn(args: trainer.executor.TrainerFnArgs) -> None:
+ """Build the TF model and train it."""
+ ps_strategy = tf.distribute.experimental.ParameterServerStrategy()
+ with ps_strategy.scope():
+ model = _build_keras_model()
+ model.fit(...)
+ if _is_chief():
+ model.save(...)
+```
+
+For details about `tf.distribute.Strategy`, please refer to the
+[distributed training guide](https://www.tensorflow.org/guide/distributed_training).
+
+## Future work
+
+* Examples for custom training loop.
+* Native support for multi-worker distribution.
diff --git a/rfcs/20200420-tfx-tuner-component.md b/rfcs/20200420-tfx-tuner-component.md
new file mode 100644
index 000000000..f83ef405c
--- /dev/null
+++ b/rfcs/20200420-tfx-tuner-component.md
@@ -0,0 +1,386 @@
+# TFX Tuner Component
+
+| Status | Proposed |
+| :------------ | :-------------------------------------------------------- |
+| **Author(s)** | Jiayi Zhao (jyzhao@google.com), Amy Wu (wuamy@google.com) |
+| **Sponsor** | Zhitao Li (zhitaoli@google.com), Tom O'Malley |
+: : (omalleyt@google.com), Matthieu Monsch (mtth@google.com) :
+: : Makoto Uchida (muchida@google.com), Goutham Bhat :
+: : (goutham@google.com) :
+| **Updated** | 2020-04-20 |
+
+## Objective
+
+### Goal
+
+*   A new Tuner component in TFX for automated hyperparameter tuning, which is
+    based on abstractions from the
+    [KerasTuner library](https://github.com/keras-team/keras-tuner), in order to
+    reuse its abstractions and algorithms.
+
+### Non Goal
+
+*   Natively support multi-worker tuning by the system. As TFX doesn't have the
+    ability to manage multi-worker clusters, running multiple trials in parallel
+    (parallel tuning) and running each trial in a distributed environment
+    (distributed training) are not supported natively. Parallel tuning may
+    instead be realized by a particular implementation of TFX Tuner (custom
+    Executor), e.g., in a Google Cloud environment.
+*   Implementation of custom tuners for the
+    [KerasTuner library](https://github.com/keras-team/keras-tuner) is out of
+    the scope of this design discussion, e.g., built-in EstimatorTuner support.
+    However, user projects can still implement a tuner that inherits from
+    [`kerastuner.BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py)
+    and provide it to the proposed TFX Tuner component.
+
+## Background and Motivation
+
+A hyperparameter is a parameter whose value is used to control the learning
+process of a model or the model itself (e.g., layers and number of nodes). By
+contrast, the values of other parameters (typically node weights) are learned.
+
+Hyperparameter optimization is a critical part of many machine learning
+pipelines. Thus we propose a new TFX component: given a search space that
+specifies the hyperparameter configuration (name, type, range, etc.), TFX will
+optimize the hyperparameters based on the tuning algorithm.
+
+## User Benefit
+
+This document proposes a built-in TFX Tuner component, which works seamlessly
+with Trainer and other TFX components. As the Tuner component will utilize the
+[KerasTuner library](https://github.com/keras-team/keras-tuner), all supported
+tuning methods will be available to TFX, including custom implementations of
+KerasTuner tuners.
+
+## Design Proposal
+
+The TFX Tuner component will be built with the
+[KerasTuner library](https://github.com/keras-team/keras-tuner). In the
+following sections, we will first briefly go over the KerasTuner library and
+several concepts in hyperparameter optimization. Then we will focus on our Tuner
+component interface and how we utilize the KerasTuner library. After that, we
+will discuss parallel tuning and our plan on Google Cloud integration.
+
+### KerasTuner Library
+
+The following graph shows a typical workflow of hyperparameter tuning under the
+KerasTuner framework:
+
+
+![](20200420-tfx-tuner-component/workflow.png)
+
+Given a user-provided model which accepts a hyperparameter container, the tuner
+searches the space through trials created by the tuning algorithm. For each
+trial, values within the search space are assigned to the hyperparameter
+container, and the user model is trained with these hyperparameter values and
+evaluated based on the objective provided to the tuner. The evaluation results
+are reported back to the tuner, and the tuning algorithm decides the
+hyperparameter values for the next trial. After reaching certain conditions,
+e.g., max trials, the tuner stops iterating and returns the optimal
+hyperparameters.
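+
+As a concrete illustration of this workflow outside of TFX, a minimal
+standalone example with KerasTuner 1.0 could look like the following (the model
+and the random data are placeholders for illustration only):
+
+```python
+import kerastuner
+import numpy as np
+import tensorflow as tf
+
+# Placeholder data standing in for a real dataset.
+x_train = np.random.rand(128, 20).astype('float32')
+y_train = np.random.randint(0, 10, size=(128,))
+
+def build_model(hp: kerastuner.HyperParameters) -> tf.keras.Model:
+  """Builds a model whose width and learning rate come from the trial values."""
+  model = tf.keras.Sequential([
+      tf.keras.layers.Dense(hp.Int('units', 32, 128, step=32), activation='relu'),
+      tf.keras.layers.Dense(10, activation='softmax'),
+  ])
+  model.compile(
+      optimizer=tf.keras.optimizers.Adam(hp.Choice('learning_rate', [1e-2, 1e-3])),
+      loss='sparse_categorical_crossentropy',
+      metrics=['accuracy'])
+  return model
+
+# RandomSearch bundles a random-search Oracle with a Keras Tuner.
+tuner = kerastuner.RandomSearch(
+    build_model,
+    objective='val_accuracy',
+    max_trials=5,
+    directory='/tmp/tuning',
+    project_name='standalone_example')
+
+# Each trial trains the model with hyperparameter values proposed by the Oracle.
+tuner.search(x_train, y_train, validation_split=0.2, epochs=2)
+best_hp = tuner.get_best_hyperparameters()[0]
+```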
+
+The KerasTuner library provides the above tuning functionality. Below are some
+of its abstractions:
+
+* [`HyperParameters`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/hyperparameters.py):
+ Hyperparameter container for both search space, and current values.
+* [`Oracle`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/oracle.py):
+ Implementation of a hyperparameter tuning algorithm, e.g., random search,
+ including state management of the algorithm’s progress.
+* [`Trial`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/trial.py):
+ Provided by the Oracle, contains information about Hyperparameter values for
+ the current iteration.
+* [`BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py):
+    A base tuner interface for the above tuning workflow, responsible for the
+    iteration of trial execution:
+    *   Generates a Trial using the Oracle.
+    *   Trains the user model with the HyperParameters in the current Trial.
+    *   Evaluates metrics and reports back to the Oracle for the next Trial.
+* [`Tuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py):
+ An implementation of BaseTuner, for Keras model tuning.
+
+Note: Other than `Tuner`, the abstractions `HyperParameters`, `Oracle`, `Trial`
+and `BaseTuner` are not restricted to Keras models, although the library is
+called KerasTuner.
+
+For more details and code examples, please refer to
+[here](https://github.com/keras-team/keras-tuner).
+
+### Component Interface
+
+The Tuner component takes raw or transformed examples as input, along with a
+schema or transform_graph for the feature specification, and outputs the
+hyperparameter tuning results. Below is the specification of the Tuner
+component:
+
+```python
+class TunerSpec(ComponentSpec):
+ """ComponentSpec for TFX Tuner Component."""
+
+ PARAMETERS = {
+ # Specify a python module file which contains a UDF `tuner_fn`.
+ 'module_file': ExecutionParameter(type=(str, Text), optional=True),
+ # Specify the steps for the training stage of each trial’s execution.
+ 'train_args': ExecutionParameter(type=trainer_pb2.TrainArgs),
+ 'eval_args': ExecutionParameter(type=trainer_pb2.EvalArgs),
+ }
+
+ INPUTS = {
+ 'examples': ChannelParameter(type=standard_artifacts.Examples),
+ 'schema': ChannelParameter(
+ type=standard_artifacts.Schema, optional=True),
+ 'transform_graph':
+ ChannelParameter(
+ type=standard_artifacts.TransformGraph, optional=True),
+ }
+
+ OUTPUTS = {
+ 'best_hyperparameters':
+ ChannelParameter(type=standard_artifacts.HyperParameters),
+ }
+```
+
+Trainer has an optional hyperparameters input; the tuning result can be fed
+into it so that Trainer can utilize the best hyperparameters to construct the
+model. Below is an example of how Tuner and Trainer are chained in the
+pipeline:
+
+```python
+# TrainerSpec:
+ INPUTS = {
+ ...
+ 'hyperparameters':
+ ChannelParameter(
+ type=standard_artifacts.HyperParameters, optional=True),
+ }
+
+# Pipeline DSL Example:
+ tuner = Tuner(
+ examples=example_gen.outputs['examples'],
+ schema=schema_gen.outputs['schema'],
+ module_file=module_file,
+ train_args=trainer_pb2.TrainArgs(num_steps=1000),
+ eval_args=trainer_pb2.EvalArgs(num_steps=500))
+
+ trainer = Trainer(
+ module_file=module_file,
+ examples=example_gen.outputs['examples'],
+ schema=schema_gen.outputs['schema'],
+ hyperparameters=tuner.outputs['best_hyperparameters'],
+ train_args=trainer_pb2.TrainArgs(num_steps=10000),
+ eval_args=trainer_pb2.EvalArgs(num_steps=5000))
+```
+
+For Trainer, users need to define model code and training logic
+([Generic Trainer](https://github.com/tensorflow/tfx/blob/r0.21.2/docs/guide/trainer.md#generic-trainer))
+in the module_file. For Tuner, in addition to the model code, users also need
+to define the hyperparameters, the search space, and a tuning algorithm in the
+module_file. A `tuner_fn` with the following signature is required for Tuner:
+
+```python
+from typing import Any, Dict, NamedTuple, Text
+
+from kerastuner.engine import base_tuner
+import tensorflow as tf
+from tfx.components.trainer.executor import TrainerFnArgs
+
+# Current TrainerFnArgs will be renamed to FnArgs as a util class.
+FnArgs = TrainerFnArgs
+TunerFnResult = NamedTuple('TunerFnResult',
+ [('tuner', base_tuner.BaseTuner),
+ ('fit_kwargs', Dict[Text, Any])])
+
+def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
+ """Build the tuner using the KerasTuner API.
+
+ Args:
+ fn_args: Holds args as name/value pairs.
+ working_dir: working dir for tuning. Automatically set by Executor.
+ train_files: List of file paths containing training tf.Example data.
+ eval_files: List of file paths containing eval tf.Example data.
+ train_steps: number of train steps.
+ eval_steps: number of eval steps.
+ schema: optional schema file of the input data.
+ transform_graph: optional transform graph produced by TFT.
+
+ Returns:
+ A namedtuple contains the following:
+ - tuner: A BaseTuner that will be used for tuning.
+      - fit_kwargs: Args to pass to the tuner’s run_trial function for fitting
+                    the model, e.g., the training and validation dataset.
+                    Required args depend on the above tuner’s implementation.
+ """
+```
+
+The `TunerFnResult` returned by the above `tuner_fn` contains an instance that
+implements the
+[`BaseTuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/base_tuner.py)
+interface; that’s the contract required by Tuner for tuning. The model code,
+hyperparameters, search space and tuning algorithm are hidden under the
+BaseTuner abstraction, so the Tuner itself is generic and agnostic to the model
+framework and tuning logic. Below is an example module file with a Keras model:
+
+```python
+import kerastuner
+import tensorflow as tf
+...
+
+def _input_fn(file_pattern: Text, ...) -> tf.data.Dataset:
+ ...
+
+# Model code for Trainer and Tuner.
+def _build_keras_model(hp: kerastuner.HyperParameters) -> tf.keras.Model:
+ ...
+ for _ in range(hp.get('num_layers')):
+ ...
+ ...
+ model = tf.keras.Model(...)
+ model.compile(
+ optimizer=tf.keras.optimizers.Adam(hp.get('learning_rate')),
+ loss='sparse_categorical_crossentropy',
+ metrics=[tf.keras.metrics.Accuracy()])
+ return model
+
+# This will be called by TFX Tuner.
+def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
+ hp = kerastuner.HyperParameters()
+ # Defines search space.
+ hp.Choice('learning_rate', [1e-1, 1e-3])
+ hp.Int('num_layers', 1, 5)
+
+  # RandomSearch is a subclass of kerastuner.Tuner, the Keras model tuner.
+ tuner = kerastuner.RandomSearch(
+ _build_keras_model,
+ max_trials=5,
+ hyperparameters=hp,
+ allow_new_entries=False,
+ objective='val_accuracy',
+ directory=fn_args.working_dir,
+ project_name='test')
+
+  train_dataset = _input_fn(fn_args.train_files, ...)
+  eval_dataset = _input_fn(fn_args.eval_files, ...)
+
+ return TunerFnResult(
+ tuner=tuner,
+ fit_kwargs={'x': train_dataset,
+ 'validation_data': eval_dataset,
+ 'steps_per_epoch': fn_args.train_steps,
+ 'validation_steps': fn_args.eval_steps})
+
+# This will be called by TFX Generic Trainer.
+def run_fn(fn_args: FnArgs) -> None:
+ hp = kerastuner.HyperParameters.from_config(fn_args.hyperparameters)
+ model = _build_keras_model(hp)
+ model.fit(...)
+ model.save(...)
+```
+
+In Tuner’s executor, `tuner_fn` is called with information resolved from
+component inputs; the executor then calls the `search` function of the returned
+tuner with `fit_kwargs` to launch trials for tuning, and finally emits the best
+trial’s hyperparameters:
+
+```python
+# Executor of Tuner Component:
+class Executor(base_executor.BaseExecutor):
+
+ def Do(self,
+ input_dict: Dict[Text, List[types.Artifact]],
+ output_dict: Dict[Text, List[types.Artifact]],
+ exec_properties: Dict[Text, Any]) -> None:
+ ...
+ tuner_spec = tuner_fn(self._create_fn_args(input_dict, exec_properties))
+ tuner_spec.tuner.search(**tuner_spec.fit_kwargs)
+ # Output file contains json format string of hyperparameters.get_config().
+ self._emit_best_hyperparameters(
+ output_dict, tuner_spec.tuner.get_best_hyperparameters()[0])
+```
+
+### Parallel Tuning
+
+In parallel tuning, multiple trials are executed in parallel. In this section,
+we will discuss how distribution works for the KerasTuner library and the
+status of TFX.
+
+In the tuner's `search` function, trials run in sequence instead of in
+parallel. To support parallel tuning, we need to launch multiple tuners (the
+tuner here refers to the one in the KerasTuner library, not the TFX Tuner
+component) and have an optimization service that manages the state of the
+tuning algorithm; the oracle of each tuner communicates with this service to
+retrieve the trials for that tuner.
+
+![](20200420-tfx-tuner-component/parallel_tuning.png)
+
+The above graph shows parallel tuning with three tuners. Each tuner runs as a
+different worker and retrieves trials from its own oracle, which talks to the
+optimization service. Trials of different tuners can run in parallel, but
+trials within the same tuner still execute in sequence. When launching tuners,
+the same identifier is assigned to each oracle, so the optimization service
+knows they belong to the same tuning job group and assigns hyperparameter
+values for their trials based on the algorithm.
+
+The number of parallel tuners can be passed to the component via `TuneArgs` as
+shown below:
+
+```python
+# Args specific to tuning.
+message TuneArgs {
+ # Number of trials to run in parallel.
+ # Each trial will be trained and evaluated by separate worker jobs.
+ int32 num_parallel_trials = 1;
+}
+
+class TunerSpec(ComponentSpec):
+
+ PARAMETERS = {
+ ...
+ 'tune_args': ExecutionParameter(type=tuner_pb2.TuneArgs),
+ }
+```
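+
+In the pipeline DSL, this could look roughly as follows (a sketch; the
+`tuner_pb2` proto module and the `tune_args` wiring are as proposed above and
+may change):
+
+```python
+  tuner = Tuner(
+      examples=example_gen.outputs['examples'],
+      schema=schema_gen.outputs['schema'],
+      module_file=module_file,
+      train_args=trainer_pb2.TrainArgs(num_steps=1000),
+      eval_args=trainer_pb2.EvalArgs(num_steps=500),
+      # Run up to three trials in parallel, one per tuner worker.
+      tune_args=tuner_pb2.TuneArgs(num_parallel_trials=3))
+```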
+
+The KerasTuner library allows users to configure a
+[`tf.distribute.Strategy`](https://www.tensorflow.org/tutorials/distribute/keras)
+if they are using the
+[`kerastuner.Tuner`](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py)
+class (or subclasses of it). In the above parallel tuning setup, each trial
+(each model training) is executed on a single worker, so only single-machine
+strategies are allowed. To support multi-worker distributed training, we need
+to be able to execute a trial (training) across different workers.
+
+At the time of writing, the KerasTuner library can be used for parallel tuning
+with a single-machine `tf.distribute.Strategy`, e.g.,
+[`MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy).
+Multi-worker strategy (distributed training for a trial) support is on the
+roadmap (note that cluster management is not part of the library).
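+
+For illustration, a module file could configure a single-machine strategy for
+each trial roughly as follows (a sketch, assuming the `distribution_strategy`
+constructor argument of KerasTuner 1.0's `Tuner`/`RandomSearch`; the
+`_get_hyperparameters` helper is hypothetical):
+
+```python
+def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
+  tuner = kerastuner.RandomSearch(
+      _build_keras_model,
+      max_trials=5,
+      hyperparameters=_get_hyperparameters(),  # Hypothetical search-space helper.
+      objective='val_accuracy',
+      # Each trial trains with MirroredStrategy on the local worker's GPUs.
+      distribution_strategy=tf.distribute.MirroredStrategy(),
+      directory=fn_args.working_dir,
+      project_name='tuning_with_strategy')
+
+  return TunerFnResult(
+      tuner=tuner,
+      fit_kwargs={'x': _input_fn(fn_args.train_files, ...),
+                  'validation_data': _input_fn(fn_args.eval_files, ...),
+                  'steps_per_epoch': fn_args.train_steps,
+                  'validation_steps': fn_args.eval_steps})
+```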
+
+At the time of writing, TFX doesn’t have the ability to manage a multi-worker
+cluster or a centralized optimization service, so parallel tuning and
+distributed training are not supported natively in TFX (local or on-prem). In
+the next section, we discuss the integration for Google Cloud; similar parallel
+tuning support can be built for other execution environments.
+
+### Google Cloud Integration
+
+In this section, we discuss the Tuner component with
+[Google Cloud AI Platform](https://cloud.google.com/ai-platform) (CAIP),
+specifically, an implementation of KerasTuner Oracle that talks to the
+[AI Platform Optimizer](https://cloud.google.com/ai-platform/optimizer/docs/overview)
+as the centralized optimization service, and a custom Tuner executor
+implementation that makes use of the Cloud Optimizer-based Oracle (symbol names
+are subject to change).
+
+As mentioned above in the parallel tuning section, parallel tuning with
+KerasTuner relies on a centralized optimization service that manages the states
+of a tuning study and its trials. For this integration, we will create a
+`CloudOracle` as a client to the AI Platform Optimizer service, and a
+`CloudTuner` which inherits from the Keras
+[Tuner](https://github.com/keras-team/keras-tuner/blob/1.0.0/kerastuner/engine/tuner.py).
+In the module file, users create the `tuner_fn` with `CloudTuner`, and then
+configure the TFX Tuner component to use a custom Tuner executor
+(`CloudExecutor`), which launches multiple `CloudTuner`s on a Google Cloud AI
+Platform Training job with possibly multiple worker machines running various
+trials concurrently. Below is the workflow for in-process tuning and Cloud
+tuning.
+
+![](20200420-tfx-tuner-component/cloud.png)
+
+## Future work
+
+* Native support for multi-worker parallel tuning.
+*   Custom Tuner (inheriting from BaseTuner) examples, e.g., for Estimator
+    support or Keras custom training loop support.
\ No newline at end of file
diff --git a/rfcs/cloud.png b/rfcs/cloud.png
new file mode 100644
index 000000000..09559da71
Binary files /dev/null and b/rfcs/cloud.png differ
diff --git a/rfcs/parallel_tuning.png b/rfcs/parallel_tuning.png
new file mode 100644
index 000000000..efd62b113
Binary files /dev/null and b/rfcs/parallel_tuning.png differ
diff --git a/rfcs/workflow.png b/rfcs/workflow.png
new file mode 100644
index 000000000..4f8bd89da
Binary files /dev/null and b/rfcs/workflow.png differ