#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Keras utils
===========
Utility classes related to Keras.
KerasMixin
^^^^^^^^^^
.. autosummary::
    :toctree: generated/
    KerasMixin.create_model
    KerasMixin.create_callback_list
    KerasMixin.create_external_metric_evaluators
    KerasMixin.prepare_data
    KerasMixin.prepare_activity
    KerasMixin.keras_model_exists
    KerasMixin.log_model_summary
    KerasMixin.plot_model
    KerasMixin.get_processing_interval
BaseCallback
^^^^^^^^^^^^
.. autosummary::
    :toctree: generated/
    BaseCallback
ProgressLoggerCallback
^^^^^^^^^^^^^^^^^^^^^^
Keras callback to store metrics with a tqdm progress bar or the logging interface. Implements the Keras Callback API.
This callback is very similar to the standard ``ProgbarLogger`` Keras callback; however, it adds support for the
logging interface, tqdm-based progress bars, and external metrics
(metrics calculated outside the Keras training process).
.. autosummary::
    :toctree: generated/
    ProgressLoggerCallback
ProgressPlotterCallback
^^^^^^^^^^^^^^^^^^^^^^^
Keras callback to plot progress during the training process and save the final progress into a figure.
Implements the Keras Callback API.
.. autosummary::
    :toctree: generated/
    ProgressPlotterCallback
StopperCallback
^^^^^^^^^^^^^^^
Keras callback to stop training when no improvement has been seen for a specified number of epochs.
Implements the Keras Callback API.
This callback is very similar to the standard ``EarlyStopping`` Keras callback; however, it adds support for
external metrics (metrics calculated outside the Keras training process).
.. autosummary::
    :toctree: generated/
    StopperCallback
StasherCallback
^^^^^^^^^^^^^^^
Keras callback to monitor the training process and store the best model. Implements the Keras Callback API.
This callback is very similar to the standard ``ModelCheckpoint`` Keras callback; however, it adds support for
external metrics (metrics calculated outside the Keras training process).
.. autosummary::
    :toctree: generated/
    StasherCallback
BaseDataGenerator
^^^^^^^^^^^^^^^^^
.. autosummary::
    :toctree: generated/
    BaseDataGenerator
    BaseDataGenerator.input_size
    BaseDataGenerator.data_size
    BaseDataGenerator.steps_count
    BaseDataGenerator.info
FeatureGenerator
^^^^^^^^^^^^^^^^
.. autosummary::
    :toctree: generated/
    FeatureGenerator
    FeatureGenerator.generator
"""
import os
import sys
import logging
import numpy
import copy
import importlib
import collections
from tqdm import tqdm
from six import iteritems
from .containers import DottedDict
from .utils import SuppressStdoutAndStderr, Timer, SimpleMathStringEvaluator, get_parameter_hash
from .features import FeatureContainer
from .metadata import EventRoll
from .data import DataBuffer
class KerasMixin(object):
    """Class Mixin for Keras based learner containers.
    """
    def __getstate__(self):
        data = {}
        excluded_fields = ['model']
        for item in self:
            if item not in excluded_fields and self.get(item):
                data[item] = copy.deepcopy(self.get(item))
        data['model'] = os.path.splitext(self.filename)[0] + '.model.hdf5'
        return data
[docs]    def keras_model_exists(self):
        """Check that keras model exists on disk
        Returns
        -------
        bool
        """
        keras_model_filename = os.path.splitext(self.filename)[0] + '.model.hdf5'
        return os.path.isfile(self.filename) and os.path.isfile(keras_model_filename) 
[docs]    def log_model_summary(self):
        """Prints model summary to the logging interface.
        Similar to Keras model summary
        """
        layer_name_map = {
            'BatchNormalization': 'BatchNorm',
        }
        import keras
        from distutils.version import LooseVersion
        import keras.backend as keras_backend
        self.logger.debug('  Model summary')
        self.logger.debug(
            '    {type:<15s} | {out:20s} | {param:6s}  | {name:21s}  | {conn:27s} | {act:7s} | {init:7s}'.format(
                type='Layer type',
                out='Output',
                param='Param',
                name='Name',
                conn='Connected to',
                act='Activ.',
                init='Init')
        )
        self.logger.debug(
            '    {type:<15s} + {out:20s} + {param:6s}  + {name:21s}  + {conn:27s} + {act:7s} + {init:6s}'.format(
                type='-' * 15,
                out='-' * 20,
                param='-' * 6,
                name='-' * 21,
                conn='-' * 27,
                act='-' * 7,
                init='-' * 6)
        )
        for layer in self.model.layers:
            connections = []
            if LooseVersion(keras.__version__) >= LooseVersion('2.1.3'):
                for node_index, node in enumerate(layer._inbound_nodes):
                    for i in range(len(node.inbound_layers)):
                        inbound_layer = node.inbound_layers[i].name
                        inbound_node_index = node.node_indices[i]
                        inbound_tensor_index = node.tensor_indices[i]
                        connections.append(inbound_layer + '[' + str(inbound_node_index) +
                                           '][' + str(inbound_tensor_index) + ']')
            else:
                for node_index, node in enumerate(layer.inbound_nodes):
                    for i in range(len(node.inbound_layers)):
                        inbound_layer = node.inbound_layers[i].name
                        inbound_node_index = node.node_indices[i]
                        inbound_tensor_index = node.tensor_indices[i]
                        connections.append(inbound_layer + '[' + str(inbound_node_index) +
                                           '][' + str(inbound_tensor_index) + ']')
            config = DottedDict(layer.get_config())
            layer_name = layer.__class__.__name__
            if layer_name in layer_name_map:
                layer_name = layer_name_map[layer_name]
            if config.get_path('kernel_initializer.class_name') == 'VarianceScaling':
                init = str(config.get_path('kernel_initializer.config.distribution', '---'))
            elif config.get_path('kernel_initializer.class_name') == 'RandomUniform':
                init = 'uniform'
            else:
                init = '---'
            self.logger.debug(
                '    {type:<15s} | {shape:20s} | {params:6s}  | {name:21s}  | {connected:27s} | {activation:7s} | {init:7s}'.format(
                    type=layer_name,
                    shape=str(layer.output_shape),
                    params=str(layer.count_params()),
                    name=str(layer.name),
                    connected=str(connections[0]) if len(connections) > 0 else '---',
                    activation=str(config.get('activation', '---')),
                    init=init,
                )
            )
        trainable_count = int(
            numpy.sum([keras_backend.count_params(p) for p in set(self.model.trainable_weights)])
        )
        non_trainable_count = int(
            numpy.sum([keras_backend.count_params(p) for p in set(self.model.non_trainable_weights)])
        )
        self.logger.debug('  ')
        self.logger.debug('  Parameters')
        self.logger.debug('    Trainable\t[{param_count:,}]'.format(param_count=int(trainable_count)))
        self.logger.debug('    Non-Trainable\t[{param_count:,}]'.format(param_count=int(non_trainable_count)))
        self.logger.debug(
            '    Total\t\t[{param_count:,}]'.format(param_count=int(trainable_count + non_trainable_count)))
        self.logger.debug('  ') 
[docs]    def plot_model(self, filename='model.png', show_shapes=True, show_layer_names=True):
        """Plots model topology
        """
        from keras.utils.visualize_util import plot
        plot(self.model, to_file=filename, show_shapes=show_shapes, show_layer_names=show_layer_names) 
[docs]    def prepare_data(self, data, files, processor='default'):
        """Concatenate feature data into one feature matrix
        Parameters
        ----------
        data : dict of FeatureContainers
            Feature data
        files : list of str
            List of filenames
        processor : str ('default', 'training')
            Data processor selector
            Default value 'default'
        Returns
        -------
        numpy.ndarray
            Features concatenated
        """
        if self.learner_params.get_path('input_sequencer.enable'):
            processed_data = []
            for item in files:
                if processor == 'training':
                    processed_data.append(
                        self.data_processor_training.process_data(
                            data=data[item].feat[0]
                        )
                    )
                else:
                    processed_data.append(
                        self.data_processor.process_data(
                            data=data[item].feat[0]
                        )
                    )
            return numpy.concatenate(processed_data)
        else:
            return numpy.vstack([data[x].feat[0] for x in files]) 
[docs]    def prepare_activity(self, activity_matrix_dict, files, processor='default'):
        """Concatenate activity matrices into one activity matrix
        Parameters
        ----------
        activity_matrix_dict : dict of binary matrices
            Meta data
        files : list of str
            List of filenames
        processor : str ('default', 'training')
            Data processor selector
            Default value 'default'
        Returns
        -------
        numpy.ndarray
            Activity matrix
        """
        if self.learner_params.get_path('input_sequencer.enable'):
            processed_activity = []
            for item in files:
                if processor == 'training':
                    processed_activity.append(
                        self.data_processor_training.process_activity_data(
                            activity_data=activity_matrix_dict[item]
                        )
                    )
                else:
                    processed_activity.append(
                        self.data_processor.process_activity_data(
                            activity_data=activity_matrix_dict[item]
                        )
                    )
            return numpy.concatenate(processed_activity)
        else:
            return numpy.vstack([activity_matrix_dict[x] for x in files]) 
[docs]    def create_model(self, input_shape):
        """Create sequential Keras model
        """
        from keras.models import Sequential
        self.model = Sequential()
        tuple_fields = [
            'input_shape',
            'kernel_size',
            'pool_size',
            'dims',
            'target_shape'
        ]
        # Get model config parameters
        model_params = copy.deepcopy(self.learner_params.get_path('model.config'))
        # Get constants for model
        constants = copy.deepcopy(self.learner_params.get_path('model.constants', {}))
        constants['CLASS_COUNT'] = len(self.class_labels)
        constants['FEATURE_VECTOR_LENGTH'] = input_shape
        if self.learner_params.get_path('input_sequencer.frames'):
            constants['INPUT_SEQUENCE_LENGTH'] = self.learner_params.get_path('input_sequencer.frames')
        def process_field(value, constants_dict):
            math_eval = SimpleMathStringEvaluator()
            if isinstance(value, str):
                # String field
                if value in constants_dict:
                    return constants_dict[value]
                elif len(value.split()) > 1:
                    sub_fields = value.split()
                    for subfield_id, subfield in enumerate(sub_fields):
                        if subfield in constants_dict:
                            sub_fields[subfield_id] = str(constants_dict[subfield])
                    return math_eval.eval(''.join(sub_fields))
                else:
                    return value
            elif isinstance(value, list):
                processed_value_list = []
                for item_id, item in enumerate(value):
                    processed_value_list.append(process_field(value=item, constants_dict=constants_dict))
                return processed_value_list
            else:
                return value
        # Inject constant into constants with equations
        for field in list(constants.keys()):
            constants[field] = process_field(value=constants[field], constants_dict=constants)
        # Setup layers
        for layer_id, layer_setup in enumerate(model_params):
            # Get layer parameters
            layer_setup = DottedDict(layer_setup)
            if 'config' not in layer_setup:
                layer_setup['config'] = {}
            # Get layer class
            try:
                layer_class = getattr(
                    importlib.import_module("keras.layers"),
                    layer_setup['class_name']
                )
            except AttributeError:
                message = '{name}: Invalid Keras layer type [{type}].'.format(
                    name=self.__class__.__name__,
                    type=layer_setup['class_name']
                )
                self.logger.exception(message)
                raise AttributeError(message)
            # Inject constants
            for config_field in list(layer_setup['config'].keys()):
                layer_setup['config'][config_field] = process_field(
                    value=layer_setup['config'][config_field],
                    constants_dict=constants
                )
            # Convert lists into tuples
            for field in tuple_fields:
                if field in layer_setup['config']:
                    layer_setup['config'][field] = tuple(layer_setup['config'][field])
            # Inject input shape for Input layer if not given
            if layer_id == 0 and layer_setup.get_path('config.input_shape') is None:
                # Set input layer dimension for the first layer if not set
                layer_setup['config']['input_shape'] = (input_shape,)
            if 'wrapper' in layer_setup:
                # Get layer wrapper class
                try:
                    wrapper_class = getattr(
                        importlib.import_module("keras.layers"),
                        layer_setup['wrapper']
                    )
                except AttributeError:
                    message = '{name}: Invalid Keras layer wrapper type [{type}].'.format(
                        name=self.__class__.__name__,
                        type=layer_setup['wrapper']
                    )
                    self.logger.exception(message)
                    raise AttributeError(message)
                wrapper_parameters = layer_setup.get('config_wrapper', {})
                if layer_setup.get('config'):
                    self.model.add(
                        wrapper_class(layer_class(**dict(layer_setup.get('config'))), **dict(wrapper_parameters)))
                else:
                    self.model.add(wrapper_class(layer_class(), **dict(wrapper_parameters)))
            else:
                if layer_setup.get('config'):
                    self.model.add(layer_class(**dict(layer_setup.get('config'))))
                else:
                    self.model.add(layer_class())
        # Get Optimizer class
        try:
            optimizer_class = getattr(
                importlib.import_module("keras.optimizers"),
                self.learner_params.get_path('model.optimizer.type')
            )
        except AttributeError:
            message = '{name}: Invalid Keras optimizer type [{type}].'.format(
                name=self.__class__.__name__,
                type=self.learner_params.get_path('model.optimizer.type')
            )
            self.logger.exception(message)
            raise AttributeError(message)
        # Compile the model
        self.model.compile(
            loss=self.learner_params.get_path('model.loss'),
            optimizer=optimizer_class(**dict(self.learner_params.get_path('model.optimizer.parameters', {}))),
            metrics=self.learner_params.get_path('model.metrics')
        ) 
[docs]    def create_callback_list(self):
        """Create list of Keras callbacks
        """
        callbacks = []
        # Fetch processing interval
        processing_interval = self.get_processing_interval()
        # Collect all external metrics
        external_metrics = collections.OrderedDict()
        if self.learner_params.get_path('training.epoch_processing.enable'):
            if self.learner_params.get_path('validation.enable') and self.learner_params.get_path(
                    'training.epoch_processing.external_metrics.enable'):
                for metric in self.learner_params.get_path('training.epoch_processing.external_metrics.metrics'):
                    current_metric_name = metric.get('name')
                    current_metric_label = metric.get('label', current_metric_name.split('.')[-1])
                    external_metrics[current_metric_label] = current_metric_name
        # ProgressLoggerCallback
        from dcase_framework.keras_utils import ProgressLoggerCallback
        callbacks.append(
            ProgressLoggerCallback(
                metric=self.learner_params.get_path('model.metrics')[0],
                loss=self.learner_params.get_path('model.loss'),
                disable_progress_bar=self.disable_progress_bar,
                log_progress=self.log_progress,
                epochs=self.learner_params.get_path('training.epochs'),
                close_progress_bar=not self.learner_params.get_path('training.epoch_processing.enable'),
                manual_update=self.learner_params.get_path('training.epoch_processing.enable'),
                manual_update_interval=processing_interval,
                external_metric_labels=external_metrics
            )
        )
        # Add model callbacks
        for cp in self.learner_params.get_path('training.callbacks', []):
            cp_params = DottedDict(cp.get('parameters', {}))
            if cp['type'] == 'Plotter':
                from dcase_framework.keras_utils import ProgressPlotterCallback
                callbacks.append(
                    ProgressPlotterCallback(
                        filename=os.path.splitext(self.filename)[0] + '.' + cp_params.get('output_format', 'pdf'),
                        metric=self.learner_params.get_path('model.metrics')[0],
                        loss=self.learner_params.get_path('model.loss'),
                        disable_progress_bar=self.disable_progress_bar,
                        log_progress=self.log_progress,
                        epochs=self.learner_params.get_path('training.epochs'),
                        close_progress_bar=not self.learner_params.get_path('training.epoch_processing.enable'),
                        manual_update=self.learner_params.get_path('training.epoch_processing.enable'),
                        interactive=cp_params.get('interactive', True),
                        save=cp_params.get('save', True),
                        focus_span=cp_params.get('focus_span'),
                        plotting_rate=cp_params.get('plotting_rate'),
                        external_metric_labels=external_metrics
                    )
                )
            elif cp['type'] == 'Stopper':
                from dcase_framework.keras_utils import StopperCallback
                callbacks.append(
                    StopperCallback(
                        epochs=self.learner_params.get_path('training.epochs'),
                        manual_update=self.learner_params.get_path('training.epoch_processing.enable'),
                        external_metric_labels=external_metrics,
                        **cp_params
                    )
                )
            elif cp['type'] == 'Stasher':
                from dcase_framework.keras_utils import StasherCallback
                callbacks.append(
                    StasherCallback(
                        epochs=self.learner_params.get_path('training.epochs'),
                        manual_update=self.learner_params.get_path('training.epoch_processing.enable'),
                        external_metric_labels=external_metrics,
                        **cp_params
                    )
                )
            else:
                # Keras standard callbacks
                if cp['type'] == 'ModelCheckpoint' and not cp['parameters'].get('filepath'):
                    cp['parameters']['filepath'] = os.path.splitext(self.filename)[0] + \
                                                   
'.weights.{epoch:02d}-{val_loss:.2f}.hdf5'
                if cp['type'] == 'EarlyStopping' and cp.get('parameters').get('monitor').startswith('val_') \
                        
and not self.learner_params.get_path('validation.enable', False):
                    message = '{name}: Cannot use callback type [{type}] with monitor parameter [{monitor}] ' \
                              
'as there is no validation set.'.format(name=self.__class__.__name__,
                                                                      type=cp['type'],
                                                                      monitor=cp.get('parameters').get('monitor')
                                                                      )
                    self.logger.exception(message)
                    raise AttributeError(message)
                try:
                    callback_class = getattr(importlib.import_module("keras.callbacks"), cp['type'])
                    callbacks.append(callback_class(**cp_params))
                except AttributeError:
                    message = '{name}: Invalid Keras callback type [{type}]'.format(
                        name=self.__class__.__name__,
                        type=cp['type']
                    )
                    self.logger.exception(message)
                    raise AttributeError(message)
        return callbacks 
[docs]    def create_external_metric_evaluators(self):
        """Create external metric evaluators
        """
        # Initialize external metrics
        external_metric_evaluators = collections.OrderedDict()
        if self.learner_params.get_path('training.epoch_processing.enable'):
            if self.learner_params.get_path('validation.enable') and self.learner_params.get_path(
                    'training.epoch_processing.external_metrics.enable'):
                import sed_eval
                for metric in self.learner_params.get_path('training.epoch_processing.external_metrics.metrics'):
                    # Current metric info
                    current_metric_evaluator = metric.get('evaluator')
                    current_metric_name = metric.get('name')
                    current_metric_params = metric.get('parameters', {})
                    current_metric_label = metric.get('label', current_metric_name.split('.')[-1])
                    # Initialize sed_eval evaluators
                    if current_metric_evaluator == 'sed_eval.scene':
                        evaluator = sed_eval.scene.SceneClassificationMetrics(
                            scene_labels=self.class_labels,
                            **current_metric_params
                        )
                    elif (current_metric_evaluator == 'sed_eval.segment_based' or
                          current_metric_evaluator == 'sed_eval.sound_event.segment_based'):
                        evaluator = sed_eval.sound_event.SegmentBasedMetrics(
                            event_label_list=self.class_labels,
                            **current_metric_params
                        )
                    elif (current_metric_evaluator == 'sed_eval.event_based' or
                          current_metric_evaluator == 'sed_eval.sound_event.event_based'):
                        evaluator = sed_eval.sound_event.EventBasedMetrics(
                            event_label_list=self.class_labels,
                            **current_metric_params
                        )
                    else:
                        message = '{name}: Unknown target metric [{metric}].'.format(
                            name=self.__class__.__name__,
                            metric=current_metric_name
                        )
                        self.logger.exception(message)
                        raise AssertionError(message)
                    # Check evaluator API
                    if (not hasattr(evaluator, 'reset') or
                       not hasattr(evaluator, 'evaluate') or
                       not hasattr(evaluator, 'results')):
                        if current_metric_evaluator.startswith('sed_eval'):
                            message = '{name}: wrong version of sed_eval for [{current_metric_evaluator}::{current_metric_name}], update sed_eval to latest version'.format(
                                name=self.__class__.__name__,
                                current_metric_evaluator=current_metric_evaluator,
                                current_metric_name=current_metric_name
                            )
                            self.logger.exception(message)
                            raise ValueError(message)
                        else:
                            message = '{name}: Evaluator has invalid API [{current_metric_evaluator}::{current_metric_name}]'.format(
                                name=self.__class__.__name__,
                                current_metric_evaluator=current_metric_evaluator,
                                current_metric_name=current_metric_name
                            )
                            self.logger.exception(message)
                            raise ValueError(message)
                    # Form unique name for metric, to allow multiple similar metrics with different parameters
                    metric_id = get_parameter_hash(metric)
                    # Metric data container
                    metric_data = {
                        'evaluator_name': current_metric_evaluator,
                        'name': current_metric_name,
                        'params': current_metric_params,
                        'label': current_metric_label,
                        'path': current_metric_name,
                        'evaluator': evaluator,
                    }
                    external_metric_evaluators[metric_id] = metric_data
        return external_metric_evaluators 
[docs]    def get_processing_interval(self):
        """Processing interval
        """
        processing_interval = 1
        if self.learner_params.get_path('training.epoch_processing.enable'):
            if self.learner_params.get_path('training.epoch_processing.external_metrics.enable'):
                processing_interval = self.learner_params.get_path(
                    'training.epoch_processing.external_metrics.evaluation_interval', 1)
        return processing_interval 
    def _after_load(self, to_return=None):
        """Load the Keras model stored next to the container file into ``self.model``.

        The model file is ``self.filename`` with its extension replaced by ``.model.hdf5``.
        ``to_return`` is not used here.

        Raises
        ------
        IOError
            Keras model file not found.
        """
        with SuppressStdoutAndStderr():
            # Keras must be configured before it is imported: its default backend is tensorflow,
            # and the import fails if tensorflow is not installed and theano is not selected.
            self._setup_keras()
            from keras.models import load_model

        model_path = os.path.splitext(self.filename)[0] + '.model.hdf5'
        if not os.path.isfile(model_path):
            message = '{name}: Keras model not found [{filename}]'.format(
                name=self.__class__.__name__,
                filename=model_path
            )
            self.logger.exception(message)
            raise IOError(message)

        with SuppressStdoutAndStderr():
            self.model = load_model(model_path)
    def _after_save(self, to_return=None):
        """Store the Keras model and its weights next to the container file.

        Writes ``<stem>.model.hdf5`` (full model) and ``<stem>.weights.hdf5`` (weights only),
        where ``<stem>`` is ``self.filename`` without its extension. ``to_return`` is not used here.
        """
        stem = os.path.splitext(self.filename)[0]
        self.model.save(stem + '.model.hdf5')
        self.model.save_weights(stem + '.weights.hdf5')
    def _setup_keras(self):
        """Setup keras backend and parameters
        """
        if not hasattr(self, 'keras_setup_done') or not self.keras_setup_done:
            # Get BLAS library associated to numpy
            if numpy.__config__.blas_opt_info and 'libraries' in numpy.__config__.blas_opt_info:
                blas_libraries = numpy.__config__.blas_opt_info['libraries']
            else:
                blas_libraries = ['']
            blas_extra_info = []
            # Set backend and parameters before importing keras
            if self.show_extra_debug:
                self.logger.debug('  Keras')
                self.logger.debug('    Backend \t[{backend}]'.format(
                    backend=self.learner_params.get_path('keras.backend', 'theano'))
                )
            # Threading
            if self.learner_params.get_path('keras.backend_parameters.threads'):
                thread_count = self.learner_params.get_path('keras.backend_parameters.threads', 1)
                os.environ['GOTO_NUM_THREADS'] = str(thread_count)
                os.environ['OMP_NUM_THREADS'] = str(thread_count)
                os.environ['MKL_NUM_THREADS'] = str(thread_count)
                blas_extra_info.append('Threads[{threads}]'.format(threads=thread_count))
                if thread_count > 1:
                    os.environ['OMP_DYNAMIC'] = 'False'
                    os.environ['MKL_DYNAMIC'] = 'False'
                else:
                    os.environ['OMP_DYNAMIC'] = 'True'
                    os.environ['MKL_DYNAMIC'] = 'True'
            # Conditional Numerical Reproducibility (CNR) for MKL BLAS library
            if self.learner_params.get_path('keras.backend_parameters.CNR', True) and blas_libraries[0].startswith('mkl'):
                os.environ['MKL_CBWR'] = 'COMPATIBLE'
                blas_extra_info.append('MKL_CBWR[{mode}]'.format(mode='COMPATIBLE'))
            # Show BLAS info
            if self.show_extra_debug:
                if numpy.__config__.blas_opt_info and 'libraries' in numpy.__config__.blas_opt_info:
                    blas_libraries = numpy.__config__.blas_opt_info['libraries']
                    if blas_libraries[0].startswith('openblas'):
                        self.logger.debug('    BLAS library\t[OpenBLAS]\t\t({info})'.format(
                            info=', '.join(blas_extra_info))
                        )
                    elif blas_libraries[0].startswith('blas'):
                        self.logger.debug(
                            '  BLAS library\t[BLAS/Atlas]\t\t({info})'.format(
                                info=', '.join(blas_extra_info))
                        )
                    elif blas_libraries[0].startswith('mkl'):
                        self.logger.debug('    BLAS library\t[MKL]\t\t({info})'.format(
                            info=', '.join(blas_extra_info))
                        )
            # Select Keras backend
            os.environ["KERAS_BACKEND"] = self.learner_params.get_path('keras.backend', 'theano')
            if self.learner_params.get_path('keras.backend', 'theano') == 'theano':
                # Theano setup
                if self.show_extra_debug:
                    self.logger.debug('  Theano')
                # Default flags
                flags = [
                    # 'ldflags=',
                    'warn.round=False',
                ]
                # Set device
                if self.learner_params.get_path('keras.backend_parameters.device'):
                    flags.append('device=' + self.learner_params.get_path('keras.backend_parameters.device', 'cpu'))
                    if self.show_extra_debug:
                        self.logger.debug('    Device \t\t[{device}]'.format(
                            device=self.learner_params.get_path('keras.backend_parameters.device', 'cpu'))
                        )
                # Set floatX
                if self.learner_params.get_path('keras.backend_parameters.floatX'):
                    flags.append('floatX=' + self.learner_params.get_path('keras.backend_parameters.floatX', 'float32'))
                    if self.show_extra_debug:
                        self.logger.debug('    floatX \t\t[{float}]'.format(
                            float=self.learner_params.get_path('keras.backend_parameters.floatX', 'float32'))
                        )
                # Set optimizer
                if self.learner_params.get_path('keras.backend_parameters.optimizer') is not None:
                    if self.learner_params.get_path('keras.backend_parameters.optimizer') in ['fast_run', 'merge',
                                                                                              'fast_compile', 'None']:
                        flags.append('optimizer=' + self.learner_params.get_path('keras.backend_parameters.optimizer'))
                if self.show_extra_debug:
                    self.logger.debug('    Optimizer \t[{optimizer}]'.format(
                        optimizer=self.learner_params.get_path('keras.backend_parameters.optimizer', 'None'))
                    )
                # Set fastmath for GPU mode only
                if self.learner_params.get_path('keras.backend_parameters.fastmath') and self.learner_params.get_path(
                        'keras.backend_parameters.device', 'cpu') != 'cpu':
                    if self.learner_params.get_path('keras.backend_parameters.fastmath', False):
                        flags.append('nvcc.fastmath=True')
                    else:
                        flags.append('nvcc.fastmath=False')
                    if self.show_extra_debug:
                        self.logger.debug('    NVCC fastmath \t[{flag}]'.format(
                            flag=str(self.learner_params.get_path('keras.backend_parameters.fastmath', False)))
                        )
                # Set OpenMP
                if self.learner_params.get_path('keras.backend_parameters.openmp') is not None:
                    if self.learner_params.get_path('keras.backend_parameters.openmp', False):
                        flags.append('openmp=True')
                    else:
                        flags.append('openmp=False')
                    if self.show_extra_debug:
                        self.logger.debug('    OpenMP\t\t[{flag}]'.format(
                            flag=str(self.learner_params.get_path('keras.backend_parameters.openmp', False)))
                        )
                # Set environmental variable for Theano
                os.environ["THEANO_FLAGS"] = ','.join(flags)
            elif self.learner_params.get_path('keras.backend', 'tensorflow') == 'tensorflow':
                # Tensorflow setup
                if self.show_extra_debug:
                    self.logger.debug('  Tensorflow')
                # Set device
                if self.learner_params.get_path('keras.backend_parameters.device', 'cpu'):
                    # In case of CPU disable visible GPU.
                    if self.learner_params.get_path('keras.backend_parameters.device', 'cpu') == 'cpu':
                        os.environ["CUDA_VISIBLE_DEVICES"] = ''
                    if self.show_extra_debug:
                        self.logger.debug('    Device \t\t[{device}]'.format(
                            device=self.learner_params.get_path('keras.backend_parameters.device', 'cpu')))
            else:
                message = '{name}: Keras backend not supported [backend].'.format(
                    name=self.__class__.__name__,
                    backend=self.learner_params.get_path('keras.backend')
                )
                self.logger.exception(message)
                raise AssertionError(message)
            if self.show_extra_debug:
                self.logger.debug('  ')
        self.keras_setup_done = True
class BaseCallback(object):
    """Base class for Callbacks

    Provides no-op implementations of the Keras Callback API hooks and of the
    external metric hooks (``add_external_metric``, ``set_external_metric_value``,
    ``update``) used by the callbacks in this module, plus the common state they share.
    """

    def __init__(self, *args, **kwargs):
        """Constructor

        Parameters
        ----------
        verbose : bool
            Default value True
        manual_update : bool
            Manually update callback, use this when injecting external metrics
            Default value False
        epochs : int
            Total amount of epochs
            Default value None
        external_metric_labels : dict or OrderedDict
            Mapping from external metric label to metric name. The mapping is copied,
            so later additions through ``add_external_metric`` do not modify the caller's object.
            Default value empty OrderedDict
        """
        self.params = None
        self.model = None
        self.verbose = kwargs.get('verbose', True)
        self.manual_update = kwargs.get('manual_update', False)
        self.epochs = kwargs.get('epochs')
        self.epoch = 0
        # Copy: subclasses add entries to this mapping, which must not leak into a dict
        # that the caller may share between several callbacks.
        self.external_metric_labels = collections.OrderedDict(kwargs.get('external_metric_labels') or {})
        self.external_metric = collections.OrderedDict()
        self.keras_metrics = [
            'binary_accuracy',
            'categorical_accuracy',
            'sparse_categorical_accuracy',
            'top_k_categorical_accuracy'
        ]
        self.logger = logging.getLogger(__name__)

    def set_model(self, model):
        self.model = model

    def set_params(self, params):
        self.params = params

    def on_train_begin(self, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_batch_begin(self, batch, logs=None):
        pass

    def on_batch_end(self, batch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def update(self):
        pass

    def add_external_metric(self, metric_label):
        pass

    def set_external_metric_value(self, metric_label, metric_value):
        pass

    def get_operator(self, metric):
        """Get the comparison operator that tells whether a metric value is an improvement.

        Parameters
        ----------
        metric : str
            Metric name, matched case-insensitively by its suffix.

        Returns
        -------
        numpy.ufunc
            ``numpy.greater`` for F-score and accuracy type metrics (higher is better),
            ``numpy.less`` for error rates and any other metric (lower is better).
        """
        metric = metric.lower()
        if metric.endswith(('error_rate', 'er')):
            return numpy.less
        if metric.endswith(('f_measure', 'fmeasure', 'fscore', 'f-score', 'f_score', 'accuracy', 'acc')):
            return numpy.greater
        return numpy.less
class ProgressLoggerCallback(BaseCallback):
    """Keras callback to store metrics with tqdm progress bar or logging interface. Implements Keras Callback API.

    This callback is very similar to standard ``ProgbarLogger`` Keras callback, however it adds support for logging
    interface and tqdm based progress bars, and external metrics (metrics calculated outside Keras training process).

    Per-epoch values are kept in ``self.data``: one NaN-initialised array of length ``epochs`` per key,
    ``l_tra``/``l_val`` (training/validation loss), ``m_tra``/``m_val`` (training/validation metric) and one
    array per external metric label.
    """

    # Keys of the built-in per-epoch series (loss/metric x train/validation)
    _BASE_KEYS = ('l_tra', 'l_val', 'm_tra', 'm_val')

    def __init__(self, *args, **kwargs):
        """Constructor

        Parameters
        ----------
        epochs : int
            Total amount of epochs. Required here: the per-epoch storage arrays are allocated from it.
        metric : str or callable
            Metric name, or a metric function whose ``__name__`` is used as the name.
            Default value None
        loss : str
            Loss name, shown in the logging header.
            Default value None
        manual_update : bool
            Manually update callback, use this when injecting external metrics
            Default value False
        manual_update_interval : int
            Epoch interval for manual update, used to anticipate updates
            Default value 1
        disable_progress_bar : bool
            Disable tqdm based progress bar
            Default value False
        close_progress_bar : bool
            Close tqdm progress bar on training end
            Default value True
        log_progress : bool
            Print progress into logging interface
            Default value False
        external_metric_labels : dict or OrderedDict
            Mapping from external metric label to metric name.
            Default value empty OrderedDict
        """
        super(ProgressLoggerCallback, self).__init__(*args, **kwargs)
        metric = kwargs.get('metric')
        if callable(metric):
            self.metric = metric.__name__
        else:
            # Metric name, or None when no metric is tracked
            self.metric = metric
        self.loss = kwargs.get('loss')
        self.disable_progress_bar = kwargs.get('disable_progress_bar', False)
        self.close_progress_bar = kwargs.get('close_progress_bar', True)
        self.manual_update_interval = kwargs.get('manual_update_interval', 1)
        self.log_progress = kwargs.get('log_progress', False)
        self.timer = Timer()
        self.progress_bar = None
        self.validation_data = None
        self.seen = 0
        self.log_values = []
        # Ordered so that tqdm shows the values in a fixed order
        self.postfix = collections.OrderedDict((key, None) for key in self._BASE_KEYS)
        # NaN marks epochs for which no value has been recorded
        self.data = {
            key: numpy.full((self.epochs,), numpy.nan)
            for key in list(self._BASE_KEYS) + list(self.external_metric_labels)
        }
        self.header_show = False
        # -1 so that epoch 0 is also treated as "not logged yet"
        self.last_update_epoch = -1
        self.target = None

    @staticmethod
    def _fit_label(label, width=19):
        """Return ``label`` as text shortened to at most ``width`` characters ('' for None)."""
        if label is None:
            return ''
        if callable(label):
            label = label.__name__
        if len(label) > width:
            label = label[0:width - 2] + '..'
        return label

    def _log_header(self):
        """Write the table header of the progress log into the logging interface."""
        labels = list(self.external_metric_labels)
        self.logger.info('  Training')

        header_extra1 = '    {epoch:<5s} | {loss:<19s} | {metric:<19s} | '.format(
            epoch=' ' * 5,
            loss='Loss',
            metric='Metric',
        )
        if labels:
            line = '{external_value:<' + str(12 * len(labels)) + 's} '
            header_extra1 += line.format(external_value='External metric')
        header_extra1 += '{time:<15s}'.format(time=' ' * 15)

        header_extra2 = '    {epoch:<5s} | {loss:<19s} | {metric:<19s} | '.format(
            epoch=' ' * 5,
            loss=self._fit_label(self.loss),
            metric=self._fit_label(self.metric),
        )
        header_extra2 += ''.join('{label:<10s} | '.format(label=label) for label in labels)
        header_extra2 += '{time:<15s}'.format(time=' ' * 15)

        header_main = '    {epoch:<5s} | {loss:<8s} | {val_loss:<8s} | {train:<8s} | {validation:<8s} | '.format(
            epoch='Epoch',
            loss='Train',
            val_loss='Val',
            train='Train',
            validation='Val',
        )
        header_main += '{label:<10s} | '.format(label='Val') * len(labels)
        header_main += '{time:<15s}'.format(time='Time')

        sep = '    {epoch:<5s} + {loss:<8s} + {val_loss:<8s} + {train:<8s} + {validation:<8s} + '.format(
            epoch='-' * 5,
            loss='-' * 8,
            val_loss='-' * 8,
            train='-' * 8,
            validation='-' * 8,
        )
        sep += '{external_value:<10s} + '.format(external_value='-' * 10) * len(labels)
        sep += '{time:<15s}'.format(time='-' * 15)

        for line in (header_extra1, header_extra2, header_main, sep):
            self.logger.info(line)

    def _close_progress_bar(self):
        """Close the tqdm bar when it exists and closing is enabled."""
        if not self.log_progress and self.close_progress_bar and self.progress_bar is not None:
            self.progress_bar.close()

    def on_train_begin(self, logs=None):
        if self.epochs is None:
            self.epochs = self.params['epochs']
        if self.log_progress:
            # Show header only once, even when training is continued with several fit() calls
            if not self.header_show:
                self.header_show = True
                self._log_header()
        elif self.progress_bar is None:
            self.progress_bar = tqdm(total=self.epochs,
                                     initial=self.epoch,
                                     file=sys.stdout,
                                     desc='    {0:>6s}'.format('Learn'),
                                     leave=False,
                                     miniters=1,
                                     disable=self.disable_progress_bar
                                     )

    def on_train_end(self, logs=None):
        self._close_progress_bar()

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch + 1
        if 'steps' in self.params:
            self.target = self.params['steps']
        elif 'samples' in self.params:
            self.target = self.params['samples']
        self.seen = 0
        self.timer.start()

    def on_batch_begin(self, batch, logs=None):
        if self.target and self.seen < self.target:
            self.log_values = []

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        batch_size = logs.get('size', 0)
        self.seen += batch_size
        for k in self.params['metrics']:
            if k in logs:
                self.log_values.append((k, logs[k]))

    def on_epoch_end(self, epoch, logs=None):
        self.timer.stop()
        self.epoch = epoch
        logs = logs or {}
        # Reset values of the built-in series; external metric postfixes are kept
        for key in self._BASE_KEYS:
            self.postfix[key] = None
        # Collect values
        for k in self.params['metrics']:
            if k not in logs:
                continue
            value = logs[k]
            self.log_values.append((k, value))
            if k == 'loss':
                key = 'l_tra'
            elif k == 'val_loss':
                key = 'l_val'
            elif self.metric and k.endswith(self.metric):
                key = 'm_val' if k.startswith('val_') else 'm_tra'
            else:
                continue
            self.data[key][self.epoch] = value
            self.postfix[key] = '{:4.3f}'.format(value)
        for metric_label in self.external_metric_labels:
            if metric_label in self.external_metric:
                metric_name = self.external_metric_labels[metric_label]
                value = self.external_metric[metric_label]
                if metric_name.endswith(('f_measure', 'f_score')):
                    # F-scores are shown in percent
                    self.postfix[metric_label] = '{:3.1f}'.format(value * 100)
                else:
                    self.postfix[metric_label] = '{:4.3f}'.format(value)
        # In manual mode, epochs where (epoch + 1) is a multiple of manual_update_interval are
        # not logged here; presumably the caller logs them through update() after injecting
        # external metric values.
        auto_update = not self.manual_update or (
            self.epoch > self.last_update_epoch and (self.epoch + 1) % self.manual_update_interval
        )
        if auto_update and self.log_progress:
            self.update_progress_log()
        # Increase iteration count and update progress bar
        if not self.log_progress:
            self.update_progress_bar(increase=1)

    def update(self):
        """Update progress output (log row or progress bar) and remember the updated epoch."""
        if self.log_progress:
            self.update_progress_log()
        else:
            self.update_progress_bar()
        self.last_update_epoch = self.epoch

    def update_progress_log(self):
        """Write the row for the current epoch into the logging interface.

        Nothing is written when progress logging is disabled or the current epoch equals
        ``last_update_epoch`` (prevents logging the same epoch twice after ``update()``).
        """
        if not self.log_progress or self.epoch == self.last_update_epoch:
            return
        output = '    '
        output += '{epoch:<5s} |'.format(epoch='{:d}'.format(self.epoch))
        output += ' {loss:<8s} |'.format(loss='{:4.6f}'.format(self.data['l_tra'][self.epoch]))
        if self.postfix['l_val']:
            val_loss = '{:4.6f}'.format(self.data['l_val'][self.epoch])
        else:
            val_loss = ' ' * 8
        output += ' {val_loss:<8s} |'.format(val_loss=val_loss)
        output += ' {train:<8s} |'.format(train='{:4.6f}'.format(self.data['m_tra'][self.epoch]))
        if self.postfix['m_val']:
            validation = '{:4.6f}'.format(self.data['m_val'][self.epoch])
        else:
            validation = ' ' * 8
        output += ' {validation:<8s} |'.format(validation=validation)
        for metric_label in self.external_metric_labels:
            value = ' ' * 10
            if metric_label in self.external_metric:
                raw_value = self.data[metric_label][self.epoch]
                if not numpy.isnan(raw_value):
                    if self.external_metric_labels[metric_label].endswith(('f_measure', 'f_score')):
                        value = '{:3.3f}'.format(float(raw_value) * 100)
                    else:
                        value = '{:4.3f}'.format(float(raw_value))
            output += ' {external_value:<10s} |'.format(external_value=value)
        output += ' {time:<15s}'.format(time=self.timer.get_string())
        self.logger.info(output)

    def update_progress_bar(self, increase=0):
        """Update progress to tqdm progress bar (no-op when no bar has been created)."""
        if self.progress_bar is None:
            return
        self.progress_bar.set_postfix(self.postfix)
        self.progress_bar.update(increase)

    def add_external_metric(self, metric_id):
        """Add external metric to be monitored

        Parameters
        ----------
        metric_id : str
            Metric name
        """
        if metric_id not in self.external_metric_labels:
            self.external_metric_labels[metric_id] = metric_id
        if metric_id not in self.data:
            self.data[metric_id] = numpy.full((self.epochs,), numpy.nan)

    def set_external_metric_value(self, metric_label, metric_value):
        """Add external metric value for the current epoch

        Parameters
        ----------
        metric_label : str
            Metric label
        metric_value : numeric
            Metric value
        """
        self.external_metric[metric_label] = metric_value
        self.data[metric_label][self.epoch] = metric_value

    def close(self):
        """Manually close progress logging"""
        self._close_progress_bar()
[docs]class ProgressPlotterCallback(ProgressLoggerCallback):
    """Keras callback to plot progress during the training process and save final progress into figure.
    Implements Keras Callback API.
    """
[docs]    def __init__(self, *args, **kwargs):
        """Constructor
        Parameters
        ----------
        epochs : int
            Total amount of epochs
        metric : str
            Metric name
        manual_update : bool
            Manually update callback, use this to when injecting external metrics
            Default value True
        interactive : bool
            Show plot during the training and update with plotting rate
            Default value True
        plotting_rate : int
            Plot update rate in seconds
            Default value 10
        save : bool
            Save plot on disk, plotting rate applies
        filename : str
            Filename of figure
            Default value 1
        focus_span : int
            Epoch amount to highlight, and show separately in the plot.
            Default value 10
        """
        super(ProgressPlotterCallback, self).__init__(*args, **kwargs)
        self.filename = kwargs.get('filename')
        # Get file format for the output plot
        file_extension = os.path.splitext(self.filename)[1]
        if file_extension == '.eps':
            self.format = 'eps'
        elif file_extension == '.svg':
            self.format = 'svg'
        elif file_extension == '.pdf':
            self.format = 'pdf'
        elif file_extension == '.png':
            self.format = 'png'
        self.plotting_rate = kwargs.get('plotting_rate', 10)
        self.interactive = kwargs.get('interactive', True)
        self.save = kwargs.get('save', True)
        self.focus_span = kwargs.get('focus_span', 10)
        if self.focus_span > self.epochs:
            self.focus_span = self.epochs
        self.timer.start()
        self.data = {
            'l_tra': numpy.empty((self.epochs,)),
            'l_val': numpy.empty((self.epochs,)),
            'm_tra': numpy.empty((self.epochs,)),
            'm_val': numpy.empty((self.epochs,)),
        }
        self.data['l_tra'][:] = numpy.nan
        self.data['l_val'][:] = numpy.nan
        self.data['m_tra'][:] = numpy.nan
        self.data['m_val'][:] = numpy.nan
        for metric_label in self.external_metric_labels:
            self.data[metric_label] = numpy.empty((self.epochs,))
            self.data[metric_label][:] = numpy.nan
        self.ax1_1 = None
        self.ax1_2 = None
        self.ax2_1 = None
        self.ax2_2 = None
        self.extra_main = {}
        self.extra_highlight = {}
        import matplotlib.pyplot as plt
        import warnings
        import matplotlib.cbook
        warnings.filterwarnings("ignore", category=matplotlib.cbook.mplDeprecation)
        figure_height = 8
        if len(self.external_metric_labels) > 2:
            figure_height = 8 + len(self.external_metric_labels)
        self.figure = plt.figure(num=None, figsize=(18, figure_height), dpi=80, facecolor='w', edgecolor='k')
        self.draw()
        if self.interactive:
            plt.show(block=False)
            plt.pause(0.1) 
    def draw(self):
        """Draw plot
        """
        import matplotlib.patches as patches
        import matplotlib.pyplot as plt
        plt.figure(self.figure.number)
        row_count = 2+len(self.external_metric_labels)
        self.ax1_1 = plt.subplot2grid((row_count, 4), (0, 0), rowspan=1, colspan=3)
        self.ax1_2 = plt.subplot2grid((row_count, 4), (0, 3), rowspan=1, colspan=1)
        self.ax2_1 = plt.subplot2grid((row_count, 4), (1, 0), rowspan=1, colspan=3)
        self.ax2_2 = plt.subplot2grid((row_count, 4), (1, 3), rowspan=1, colspan=1)
        self.extra_main = {}
        self.extra_highlight = {}
        row_id = 2
        for metric_label in self.external_metric_labels:
            self.extra_main[metric_label] = plt.subplot2grid((row_count, 4), (row_id, 0), rowspan=1, colspan=3)
            self.extra_highlight[metric_label] = plt.subplot2grid((row_count, 4), (row_id, 3), rowspan=1, colspan=1)
            row_id += 1
        span = [self.epoch - self.focus_span, self.epoch]
        if span[0] < 0:
            span[0] = 0
        # PLOT 1 / Main
        self.ax1_1.cla()
        self.ax1_1.set_title('Loss')
        self.ax1_1.set_ylabel('Model Loss')
        self.ax1_1.plot(
            numpy.arange(self.epochs),
            self.data['l_tra'],
            lw=3,
            color='red',
        )
        self.ax1_1.plot(
            numpy.arange(self.epochs),
            self.data['l_val'],
            lw=3,
            color='green',
        )
        self.ax1_1.add_patch(
            patches.Rectangle(
                (span[0], self.ax1_1.get_ylim()[0]),  # (x,y)
                width=span[1]-span[0],
                height=self.ax1_1.get_ylim()[1],
                facecolor="#000000",
                alpha=0.05
            )
        )
        # Horizontal lines
        if not numpy.all(numpy.isnan(self.data['l_tra'])):
            self.ax1_1.axhline(y=numpy.nanmin(self.data['l_tra']), lw=1, color='red', linestyle='--')
            self.ax1_1.axhline(y=numpy.nanmin(self.data['l_val']), lw=1, color='green', linestyle='--')
        self.ax1_1.legend(['Train', 'Validation'], loc='upper right')
        self.ax1_1.set_xlim([0, self.epochs - 1])
        self.ax1_1.set_xticklabels([])
        self.ax1_1.grid(True)
        # PLOT 1 / Highlighted area
        self.ax1_2.cla()
        self.ax1_2.set_title('Loss / Highlighted area')
        self.ax1_2.set_ylabel('Model Loss')
        self.ax1_2.plot(
            numpy.arange(span[0], span[1]),
            self.data['l_tra'][span[0]:span[1]],
            lw=3,
            color='red',
        )
        self.ax1_2.plot(
            numpy.arange(span[0], span[1]),
            self.data['l_val'][span[0]:span[1]],
            lw=3,
            color='green',
        )
        self.ax1_2.set_xticklabels([])
        self.ax1_2.grid(True)
        self.ax1_2.yaxis.tick_right()
        self.ax1_2.yaxis.set_label_position("right")
        # PLOT 2 / Main
        self.ax2_1.cla()
        self.ax2_1.set_title('Metric')
        self.ax2_1.set_ylabel(self.metric)
        # Plots
        self.ax2_1.plot(
            numpy.arange(self.epochs),
            self.data['m_tra'],
            lw=3,
            color='red',
        )
        self.ax2_1.plot(
            numpy.arange(self.epochs),
            self.data['m_val'],
            lw=3,
            color='green',
        )
        # Horizontal lines
        if not numpy.all(numpy.isnan(self.data['m_tra'])):
            if self.get_operator(metric=self.metric) == numpy.greater:
                h_tra_line_y = numpy.nanmax(self.data['m_tra'])
                h_val_line_y = numpy.nanmax(self.data['m_val'])
            else:
                h_tra_line_y = numpy.nanmin(self.data['m_tra'])
                h_val_line_y = numpy.nanmin(self.data['m_tra'])
            self.ax2_1.axhline(y=h_tra_line_y, lw=1, color='red', linestyle='--')
            self.ax2_1.axhline(y=h_val_line_y, lw=1, color='green', linestyle='--')
        self.ax2_1.add_patch(
            patches.Rectangle(
                (span[0], self.ax2_1.get_ylim()[0]),  # (x,y)
                width=span[1]-span[0],
                height=self.ax2_1.get_ylim()[1],
                facecolor="#000000",
                alpha=0.05
            )
        )
        if self.get_operator(metric=self.metric) == numpy.greater:
            legend_location = 'lower right'
        else:
            legend_location = 'upper right'
        self.ax2_1.legend(['Train', 'Validation'], loc=legend_location)
        self.ax2_1.set_xlim([0, self.epochs - 1])
        if self.external_metric_labels:
            self.ax2_1.set_xticklabels([])
        self.ax2_1.grid(True)
        # PLOT 2 / Highlighted area
        self.ax2_2.cla()
        self.ax2_2.set_title('Metric / Highlighted area')
        self.ax2_2.set_ylabel(self.metric)
        self.ax2_2.plot(
            numpy.arange(span[0], span[1]),
            self.data['m_tra'][span[0]:span[1]],
            lw=3,
            color='red',
        )
        self.ax2_2.plot(
            numpy.arange(span[0], span[1]),
            self.data['m_val'][span[0]:span[1]],
            lw=3,
            color='green',
        )
        self.ax2_2.set_xticklabels([])
        self.ax2_2.grid(True)
        self.ax2_2.yaxis.tick_right()
        self.ax2_2.yaxis.set_label_position("right")
        for mid, metric_label in enumerate(self.external_metric_labels):
            metric_name = self.external_metric_labels[metric_label]
            if metric_name.endswith('f_measure') or metric_name.endswith('f_score'):
                factor = 100
            else:
                factor = 1
            # PLOT 3 / Main
            self.extra_main[metric_label].cla()
            self.extra_main[metric_label].set_title('External metric')
            self.extra_main[metric_label].set_ylabel(str(metric_label))
            mask = numpy.isfinite(self.data[metric_label])
            self.extra_main[metric_label].plot(
                numpy.arange(self.epochs)[mask],
                self.data[metric_label][mask]*factor,
                lw=3,
                color='green',
                marker='o',
            )
            self.extra_main[metric_label].add_patch(
                patches.Rectangle(
                    (span[0], self.extra_main[metric_label].get_ylim()[0]),  # (x,y)
                    width=span[1]-span[0],
                    height=self.extra_main[metric_label].get_ylim()[1],
                    facecolor="#000000",
                    alpha=0.05
                )
            )
            # Horizontal lines
            if not numpy.all(numpy.isnan(self.data[metric_label][mask])):
                if self.get_operator(metric=str(metric_label)) == numpy.greater:
                    h_extra_line_y = numpy.nanmax((self.data[metric_label][mask]*factor))
                else:
                    h_extra_line_y = numpy.nanmin((self.data[metric_label][mask]*factor))
                self.extra_main[metric_label].axhline(y=h_extra_line_y, lw=1, color='blue', linestyle='--')
            if self.get_operator(metric=self.metric) == numpy.greater:
                legend_location = 'lower right'
            else:
                legend_location = 'upper right'
            self.extra_main[metric_label].legend(['Validation'], loc=legend_location)
            self.extra_main[metric_label].set_xlim([0, self.epochs - 1])
            if (mid + 1) < len(self.external_metric_labels):
                self.extra_main[metric_label].set_xticklabels([])
            else:
                self.extra_main[metric_label].set_xlabel('Epochs')
            self.extra_main[metric_label].grid(True)
            # PLOT 3 / Highlighted area
            self.extra_highlight[metric_label].cla()
            self.extra_highlight[metric_label].set_title('External metric / Highlighted area')
            self.extra_highlight[metric_label].set_ylabel(str(metric_label))
            highlight_data = self.data[metric_label][span[0]:span[1]]*factor
            mask = numpy.isfinite(highlight_data)
            self.extra_highlight[metric_label].plot(
                numpy.arange(span[0], span[1])[mask],
                highlight_data[mask],
                lw=3,
                color='green',
                marker='o',
            )
            if (mid + 1) < len(self.external_metric_labels):
                self.extra_highlight[metric_label].set_xticklabels([])
            else:
                self.extra_highlight[metric_label].set_xlabel('Epochs')
            self.extra_highlight[metric_label].yaxis.tick_right()
            self.extra_highlight[metric_label].yaxis.set_label_position("right")
            self.extra_highlight[metric_label].grid(True)
        plt.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0.1, wspace=0.02, hspace=0.2)
    def on_train_begin(self, logs=None):
        if self.epochs is None:
            self.epochs = self.params['epochs']
    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch + 1
        self.seen = 0
    def on_epoch_end(self, epoch, logs=None):
        self.epoch = epoch
        logs = logs or {}
        # Collect values
        for k in self.params['metrics']:
            if k in logs:
                self.log_values.append((k, logs[k]))
                if k == 'loss':
                    self.data['l_tra'][self.epoch] = logs[k]
                elif k == 'val_loss':
                    self.data['l_val'][self.epoch] = logs[k]
                elif self.metric and k.endswith(self.metric):
                    if k.startswith('val_'):
                        self.data['m_val'][self.epoch] = logs[k]
                    else:
                        self.data['m_tra'][self.epoch] = logs[k]
        if not self.manual_update:
            # Update logged progress
            self.update()
    def update(self):
        """Redraw the progress figure when the timer's elapsed time exceeds ``plotting_rate``.
        After a redraw the GUI event loop gets a short pause if ``interactive`` is set, the figure is
        written to ``filename`` if ``save`` is set, and the timer is restarted.
        """
        import matplotlib.pyplot as pyplot
        if not self.timer.elapsed() > self.plotting_rate:
            # Throttled: not enough time has passed for another redraw
            return
        self.draw()
        self.figure.canvas.flush_events()
        if self.interactive:
            pyplot.pause(0.01)
        if self.save:
            pyplot.savefig(self.filename, bbox_inches='tight', format=self.format, dpi=1000)
        self.timer.start()
    def add_external_metric(self, metric_label):
        """Register an external metric so that it is tracked.
        The label is added to ``external_metric_labels`` and, if not present yet, a per-epoch
        history array of length ``epochs`` filled with NaN is created in ``data``.
        Parameters
        ----------
        metric_label : str
            Metric label
        """
        known_labels = self.external_metric_labels
        if metric_label not in known_labels:
            known_labels[metric_label] = metric_label
        if metric_label in self.data:
            return
        history = numpy.empty((self.epochs,))
        history.fill(numpy.nan)
        self.data[metric_label] = history
    def set_external_metric_value(self, metric_label, metric_value):
        """Store the value of an external metric for the current epoch.
        The metric must already have a history array in ``data`` (see ``add_external_metric``).
        Parameters
        ----------
        metric_label : str
            Metric label
        metric_value : numeric
            Metric value
        """
        current_epoch = self.epoch
        self.external_metric[metric_label] = metric_value
        history = self.data[metric_label]
        history[current_epoch] = metric_value
    def close(self):
        """Manually close progress logging.
        If ``save`` is set, the figure is redrawn and written to ``filename`` first. The figure is
        then closed.
        """
        import matplotlib.pyplot as pyplot
        if self.save:
            self.draw()
            pyplot.savefig(self.filename, bbox_inches='tight', format=self.format, dpi=1000)
        pyplot.close(self.figure)
class StopperCallback(BaseCallback):
    """Keras callback to stop training when no improvement has been seen for a specified number of epochs.
    Implements Keras Callback API.
    Callback is very similar to standard ``EarlyStopping`` Keras callback, however it adds support for external metrics
    (calculated outside Keras training process).
    """
    def __init__(self, *args, **kwargs):
        """Constructor
        Parameters
        ----------
        epochs : int
            Total amount of epochs
        manual_update : bool
            Manually update callback, use this when injecting external metrics
            Default value True
        monitor : str
            Metric value to be monitored
            Default value "val_loss"
        mode : str
            Which way metric is interpreted, values {auto, min, max}. Unknown values fall back to 'auto'.
            Default value 'auto'
        patience : int
            Number of epochs with no improvement after which training will be stopped.
            Default value 0
        min_delta : float
            Minimum change in the monitored quantity to qualify as an improvement.
            Default value 0
        initial_delay : int
            Amount of epochs to wait at the beginning before quantity is monitored.
            Default value 10
        """
        super(StopperCallback, self).__init__(*args, **kwargs)
        self.monitor = kwargs.get('monitor', 'val_loss')
        self.patience = kwargs.get('patience', 0)
        self.min_delta = kwargs.get('min_delta', 0)
        self.initial_delay = kwargs.get('initial_delay', 10)
        # wait / stopped_epoch are initialised in on_train_begin, only if not already set
        self.wait = None
        self.stopped_epoch = None
        self.model = None
        self.params = None
        self.last_update_epoch = 0
        self.stopped = False
        self.logger = logging.getLogger(__name__)
        # Monitored values per epoch, NaN until a value is reported
        self.metric_data = {
            self.monitor: numpy.empty((self.epochs,))
        }
        self.metric_data[self.monitor][:] = numpy.nan
        mode = kwargs.get('mode', 'auto')
        if mode not in ['min', 'max', 'auto']:
            self.logger.warning('{name}: Unknown mode [{mode}], using [auto]'.format(
                name=self.__class__.__name__,
                mode=mode
            ))
            mode = 'auto'
        if mode == 'min':
            self.monitor_op = numpy.less
        elif mode == 'max':
            self.monitor_op = numpy.greater
        else:
            self.monitor_op = self.get_operator(metric=self.monitor)
        # numpy.inf (lowercase): the numpy.Inf alias was removed in NumPy 2.0
        self.best = numpy.inf if self.monitor_op == numpy.less else -numpy.inf
        # Sign min_delta so that ``current - min_delta`` must beat ``best`` by more than min_delta
        # in the monitored direction (for a non-negative min_delta)
        if self.monitor_op != numpy.greater:
            self.min_delta *= -1
    def on_train_begin(self, logs=None):
        if self.epochs is None:
            self.epochs = self.params['epochs']
        if self.wait is None:
            self.wait = 0
        if self.stopped_epoch is None:
            self.stopped_epoch = 0
    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch + 1
    def on_epoch_end(self, epoch, logs=None):
        self.epoch = epoch
        logs = logs or {}
        if self.monitor in logs:
            self.metric_data[self.monitor][self.epoch] = logs.get(self.monitor)
        if not self.manual_update:
            self.update()
    def set_external_metric_value(self, metric_label, metric_value):
        """Add external metric value
        Parameters
        ----------
        metric_label : str
            Metric label
        metric_value : numeric
            Metric value
        """
        if metric_label not in self.metric_data:
            self.metric_data[metric_label] = numpy.empty((self.epochs,))
            self.metric_data[metric_label][:] = numpy.nan
        self.metric_data[metric_label][self.epoch] = metric_value
    def stop(self):
        """Return True once the stopping criterion has been met."""
        return self.stopped
    def update(self):
        """Evaluate the stopping criterion for the current epoch.
        Returns
        -------
        bool
            True once the stopping criterion has been met
        Raises
        ------
        ValueError
            The monitored metric has no value (NaN) for the current epoch
        """
        if self.epoch > self.initial_delay:
            if self.wait is None:
                # update() used without on_train_begin (manual flow)
                self.wait = 0
            # get current metric value
            current = self.metric_data[self.monitor][self.epoch]
            if numpy.isnan(current):
                message = '{name}: Metric to monitor is Nan, metric:[{metric}]'.format(
                    name=self.__class__.__name__,
                    metric=self.monitor
                )
                # Not inside an except block, so log as error (logger.exception expects an active exception)
                self.logger.error(message)
                raise ValueError(message)
            if self.monitor_op(current - self.min_delta, self.best):
                # New best value found
                self.best = current
                self.wait = 0
            else:
                if self.wait >= self.patience:
                    # Stopping criteria met
                    self.stopped_epoch = self.epoch
                    if self.model is not None:
                        self.model.stop_training = True
                    self.logger.info('  Stopping criteria met at epoch[{epoch:d}]'.format(
                        epoch=self.epoch,
                    ))
                    self.logger.info('    metric[{metric}]={current}, patience[{patience:d}]'.format(
                        metric=self.monitor,
                        current='{:4.4f}'.format(current),
                        patience=self.patience
                    ))
                    self.logger.info('  ')
                    self.stopped = True
                    return self.stopped
                # Increase waiting counter by the number of epochs since the previous update
                self.wait += self.epoch - self.last_update_epoch
        self.last_update_epoch = self.epoch
        return self.stopped
class StasherCallback(BaseCallback):
    """Keras callback to monitor training process and store best model. Implements Keras Callback API.
    This callback is very similar to standard ``ModelCheckpoint`` Keras callback, however it adds support for external
    metrics (metrics calculated outside Keras training process).
    """
    def __init__(self, *args, **kwargs):
        """Constructor
        Parameters
        ----------
        epochs : int
            Total amount of epochs
        manual_update : bool
            Manually update callback, use this when injecting external metrics
            Default value True
        monitor : str
            Metric to monitor
            Default value 'val_loss'
        mode : str
            Which way metric is interpreted, values {auto, min, max}. Unknown values fall back to 'auto'.
            Default value 'auto'
        period : int
            Number of update() calls (counted after initial_delay) between checks for a new best model
            Default value 1
        initial_delay : int
            Amount of epochs to wait at the beginning before quantity is monitored.
            Default value 10
        save_weights : bool
            Save weight to the disk
            Default value False
        file_path : str
            File path template for model weights, formatted with ``epoch`` and the latest Keras log values
            Default value None
        """
        super(StasherCallback, self).__init__(*args, **kwargs)
        self.monitor = kwargs.get('monitor', 'val_loss')
        self.period = kwargs.get('period', 1)
        self.initial_delay = kwargs.get('initial_delay', 10)
        self.save_weights = kwargs.get('save_weights', False)
        self.file_path = kwargs.get('file_path', None)
        self.epochs_since_last_save = 0
        self.logger = logging.getLogger(__name__)
        mode = kwargs.get('mode', 'auto')
        if mode not in ['auto', 'min', 'max']:
            self.logger.warning('{name}: Unknown mode [{mode}], using [auto]'.format(
                name=self.__class__.__name__,
                mode=mode
            ))
            mode = 'auto'
        if mode == 'min':
            self.monitor_op = numpy.less
        elif mode == 'max':
            self.monitor_op = numpy.greater
        else:
            self.monitor_op = self.get_operator(metric=self.monitor)
        # numpy.inf (lowercase): the numpy.Inf alias was removed in NumPy 2.0
        self.best = numpy.inf if self.monitor_op == numpy.less else -numpy.inf
        # Monitored values per epoch, NaN until a value is reported
        self.metric_data = {
            self.monitor: numpy.empty((self.epochs,))
        }
        self.metric_data[self.monitor][:] = numpy.nan
        self.best_model_weights = None
        self.best_model_epoch = 0
        self.last_logs = None
    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch + 1
    def on_epoch_end(self, epoch, logs=None):
        self.epoch = epoch
        logs = logs or {}
        if self.monitor in logs:
            self.metric_data[self.monitor][self.epoch] = logs.get(self.monitor)
        self.last_logs = logs
        if not self.manual_update:
            self.update()
    def update(self):
        """Check the monitored metric and store the model weights if a new best value was reached.
        Raises
        ------
        ValueError
            The monitored metric has no value (NaN) for the current epoch
        """
        if self.epoch > self.initial_delay:
            self.epochs_since_last_save += 1
            if self.epochs_since_last_save >= self.period:
                self.epochs_since_last_save = 0
                current = self.metric_data[self.monitor][self.epoch]
                if numpy.isnan(current):
                    message = '{name}: Metric to monitor is Nan, metric:[{metric}]'.format(
                        name=self.__class__.__name__,
                        metric=self.monitor
                    )
                    # Not inside an except block, so log as error (logger.exception expects an active exception)
                    self.logger.error(message)
                    raise ValueError(message)
                if self.monitor_op(current, self.best):
                    # Store the best
                    self.best = current
                    self.best_model_weights = self.model.get_weights()
                    self.best_model_epoch = self.epoch
                    if self.save_weights and self.file_path:
                        # Work on a copy: the logs dict belongs to Keras and is shared with other callbacks
                        format_values = dict(self.last_logs or {})
                        format_values.setdefault(self.monitor, current)
                        # Explicit epoch wins; avoids a duplicate keyword if logs also contain 'epoch'
                        format_values['epoch'] = self.epoch
                        file_path = self.file_path.format(**format_values)
                        self.model.save_weights(file_path, overwrite=True)
    def set_external_metric_value(self, metric_label, metric_value):
        """Add external metric value
        Parameters
        ----------
        metric_label : str
            Metric label
        metric_value : numeric
            Metric value
        """
        if metric_label not in self.metric_data:
            self.metric_data[metric_label] = numpy.empty((self.epochs,))
            self.metric_data[metric_label][:] = numpy.nan
        self.metric_data[metric_label][self.epoch] = metric_value
    def get_best(self):
        """Return best model seen
        Returns
        -------
        dict
            Dictionary with keys 'weights', 'epoch', 'metric_value'
        """
        return {
            'epoch': self.best_model_epoch,
            'weights': self.best_model_weights,
            'metric_value': self.best,
        }
    def log(self):
        """Print information about the best model into logging interface
        """
        self.logger.info('  Best model weights at epoch[{epoch:d}]'.format(epoch=self.best_model_epoch))
        self.logger.info('    metric[{metric}]={best}'.format(
                metric=self.monitor,
                best='{:4.4f}'.format(self.best)
            )
        )
        self.logger.info('  ')
class BaseDataGenerator(object):
    """Base class for data generator.
    Subclasses implement ``process_item`` to load an item into ``self.data_buffer`` (a ``DataBuffer``).
    ``data_buffer.get`` is expected to return a ``(data, meta)`` pair.
    """
    def __init__(self, *args, **kwargs):
        """Constructor
        Parameters
        ----------
        files : list
            Active item identifiers (a shallow copy is stored)
        data_filenames : dict
            Data filenames keyed with item identifiers
        annotations : dict
            Annotations keyed with item identifiers
        class_labels : list of str
            Class labels
        hop_length_seconds : float
            Default value 0.2
        shuffle : bool
            Default value True
        batch_size : int
            Default value 64
        buffer_size : int
            Default value 256
        logger : logging.Logger
            Default value logging.getLogger(__name__)
        """
        self.method = 'base_generator'
        # Data
        self.item_list = copy.copy(kwargs.get('files', []))
        self.data_filenames = kwargs.get('data_filenames', {})
        self.annotations = kwargs.get('annotations', {})
        # Activity matrix
        self.class_labels = kwargs.get('class_labels', [])
        self.hop_length_seconds = kwargs.get('hop_length_seconds', 0.2)
        self.shuffle = kwargs.get('shuffle', True)
        self.batch_size = kwargs.get('batch_size', 64)
        self.buffer_size = kwargs.get('buffer_size', 256)
        self.logger = kwargs.get('logger', logging.getLogger(__name__))
        # Internal state variables
        self.batch_index = 0
        self.item_index = 0
        self.data_position = 0
        # Lazily computed caches (see data_size and input_size), set before any item is processed
        self._data_size = None
        self._input_size = None
        # Initialize data buffer
        self.data_buffer = DataBuffer(size=self.buffer_size)
        if self.buffer_size >= len(self.item_list):
            # Fill data buffer at initialization if it fits fully to the buffer
            for current_item in self.item_list:
                self.process_item(item=current_item)
                if self.data_buffer.full():
                    break
    @property
    def steps_count(self):
        """Number of batches in one epoch (at least 1)
        """
        num_batches = int(numpy.ceil(self.data_size / float(self.batch_size)))
        return max(num_batches, 1)
    @property
    def input_size(self):
        """Length of input feature vector (last axis of the first item's data)
        """
        if self._input_size is None:
            # Load first item
            first_item = list(self.data_filenames.keys())[0]
            self.process_item(item=first_item)
            # Get Feature vector length
            self._input_size = self.data_buffer.get(key=first_item)[0].shape[-1]
        return self._input_size
    @property
    def data_size(self):
        """Total data amount: sum of the first-axis lengths of all active items' data
        The value is cached after the first successful computation.
        """
        if self._data_size is None:
            total = 0
            for current_item in self.item_list:
                self.process_item(item=current_item)
                data, _ = self.data_buffer.get(key=current_item)
                # Accumulate feature matrix length
                total += data.shape[0]
            # Cache only after every item was processed, so a failing load does not leave a partial count behind
            self._data_size = total
        return self._data_size
    def info(self):
        """Information logging
        Returns
        -------
        list of str
            Lines describing the generator setup
        """
        info = [
            '  Generator',
            '    Shuffle \t[{shuffle}]'.format(shuffle='True' if self.shuffle else 'False'),
            '    Epoch size\t[{steps:d} batches]'.format(steps=self.steps_count),
            '    Buffer size \t[{buffer_size:d} files]'.format(buffer_size=self.buffer_size),
            ' '
        ]
        return info
    def process_item(self, item):
        """Load ``item`` into the data buffer; implemented by subclasses."""
        pass
    def on_epoch_start(self):
        """Hook called at the start of each epoch; implemented by subclasses."""
        pass
    def on_epoch_end(self):
        """Hook called at the end of each epoch; implemented by subclasses."""
        pass
class FeatureGenerator(BaseDataGenerator):
    """Feature data generator
    Yields ``(data, activity)`` batches built frame by frame from the active items.
    """
    def __init__(self, *args, **kwargs):
        """Constructor
        Parameters
        ----------
        files : list of str
            List of active item identifies, usually filenames
        data_filenames : dict of dicts
            Data structure keyed with item identifiers (defined with files parameter), data dict feature extractor
            labels as keys and values the filename on disk.
        annotations : dict of MetaDataContainers or MetaDataItems
            Annotations for all items keyed with item identifiers
        class_labels : list of str
            Class labels in a list
        hop_length_seconds : float
            Analysis frame hop length in seconds
            Default value 0.2
        shuffle : bool
            Shuffle data before each epoch
            Default value True
        batch_size : int
            Batch size to generate
            Default value 64
        buffer_size : int
            Internal item buffer size, set large enough for smaller dataset to avoid loading
            Default value 256
        data_processor : class
            Data processor class used to process load features
        data_refresh_on_each_epoch : bool
            Internal data buffer reset at the start of each epoch
            Default value False
        label_mode : str ('event', 'scene')
            Activity matrix forming mode.
            Default value "event"
        logger : logging.Logger
            Default value logging.getLogger(__name__)
        Raises
        ------
        ValueError
            Unknown label_mode
        """
        # These must exist before the base constructor runs: it may pre-fill the data buffer via process_item()
        self.data_processor = kwargs.get('data_processor')
        self.data_refresh_on_each_epoch = kwargs.get('data_refresh_on_each_epoch', False)
        self.label_mode = kwargs.get('label_mode', 'event')
        self.logger = kwargs.get('logger', logging.getLogger(__name__))
        # Validate before any item is processed with an unusable label mode
        if self.label_mode not in ['event', 'scene']:
            message = '{name}: Label mode unknown [{label_mode}]'.format(
                name=self.__class__.__name__,
                label_mode=self.label_mode
            )
            # Not inside an except block, so log as error (logger.exception expects an active exception)
            self.logger.error(message)
            raise ValueError(message)
        super(FeatureGenerator, self).__init__(*args, **kwargs)
        self.method = 'feature'
    def process_item(self, item):
        """Load features and activity matrix of ``item`` into the data buffer unless already buffered."""
        if not self.data_buffer.key_exists(key=item):
            current_data, current_length = self.data_processor.load(
                feature_filename_dict=self.data_filenames[item]
            )
            current_activity_matrix = self.get_activity_matrix(
                annotation=self.annotations[item],
                data_length=current_length
            )
            self.data_buffer.set(key=item, data=current_data, meta=current_activity_matrix)
    def on_epoch_start(self):
        """Reset the batch counter, optionally shuffle item order and clear the data buffer."""
        self.batch_index = 0
        if self.shuffle:
            # Shuffle item list order (in place)
            numpy.random.shuffle(self.item_list)
        if self.data_refresh_on_each_epoch:
            # Force reload of data
            self.data_buffer.clear()
    @staticmethod
    def _to_batch(batch_data, batch_meta):
        """Stack collected per-frame rows into a ``(data, meta)`` tuple of arrays."""
        return numpy.array(batch_data), numpy.array(batch_meta)
    def generator(self):
        """Generator method
        Runs indefinitely, one pass over ``item_list`` per epoch. Batches may span item boundaries;
        the last batch of an epoch can be smaller than ``batch_size``.
        Returns
        -------
        ndarray
            data batches
        """
        while True:
            # Start of epoch
            self.on_epoch_start()
            batch_buffer_data = []
            batch_buffer_meta = []
            # Go through items
            for item in self.item_list:
                # Load item data into buffer
                self.process_item(item=item)
                # Fetch item from buffer
                data, meta = self.data_buffer.get(key=item)
                # Data indexing
                data_ids = numpy.arange(data.shape[0])
                # Shuffle data order
                if self.shuffle:
                    numpy.random.shuffle(data_ids)
                for data_id in data_ids:
                    if len(batch_buffer_data) == self.batch_size:
                        # Batch buffer full, yield data
                        yield self._to_batch(batch_buffer_data, batch_buffer_meta)
                        # Empty batch buffers
                        batch_buffer_data = []
                        batch_buffer_meta = []
                        # Increase batch counter
                        self.batch_index += 1
                    # Collect data for the batch
                    batch_buffer_data.append(data[data_id])
                    batch_buffer_meta.append(meta[data_id])
            if batch_buffer_data:
                # Last batch, usually not full
                yield self._to_batch(batch_buffer_data, batch_buffer_meta)
                # Increase batch counter
                self.batch_index += 1
            # End of epoch
            self.on_epoch_end()
    def get_activity_matrix(self, annotation, data_length):
        """Convert annotation into activity matrix and run it through data processor.
        Parameters
        ----------
        annotation : MetaDataContainer or MetaDataItem
            Annotation of one item; in 'scene' mode its ``scene_label`` is used
        data_length : int
            Number of rows the activity matrix is padded / sized to
        Returns
        -------
        Output of ``data_processor.process_activity_data``, or None for an unknown label mode
        """
        event_roll = None
        if self.label_mode == 'event':
            # Event activity, event onset and offset specified
            event_roll = EventRoll(metadata_container=annotation,
                                   label_list=self.class_labels,
                                   time_resolution=self.hop_length_seconds
                                   )
            event_roll = event_roll.pad(length=data_length)
        elif self.label_mode == 'scene':
            # Scene activity, one-hot activity throughout whole file
            pos = self.class_labels.index(annotation.scene_label)
            event_roll = numpy.zeros((data_length, len(self.class_labels)))
            event_roll[:, pos] = 1
        if event_roll is not None:
            return self.data_processor.process_activity_data(
                activity_data=event_roll
            )
        return None