
Nested output structures and corresponding nested losses are supported, but nested metrics are not #21700

@hertschuh

Description


This is a follow-up to #21675 and #21694.

  • nested output structures are supported
  • nested losses matching the nested output structures are supported
  • however, the same is not true for metrics: the code within the try/except block below fails
```python
import numpy as np
import keras

print('keras version:', keras.__version__)

X = np.random.rand(100, 32)
Y1 = np.random.rand(100, 10)
Y2 = np.random.rand(100, 10)
Y3 = np.random.rand(100, 10)

def create_model():
    # One top-level output 'a' plus a nested dict of outputs 'b' -> {'c', 'd'}
    x = keras.Input(shape=(32,))
    y1 = keras.layers.Dense(10)(x)
    y2 = keras.layers.Dense(10)(x)
    y3 = keras.layers.Dense(10)(x)
    return keras.Model(inputs=x, outputs={'a': y1, 'b': {'c': y2, 'd': y3}})

# Nested losses matching the nested outputs: works
model = create_model()
model.compile(optimizer='adam', loss={'a': 'bce', 'b': {'c': 'mse', 'd': 'mse'}})
model.train_on_batch(X, {'a': Y1, 'b': {'c': Y2, 'd': Y3}})
print('Successful training not using metrics\n')

# Nested metrics matching the nested outputs: this is the case that fails
try:
    model = create_model()
    model.compile(
        optimizer='adam',
        loss={'a': 'bce', 'b': {'c': 'mse', 'd': 'mse'}},
        metrics={'a': ['mae', 'acc'], 'b': {'c': 'mse', 'd': 'mse'}},
    )
    model.train_on_batch(X, {'a': Y1, 'b': {'c': Y2, 'd': Y3}})
    print('Successful training using metrics')
except ValueError as ex:
    print('Unsuccessful training with metrics')
    print(ex)

# Nested loss_weights matching the nested outputs: works
model = create_model()
model.compile(
    optimizer='adam',
    loss={'a': 'bce', 'b': {'c': 'mse', 'd': 'mse'}},
    loss_weights={'a': 1., 'b': {'c': 0.5, 'd': 0.4}},
)
model.train_on_batch(X, {'a': Y1, 'b': {'c': Y2, 'd': Y3}})
print('Successful training using loss_weights\n')
```

Ideally, metrics and losses should be handled the same way. Part of the complication is that there is also a mechanism for matching metrics to outputs by output name when the metrics are not given in the same structure as the outputs.
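To illustrate what handling metrics like losses could look like, here is a minimal pure-Python sketch that walks the output structure and pairs each output leaf with the entry at the same path in `metrics`, keeping a list of metrics found at a leaf whole. It assumes outputs and metrics are nested dicts; `flatten_up_to` is a hypothetical helper, not a Keras API, and it deliberately ignores the name-based fallback.

```python
def flatten_up_to(shallow, deep, path=()):
    """Pair each leaf of `shallow` (the outputs) with the subtree of `deep`
    (the metrics) at the same path, returning {path_tuple: entry}.

    Recursion only continues while `shallow` is a dict, so whatever sits in
    `deep` at an output leaf (e.g. a list of metrics) is kept as one value.
    """
    if not isinstance(shallow, dict):
        return {path: deep}
    if not isinstance(deep, dict) or set(deep) != set(shallow):
        raise ValueError(
            f"Structure mismatch at {path or 'root'}: "
            f"expected keys {sorted(shallow)}, got {deep!r}"
        )
    result = {}
    for key in shallow:
        result.update(flatten_up_to(shallow[key], deep[key], path + (key,)))
    return result


outputs = {"a": "y1", "b": {"c": "y2", "d": "y3"}}  # placeholders for the model outputs
metrics = {"a": ["mae", "acc"], "b": {"c": "mse", "d": "mse"}}
for path, entry in flatten_up_to(outputs, metrics).items():
    print("/".join(path), "->", entry)
```

By contrast, a lookup that treats each top-level key as the name of a single output would find a dict of metrics under `'b'` rather than the metrics for one output, which is one way the two mechanisms can conflict.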

Another complication is that, for each output, metrics can be given either as a single metric or as a list of metrics. The logic that turns single metrics into lists must handle nested structures correctly (with keras.tree.map_structure_up_to).
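A minimal pure-Python sketch of that single-metric-to-list normalization, assuming nested dict outputs. `normalize_metrics` recurses along the output structure, which is the behavior `keras.tree.map_structure_up_to(outputs, fn, metrics)` provides; `normalize_naive` recurses along the metrics' own structure, for comparison. Both are hypothetical helpers, not Keras code, and giving an output with no entry an empty list is an assumption.

```python
def normalize_metrics(outputs, metrics):
    """Turn every per-output metrics entry into a list.

    Recursion follows the *output* structure (like
    keras.tree.map_structure_up_to), so a list of metrics found at an
    output leaf is treated as one value instead of being descended into.
    """
    if isinstance(outputs, dict):
        if metrics is None:
            metrics = {}
        if not isinstance(metrics, dict):
            raise ValueError(
                f"Expected a dict of metrics for keys {sorted(outputs)}, got {metrics!r}"
            )
        # Assumption: an output with no entry gets an empty metrics list.
        return {key: normalize_metrics(outputs[key], metrics.get(key)) for key in outputs}
    # `outputs` is a single output here, so `metrics` is that output's entry.
    if metrics is None:
        return []
    if isinstance(metrics, (list, tuple)):
        return list(metrics)
    return [metrics]


def normalize_naive(metrics):
    """Wrap leaves found by recursing into the metrics structure itself (incorrect)."""
    if isinstance(metrics, dict):
        return {key: normalize_naive(value) for key, value in metrics.items()}
    if isinstance(metrics, (list, tuple)):
        # Bug: a list of metrics for one output is mistaken for more structure.
        return [normalize_naive(m) for m in metrics]
    return [metrics]


outputs = {"a": "y1", "b": {"c": "y2", "d": "y3"}}  # placeholders for the model outputs
metrics = {"a": ["mae", "acc"], "b": {"c": "mse", "d": "mse"}}
print(normalize_metrics(outputs, metrics))
print(normalize_naive(metrics))
```

With the metrics from the reproduction, `normalize_naive` descends into `['mae', 'acc']` and produces `[['mae'], ['acc']]` for output `'a'`, while `normalize_metrics` keeps `['mae', 'acc']` as is and only wraps the single `'mse'` entries under `'b'`.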
