Shortcuts

Source code for sentarget.utils.functions

"""
Utility functions.
"""

import functools
import itertools

import torch


[docs]def append2dict(main_dict, *dicts): """ Append key values to another dict with the same keys. Args: main_dict (dict): dictionary where values will be added. *dicts (dict): dictionaries to extract values and append to another one. These dictionaries should have the same keys as dict. Examples:: >>> dict1 = {"key1": [], "key2": []} >>> dict2 = {"key1": 0, "key2": 1} >>> append2dict(dict1, dict2) >>> dict1 {"key1": [0], "key2": [1]} >>> dict3 = {"key1": 2, "key2": 3} >>> dict4 = {"key1": 4, "key2": 5} >>> append2dict(dict1, dict3, dict4) >>> dict1 {"key1": [0, 2, 4], "key2": [1, 3, 5]} """ # Multiples dictionaries to merge for d in dicts: for (key, value) in d.items(): # Test if the dictionary to append have the key try: main_dict[key].append(value) # If not, create the key and merge the value except: main_dict[key] = [value]
[docs]def permutation_dict(params): r"""Generate a list of configuration files used to tune a model. Returns: list Examples:: >>> hyper_params = {'dropout': [0, 0.1, 0.2, 0.3], ... 'in_features': [10, 20, 30, 40], ... 'out_features': [20, 30, 40, 50]} >>> permutation_dict(hyper_params) [{'dropout': 0, 'in_features': 10, 'out_features': 20}, {'dropout': 0, 'in_features': 10, 'out_features': 30}, {'dropout': 0, 'in_features': 10, 'out_features': 40}, {'dropout': 0, 'in_features': 10, 'out_features': 50}, {'dropout': 0, 'in_features': 20, 'out_features': 20}, {'dropout': 0, 'in_features': 20, 'out_features': 30}, ... ] """ params_list = {key: value for (key, value) in params.items() if isinstance(value, list)} params_single = {key: value for (key, value) in params.items() if not isinstance(value, list)} keys, values = zip(*params_list.items()) permutations = [dict(zip(keys, v), **params_single) for v in itertools.product(*values)] return permutations
[docs]def serialize_dict(data): r"""Serialize recursively a dict to another dict composed of basic python object (list, dict, int, float, str...) Args: data (dict): dict to serialize Returns: dict Examples:: >>> data = {'tensor': torch.tensor([0, 1, 2, 3, 4]), ... 'sub_tensor': [torch.tensor([1, 2, 3, 4, 5])], ... 'data': [1, 2, 3, 4, 5], ... 'num': 1} >>> serialize_dict(data) {'tensor': None, 'sub_tensor': [], 'data': [1, 2, 3, 4, 5], 'num': 1} """ new_data = {} for (key, value) in data.items(): if isinstance(value, dict): new_data[key] = serialize_dict(value) elif isinstance(value, list): new_data[key] = serialize_list(value) elif isinstance(value, int) or isinstance(value, float) or isinstance(value, str) or isinstance(value, bool): new_data[key] = value else: new_data[str(key)] = None return new_data
[docs]def serialize_list(data): """Serialize recursively a list to another list composed of basic python object (list, dict, int, float, str...) Args: data (list): list to serialize Returns: list Examples:: >>> data = [1, 2, 3, 4] >>> serialize_list(data) [1, 2, 3, 4] >>> data = [torch.tensor([1, 2, 3, 4])] >>> serialize_list(data) [] >>> data = [1, 2, 3, 4, torch.tensor([1, 2, 3, 4])] >>> serialize_list(data) [1, 2, 3, 4] """ new_data = [] for value in data: if isinstance(value, list): new_data.append(serialize_list(value)) elif isinstance(value, dict): new_data.append(serialize_dict(value)) elif isinstance(value, int) or isinstance(value, float) or isinstance(value, str) or isinstance(value, bool): new_data.append(value) else: return [] return new_data
[docs]def rsetattr(obj, attr, val): r"""Set an attribute recursively. ..note :: Attributes should be split with a dot ``.``. Args: obj (object): object to set the attribute. attr (string): path to the attribute. val (value): value to set. """ pre, _, post = attr.rpartition('.') return setattr(rgetattr(obj, pre) if pre else obj, post, val)
[docs]def rgetattr(obj, attr, *args): r"""Get an attribute recursively. Args: obj (object): object to get the attribute. attr (string): path to the attribute. *args: Returns: attribute """ def _getattr(obj, attr): return getattr(obj, attr, *args) return functools.reduce(_getattr, [obj] + attr.split('.'))

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Check the GitHub page and contribute to the project

View GitHub