Skip to content

Study Keras: weight initialization

Yang Yang(Tony) edited this page Apr 23, 2019 · 1 revision
  • TensorFlow Version: 2.0
  • Keras Model API: subclassing tf.keras.Model

Sample Code

class DNNClassifier(tf.keras.Model):
    """Feed-forward classifier built by subclassing ``tf.keras.Model``.

    Inputs pass through a ``DenseFeatures`` layer, then through one
    ``Dense`` layer per entry of ``hidden_units`` (no activation is
    configured on these), and finally through a softmax ``Dense`` layer
    with ``n_classes`` units.
    """

    def __init__(self, feature_columns, hidden_units, n_classes):
        """Create the sub-layers.

        Args:
            feature_columns: Passed to ``tf.keras.layers.DenseFeatures``.
            hidden_units: Iterable of unit counts; one ``Dense`` layer is
                created per entry, in order.
            n_classes: Number of units of the softmax output layer.
        """
        super().__init__()
        self.feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
        # One Dense layer per requested width, created in the given order.
        self.hidden_layers = [
            tf.keras.layers.Dense(units) for units in hidden_units
        ]
        self.prediction_layer = tf.keras.layers.Dense(
            n_classes, activation='softmax')

    def call(self, inputs):
        """Run the forward pass and return the softmax layer's output."""
        activations = self.feature_layer(inputs)
        for dense in self.hidden_layers:
            activations = dense(activations)
        predictions = self.prediction_layer(activations)
        return predictions

Stacktrace on model.fit

  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(86)<module>()
-> model.fit(train_ds, validation_data=val_ds, epochs=model.default_training_epochs(), verbose=0)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(794)fit()
-> initial_epoch=initial_epoch)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1519)fit_generator()
-> steps_name='steps_per_epoch')
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_generator.py(257)model_iteration()
-> batch_outs = batch_function(*batch_data)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1242)train_on_batch()
-> extract_tensors_from_dataset=True)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2507)_standardize_user_data()
-> self._set_inputs(cast_inputs)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py(456)_method_wrapper()
-> result = method(self, *args, **kwargs)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2779)_set_inputs()
-> outputs = self.call(inputs)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(62)call()
-> x = hidden_layer(x)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(594)__call__()
-> self._maybe_build(inputs)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(1713)_maybe_build()
-> self.build(input_shapes)
> /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py(961)build()
-> dtype = dtypes.as_dtype(self.dtype or K.floatx())

Stacktrace on model.predict

  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(90)<module>()
-> model.predict(test_ds)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1147)predict()
-> callbacks=callbacks)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_generator.py(257)model_iteration()
-> batch_outs = batch_function(*batch_data)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_generator.py(531)predict_on_batch()
-> return model.predict_on_batch(x)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1365)predict_on_batch()
-> x, extract_tensors_from_dataset=True)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2507)_standardize_user_data()
-> self._set_inputs(cast_inputs)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py(456)_method_wrapper()
-> result = method(self, *args, **kwargs)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2779)_set_inputs()
-> outputs = self.call(inputs)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(62)call()
-> x = hidden_layer(x)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(594)__call__()
-> self._maybe_build(inputs)
  /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(1713)_maybe_build()
-> self.build(input_shapes)
> /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py(961)build()
-> dtype = dtypes.as_dtype(self.dtype or K.floatx())

Conclusion

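The main logic lives in Model._set_inputs, which both model.fit and model.predict reach through _standardize_user_data. It creates symbolic tensors and passes them to the model's call method, which in turn invokes layer.__call__ on each layer. Inside layer.__call__, if the inputs are symbolic tensors, layer.__call__ invokes layer.build (via _maybe_build).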