Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for stateful metrics. #9253

Merged
merged 7 commits into from
Feb 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2039,8 +2039,8 @@ def batch_get_value(xs):
def set_value(x, value):
if (isinstance(x, C.variables.Parameter) or
isinstance(x, C.variables.Constant)):
if isinstance(value, float):
value = np.full(x.shape, value)
if isinstance(value, (float, int)):
value = np.full(x.shape, value, dtype=floatx())
x.value = value
else:
raise NotImplementedError
Expand Down
41 changes: 35 additions & 6 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from collections import Iterable
from .utils.generic_utils import Progbar
from . import backend as K
from .engine.topology import Layer

try:
import requests
Expand Down Expand Up @@ -202,8 +203,20 @@ class BaseLogger(Callback):
"""Callback that accumulates epoch averages of metrics.

This callback is automatically applied to every Keras model.

# Arguments
stateful_metrics: Iterable of string names of metrics that
should *not* be averaged over an epoch.
Metrics in this list will be logged as-is in `on_epoch_end`.
All others will be averaged in `on_epoch_end`.
"""

def __init__(self, stateful_metrics=None):
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()

def on_epoch_begin(self, epoch, logs=None):
self.seen = 0
self.totals = {}
Expand All @@ -214,17 +227,23 @@ def on_batch_end(self, batch, logs=None):
self.seen += batch_size

for k, v in logs.items():
if k in self.totals:
self.totals[k] += v * batch_size
if k in self.stateful_metrics:
self.totals[k] = v
else:
self.totals[k] = v * batch_size
if k in self.totals:
self.totals[k] += v * batch_size
else:
self.totals[k] = v * batch_size

def on_epoch_end(self, epoch, logs=None):
if logs is not None:
for k in self.params['metrics']:
if k in self.totals:
# Make value available to next callbacks.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"make value available to future callbacks"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Callbacks are processed sequentially, future is more vague than next IMO.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the grammar is right as-is with "next".

Alternate: "make value available to each of the following callbacks"

logs[k] = self.totals[k] / self.seen
if k in self.stateful_metrics:
logs[k] = self.totals[k]
else:
logs[k] = self.totals[k] / self.seen


class TerminateOnNaN(Callback):
Expand All @@ -250,19 +269,28 @@ class ProgbarLogger(Callback):
count_mode: One of "steps" or "samples".
Whether the progress bar should
count samples seen or steps (batches) seen.
stateful_metrics: Iterable of string names of metrics that
should *not* be averaged over an epoch.
Metrics in this list will be logged as-is.
All others will be averaged over time (e.g. loss, etc).

# Raises
ValueError: In case of invalid `count_mode`.
"""

def __init__(self, count_mode='samples'):
def __init__(self, count_mode='samples',
stateful_metrics=None):
super(ProgbarLogger, self).__init__()
if count_mode == 'samples':
self.use_steps = False
elif count_mode == 'steps':
self.use_steps = True
else:
raise ValueError('Unknown `count_mode`: ' + str(count_mode))
if stateful_metrics:
self.stateful_metrics = set(stateful_metrics)
else:
self.stateful_metrics = set()

def on_train_begin(self, logs=None):
self.verbose = self.params['verbose']
Expand All @@ -277,7 +305,8 @@ def on_epoch_begin(self, epoch, logs=None):
target = self.params['samples']
self.target = target
self.progbar = Progbar(target=self.target,
verbose=self.verbose)
verbose=self.verbose,
stateful_metrics=self.stateful_metrics)
self.seen = 0

def on_batch_begin(self, batch, logs=None):
Expand Down
Loading