# Source code for deepcpg.callbacks
"""Keras callback classes used by `dcpg_train.py`."""
from __future__ import division
from __future__ import print_function
from collections import OrderedDict
import os
from time import time
from keras.callbacks import Callback
import numpy as np
import six
from .utils import format_table
class PerformanceLogger(Callback):
    """Logs performance metrics during training.

    Stores and prints performance metrics for each batch, epoch, and output.

    Parameters
    ----------
    metrics: list
        Name of metrics to be logged. Defaults to `['loss', 'acc']`.
    log_freq: float
        Logging frequency as the percentage of training samples per epoch.
        If zero, only the first and last step of an epoch are logged.
    precision: int
        Floating point precision.
    callbacks: list
        List of functions with parameters `epoch`, `epoch_logs`, and
        `val_epoch_logs` that are called at the end of each epoch.
    verbose: bool
        If `True`, log performance metrics of individual outputs.
    logger: function
        Logging function. If falsy, nothing is logged.
    """

    def __init__(self, metrics=None, log_freq=0.1,
                 precision=4, callbacks=None, verbose=True, logger=print):
        super(PerformanceLogger, self).__init__()
        # `None` defaults avoid sharing one mutable list between instances.
        self.metrics = ['loss', 'acc'] if metrics is None else metrics
        self.log_freq = log_freq
        self.precision = precision
        self.callbacks = [] if callbacks is None else callbacks
        self.verbose = verbose
        self.logger = logger
        self._line = '=' * 100
        # Tables of per-epoch training / validation logs, created lazily in
        # `on_epoch_end` once the names of the available logs are known.
        self.epoch_logs = None
        self.val_epoch_logs = None
        # One table of batch logs per finished epoch.
        self.batch_logs = []

    def _log(self, x):
        """Pass `x` to the logging function, if one was given."""
        if self.logger:
            self.logger(x)

    def _init_logs(self, logs, train=True):
        """Extracts metric names from `logs` and initializes table to store
        epoch or batch logs.

        Parameters
        ----------
        logs: iterable
            Names of the available logs, e.g. a dict or its keys.
        train: bool
            If `True`, select logs without `val_` prefix. Otherwise select
            logs with `val_` prefix and strip the prefix.

        Returns
        -------
        tuple
            Tuple (`metrics`, `logs_dict`). `metrics` maps metrics, e.g.
            `metrics['acc'] = ['acc', 'output_acc1']`. `logs_dict` is a dict of
            lists to store logs, e.g. `logs_dict['acc'] = []`.
        """
        prefix = 'val_'
        # Select either only training or validation logs
        if train:
            log_names = [log for log in logs if not log.startswith(prefix)]
        else:
            log_names = [log[len(prefix):] for log in logs
                         if log.startswith(prefix)]

        # `metrics` stores for each metric in self.metrics that exists in logs
        # the name for the metric itself, followed by all output metrics:
        #   metrics['acc'] = ['acc', 'output1_acc', 'output2_acc']
        metrics = OrderedDict()
        for name in self.metrics:
            suffix = '_' + name
            output_logs = [log for log in log_names if log.endswith(suffix)]
            if name in log_names or output_logs:
                # If the mean, e.g. 'acc', does not exist in logs, it is still
                # added here and computed later over all outputs by
                # `_update_means`.
                metrics[name] = [name] + output_logs

        # `logs_dict` stores the actual logs for each metric in `metrics`
        logs_dict = OrderedDict()
        # Show mean metrics first
        for mean_name in metrics:
            logs_dict[mean_name] = []
        # Followed by all output metrics
        for names in six.itervalues(metrics):
            for name in names:
                logs_dict[name] = []

        return metrics, logs_dict

    def _update_means(self, logs, metrics):
        """Computes the mean over all outputs, if it does not exist yet.

        Only the most recent entry of each list in `logs` is read and updated.
        Values that are `None` or NaN are ignored; if no value is left, the
        mean is set to NaN.
        """
        for mean_name, names in six.iteritems(metrics):
            # Skip, if mean already exists, e.g. loss.
            if logs[mean_name][-1] is not None:
                continue

            total = 0
            count = 0
            for name in names:
                if name not in logs:
                    continue
                value = logs[name][-1]
                if value is not None and not np.isnan(value):
                    total += value
                    count += 1

            logs[mean_name][-1] = total / count if count else np.nan

    def on_train_begin(self, logs=None):
        """Start the timer and log the number of epochs."""
        self._time_start = time()
        self._log('Epochs: %d' % (self.params['epochs']))

    def on_train_end(self, logs=None):
        """Log a closing separator line."""
        self._log(self._line)

    def on_epoch_begin(self, epoch, logs=None):
        """Log the epoch header and reset the per-epoch batch state."""
        self._log(self._line)
        self._log('Epoch %d/%d' % (epoch + 1, self.params['epochs']))
        self._log(self._line)

        self._step = 0
        self._steps = self.params['steps']
        # Number of steps between two log lines.
        self._log_freq = int(np.ceil(self.log_freq * self._steps))
        self._batch_logs = None
        self._totals = None

    def on_epoch_end(self, epoch, logs=None):
        """Store the epoch logs, print them as a table and call callbacks."""
        logs = logs or {}

        if self._batch_logs:
            self.batch_logs.append(self._batch_logs)

        if not self.epoch_logs:
            # Initialize epoch metrics and logs
            self._epoch_metrics, self.epoch_logs = self._init_logs(logs)
            tmp = self._init_logs(logs, False)
            self._val_epoch_metrics, self.val_epoch_logs = tmp

        # Add new epoch logs to logs table; `None` if log value is missing
        for metric, metric_logs in six.iteritems(self.epoch_logs):
            metric_logs.append(logs.get(metric))
        self._update_means(self.epoch_logs, self._epoch_metrics)

        # Add new validation epoch logs to logs table
        for metric, metric_logs in six.iteritems(self.val_epoch_logs):
            metric_logs.append(logs.get('val_' + metric))
        self._update_means(self.val_epoch_logs, self._val_epoch_metrics)

        # Show table
        table = OrderedDict()
        table['split'] = ['train']
        # Show mean logs first
        for mean_name in self._epoch_metrics:
            table[mean_name] = []
        # Show output logs
        if self.verbose:
            for names in six.itervalues(self._epoch_metrics):
                for name in names:
                    table[name] = []
        for name, values in six.iteritems(self.epoch_logs):
            if name in table:
                table[name].append(values[-1])
        if self.val_epoch_logs:
            table['split'].append('val')
            for name, values in six.iteritems(self.val_epoch_logs):
                if name in table:
                    table[name].append(values[-1])

        self._log('')
        self._log(format_table(table, precision=self.precision))

        # Trigger callbacks
        for callback in self.callbacks:
            callback(epoch, self.epoch_logs, self.val_epoch_logs)

    def on_batch_end(self, batch, logs=None):
        """Update the running means of all metrics and log them regularly."""
        logs = logs or {}
        self._step += 1
        batch_size = logs.get('size', 0)

        if not self._batch_logs:
            # Initialize batch metrics and logs table
            tmp = self._init_logs(logs.keys())
            self._batch_metrics, self._batch_logs = tmp
            # Sum of logs up to the current batch
            self._totals = OrderedDict()
            # Number of samples up to the current batch
            self._nb_totals = OrderedDict()
            for name in self._batch_logs:
                if name in logs:
                    self._totals[name] = 0
                    self._nb_totals[name] = 0

        for name, value in six.iteritems(logs):
            # Only tracked metrics are accumulated; checking the name first
            # keeps `np.isnan` away from unrelated log entries.
            if name not in self._totals:
                continue
            # Skip value if nan, which can occur if the batch size is small.
            if np.isnan(value):
                continue
            self._totals[name] += value * batch_size
            self._nb_totals[name] += batch_size

        # Compute the accumulative mean over logs and store it in
        # `_batch_logs`. `None` marks means that `_update_means` has to
        # compute over the outputs.
        for name in self._batch_logs:
            if name in self._totals:
                if self._nb_totals[name]:
                    tmp = self._totals[name] / self._nb_totals[name]
                else:
                    tmp = np.nan
            else:
                tmp = None
            self._batch_logs[name].append(tmp)

        self._update_means(self._batch_logs, self._batch_metrics)

        # Show logs table at a certain frequency, and always for the first
        # and last step. The `_log_freq` check avoids a modulo by zero.
        do_log = self._step == 1 or self._step == self._steps
        if self._log_freq and self._step % self._log_freq == 0:
            do_log = True
        if not do_log:
            return

        table = OrderedDict()
        table['done (%)'] = [self._step / self._steps * 100]
        table['time'] = [(time() - self._time_start) / 60]
        for mean_name in self._batch_metrics:
            table[mean_name] = []
        if self.verbose:
            for names in six.itervalues(self._batch_metrics):
                for name in names:
                    table[name] = []
        for name, values in six.iteritems(self._batch_logs):
            if name in table:
                table[name].append(values[-1])

        # One precision per column: 1 for progress and time, `self.precision`
        # for every metric column.
        precision = [1, 1] + [self.precision] * (len(table) - 2)

        self._log(format_table(table, precision=precision,
                               header=self._step == 1))
class TrainingStopper(Callback):
    """Stop training after certain time or when file is detected.

    Both conditions are checked at the end of each epoch.

    Parameters
    ----------
    max_time: int
        Maximum training time in seconds.
    stop_file: str
        Name of stop file that triggers the end of training when existing.
    verbose: bool
        If `True`, log message when training is stopped.
    logger: function
        Logging function.
    """

    def __init__(self, max_time=None, stop_file=None,
                 verbose=1, logger=print):
        """max_time in seconds."""
        super(TrainingStopper, self).__init__()
        self.max_time = max_time
        self.stop_file = stop_file
        self.verbose = verbose
        self.logger = logger

    def on_train_begin(self, logs=None):
        """Remember when training started."""
        self._time_start = time()

    def log(self, msg):
        """Pass `msg` to the logging function if `verbose` is set."""
        if self.verbose:
            self.logger(msg)

    def on_epoch_end(self, batch, logs=None):
        """Request the end of training if a stop condition is met.

        `batch` is the first positional argument of the epoch-end hook, i.e.
        the epoch index; the name is kept for backward compatibility.
        """
        if self.max_time is not None:
            elapsed = time() - self._time_start
            if elapsed > self.max_time:
                self.log('Stopping training after %.2fh' % (elapsed / 3600))
                self.model.stop_training = True

        if self.stop_file and os.path.isfile(self.stop_file):
            self.log('Stopping training due to stop file!')
            self.model.stop_training = True