#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Keras utils
===========
Utility classes related to Keras.
KerasMixin
^^^^^^^^^^
.. autosummary::
:toctree: generated/
KerasMixin.create_model
KerasMixin.create_callback_list
KerasMixin.create_external_metric_evaluators
KerasMixin.prepare_data
KerasMixin.prepare_activity
KerasMixin.keras_model_exists
KerasMixin.log_model_summary
KerasMixin.plot_model
KerasMixin.get_processing_interval
BaseCallback
^^^^^^^^^^^^
.. autosummary::
:toctree: generated/
BaseCallback
ProgressLoggerCallback
^^^^^^^^^^^^^^^^^^^^^^
Keras callback to store metrics with tqdm progress bar or logging interface. Implements Keras Callback API.
This callback is very similar to the standard ``ProgbarLogger`` Keras callback; however, it adds support for a
logging interface, tqdm-based progress bars, and external metrics
(metrics calculated outside the Keras training process).
.. autosummary::
:toctree: generated/
ProgressLoggerCallback
ProgressPlotterCallback
^^^^^^^^^^^^^^^^^^^^^^^
Keras callback to plot progress during the training process and save the final progress into a figure.
Implements Keras Callback API.
.. autosummary::
:toctree: generated/
ProgressPlotterCallback
StopperCallback
^^^^^^^^^^^^^^^
Keras callback to stop training when no improvement has been seen in a specified number of epochs.
Implements Keras Callback API.
This callback is very similar to the standard ``EarlyStopping`` Keras callback; however, it adds support for
external metrics (metrics calculated outside the Keras training process).
.. autosummary::
:toctree: generated/
StopperCallback
StasherCallback
^^^^^^^^^^^^^^^
Keras callback to monitor the training process and store the best model. Implements Keras Callback API.
This callback is very similar to the standard ``ModelCheckpoint`` Keras callback; however, it adds support for
external metrics (metrics calculated outside the Keras training process).
.. autosummary::
:toctree: generated/
StasherCallback
BaseDataGenerator
^^^^^^^^^^^^^^^^^
.. autosummary::
:toctree: generated/
BaseDataGenerator
BaseDataGenerator.input_size
BaseDataGenerator.data_size
BaseDataGenerator.steps_count
BaseDataGenerator.info
FeatureGenerator
^^^^^^^^^^^^^^^^
.. autosummary::
:toctree: generated/
FeatureGenerator
FeatureGenerator.generator
"""
import os
import sys
import logging
import numpy
import copy
import importlib
import collections
from tqdm import tqdm
from six import iteritems
from .containers import DottedDict
from .utils import SuppressStdoutAndStderr, Timer, SimpleMathStringEvaluator, get_parameter_hash
from .features import FeatureContainer
from .metadata import EventRoll
from .data import DataBuffer
class KerasMixin(object):
"""Class Mixin for Keras based learner containers.
"""
def __getstate__(self):
data = {}
excluded_fields = ['model']
for item in self:
if item not in excluded_fields and self.get(item):
data[item] = copy.deepcopy(self.get(item))
data['model'] = os.path.splitext(self.filename)[0] + '.model.hdf5'
return data
def keras_model_exists(self):
"""Check that the Keras model exists on disk
Returns
-------
bool
"""
keras_model_filename = os.path.splitext(self.filename)[0] + '.model.hdf5'
return os.path.isfile(self.filename) and os.path.isfile(keras_model_filename)
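# For example (hypothetical filename): with self.filename == 'model.cpickle',
# keras_model_exists() checks for both 'model.cpickle' and 'model.model.hdf5'.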
def log_model_summary(self):
"""Print the model summary to the logging interface.
Similar to the Keras model summary
"""
layer_name_map = {
'BatchNormalization': 'BatchNorm',
}
import keras
from distutils.version import LooseVersion
import keras.backend as keras_backend
self.logger.debug(' Model summary')
self.logger.debug(
' {type:<15s} | {out:20s} | {param:6s} | {name:21s} | {conn:27s} | {act:7s} | {init:7s}'.format(
type='Layer type',
out='Output',
param='Param',
name='Name',
conn='Connected to',
act='Activ.',
init='Init')
)
self.logger.debug(
' {type:<15s} + {out:20s} + {param:6s} + {name:21s} + {conn:27s} + {act:7s} + {init:6s}'.format(
type='-' * 15,
out='-' * 20,
param='-' * 6,
name='-' * 21,
conn='-' * 27,
act='-' * 7,
init='-' * 6)
)
for layer in self.model.layers:
connections = []
if LooseVersion(keras.__version__) >= LooseVersion('2.1.3'):
for node_index, node in enumerate(layer._inbound_nodes):
for i in range(len(node.inbound_layers)):
inbound_layer = node.inbound_layers[i].name
inbound_node_index = node.node_indices[i]
inbound_tensor_index = node.tensor_indices[i]
connections.append(inbound_layer + '[' + str(inbound_node_index) +
'][' + str(inbound_tensor_index) + ']')
else:
for node_index, node in enumerate(layer.inbound_nodes):
for i in range(len(node.inbound_layers)):
inbound_layer = node.inbound_layers[i].name
inbound_node_index = node.node_indices[i]
inbound_tensor_index = node.tensor_indices[i]
connections.append(inbound_layer + '[' + str(inbound_node_index) +
'][' + str(inbound_tensor_index) + ']')
config = DottedDict(layer.get_config())
layer_name = layer.__class__.__name__
if layer_name in layer_name_map:
layer_name = layer_name_map[layer_name]
if config.get_path('kernel_initializer.class_name') == 'VarianceScaling':
init = str(config.get_path('kernel_initializer.config.distribution', '---'))
elif config.get_path('kernel_initializer.class_name') == 'RandomUniform':
init = 'uniform'
else:
init = '---'
self.logger.debug(
' {type:<15s} | {shape:20s} | {params:6s} | {name:21s} | {connected:27s} | {activation:7s} | {init:7s}'.format(
type=layer_name,
shape=str(layer.output_shape),
params=str(layer.count_params()),
name=str(layer.name),
connected=str(connections[0]) if len(connections) > 0 else '---',
activation=str(config.get('activation', '---')),
init=init,
)
)
trainable_count = int(
numpy.sum([keras_backend.count_params(p) for p in set(self.model.trainable_weights)])
)
non_trainable_count = int(
numpy.sum([keras_backend.count_params(p) for p in set(self.model.non_trainable_weights)])
)
self.logger.debug(' ')
self.logger.debug(' Parameters')
self.logger.debug(' Trainable\t[{param_count:,}]'.format(param_count=int(trainable_count)))
self.logger.debug(' Non-Trainable\t[{param_count:,}]'.format(param_count=int(non_trainable_count)))
self.logger.debug(
' Total\t\t[{param_count:,}]'.format(param_count=int(trainable_count + non_trainable_count)))
self.logger.debug(' ')
def plot_model(self, filename='model.png', show_shapes=True, show_layer_names=True):
"""Plot model topology
"""
try:
# Keras 2 location of the plotting utility
from keras.utils import plot_model as plot
except ImportError:
# Keras 1 location
from keras.utils.visualize_util import plot
plot(self.model, to_file=filename, show_shapes=show_shapes, show_layer_names=show_layer_names)
def prepare_data(self, data, files, processor='default'):
"""Concatenate feature data into one feature matrix
Parameters
----------
data : dict of FeatureContainers
Feature data
files : list of str
List of filenames
processor : str ('default', 'training')
Data processor selector
Default value 'default'
Returns
-------
numpy.ndarray
Features concatenated
"""
if self.learner_params.get_path('input_sequencer.enable'):
processed_data = []
for item in files:
if processor == 'training':
processed_data.append(
self.data_processor_training.process_data(
data=data[item].feat[0]
)
)
else:
processed_data.append(
self.data_processor.process_data(
data=data[item].feat[0]
)
)
return numpy.concatenate(processed_data)
else:
return numpy.vstack([data[x].feat[0] for x in files])
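# Illustrative sketch of prepare_data semantics (hypothetical shapes and filenames):
# with input_sequencer disabled and two files whose FeatureContainers hold
# feature matrices of shapes (100, 40) and (60, 40), the call
#   X = self.prepare_data(data=data, files=['a.wav', 'b.wav'])
# is equivalent to
#   X = numpy.vstack([data['a.wav'].feat[0], data['b.wav'].feat[0]])
# and yields a single (160, 40) feature matrix.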
def prepare_activity(self, activity_matrix_dict, files, processor='default'):
"""Concatenate activity matrices into one activity matrix
Parameters
----------
activity_matrix_dict : dict of binary matrices
Meta data
files : list of str
List of filenames
processor : str ('default', 'training')
Data processor selector
Default value 'default'
Returns
-------
numpy.ndarray
Activity matrix
"""
if self.learner_params.get_path('input_sequencer.enable'):
processed_activity = []
for item in files:
if processor == 'training':
processed_activity.append(
self.data_processor_training.process_activity_data(
activity_data=activity_matrix_dict[item]
)
)
else:
processed_activity.append(
self.data_processor.process_activity_data(
activity_data=activity_matrix_dict[item]
)
)
return numpy.concatenate(processed_activity)
else:
return numpy.vstack([activity_matrix_dict[x] for x in files])
def create_model(self, input_shape):
"""Create sequential Keras model
"""
from keras.models import Sequential
self.model = Sequential()
tuple_fields = [
'input_shape',
'kernel_size',
'pool_size',
'dims',
'target_shape'
]
# Get model config parameters
model_params = copy.deepcopy(self.learner_params.get_path('model.config'))
# Get constants for model
constants = copy.deepcopy(self.learner_params.get_path('model.constants', {}))
constants['CLASS_COUNT'] = len(self.class_labels)
constants['FEATURE_VECTOR_LENGTH'] = input_shape
if self.learner_params.get_path('input_sequencer.frames'):
constants['INPUT_SEQUENCE_LENGTH'] = self.learner_params.get_path('input_sequencer.frames')
def process_field(value, constants_dict):
math_eval = SimpleMathStringEvaluator()
if isinstance(value, str):
# String field
if value in constants_dict:
return constants_dict[value]
elif len(value.split()) > 1:
sub_fields = value.split()
for subfield_id, subfield in enumerate(sub_fields):
if subfield in constants_dict:
sub_fields[subfield_id] = str(constants_dict[subfield])
return math_eval.eval(''.join(sub_fields))
else:
return value
elif isinstance(value, list):
processed_value_list = []
for item_id, item in enumerate(value):
processed_value_list.append(process_field(value=item, constants_dict=constants_dict))
return processed_value_list
else:
return value
# Resolve constants whose values are equations referencing other constants
for field in list(constants.keys()):
constants[field] = process_field(value=constants[field], constants_dict=constants)
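# For example (illustrative values): with
#   constants = {'CLASS_COUNT': 10, 'WIDTH': 'CLASS_COUNT * 4'}
# process_field splits 'CLASS_COUNT * 4' into sub fields, substitutes the known
# constant to get '10*4', and SimpleMathStringEvaluator evaluates it to 40,
# so 'WIDTH' can then be referenced from layer configs as the value 40.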
# Setup layers
for layer_id, layer_setup in enumerate(model_params):
# Get layer parameters
layer_setup = DottedDict(layer_setup)
if 'config' not in layer_setup:
layer_setup['config'] = {}
# Get layer class
try:
layer_class = getattr(
importlib.import_module("keras.layers"),
layer_setup['class_name']
)
except AttributeError:
message = '{name}: Invalid Keras layer type [{type}].'.format(
name=self.__class__.__name__,
type=layer_setup['class_name']
)
self.logger.exception(message)
raise AttributeError(message)
# Inject constants
for config_field in list(layer_setup['config'].keys()):
layer_setup['config'][config_field] = process_field(
value=layer_setup['config'][config_field],
constants_dict=constants
)
# Convert lists into tuples
for field in tuple_fields:
if field in layer_setup['config']:
layer_setup['config'][field] = tuple(layer_setup['config'][field])
# Inject input shape for Input layer if not given
if layer_id == 0 and layer_setup.get_path('config.input_shape') is None:
# Set input layer dimension for the first layer if not set
layer_setup['config']['input_shape'] = (input_shape,)
if 'wrapper' in layer_setup:
# Get layer wrapper class
try:
wrapper_class = getattr(
importlib.import_module("keras.layers"),
layer_setup['wrapper']
)
except AttributeError:
message = '{name}: Invalid Keras layer wrapper type [{type}].'.format(
name=self.__class__.__name__,
type=layer_setup['wrapper']
)
self.logger.exception(message)
raise AttributeError(message)
wrapper_parameters = layer_setup.get('config_wrapper', {})
if layer_setup.get('config'):
self.model.add(
wrapper_class(layer_class(**dict(layer_setup.get('config'))), **dict(wrapper_parameters)))
else:
self.model.add(wrapper_class(layer_class(), **dict(wrapper_parameters)))
else:
if layer_setup.get('config'):
self.model.add(layer_class(**dict(layer_setup.get('config'))))
else:
self.model.add(layer_class())
# Get Optimizer class
try:
optimizer_class = getattr(
importlib.import_module("keras.optimizers"),
self.learner_params.get_path('model.optimizer.type')
)
except AttributeError:
message = '{name}: Invalid Keras optimizer type [{type}].'.format(
name=self.__class__.__name__,
type=self.learner_params.get_path('model.optimizer.type')
)
self.logger.exception(message)
raise AttributeError(message)
# Compile the model
self.model.compile(
loss=self.learner_params.get_path('model.loss'),
optimizer=optimizer_class(**dict(self.learner_params.get_path('model.optimizer.parameters', {}))),
metrics=self.learner_params.get_path('model.metrics')
)
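# A minimal sketch of the learner parameters this method consumes
# (hypothetical values; field names follow the get_path calls above):
# learner_params['model'] = {
#     'config': [
#         {'class_name': 'Dense',
#          'config': {'units': 'CLASS_COUNT * 5', 'activation': 'relu'}},
#         {'class_name': 'Dense',
#          'config': {'units': 'CLASS_COUNT', 'activation': 'softmax'}},
#     ],
#     'loss': 'categorical_crossentropy',
#     'optimizer': {'type': 'Adam', 'parameters': {'lr': 0.001}},
#     'metrics': ['categorical_accuracy'],
# }
# self.create_model(input_shape=40)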
def create_callback_list(self):
"""Create list of Keras callbacks
"""
callbacks = []
# Fetch processing interval
processing_interval = self.get_processing_interval()
# Collect all external metrics
external_metrics = collections.OrderedDict()
if self.learner_params.get_path('training.epoch_processing.enable'):
if self.learner_params.get_path('validation.enable') and self.learner_params.get_path(
'training.epoch_processing.external_metrics.enable'):
for metric in self.learner_params.get_path('training.epoch_processing.external_metrics.metrics'):
current_metric_name = metric.get('name')
current_metric_label = metric.get('label', current_metric_name.split('.')[-1])
external_metrics[current_metric_label] = current_metric_name
# ProgressLoggerCallback
from dcase_framework.keras_utils import ProgressLoggerCallback
callbacks.append(
ProgressLoggerCallback(
metric=self.learner_params.get_path('model.metrics')[0],
loss=self.learner_params.get_path('model.loss'),
disable_progress_bar=self.disable_progress_bar,
log_progress=self.log_progress,
epochs=self.learner_params.get_path('training.epochs'),
close_progress_bar=not self.learner_params.get_path('training.epoch_processing.enable'),
manual_update=self.learner_params.get_path('training.epoch_processing.enable'),
manual_update_interval=processing_interval,
external_metric_labels=external_metrics
)
)
# Add model callbacks
for cp in self.learner_params.get_path('training.callbacks', []):
cp_params = DottedDict(cp.get('parameters', {}))
if cp['type'] == 'Plotter':
from dcase_framework.keras_utils import ProgressPlotterCallback
callbacks.append(
ProgressPlotterCallback(
filename=os.path.splitext(self.filename)[0] + '.' + cp_params.get('output_format', 'pdf'),
metric=self.learner_params.get_path('model.metrics')[0],
loss=self.learner_params.get_path('model.loss'),
disable_progress_bar=self.disable_progress_bar,
log_progress=self.log_progress,
epochs=self.learner_params.get_path('training.epochs'),
close_progress_bar=not self.learner_params.get_path('training.epoch_processing.enable'),
manual_update=self.learner_params.get_path('training.epoch_processing.enable'),
interactive=cp_params.get('interactive', True),
save=cp_params.get('save', True),
focus_span=cp_params.get('focus_span'),
plotting_rate=cp_params.get('plotting_rate'),
external_metric_labels=external_metrics
)
)
elif cp['type'] == 'Stopper':
from dcase_framework.keras_utils import StopperCallback
callbacks.append(
StopperCallback(
epochs=self.learner_params.get_path('training.epochs'),
manual_update=self.learner_params.get_path('training.epoch_processing.enable'),
external_metric_labels=external_metrics,
**cp_params
)
)
elif cp['type'] == 'Stasher':
from dcase_framework.keras_utils import StasherCallback
callbacks.append(
StasherCallback(
epochs=self.learner_params.get_path('training.epochs'),
manual_update=self.learner_params.get_path('training.epoch_processing.enable'),
external_metric_labels=external_metrics,
**cp_params
)
)
else:
# Keras standard callbacks
if cp['type'] == 'ModelCheckpoint' and not cp['parameters'].get('filepath'):
cp['parameters']['filepath'] = os.path.splitext(self.filename)[0] + \
'.weights.{epoch:02d}-{val_loss:.2f}.hdf5'
if cp['type'] == 'EarlyStopping' and cp.get('parameters', {}).get('monitor', '').startswith('val_') \
and not self.learner_params.get_path('validation.enable', False):
message = '{name}: Cannot use callback type [{type}] with monitor parameter [{monitor}] ' \
'as there is no validation set.'.format(name=self.__class__.__name__,
type=cp['type'],
monitor=cp.get('parameters').get('monitor')
)
self.logger.exception(message)
raise AttributeError(message)
try:
callback_class = getattr(importlib.import_module("keras.callbacks"), cp['type'])
callbacks.append(callback_class(**cp_params))
except AttributeError:
message = '{name}: Invalid Keras callback type [{type}]'.format(
name=self.__class__.__name__,
type=cp['type']
)
self.logger.exception(message)
raise AttributeError(message)
return callbacks
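# A sketch of the training.callbacks parameter list this method parses
# (hypothetical values):
# learner_params['training']['callbacks'] = [
#     {'type': 'Plotter', 'parameters': {'interactive': False, 'save': True}},
#     {'type': 'Stopper', 'parameters': {'monitor': 'val_loss', 'patience': 10}},
#     {'type': 'Stasher', 'parameters': {'monitor': 'val_loss'}},
#     # Any other type is resolved from keras.callbacks, e.g.:
#     {'type': 'ReduceLROnPlateau', 'parameters': {'monitor': 'val_loss'}},
# ]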
def create_external_metric_evaluators(self):
"""Create external metric evaluators
"""
# Initialize external metrics
external_metric_evaluators = collections.OrderedDict()
if self.learner_params.get_path('training.epoch_processing.enable'):
if self.learner_params.get_path('validation.enable') and self.learner_params.get_path(
'training.epoch_processing.external_metrics.enable'):
import sed_eval
for metric in self.learner_params.get_path('training.epoch_processing.external_metrics.metrics'):
# Current metric info
current_metric_evaluator = metric.get('evaluator')
current_metric_name = metric.get('name')
current_metric_params = metric.get('parameters', {})
current_metric_label = metric.get('label', current_metric_name.split('.')[-1])
# Initialize sed_eval evaluators
if current_metric_evaluator == 'sed_eval.scene':
evaluator = sed_eval.scene.SceneClassificationMetrics(
scene_labels=self.class_labels,
**current_metric_params
)
elif (current_metric_evaluator == 'sed_eval.segment_based' or
current_metric_evaluator == 'sed_eval.sound_event.segment_based'):
evaluator = sed_eval.sound_event.SegmentBasedMetrics(
event_label_list=self.class_labels,
**current_metric_params
)
elif (current_metric_evaluator == 'sed_eval.event_based' or
current_metric_evaluator == 'sed_eval.sound_event.event_based'):
evaluator = sed_eval.sound_event.EventBasedMetrics(
event_label_list=self.class_labels,
**current_metric_params
)
else:
message = '{name}: Unknown target metric [{metric}].'.format(
name=self.__class__.__name__,
metric=current_metric_name
)
self.logger.exception(message)
raise AssertionError(message)
# Check evaluator API
if (not hasattr(evaluator, 'reset') or
not hasattr(evaluator, 'evaluate') or
not hasattr(evaluator, 'results')):
if current_metric_evaluator.startswith('sed_eval'):
message = '{name}: wrong version of sed_eval for [{current_metric_evaluator}::{current_metric_name}], update sed_eval to latest version'.format(
name=self.__class__.__name__,
current_metric_evaluator=current_metric_evaluator,
current_metric_name=current_metric_name
)
self.logger.exception(message)
raise ValueError(message)
else:
message = '{name}: Evaluator has invalid API [{current_metric_evaluator}::{current_metric_name}]'.format(
name=self.__class__.__name__,
current_metric_evaluator=current_metric_evaluator,
current_metric_name=current_metric_name
)
self.logger.exception(message)
raise ValueError(message)
# Form unique name for metric, to allow multiple similar metrics with different parameters
metric_id = get_parameter_hash(metric)
# Metric data container
metric_data = {
'evaluator_name': current_metric_evaluator,
'name': current_metric_name,
'params': current_metric_params,
'label': current_metric_label,
'path': current_metric_name,
'evaluator': evaluator,
}
external_metric_evaluators[metric_id] = metric_data
return external_metric_evaluators
def get_processing_interval(self):
"""Processing interval
"""
processing_interval = 1
if self.learner_params.get_path('training.epoch_processing.enable'):
if self.learner_params.get_path('training.epoch_processing.external_metrics.enable'):
processing_interval = self.learner_params.get_path(
'training.epoch_processing.external_metrics.evaluation_interval', 1)
return processing_interval
def _after_load(self, to_return=None):
with SuppressStdoutAndStderr():
# Set up Keras if not yet done. This is needed because Keras defaults to the TensorFlow backend,
# which raises an error if TensorFlow is not installed and Theano has not been configured as the backend.
self._setup_keras()
from keras.models import load_model
keras_model_filename = os.path.splitext(self.filename)[0] + '.model.hdf5'
if os.path.isfile(keras_model_filename):
with SuppressStdoutAndStderr():
self.model = load_model(keras_model_filename)
else:
message = '{name}: Keras model not found [{filename}]'.format(
name=self.__class__.__name__,
filename=keras_model_filename
)
self.logger.exception(message)
raise IOError(message)
def _after_save(self, to_return=None):
# Save keras model and weight
keras_model_filename = os.path.splitext(self.filename)[0] + '.model.hdf5'
model_weights_filename = os.path.splitext(self.filename)[0] + '.weights.hdf5'
self.model.save(keras_model_filename)
self.model.save_weights(model_weights_filename)
def _setup_keras(self):
"""Setup keras backend and parameters
"""
if not hasattr(self, 'keras_setup_done') or not self.keras_setup_done:
# Get BLAS library associated to numpy
if numpy.__config__.blas_opt_info and 'libraries' in numpy.__config__.blas_opt_info:
blas_libraries = numpy.__config__.blas_opt_info['libraries']
else:
blas_libraries = ['']
blas_extra_info = []
# Set backend and parameters before importing keras
if self.show_extra_debug:
self.logger.debug(' Keras')
self.logger.debug(' Backend \t[{backend}]'.format(
backend=self.learner_params.get_path('keras.backend', 'theano'))
)
# Threading
if self.learner_params.get_path('keras.backend_parameters.threads'):
thread_count = self.learner_params.get_path('keras.backend_parameters.threads', 1)
os.environ['GOTO_NUM_THREADS'] = str(thread_count)
os.environ['OMP_NUM_THREADS'] = str(thread_count)
os.environ['MKL_NUM_THREADS'] = str(thread_count)
blas_extra_info.append('Threads[{threads}]'.format(threads=thread_count))
if thread_count > 1:
os.environ['OMP_DYNAMIC'] = 'False'
os.environ['MKL_DYNAMIC'] = 'False'
else:
os.environ['OMP_DYNAMIC'] = 'True'
os.environ['MKL_DYNAMIC'] = 'True'
# Conditional Numerical Reproducibility (CNR) for MKL BLAS library
if self.learner_params.get_path('keras.backend_parameters.CNR', True) and blas_libraries[0].startswith('mkl'):
os.environ['MKL_CBWR'] = 'COMPATIBLE'
blas_extra_info.append('MKL_CBWR[{mode}]'.format(mode='COMPATIBLE'))
# Show BLAS info
if self.show_extra_debug:
if numpy.__config__.blas_opt_info and 'libraries' in numpy.__config__.blas_opt_info:
blas_libraries = numpy.__config__.blas_opt_info['libraries']
if blas_libraries[0].startswith('openblas'):
self.logger.debug(' BLAS library\t[OpenBLAS]\t\t({info})'.format(
info=', '.join(blas_extra_info))
)
elif blas_libraries[0].startswith('blas'):
self.logger.debug(
' BLAS library\t[BLAS/Atlas]\t\t({info})'.format(
info=', '.join(blas_extra_info))
)
elif blas_libraries[0].startswith('mkl'):
self.logger.debug(' BLAS library\t[MKL]\t\t({info})'.format(
info=', '.join(blas_extra_info))
)
# Select Keras backend
os.environ["KERAS_BACKEND"] = self.learner_params.get_path('keras.backend', 'theano')
if self.learner_params.get_path('keras.backend', 'theano') == 'theano':
# Theano setup
if self.show_extra_debug:
self.logger.debug(' Theano')
# Default flags
flags = [
# 'ldflags=',
'warn.round=False',
]
# Set device
if self.learner_params.get_path('keras.backend_parameters.device'):
flags.append('device=' + self.learner_params.get_path('keras.backend_parameters.device', 'cpu'))
if self.show_extra_debug:
self.logger.debug(' Device \t\t[{device}]'.format(
device=self.learner_params.get_path('keras.backend_parameters.device', 'cpu'))
)
# Set floatX
if self.learner_params.get_path('keras.backend_parameters.floatX'):
flags.append('floatX=' + self.learner_params.get_path('keras.backend_parameters.floatX', 'float32'))
if self.show_extra_debug:
self.logger.debug(' floatX \t\t[{float}]'.format(
float=self.learner_params.get_path('keras.backend_parameters.floatX', 'float32'))
)
# Set optimizer
if self.learner_params.get_path('keras.backend_parameters.optimizer') is not None:
if self.learner_params.get_path('keras.backend_parameters.optimizer') in ['fast_run', 'merge',
'fast_compile', 'None']:
flags.append('optimizer=' + self.learner_params.get_path('keras.backend_parameters.optimizer'))
if self.show_extra_debug:
self.logger.debug(' Optimizer \t[{optimizer}]'.format(
optimizer=self.learner_params.get_path('keras.backend_parameters.optimizer', 'None'))
)
# Set fastmath for GPU mode only
if self.learner_params.get_path('keras.backend_parameters.fastmath') and self.learner_params.get_path(
'keras.backend_parameters.device', 'cpu') != 'cpu':
if self.learner_params.get_path('keras.backend_parameters.fastmath', False):
flags.append('nvcc.fastmath=True')
else:
flags.append('nvcc.fastmath=False')
if self.show_extra_debug:
self.logger.debug(' NVCC fastmath \t[{flag}]'.format(
flag=str(self.learner_params.get_path('keras.backend_parameters.fastmath', False)))
)
# Set OpenMP
if self.learner_params.get_path('keras.backend_parameters.openmp') is not None:
if self.learner_params.get_path('keras.backend_parameters.openmp', False):
flags.append('openmp=True')
else:
flags.append('openmp=False')
if self.show_extra_debug:
self.logger.debug(' OpenMP\t\t[{flag}]'.format(
flag=str(self.learner_params.get_path('keras.backend_parameters.openmp', False)))
)
# Set environmental variable for Theano
os.environ["THEANO_FLAGS"] = ','.join(flags)
elif self.learner_params.get_path('keras.backend', 'tensorflow') == 'tensorflow':
# Tensorflow setup
if self.show_extra_debug:
self.logger.debug(' Tensorflow')
# Set device
if self.learner_params.get_path('keras.backend_parameters.device', 'cpu'):
# In case of CPU disable visible GPU.
if self.learner_params.get_path('keras.backend_parameters.device', 'cpu') == 'cpu':
os.environ["CUDA_VISIBLE_DEVICES"] = ''
if self.show_extra_debug:
self.logger.debug(' Device \t\t[{device}]'.format(
device=self.learner_params.get_path('keras.backend_parameters.device', 'cpu')))
else:
message = '{name}: Keras backend not supported [{backend}].'.format(
name=self.__class__.__name__,
backend=self.learner_params.get_path('keras.backend')
)
self.logger.exception(message)
raise AssertionError(message)
if self.show_extra_debug:
self.logger.debug(' ')
self.keras_setup_done = True
class BaseCallback(object):
"""Base class for callbacks
"""
def __init__(self, *args, **kwargs):
self.params = None
self.model = None
self.verbose = kwargs.get('verbose', True)
self.manual_update = kwargs.get('manual_update', False)
self.epochs = kwargs.get('epochs')
self.epoch = 0
self.external_metric_labels = kwargs.get('external_metric_labels', collections.OrderedDict())
self.external_metric = collections.OrderedDict()
self.keras_metrics = [
'binary_accuracy',
'categorical_accuracy',
'sparse_categorical_accuracy',
'top_k_categorical_accuracy'
]
self.logger = logging.getLogger(__name__)
def set_model(self, model):
self.model = model
def set_params(self, params):
self.params = params
def on_train_begin(self, logs=None):
pass
def on_train_end(self, logs=None):
pass
def on_epoch_begin(self, epoch, logs=None):
pass
def on_batch_begin(self, batch, logs=None):
pass
def on_batch_end(self, batch, logs=None):
pass
def on_epoch_end(self, epoch, logs=None):
pass
def update(self):
pass
def add_external_metric(self, metric_label):
pass
def set_external_metric_value(self, metric_label, metric_value):
pass
def get_operator(self, metric):
metric = metric.lower()
if metric.endswith('error_rate') or metric.endswith('er'):
return numpy.less
elif (metric.endswith('f_measure') or
metric.endswith('fmeasure') or
metric.endswith('fscore') or
metric.endswith('f-score')):
return numpy.greater
elif metric.endswith('accuracy') or metric.endswith('acc'):
return numpy.greater
else:
return numpy.less
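# Illustrative mapping produced by get_operator:
#   'segment_based_er'      -> numpy.less    (error rate, lower is better)
#   'f_measure'             -> numpy.greater (F-score, higher is better)
#   'categorical_accuracy'  -> numpy.greater
#   anything unrecognized   -> numpy.less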
class ProgressLoggerCallback(BaseCallback):
"""Keras callback to store metrics with tqdm progress bar or logging interface. Implements Keras Callback API.
This callback is very similar to the standard ``ProgbarLogger`` Keras callback; however, it adds support for a logging
interface, tqdm-based progress bars, and external metrics (metrics calculated outside the Keras training process).
"""
def __init__(self, *args, **kwargs):
"""Constructor
Parameters
----------
epochs : int
Total number of epochs
metric : str
Metric name
manual_update : bool
Manually update callback, use this when injecting external metrics
Default value False
manual_update_interval : int
Epoch interval for manual updates
Default value 1
disable_progress_bar : bool
Disable tqdm based progress bar
Default value False
close_progress_bar : bool
Close tqdm progress bar on training end
Default value True
log_progress : bool
Print progress into logging interface
Default value False
external_metric_labels : dict or OrderedDict
Dictionary with external metric labels as keys and metric names as values
Default value OrderedDict()
"""
super(ProgressLoggerCallback, self).__init__(*args, **kwargs)
if isinstance(kwargs.get('metric'), str):
self.metric = kwargs.get('metric')
elif callable(kwargs.get('metric')):
self.metric = kwargs.get('metric').__name__
self.loss = kwargs.get('loss')
self.disable_progress_bar = kwargs.get('disable_progress_bar', False)
self.close_progress_bar = kwargs.get('close_progress_bar', True)
self.manual_update_interval = kwargs.get('manual_update_interval', 1)
self.log_progress = kwargs.get('log_progress', False)
self.timer = Timer()
self.progress_bar = None
self.validation_data = None
self.seen = 0
self.log_values = []
self.logger = logging.getLogger(__name__)
self.postfix = collections.OrderedDict()
self.postfix['l_tra'] = None
self.postfix['l_val'] = None
self.postfix['m_tra'] = None
self.postfix['m_val'] = None
self.data = {
'l_tra': numpy.empty((self.epochs,)),
'l_val': numpy.empty((self.epochs,)),
'm_tra': numpy.empty((self.epochs,)),
'm_val': numpy.empty((self.epochs,)),
}
self.data['l_tra'][:] = numpy.nan
self.data['l_val'][:] = numpy.nan
self.data['m_tra'][:] = numpy.nan
self.data['m_val'][:] = numpy.nan
for metric_label in self.external_metric_labels:
self.data[metric_label] = numpy.empty((self.epochs,))
self.data[metric_label][:] = numpy.nan
self.header_show = False
self.last_update_epoch = 0
self.target = None
def on_train_begin(self, logs=None):
if self.epochs is None:
self.epochs = self.params['epochs']
if self.log_progress and not self.header_show:
# Show header only once
self.header_show = True
self.logger.info(' Training')
header_extra1 = ' {epoch:<5s} | {loss:<19s} | {metric:<19s} | '.format(
epoch=' '*5,
loss='Loss',
metric='Metric',
validation=' '*8,
)
if self.external_metric_labels:
line = '{external_value:<'+str(12*len(self.external_metric_labels))+'s} '
header_extra1 += line.format(
external_value='External metric',
)
header_extra1 += '{time:<15s}'.format(
time=' '*15,
)
loss_label = self.loss
if len(loss_label) > 19:
loss_label = loss_label[0:17]+'..'
metric_label = self.metric
if len(metric_label) > 19:
metric_label = metric_label[0:17]+'..'
header_extra2 = ' {epoch:<5s} | {loss:<19s} | {metric:<19s} | '.format(
epoch=' '*5,
loss=loss_label,
metric=metric_label,
validation=' '*8,
)
if self.external_metric_labels:
for metric_label in self.external_metric_labels:
header_extra2 += '{label:<10s} | '.format(label=metric_label)
header_extra2 += '{time:<15s}'.format(
time=' '*15,
)
header_main = ' {epoch:<5s} | {loss:<8s} | {val_loss:<8s} | {train:<8s} | {validation:<8s} | '.format(
epoch='Epoch',
loss='Train',
val_loss='Val',
train='Train',
validation='Val',
)
if self.external_metric_labels:
for metric_label in self.external_metric_labels:
header_main += '{label:<10s} | '.format(label='Val')
header_main += '{time:<15s}'.format(
time='Time'
)
sep = ' {epoch:<5s} + {loss:<8s} + {val_loss:<8s} + {train:<8s} + {validation:<8s} + '.format(
epoch='-'*5,
loss='-'*8,
val_loss='-' * 8,
train='-'*8,
validation='-'*8,
)
if self.external_metric_labels:
for metric_label in self.external_metric_labels:
sep += '{external_value:<10s} + '.format(
external_value='-'*10,
)
sep += '{time:<15s}'.format(
time='-'*15,
)
self.logger.info(header_extra1)
self.logger.info(header_extra2)
self.logger.info(header_main)
self.logger.info(sep)
elif self.progress_bar is None:
self.progress_bar = tqdm(total=self.epochs,
initial=self.epoch,
file=sys.stdout,
desc=' {0:>6s}'.format('Learn'),
leave=False,
miniters=1,
disable=self.disable_progress_bar
)
def on_train_end(self, logs=None):
if not self.log_progress and self.close_progress_bar:
self.progress_bar.close()
def on_epoch_begin(self, epoch, logs=None):
self.epoch = epoch + 1
if 'steps' in self.params:
self.target = self.params['steps']
elif 'samples' in self.params:
self.target = self.params['samples']
self.seen = 0
self.timer.start()
def on_batch_begin(self, batch, logs=None):
if self.target and self.seen < self.target:
self.log_values = []
def on_batch_end(self, batch, logs=None):
logs = logs or {}
batch_size = logs.get('size', 0)
self.seen += batch_size
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
def on_epoch_end(self, epoch, logs=None):
self.timer.stop()
self.epoch = epoch
logs = logs or {}
# Reset values
self.postfix['l_tra'] = None
self.postfix['l_val'] = None
self.postfix['m_tra'] = None
self.postfix['m_val'] = None
# Collect values
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
if k == 'loss':
self.data['l_tra'][self.epoch] = logs[k]
self.postfix['l_tra'] = '{:4.3f}'.format(logs[k])
elif k == 'val_loss':
self.data['l_val'][self.epoch] = logs[k]
self.postfix['l_val'] = '{:4.3f}'.format(logs[k])
elif self.metric and k.endswith(self.metric):
if k.startswith('val_'):
self.data['m_val'][self.epoch] = logs[k]
self.postfix['m_val'] = '{:4.3f}'.format(logs[k])
else:
self.data['m_tra'][self.epoch] = logs[k]
self.postfix['m_tra'] = '{:4.3f}'.format(logs[k])
for metric_label in self.external_metric_labels:
if metric_label in self.external_metric:
metric_name = self.external_metric_labels[metric_label]
value = self.external_metric[metric_label]
if metric_name.endswith('f_measure') or metric_name.endswith('f_score'):
self.postfix[metric_label] = '{:3.1f}'.format(value*100)
else:
self.postfix[metric_label] = '{:4.3f}'.format(value)
if (not self.manual_update or
(self.epoch - self.last_update_epoch > 0 and (self.epoch+1) % self.manual_update_interval)):
# Update logged progress
if self.log_progress:
self.update_progress_log()
# Increase iteration count and update progress bar
if not self.log_progress:
self.update_progress_bar(increase=1)
def update(self):
"""Update
"""
if self.log_progress:
self.update_progress_log()
else:
self.update_progress_bar()
self.last_update_epoch = self.epoch
def update_progress_log(self):
"""Update progress to logging interface
"""
if self.log_progress and self.epoch - self.last_update_epoch:
output = ' '
output += '{epoch:<5s} |'.format(epoch='{:d}'.format(self.epoch))
output += ' {loss:<8s} |'.format(loss='{:4.6f}'.format(self.data['l_tra'][self.epoch]))
if self.postfix['l_val']:
output += ' {val_loss:<8s} |'.format(val_loss='{:4.6f}'.format(self.data['l_val'][self.epoch]))
else:
output += ' {val_loss:<8s} |'.format(val_loss=' '*8)
output += ' {train:<8s} |'.format(train='{:4.6f}'.format(self.data['m_tra'][self.epoch]))
if self.postfix['m_val']:
output += ' {validation:<8s} |'.format(validation='{:4.6f}'.format(self.data['m_val'][self.epoch]))
else:
output += ' {validation:<8s} |'.format(validation=' '*8)
for metric_label in self.external_metric_labels:
if metric_label in self.external_metric:
value = self.data[metric_label][self.epoch]
if numpy.isnan(value):
value = ' '*10
else:
if (self.external_metric_labels[metric_label].endswith('f_measure') or
self.external_metric_labels[metric_label].endswith('f_score')):
value = '{:3.3f}'.format(float(value)*100)
else:
value = '{:4.3f}'.format(float(value))
output += ' {external_value:<10s} |'.format(
external_value=value
)
else:
output += ' {external_value:<10s} |'.format(
external_value=' '*10
)
output += ' {time:<15s}'.format(
time=self.timer.get_string()
)
self.logger.info(output)
def update_progress_bar(self, increase=0):
"""Update progress to tqdm progress bar
"""
self.progress_bar.set_postfix(self.postfix)
self.progress_bar.update(increase)
def add_external_metric(self, metric_id):
"""Add external metric to be monitored
Parameters
----------
metric_id : str
Metric name
"""
if metric_id not in self.external_metric_labels:
self.external_metric_labels[metric_id] = metric_id
if metric_id not in self.data:
self.data[metric_id] = numpy.empty((self.epochs,))
self.data[metric_id][:] = numpy.nan
def set_external_metric_value(self, metric_label, metric_value):
"""Add external metric value
Parameters
----------
metric_label : str
Metric label
metric_value : numeric
Metric value
"""
self.external_metric[metric_label] = metric_value
self.data[metric_label][self.epoch] = metric_value
def close(self):
"""Manually close progress logging
"""
if not self.log_progress and self.close_progress_bar:
self.progress_bar.close()
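# Usage sketch (hypothetical model and data; Keras 2 fit signature): attach the
# callback to training with verbose=0 so this callback handles progress output.
# progress = ProgressLoggerCallback(epochs=200,
#                                   metric='categorical_accuracy',
#                                   loss='categorical_crossentropy',
#                                   manual_update=False)
# model.fit(X, Y, epochs=200, callbacks=[progress], verbose=0)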
class ProgressPlotterCallback(ProgressLoggerCallback):
"""Keras callback to plot progress during the training process and save the final progress into a figure.
Implements Keras Callback API.
"""
def __init__(self, *args, **kwargs):
"""Constructor
Parameters
----------
epochs : int
Total number of epochs
metric : str
Metric name
manual_update : bool
Manually update callback, use this when injecting external metrics
Default value False
interactive : bool
Show plot during the training and update it at the plotting rate
Default value True
plotting_rate : int
Plot update rate in seconds
Default value 10
save : bool
Save plot on disk, plotting rate applies
Default value True
filename : str
Filename of the figure
focus_span : int
Number of recent epochs to highlight and show separately in the plot.
Default value 10
"""
super(ProgressPlotterCallback, self).__init__(*args, **kwargs)
self.filename = kwargs.get('filename')
# Get file format for the output plot
file_extension = os.path.splitext(self.filename)[1]
if file_extension == '.eps':
self.format = 'eps'
elif file_extension == '.svg':
self.format = 'svg'
elif file_extension == '.pdf':
self.format = 'pdf'
elif file_extension == '.png':
self.format = 'png'
self.plotting_rate = kwargs.get('plotting_rate', 10)
self.interactive = kwargs.get('interactive', True)
self.save = kwargs.get('save', True)
self.focus_span = kwargs.get('focus_span', 10)
if self.focus_span > self.epochs:
self.focus_span = self.epochs
self.timer.start()
self.data = {
'l_tra': numpy.empty((self.epochs,)),
'l_val': numpy.empty((self.epochs,)),
'm_tra': numpy.empty((self.epochs,)),
'm_val': numpy.empty((self.epochs,)),
}
self.data['l_tra'][:] = numpy.nan
self.data['l_val'][:] = numpy.nan
self.data['m_tra'][:] = numpy.nan
self.data['m_val'][:] = numpy.nan
for metric_label in self.external_metric_labels:
self.data[metric_label] = numpy.empty((self.epochs,))
self.data[metric_label][:] = numpy.nan
self.ax1_1 = None
self.ax1_2 = None
self.ax2_1 = None
self.ax2_2 = None
self.extra_main = {}
self.extra_highlight = {}
import matplotlib.pyplot as plt
import warnings
import matplotlib.cbook
warnings.filterwarnings("ignore", category=matplotlib.cbook.mplDeprecation)
figure_height = 8
if len(self.external_metric_labels) > 2:
figure_height = 8 + len(self.external_metric_labels)
self.figure = plt.figure(num=None, figsize=(18, figure_height), dpi=80, facecolor='w', edgecolor='k')
self.draw()
if self.interactive:
plt.show(block=False)
plt.pause(0.1)
def draw(self):
"""Draw plot
"""
import matplotlib.patches as patches
import matplotlib.pyplot as plt
plt.figure(self.figure.number)
row_count = 2+len(self.external_metric_labels)
self.ax1_1 = plt.subplot2grid((row_count, 4), (0, 0), rowspan=1, colspan=3)
self.ax1_2 = plt.subplot2grid((row_count, 4), (0, 3), rowspan=1, colspan=1)
self.ax2_1 = plt.subplot2grid((row_count, 4), (1, 0), rowspan=1, colspan=3)
self.ax2_2 = plt.subplot2grid((row_count, 4), (1, 3), rowspan=1, colspan=1)
self.extra_main = {}
self.extra_highlight = {}
row_id = 2
for metric_label in self.external_metric_labels:
self.extra_main[metric_label] = plt.subplot2grid((row_count, 4), (row_id, 0), rowspan=1, colspan=3)
self.extra_highlight[metric_label] = plt.subplot2grid((row_count, 4), (row_id, 3), rowspan=1, colspan=1)
row_id += 1
span = [self.epoch - self.focus_span, self.epoch]
if span[0] < 0:
span[0] = 0
# PLOT 1 / Main
self.ax1_1.cla()
self.ax1_1.set_title('Loss')
self.ax1_1.set_ylabel('Model Loss')
self.ax1_1.plot(
numpy.arange(self.epochs),
self.data['l_tra'],
lw=3,
color='red',
)
self.ax1_1.plot(
numpy.arange(self.epochs),
self.data['l_val'],
lw=3,
color='green',
)
self.ax1_1.add_patch(
patches.Rectangle(
(span[0], self.ax1_1.get_ylim()[0]), # (x,y)
width=span[1]-span[0],
height=self.ax1_1.get_ylim()[1],
facecolor="#000000",
alpha=0.05
)
)
# Horizontal lines
if not numpy.all(numpy.isnan(self.data['l_tra'])):
self.ax1_1.axhline(y=numpy.nanmin(self.data['l_tra']), lw=1, color='red', linestyle='--')
self.ax1_1.axhline(y=numpy.nanmin(self.data['l_val']), lw=1, color='green', linestyle='--')
self.ax1_1.legend(['Train', 'Validation'], loc='upper right')
self.ax1_1.set_xlim([0, self.epochs - 1])
self.ax1_1.set_xticklabels([])
self.ax1_1.grid(True)
# PLOT 1 / Highlighted area
self.ax1_2.cla()
self.ax1_2.set_title('Loss / Highlighted area')
self.ax1_2.set_ylabel('Model Loss')
self.ax1_2.plot(
numpy.arange(span[0], span[1]),
self.data['l_tra'][span[0]:span[1]],
lw=3,
color='red',
)
self.ax1_2.plot(
numpy.arange(span[0], span[1]),
self.data['l_val'][span[0]:span[1]],
lw=3,
color='green',
)
self.ax1_2.set_xticklabels([])
self.ax1_2.grid(True)
self.ax1_2.yaxis.tick_right()
self.ax1_2.yaxis.set_label_position("right")
# PLOT 2 / Main
self.ax2_1.cla()
self.ax2_1.set_title('Metric')
self.ax2_1.set_ylabel(self.metric)
# Plots
self.ax2_1.plot(
numpy.arange(self.epochs),
self.data['m_tra'],
lw=3,
color='red',
)
self.ax2_1.plot(
numpy.arange(self.epochs),
self.data['m_val'],
lw=3,
color='green',
)
# Horizontal lines
if not numpy.all(numpy.isnan(self.data['m_tra'])):
if self.get_operator(metric=self.metric) == numpy.greater:
h_tra_line_y = numpy.nanmax(self.data['m_tra'])
h_val_line_y = numpy.nanmax(self.data['m_val'])
else:
h_tra_line_y = numpy.nanmin(self.data['m_tra'])
h_val_line_y = numpy.nanmin(self.data['m_val'])
self.ax2_1.axhline(y=h_tra_line_y, lw=1, color='red', linestyle='--')
self.ax2_1.axhline(y=h_val_line_y, lw=1, color='green', linestyle='--')
self.ax2_1.add_patch(
patches.Rectangle(
(span[0], self.ax2_1.get_ylim()[0]), # (x,y)
width=span[1]-span[0],
height=self.ax2_1.get_ylim()[1],
facecolor="#000000",
alpha=0.05
)
)
if self.get_operator(metric=self.metric) == numpy.greater:
legend_location = 'lower right'
else:
legend_location = 'upper right'
self.ax2_1.legend(['Train', 'Validation'], loc=legend_location)
self.ax2_1.set_xlim([0, self.epochs - 1])
if self.external_metric_labels:
self.ax2_1.set_xticklabels([])
self.ax2_1.grid(True)
# PLOT 2 / Highlighted area
self.ax2_2.cla()
self.ax2_2.set_title('Metric / Highlighted area')
self.ax2_2.set_ylabel(self.metric)
self.ax2_2.plot(
numpy.arange(span[0], span[1]),
self.data['m_tra'][span[0]:span[1]],
lw=3,
color='red',
)
self.ax2_2.plot(
numpy.arange(span[0], span[1]),
self.data['m_val'][span[0]:span[1]],
lw=3,
color='green',
)
self.ax2_2.set_xticklabels([])
self.ax2_2.grid(True)
self.ax2_2.yaxis.tick_right()
self.ax2_2.yaxis.set_label_position("right")
for mid, metric_label in enumerate(self.external_metric_labels):
metric_name = self.external_metric_labels[metric_label]
if metric_name.endswith('f_measure') or metric_name.endswith('f_score'):
factor = 100
else:
factor = 1
# PLOT 3 / Main
self.extra_main[metric_label].cla()
self.extra_main[metric_label].set_title('External metric')
self.extra_main[metric_label].set_ylabel(str(metric_label))
mask = numpy.isfinite(self.data[metric_label])
self.extra_main[metric_label].plot(
numpy.arange(self.epochs)[mask],
self.data[metric_label][mask]*factor,
lw=3,
color='green',
marker='o',
)
self.extra_main[metric_label].add_patch(
patches.Rectangle(
(span[0], self.extra_main[metric_label].get_ylim()[0]), # (x,y)
width=span[1]-span[0],
height=self.extra_main[metric_label].get_ylim()[1],
facecolor="#000000",
alpha=0.05
)
)
# Horizontal lines
if not numpy.all(numpy.isnan(self.data[metric_label][mask])):
if self.get_operator(metric=str(metric_label)) == numpy.greater:
h_extra_line_y = numpy.nanmax((self.data[metric_label][mask]*factor))
else:
h_extra_line_y = numpy.nanmin((self.data[metric_label][mask]*factor))
self.extra_main[metric_label].axhline(y=h_extra_line_y, lw=1, color='blue', linestyle='--')
if self.get_operator(metric=self.metric) == numpy.greater:
legend_location = 'lower right'
else:
legend_location = 'upper right'
self.extra_main[metric_label].legend(['Validation'], loc=legend_location)
self.extra_main[metric_label].set_xlim([0, self.epochs - 1])
if (mid + 1) < len(self.external_metric_labels):
self.extra_main[metric_label].set_xticklabels([])
else:
self.extra_main[metric_label].set_xlabel('Epochs')
self.extra_main[metric_label].grid(True)
# PLOT 3 / Highlighted area
self.extra_highlight[metric_label].cla()
self.extra_highlight[metric_label].set_title('External metric / Highlighted area')
self.extra_highlight[metric_label].set_ylabel(str(metric_label))
highlight_data = self.data[metric_label][span[0]:span[1]]*factor
mask = numpy.isfinite(highlight_data)
self.extra_highlight[metric_label].plot(
numpy.arange(span[0], span[1])[mask],
highlight_data[mask],
lw=3,
color='green',
marker='o',
)
if (mid + 1) < len(self.external_metric_labels):
self.extra_highlight[metric_label].set_xticklabels([])
else:
self.extra_highlight[metric_label].set_xlabel('Epochs')
self.extra_highlight[metric_label].yaxis.tick_right()
self.extra_highlight[metric_label].yaxis.set_label_position("right")
self.extra_highlight[metric_label].grid(True)
plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.1, wspace=0.02, hspace=0.2)
def on_train_begin(self, logs=None):
if self.epochs is None:
self.epochs = self.params['epochs']
def on_epoch_begin(self, epoch, logs=None):
self.epoch = epoch + 1
self.seen = 0
def on_epoch_end(self, epoch, logs=None):
self.epoch = epoch
logs = logs or {}
# Collect values
for k in self.params['metrics']:
if k in logs:
self.log_values.append((k, logs[k]))
if k == 'loss':
self.data['l_tra'][self.epoch] = logs[k]
elif k == 'val_loss':
self.data['l_val'][self.epoch] = logs[k]
elif self.metric and k.endswith(self.metric):
if k.startswith('val_'):
self.data['m_val'][self.epoch] = logs[k]
else:
self.data['m_tra'][self.epoch] = logs[k]
if not self.manual_update:
# Update logged progress
self.update()
def update(self):
"""Update
"""
import matplotlib.pyplot as plt
if self.timer.elapsed() > self.plotting_rate:
self.draw()
self.figure.canvas.flush_events()
if self.interactive:
plt.pause(0.01)
if self.save:
plt.savefig(self.filename, bbox_inches='tight', format=self.format, dpi=1000)
self.timer.start()
def add_external_metric(self, metric_label):
"""Add external metric to be monitored
Parameters
----------
metric_label : str
Metric label
"""
if metric_label not in self.external_metric_labels:
self.external_metric_labels[metric_label] = metric_label
if metric_label not in self.data:
self.data[metric_label] = numpy.empty((self.epochs,))
self.data[metric_label][:] = numpy.nan
def set_external_metric_value(self, metric_label, metric_value):
"""Add external metric value
Parameters
----------
metric_label : str
Metric label
metric_value : numeric
Metric value
"""
self.external_metric[metric_label] = metric_value
self.data[metric_label][self.epoch] = metric_value
def close(self):
"""Manually close progress logging
"""
import matplotlib.pyplot as plt
if self.save:
self.draw()
plt.savefig(self.filename, bbox_inches='tight', format=self.format, dpi=1000)
plt.close(self.figure)
class StopperCallback(BaseCallback):
"""Keras callback to stop training when no improvement has been seen in a specified number of epochs.
Implements Keras Callback API.
This callback is very similar to the standard ``EarlyStopping`` Keras callback; however, it adds support for external
metrics (metrics calculated outside the Keras training process).
"""
def __init__(self, *args, **kwargs):
"""Constructor
Parameters
----------
epochs : int
Total number of epochs
manual_update : bool
Manually update callback, use this when injecting external metrics
Default value False
monitor : str
Metric value to be monitored
Default value 'val_loss'
patience : int
Number of epochs with no improvement after which training will be stopped.
Default value 0
min_delta : float
Minimum change in the monitored quantity to qualify as an improvement.
Default value 0
initial_delay : int
Number of epochs to wait at the beginning before the quantity is monitored.
Default value 10
"""
super(StopperCallback, self).__init__(*args, **kwargs)
self.monitor = kwargs.get('monitor', 'val_loss')
self.patience = kwargs.get('patience', 0)
self.min_delta = kwargs.get('min_delta', 0)
self.initial_delay = kwargs.get('initial_delay', 10)
self.wait = None
self.stopped_epoch = None
self.model = None
self.params = None
self.last_update_epoch = 0
self.stopped = False
self.logger = logging.getLogger(__name__)
self.metric_data = {
self.monitor: numpy.empty((self.epochs,))
}
self.metric_data[self.monitor][:] = numpy.nan
mode = kwargs.get('mode', 'auto')
if mode not in ['min', 'max', 'auto']:
mode = 'auto'
if mode == 'min':
self.monitor_op = numpy.less
elif mode == 'max':
self.monitor_op = numpy.greater
else:
self.monitor_op = self.get_operator(metric=self.monitor)
self.best = numpy.Inf if self.monitor_op == numpy.less else -numpy.Inf
if self.monitor_op == numpy.greater:
self.min_delta *= 1
else:
self.min_delta *= -1
def on_train_begin(self, logs=None):
if self.epochs is None:
self.epochs = self.params['epochs']
if self.wait is None:
self.wait = 0
if self.stopped_epoch is None:
self.stopped_epoch = 0
def on_epoch_begin(self, epoch, logs=None):
self.epoch = epoch + 1
def on_epoch_end(self, epoch, logs=None):
self.epoch = epoch
if self.monitor in logs:
self.metric_data[self.monitor][self.epoch] = logs.get(self.monitor)
if not self.manual_update:
self.update()
def set_external_metric_value(self, metric_label, metric_value):
"""Add external metric value
Parameters
----------
metric_label : str
Metric label
metric_value : numeric
Metric value
"""
if metric_label not in self.metric_data:
self.metric_data[metric_label] = numpy.empty((self.epochs,))
self.metric_data[metric_label][:] = numpy.nan
self.metric_data[metric_label][self.epoch] = metric_value
def stop(self):
return self.stopped
def update(self):
if self.epoch > self.initial_delay:
# get current metric value
current = self.metric_data[self.monitor][self.epoch]
if numpy.isnan(current):
message = '{name}: Metric to monitor is NaN, metric:[{metric}]'.format(
name=self.__class__.__name__,
metric=self.monitor
)
self.logger.exception(message)
raise ValueError(message)
if self.monitor_op(current - self.min_delta, self.best):
# New best value found
self.best = current
self.wait = 0
else:
if self.wait >= self.patience:
# Stopping criteria met, flag training to stop
self.stopped_epoch = self.epoch
self.model.stop_training = True
self.logger.info(' Stopping criteria met at epoch[{epoch:d}]'.format(
epoch=self.epoch,
))
self.logger.info(' metric[{metric}]={current}, patience[{patience:d}]'.format(
metric=self.monitor,
current='{:4.4f}'.format(current),
patience=self.patience
))
self.logger.info(' ')
self.stopped = True
return self.stopped
# Increase waiting counter
self.wait += self.epoch - self.last_update_epoch
self.last_update_epoch = self.epoch
return self.stopped
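# Usage sketch of the manual-update training loop this callback supports
# (hypothetical model, data, interval and evaluate_external_metric helper):
# stopper = StopperCallback(epochs=200, manual_update=True,
#                           monitor='f_measure', patience=20)
# for epoch_start in range(0, 200, interval):
#     model.fit(X, Y, initial_epoch=epoch_start, epochs=epoch_start + interval,
#               callbacks=[stopper], verbose=0)
#     stopper.set_external_metric_value('f_measure', evaluate_external_metric(model))
#     if stopper.update():   # returns True once stopping criteria are met
#         break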
class StasherCallback(BaseCallback):
"""Keras callback to monitor the training process and store the best model. Implements Keras Callback API.
This callback is very similar to the standard ``ModelCheckpoint`` Keras callback; however, it adds support for external
metrics (metrics calculated outside the Keras training process).
"""
def __init__(self, *args, **kwargs):
"""Constructor
Parameters
----------
epochs : int
Total number of epochs
manual_update : bool
Manually update callback, use this when injecting external metrics
Default value False
monitor : str
Metric to monitor
Default value 'val_loss'
mode : str
Which way the metric is interpreted, values {auto, min, max}
Default value 'auto'
period : int
Interval (in epochs) between checks for a new best model
Default value 1
initial_delay : int
Number of epochs to wait at the beginning before the quantity is monitored.
Default value 10
save_weights : bool
Save weights to the disk
Default value False
file_path : str
File path for the model weights
Default value None
"""
super(StasherCallback, self).__init__(*args, **kwargs)
self.monitor = kwargs.get('monitor', 'val_loss')
self.period = kwargs.get('period', 1)
self.initial_delay = kwargs.get('initial_delay', 10)
self.save_weights = kwargs.get('save_weights', False)
self.file_path = kwargs.get('file_path', None)
self.epochs_since_last_save = 0
self.logger = logging.getLogger(__name__)
mode = kwargs.get('mode', 'auto')
if mode not in ['auto', 'min', 'max']:
mode = 'auto'
if mode == 'min':
self.monitor_op = numpy.less
elif mode == 'max':
self.monitor_op = numpy.greater
else:
self.monitor_op = self.get_operator(metric=self.monitor)
self.best = numpy.Inf if self.monitor_op == numpy.less else -numpy.Inf
self.metric_data = {
self.monitor: numpy.empty((self.epochs,))
}
self.metric_data[self.monitor][:] = numpy.nan
self.best_model_weights = None
self.best_model_epoch = 0
self.last_logs = None
def on_epoch_begin(self, epoch, logs=None):
self.epoch = epoch + 1
def on_epoch_end(self, epoch, logs=None):
self.epoch = epoch
if self.monitor in logs:
self.metric_data[self.monitor][self.epoch] = logs.get(self.monitor)
self.last_logs = logs
if not self.manual_update:
self.update()
def update(self):
if self.epoch > self.initial_delay:
self.epochs_since_last_save += 1
if self.epochs_since_last_save >= self.period:
self.epochs_since_last_save = 0
current = self.metric_data[self.monitor][self.epoch]
if numpy.isnan(current):
message = '{name}: Metric to monitor is NaN, metric:[{metric}]'.format(
name=self.__class__.__name__,
metric=self.monitor
)
self.logger.exception(message)
raise ValueError(message)
else:
if self.monitor_op(current, self.best):
# Store the best
self.best = current
self.best_model_weights = self.model.get_weights()
self.best_model_epoch = self.epoch
if self.save_weights and self.file_path:
# Save weight on disk
logs = self.last_logs
if self.monitor not in logs:
logs[self.monitor] = current
file_path = self.file_path.format(epoch=self.epoch, **self.last_logs)
self.model.save_weights(file_path, overwrite=True)
def set_external_metric_value(self, metric_label, metric_value):
"""Add external metric value
Parameters
----------
metric_label : str
Metric label
metric_value : numeric
Metric value
"""
if metric_label not in self.metric_data:
self.metric_data[metric_label] = numpy.empty((self.epochs,))
self.metric_data[metric_label][:] = numpy.nan
self.metric_data[metric_label][self.epoch] = metric_value
def get_best(self):
"""Return best model seen
Returns
-------
dict
Dictionary with keys 'weights', 'epoch', 'metric_value'
"""
return {
'epoch': self.best_model_epoch,
'weights': self.best_model_weights,
'metric_value': self.best,
}
def log(self):
"""Print information about the best model into logging interface
"""
self.logger.info(' Best model weights at epoch[{epoch:d}]'.format(epoch=self.best_model_epoch))
self.logger.info(' metric[{metric}]={best}'.format(
metric=self.monitor,
best='{:4.4f}'.format(self.best)
)
)
self.logger.info(' ')
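# Usage sketch (hypothetical model and data): after training, roll the model
# back to the best state seen by the callback.
# stasher = StasherCallback(epochs=200, monitor='val_loss')
# model.fit(X, Y, validation_data=(X_val, Y_val), epochs=200,
#           callbacks=[stasher], verbose=0)
# best = stasher.get_best()
# if best['weights'] is not None:
#     model.set_weights(best['weights'])
# stasher.log()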
class BaseDataGenerator(object):
"""Base class for data generators.
"""
def __init__(self, *args, **kwargs):
"""Constructor
Parameters
----------
files : list
data_filenames : dict
annotations : dict
class_labels : list of str
hop_length_seconds : float
Default value 0.2
shuffle : bool
Default value True
batch_size : int
Default value 64
buffer_size : int
Default value 256
"""
self.method = 'base_generator'
# Data
self.item_list = copy.copy(kwargs.get('files', []))
self.data_filenames = kwargs.get('data_filenames', {})
self.annotations = kwargs.get('annotations', {})
# Activity matrix
self.class_labels = kwargs.get('class_labels', [])
self.hop_length_seconds = kwargs.get('hop_length_seconds', 0.2)
self.shuffle = kwargs.get('shuffle', True)
self.batch_size = kwargs.get('batch_size', 64)
self.buffer_size = kwargs.get('buffer_size', 256)
self.logger = kwargs.get('logger', logging.getLogger(__name__))
# Internal state variables
self.batch_index = 0
self.item_index = 0
self.data_position = 0
# Initialize data buffer
self.data_buffer = DataBuffer(size=self.buffer_size)
if self.buffer_size >= len(self.item_list):
# Fill data buffer at initialization if it fits fully to the buffer
for current_item in self.item_list:
self.process_item(item=current_item)
if self.data_buffer.full():
break
self._data_size = None
self._input_size = None
@property
def steps_count(self):
"""Number of batches in one epoch
"""
num_batches = int(numpy.ceil(self.data_size / float(self.batch_size)))
if num_batches > 0:
return num_batches
else:
return 1
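# For example: data_size == 1000 feature vectors with batch_size == 64
# gives ceil(1000 / 64) == 16 batches per epoch.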
@property
def input_size(self):
"""Length of input feature vector
"""
if self._input_size is None:
# Load first item
first_item = list(self.data_filenames.keys())[0]
self.process_item(item=first_item)
# Get Feature vector length
self._input_size = self.data_buffer.get(key=first_item)[0].shape[-1]
return self._input_size
@property
def data_size(self):
"""Total data amount
"""
if self._data_size is None:
self._data_size = 0
for current_item in self.item_list:
self.process_item(item=current_item)
data, meta = self.data_buffer.get(key=current_item)
# Accumulate feature matrix length
self._data_size += data.shape[0]
return self._data_size
def info(self):
"""Information logging
"""
info = [
' Generator',
' Shuffle \t[{shuffle}]'.format(shuffle='True' if self.shuffle else 'False'),
' Epoch size\t[{steps:d} batches]'.format(steps=self.steps_count),
' Buffer size \t[{buffer_size:d} files]'.format(buffer_size=self.buffer_size),
' '
]
return info
def process_item(self, item):
pass
def on_epoch_start(self):
pass
def on_epoch_end(self):
pass
class FeatureGenerator(BaseDataGenerator):
"""Feature data generator
"""
def __init__(self, *args, **kwargs):
"""Constructor
Parameters
----------
files : list of str
List of active item identifiers, usually filenames
data_filenames : dict of dicts
Data structure keyed by item identifiers (defined with the files parameter); each inner dict has feature
extractor labels as keys and filenames on disk as values.
annotations : dict of MetaDataContainers or MetaDataItems
Annotations for all items keyed with item identifiers
class_labels : list of str
Class labels in a list
hop_length_seconds : float
Analysis frame hop length in seconds
Default value 0.2
shuffle : bool
Shuffle data before each epoch
Default value True
batch_size : int
Batch size to generate
Default value 64
buffer_size : int
Internal item buffer size; set it large enough for smaller datasets so that all items fit in the buffer
and reloading from disk is avoided
Default value 256
data_processor : class
Data processor class used to process loaded features
data_refresh_on_each_epoch : bool
Internal data buffer reset at the start of each epoch
Default value False
label_mode : str ('event', 'scene')
Activity matrix forming mode.
Default value "event"
"""
self.data_processor = kwargs.get('data_processor')
self.data_refresh_on_each_epoch = kwargs.get('data_refresh_on_each_epoch', False)
self.label_mode = kwargs.get('label_mode', 'event')
super(FeatureGenerator, self).__init__(*args, **kwargs)
self.method = 'feature'
self.logger = kwargs.get('logger', logging.getLogger(__name__))
if self.label_mode not in ['event', 'scene']:
message = '{name}: Label mode unknown [{label_mode}]'.format(
name=self.__class__.__name__,
label_mode=self.label_mode
)
self.logger.exception(message)
raise ValueError(message)
def process_item(self, item):
if not self.data_buffer.key_exists(key=item):
current_data, current_length = self.data_processor.load(
feature_filename_dict=self.data_filenames[item]
)
current_activity_matrix = self.get_activity_matrix(
annotation=self.annotations[item],
data_length=current_length
)
self.data_buffer.set(key=item, data=current_data, meta=current_activity_matrix)
def on_epoch_start(self):
self.batch_index = 0
if self.shuffle:
# Shuffle item list order
numpy.random.shuffle(self.item_list)
if self.data_refresh_on_each_epoch:
# Force reload of data
self.data_buffer.clear()
def generator(self):
"""Generator method
Yields
------
tuple of numpy.ndarray
Batches of (data, meta)
"""
while True:
# Start of epoch
self.on_epoch_start()
batch_buffer_data = []
batch_buffer_meta = []
# Go through items
for item in self.item_list:
# Load item data into buffer
self.process_item(item=item)
# Fetch item from buffer
data, meta = self.data_buffer.get(key=item)
# Data indexing
data_ids = numpy.arange(data.shape[0])
# Shuffle data order
if self.shuffle:
numpy.random.shuffle(data_ids)
for data_id in data_ids:
if len(batch_buffer_data) == self.batch_size:
# Batch buffer full, yield data
yield (
numpy.concatenate(
numpy.expand_dims(batch_buffer_data, axis=0)
),
numpy.concatenate(
numpy.expand_dims(batch_buffer_meta, axis=0)
)
)
# Empty batch buffers
batch_buffer_data = []
batch_buffer_meta = []
# Increase batch counter
self.batch_index += 1
# Collect data for the batch
batch_buffer_data.append(data[data_id])
batch_buffer_meta.append(meta[data_id])
if len(batch_buffer_data):
# Last batch, usually not full
yield (
numpy.concatenate(
numpy.expand_dims(batch_buffer_data, axis=0)
),
numpy.concatenate(
numpy.expand_dims(batch_buffer_meta, axis=0)
),
)
# Increase batch counter
self.batch_index += 1
# End of epoch
self.on_epoch_end()
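# Usage sketch (hypothetical file lists and data_processor; Keras 2
# fit_generator signature): feed the generator straight into Keras training.
# train_generator = FeatureGenerator(files=training_files,
#                                    data_filenames=feature_filenames,
#                                    annotations=annotations,
#                                    class_labels=class_labels,
#                                    data_processor=data_processor,
#                                    batch_size=64)
# model.fit_generator(generator=train_generator.generator(),
#                     steps_per_epoch=train_generator.steps_count,
#                     epochs=epochs)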
def get_activity_matrix(self, annotation, data_length):
"""Convert annotation into activity matrix and run it through data processor.
"""
event_roll = None
if self.label_mode == 'event':
# Event activity, event onset and offset specified
event_roll = EventRoll(metadata_container=annotation,
label_list=self.class_labels,
time_resolution=self.hop_length_seconds
)
event_roll = event_roll.pad(length=data_length)
elif self.label_mode == 'scene':
# Scene activity, one-hot activity throughout whole file
pos = self.class_labels.index(annotation.scene_label)
event_roll = numpy.zeros((data_length, len(self.class_labels)))
event_roll[:, pos] = 1
if event_roll is not None:
return self.data_processor.process_activity_data(
activity_data=event_roll
)
else:
return None