Source code for dcase_framework.containers

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Containers
==================
Classes for data containers



"""

from __future__ import print_function, absolute_import
from six import iteritems
import numpy
import logging
import hashlib
import json
import copy


class ContainerMixin(object):
    """Mixin adding dotted-path access, printing, merging and hashing to dict containers.

    All methods treat ``self`` as a (possibly nested) dict, so the mixin is
    intended to be combined with ``dict`` or a dict subclass.

    """

    @staticmethod
    def _is_dict_like(obj):
        """Tell whether ``obj`` can be traversed like a dict (keys + item access)."""

        return isinstance(obj, dict) or (hasattr(obj, 'keys') and hasattr(obj, '__getitem__'))

    def get_path(self, dotted_path, default=None, data=None):
        """Get value from nested dict with dotted path

        Parameters
        ----------
        dotted_path : str
            String in form of "field1.field2.field3". The magic field "*"
            matches every child of the current level.
        default : str, int, float
            Default value returned if path does not exist
            Default value "None"
        data : dict, optional
            Dict for which path search is done, if None given self is used.
            Default value "None"

        Returns
        -------
        value
            Value stored under the path, ``default`` if the path cannot be
            resolved, or a list with one entry per child when "*" is used.

        """

        if data is None:
            data = self
        return self._lookup_path(data, dotted_path.split('.'), default)

    def _lookup_path(self, node, fields, default):
        """Resolve the list of path ``fields`` starting from ``node``.

        Kept separate from :meth:`get_path` so that a nested value of ``None``
        is never mistaken for the "use self" marker of the public method.

        """

        if not self._is_dict_like(node):
            # Scalars, strings, None etc. cannot be descended into
            return default

        head, remaining = fields[0], fields[1:]

        if head == '*':
            # Magic field to return all children in a list
            if remaining:
                return [self._lookup_path(child, remaining, default) for _, child in node.items()]
            return [child for _, child in node.items()]

        if head not in node:
            return default

        if remaining:
            # Go deeper
            return self._lookup_path(node[head], remaining, default)

        # We reached the node
        return node[head]

    def set_path(self, dotted_path, new_value, data=None):
        """Set value in nested dict with dotted path

        Missing intermediate levels are created, and intermediate values which
        are not dicts are overwritten with a new dict. The magic field "*"
        applies the rest of the path to every existing child.

        Parameters
        ----------
        dotted_path : str
            String in form of "field1.field2.field3"
        new_value :
            new value to be placed
        data : dict, optional
            Dict for which path search is done, if None given self is used. Used for recursive path search.
            Default value "None"

        Returns
        -------
        None

        """

        if data is None:
            data = self

        head, separator, remaining = dotted_path.partition('.')

        # Snapshot the keys for the wildcard, entries are replaced while looping
        keys = list(data.keys()) if head == '*' else [head]

        for key in keys:
            if not separator:
                # We reached the node
                data[key] = new_value
                continue

            if key not in data or not isinstance(data[key], dict):
                # Create missing level / overwrite non-dict value on the path
                data[key] = {}

            self.set_path(dotted_path=remaining, new_value=new_value, data=data[key])

    @staticmethod
    def _describe_array(array):
        """Short textual description of a numpy array based on its shape."""

        shape = array.shape
        if len(shape) == 1:
            return "array (%d)" % shape[0]
        if len(shape) == 2:
            return "matrix (%d,%d)" % (shape[0], shape[1])
        # Scalars (0-d) and arrays with more than two dimensions
        return "array (%s)" % ",".join(str(dim) for dim in shape)

    def _walk(self, d, depth=0):
        """Recursive dict walk to get string of the content nicely formatted

        Parameters
        ----------
        d : dict
            Dict for walking
        depth : int
            Depth of walk, string is indented with this
            Default value 0

        Returns
        -------
        str
            One line per field (sorted by key), nested content indented.

        """

        indent = 3
        header_width = 35 - depth * indent
        pad = "".ljust(depth * indent)
        item_pad = "".ljust((depth + 1) * indent)

        # Keys are turned into text so sorting and padding also work for
        # non-string keys; string keys are left as they are by the formatting.
        entries = sorted(
            (("%s" % (key,), value) for key, value in d.items()),
            key=lambda entry: entry[0]
        )

        chunks = []
        for label, value in entries:
            if isinstance(value, dict):
                chunks.append(pad + label + '\n')
                chunks.append(self._walk(value, depth + 1))
                continue

            header = pad + label.ljust(header_width) + " : "
            first = value[0] if isinstance(value, list) and len(value) else None

            if isinstance(value, numpy.ndarray):
                # Numpy array or matrix
                chunks.append(header + self._describe_array(value) + '\n')

            elif isinstance(first, str):
                # List of strings
                chunks.append(header + "list (%d)\n" % len(value))
                for item_id, item in enumerate(value):
                    index = ("[" + str(item_id) + "]").ljust(header_width - 3)
                    chunks.append(item_pad + index + " : " + str(item) + '\n')

            elif isinstance(first, numpy.ndarray):
                # List of arrays
                chunks.append(header + "list (%d)\n" % len(value))
                for item_id, item in enumerate(value):
                    index = ("[" + str(item_id) + "]").ljust(header_width - 3)
                    if isinstance(item, numpy.ndarray):
                        description = self._describe_array(item)
                    else:
                        description = str(item)
                    chunks.append(item_pad + index + " : " + description + '\n')

            elif isinstance(first, dict):
                # List of dicts
                chunks.append(header + "list (%d)\n" % len(value))
                for item_id, item in enumerate(value):
                    index = "[" + str(item_id) + "]"
                    if isinstance(item, dict):
                        chunks.append(item_pad + index + '\n')
                        chunks.append(self._walk(item, depth + 2))
                    else:
                        chunks.append(item_pad + index.ljust(header_width - 3) + " : " + str(item) + '\n')

            else:
                chunks.append(header + str(value) + '\n')

        return "".join(chunks)

    def __str__(self):
        return self._walk(self, depth=1)

    def show(self):
        """Print container content

        Returns
        -------
            Nothing

        """

        print(self._walk(self, depth=1))

    def log(self, level='info'):
        """Log container content

        Parameters
        ----------
        level : str
            Logging level, possible values [info, debug, warn, warning, error, critical]
            Any other value results in nothing being logged.
            Default value "info"

        Returns
        -------
            Nothing

        """

        logger = logging.getLogger(__name__)
        emitters = {
            'debug': logger.debug,
            'info': logger.info,
            # Logger.warn is a deprecated alias, use Logger.warning for both spellings
            'warn': logger.warning,
            'warning': logger.warning,
            'error': logger.error,
            'critical': logger.critical,
        }

        emit = emitters.get(level.lower())
        if emit is None:
            return

        for line in str(self).split('\n'):
            emit(line)

    @staticmethod
    def _search_list_of_dictionaries(key, value, list_of_dictionaries):
        """Search in the list of dictionaries

        Parameters
        ----------
        key : str
            Dict key for the search
        value :
            Value for the key to match
        list_of_dictionaries : list of dicts
            List to search

        Returns
        -------
            First dict whose ``key`` equals ``value``, or None if there is no match

        """

        for element in list_of_dictionaries:
            if element.get(key) == value:
                return element
        return None

    def merge(self, override, target=None):
        """ Recursive dict merge

        Values from ``override`` replace values in ``target``; when both sides
        hold a dict under the same key, the dicts are merged recursively.

        Parameters
        ----------
        target : dict
            target parameter dict, if None given self is used.

        override : dict
            override parameter dict

        Returns
        -------
        None

        """

        # Explicit None test: an empty target dict is a valid merge target
        if target is None:
            target = self

        for key, value in override.items():
            if key in target and isinstance(target[key], dict) and isinstance(value, dict):
                self.merge(target=target[key], override=value)
            else:
                target[key] = value

    def get_hash_for_path(self, dotted_path=None):
        """Get hash for the data under given dotted path

        Parameters
        ----------
        dotted_path : str, optional
            String in form of "field1.field2.field3", if None or empty the
            whole container is hashed.
            Default value "None"

        Returns
        -------
        md5_hash : str or None
            Hash of the data, None if nothing is stored under the path

        """

        if not dotted_path:
            return self.get_hash(self)

        data = self.get_path(dotted_path=dotted_path)
        if data is None:
            return None
        return self.get_hash(data)

    def get_hash(self, data=None):
        """Get unique hash string (md5) for given parameter dict

        The data is copied before cleaning, so the input is not modified.

        Parameters
        ----------
        data : dict
            Input parameters, if None given content of self is used.

        Returns
        -------
        md5_hash : str
            Unique hash for parameter dict

        """

        if data is None:
            data = dict(self)

        cleaned = self._clean_for_hashing(copy.deepcopy(data))
        serialized = json.dumps(cleaned, sort_keys=True)
        return hashlib.md5(serialized.encode('utf-8')).hexdigest()

    def _clean_for_hashing(self, data, non_hashable_fields=None):
        """Recursively remove fields which should not affect the hash

        Removed are fields set to False, fields listed in
        ``non_hashable_fields`` and dict blocks having ``enable`` set to a
        false value. A disabled top level block collapses to
        ``{'enable': False}``. Dicts are modified in place.

        Parameters
        ----------
        data :
            Data to clean, values which are not dicts are returned as they are.
        non_hashable_fields : list of str, optional
            Field names to remove, if None given ``self.non_hashable_fields``
            is used when available, otherwise an empty list.

        Returns
        -------
            Cleaned data

        """

        if non_hashable_fields is None:
            non_hashable_fields = getattr(self, 'non_hashable_fields', [])

        if not data or not self._is_dict_like(data):
            # Empty values and leaf values (numbers, strings, lists) need no cleaning
            return data

        if 'enable' in data and not data['enable']:
            return {
                'enable': False,
            }

        if not isinstance(data, dict):
            return data

        for key in list(data.keys()):
            value = data[key]
            if value is False or key in non_hashable_fields:
                # Remove fields marked False or listed in non_hashable_fields
                del data[key]
            elif isinstance(value, dict):
                if 'enable' in value and not value['enable']:
                    # Remove dict block which is disabled
                    del data[key]
                else:
                    # Proceed recursively with the same field list
                    data[key] = self._clean_for_hashing(value, non_hashable_fields)

        return data


class DottedDict(dict, ContainerMixin):
    """Dict with the dotted-path, printing, merging and hashing helpers of ContainerMixin."""

    def __init__(self, *args, **kwargs):
        """Constructor

        Parameters
        ----------
        *args, **kwargs
            Passed to the ``dict`` constructor to initialize the content.
        non_hashable_fields : list of str, optional
            Keyword argument; extra field names appended to the default
            ``non_hashable_fields`` list ('_hash', 'verbose'). It is consumed
            here and does not become part of the dict content.

        """

        # Take the option out before dict sees it, otherwise it ends up as data
        extra_fields = kwargs.pop('non_hashable_fields', None)

        super(DottedDict, self).__init__(*args, **kwargs)

        self.non_hashable_fields = [
            '_hash',
            'verbose',
        ]
        if extra_fields:
            # The attribute is a list, so new names are appended (a list has no update())
            self.non_hashable_fields.extend(
                field for field in extra_fields if field not in self.non_hashable_fields
            )

    def __getstate__(self):
        # State handed to pickle/copy is a plain dict copy of the mapping content.
        # NOTE(review): the non_hashable_fields attribute is not part of this state.
        return dict(self)

    def __setstate__(self, state):
        # The restored state replaces the instance attribute dict as a whole.
        # NOTE(review): non_hashable_fields is therefore only present afterwards
        # if the state happens to contain that key - confirm this is intended.
        self.__dict__ = state