-
Notifications
You must be signed in to change notification settings - Fork 704
Study Keras: weight initialization
Yang Yang(Tony) edited this page Apr 23, 2019
·
1 revision
- TensorFlow Version: 2.0
- Keras Model API: subclassing
Model definition (subclassing `tf.keras.Model`):
class DNNClassifier(tf.keras.Model):
    """Feed-forward classifier defined by subclassing ``tf.keras.Model``.

    Inputs flow through a ``DenseFeatures`` layer, then a stack of fully
    connected hidden layers, and finally a softmax output layer.

    Layers are only *constructed* in ``__init__``; their weights are created
    lazily by each layer's ``build`` method the first time the layer is
    called (for ``fit``/``predict`` that first call happens from
    ``Model._set_inputs``).

    Args:
        feature_columns: feature columns handed to
            ``tf.keras.layers.DenseFeatures``.
        hidden_units: sequence of ints; one ``Dense`` hidden layer is created
            per entry, with that many units.
        n_classes: number of output units of the softmax prediction layer.
        hidden_activation: activation applied by every hidden layer.
            Defaults to ``'relu'``. Without a nonlinearity, consecutive
            ``Dense`` layers compose into a single linear map, which makes
            the hidden stack redundant. Pass ``None`` to get the previous
            purely linear hidden layers.
    """

    def __init__(self, feature_columns, hidden_units, n_classes,
                 hidden_activation='relu'):
        super().__init__()
        self.feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
        # Keras tracks layers stored in a list attribute, so their weights
        # are still registered with the model.
        self.hidden_layers = [
            tf.keras.layers.Dense(units, activation=hidden_activation)
            for units in hidden_units
        ]
        self.prediction_layer = tf.keras.layers.Dense(
            n_classes, activation='softmax')

    def call(self, inputs):
        """Run the forward pass and return per-class softmax outputs.

        Args:
            inputs: the feature input consumed by ``self.feature_layer``
                (presumably a dict of feature name -> tensor, as
                ``DenseFeatures`` expects; confirm against the input
                pipeline).
        """
        x = self.feature_layer(inputs)
        for hidden_layer in self.hidden_layers:
            x = hidden_layer(x)
        return self.prediction_layer(x)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(86)<module>()
-> model.fit(train_ds, validation_data=val_ds, epochs=model.default_training_epochs(), verbose=0)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(794)fit()
-> initial_epoch=initial_epoch)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1519)fit_generator()
-> steps_name='steps_per_epoch')
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_generator.py(257)model_iteration()
-> batch_outs = batch_function(*batch_data)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1242)train_on_batch()
-> extract_tensors_from_dataset=True)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2507)_standardize_user_data()
-> self._set_inputs(cast_inputs)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py(456)_method_wrapper()
-> result = method(self, *args, **kwargs)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2779)_set_inputs()
-> outputs = self.call(inputs)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(62)call()
-> x = hidden_layer(x)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(594)__call__()
-> self._maybe_build(inputs)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(1713)_maybe_build()
-> self.build(input_shapes)
> /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py(961)build()
-> dtype = dtypes.as_dtype(self.dtype or K.floatx())
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(90)<module>()
-> model.predict(test_ds)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1147)predict()
-> callbacks=callbacks)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_generator.py(257)model_iteration()
-> batch_outs = batch_function(*batch_data)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training_generator.py(531)predict_on_batch()
-> return model.predict_on_batch(x)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(1365)predict_on_batch()
-> x, extract_tensors_from_dataset=True)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2507)_standardize_user_data()
-> self._set_inputs(cast_inputs)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/training/tracking/base.py(456)_method_wrapper()
-> result = method(self, *args, **kwargs)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py(2779)_set_inputs()
-> outputs = self.call(inputs)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/example.py(62)call()
-> x = hidden_layer(x)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(594)__call__()
-> self._maybe_build(inputs)
/Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py(1713)_maybe_build()
-> self.build(input_shapes)
> /Users/yang.y/go/src/github.com/sql-machine-learning/sqlflow/sql/python/models/venv/lib/python3.7/site-packages/tensorflow/python/keras/layers/core.py(961)build()
-> dtype = dtypes.as_dtype(self.dtype or K.floatx())
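The main logic is in `Model._set_inputs`. It creates symbolic tensors and passes them through `self.call`, which invokes each layer's `layer.__call__`. Inside `layer.__call__`, if the inputs are symbolic tensors, `_maybe_build` invokes `layer.build(input_shapes)` (`Dense.build` in both traces above). That `build` call is where the layer creates and initializes its weights. So weight initialization happens at the first call of `fit` or `predict`, not when the model is constructed.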