Open
Labels
stat:contributions welcome (A pull request to fix this issue would be welcome.)
Description
This is a follow-up to #21675 and #21694:
- nested output structures are supported
- nested losses matching the nested output structures are supported
- however, the same is not true for metrics: the code within the try/except block below fails
import numpy as np
import keras

print('keras version:', keras.__version__)

X = np.random.rand(100, 32)
Y1 = np.random.rand(100, 10)
Y2 = np.random.rand(100, 10)
Y3 = np.random.rand(100, 10)

def create_model():
    x = keras.Input(shape=(32,))
    y1 = keras.layers.Dense(10)(x)
    y2 = keras.layers.Dense(10)(x)
    y3 = keras.layers.Dense(10)(x)
    return keras.Model(inputs=x, outputs={'a': y1, 'b': {'c': y2, 'd': y3}})

model = create_model()
model.compile(optimizer='adam', loss={'a': 'bce', 'b': {'c': 'mse', 'd': 'mse'}})
model.train_on_batch(X, {'a': Y1, 'b': {'c': Y2, 'd': Y3}})
print('Successful training not using metrics\n')

try:
    model = create_model()
    model.compile(
        optimizer='adam',
        loss={'a': 'bce', 'b': {'c': 'mse', 'd': 'mse'}},
        metrics={'a': ['mae', 'acc'], 'b': {'c': 'mse', 'd': 'mse'}}
    )
    model.train_on_batch(X, {'a': Y1, 'b': {'c': Y2, 'd': Y3}})
    print('Successful training using metrics')
except ValueError as ex:
    print('Unsuccessful training with metrics')
    print(ex)

model = create_model()
model.compile(
    optimizer='adam',
    loss={'a': 'bce', 'b': {'c': 'mse', 'd': 'mse'}},
    loss_weights={'a': 1., 'b': {'c': 0.5, 'd': 0.4}}
)
model.train_on_batch(X, {'a': Y1, 'b': {'c': Y2, 'd': Y3}})
print('Successful training using loss_weights\n')

Ideally, metrics and losses should be handled the same. Part of the complication is that there is a mechanism to use the output names when not using the output structure.
Another complication is that, for each output, metrics can be given either as a single metric or as a list of metrics. The logic that turns a single metric into a list must handle nested structures correctly (with keras.tree.map_structure_up_to).
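To make the intended behavior concrete, here is a minimal sketch in plain Python of the single-metric-to-list normalization over a nested output structure. It is not Keras's implementation and does not call keras.tree.map_structure_up_to. The helper name normalize_metrics and its error messages are hypothetical, and it assumes the output structure is built from nested dicts only. The key point is that recursion stops at the leaves of the output structure, so a list of metrics is treated as one leaf value rather than as another level of nesting.

```python
def normalize_metrics(output_structure, metrics):
    """Hypothetical helper: map every leaf of `output_structure` to a list of metrics.

    Traversal follows `output_structure`, not `metrics`, so a list such as
    ['mae', 'acc'] is kept whole as the value for a single output.
    """
    if isinstance(output_structure, dict):
        if metrics is None:
            metrics = {}
        if not isinstance(metrics, dict):
            raise ValueError(
                f"Expected a dict of metrics with keys among "
                f"{sorted(output_structure)}, got {metrics!r}"
            )
        unknown = set(metrics) - set(output_structure)
        if unknown:
            raise ValueError(f"Metrics given for unknown outputs: {sorted(unknown)}")
        return {
            key: normalize_metrics(sub_structure, metrics.get(key))
            for key, sub_structure in output_structure.items()
        }
    # Leaf of the output structure: wrap whatever was supplied into a list.
    if metrics is None:
        return []
    if isinstance(metrics, (list, tuple)):
        return list(metrics)
    return [metrics]


# Placeholder strings stand in for the output tensors y1, y2, y3.
outputs = {'a': 'y1', 'b': {'c': 'y2', 'd': 'y3'}}
metrics = {'a': ['mae', 'acc'], 'b': {'c': 'mse', 'd': 'mse'}}
normalized = normalize_metrics(outputs, metrics)
print(normalized)
```

With this shape, every output has a list of metrics regardless of how the user spelled it, which is the property the loss handling already has for nested structures.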