Source code for sentarget.tuner.functional

"""
Optimization functions used for hyperparameter tuning.
"""

import inspect
from sentarget.utils import rgetattr, rsetattr


def tune(model, config):
    r"""Tune the hyperparameters of a model from a configuration dictionary.

    .. note::
        If a key contains a ``.``, its first component names the module type to modify and the rest
        is the attribute path: ``key = 'LSTM.dropout'`` modifies the dropout of ``LSTM`` layers only.
        The double underscore ``__`` targets an attribute by its name (and not by its type):
        ``key = 'linear__in_features'`` modifies only the ``in_features`` attribute of the ``Linear``
        layer saved under the ``linear`` attribute of the custom model.

    .. warning::
        The operation modifies the model in place.

    Args:
        model (Model): the model whose hyperparameters are tuned.
        config (dict): dictionary of parameters to change.

    Examples::

        >>> from sentarget.nn.models.lstm import BiLSTM
        >>> # Defines the shape of the models
        >>> INPUT_DIM = len(TEXT.vocab)
        >>> EMBEDDING_DIM = 100
        >>> HIDDEN_DIM = 128
        >>> OUTPUT_DIM = len(LABEL.vocab)
        >>> N_LAYERS = 2
        >>> BIDIRECTIONAL = True
        >>> DROPOUT = 0.25
        >>> PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]
        >>> model = BiLSTM(INPUT_DIM,
        ...                EMBEDDING_DIM,
        ...                HIDDEN_DIM,
        ...                OUTPUT_DIM,
        ...                N_LAYERS,
        ...                BIDIRECTIONAL,
        ...                DROPOUT,
        ...                PAD_IDX)
        >>> config = {'LSTM.dropout': 0.2}
        >>> tune(model, config)
    """
    for (key, value) in config.items():
        attribute_list = key.split('__')
        attribute = attribute_list[0]
        module_path = attribute_list[-1]
        # Change a value through the attribute's name, e.g. 'linear__in_features'
        if len(attribute_list) == 2:
            attribute = getattr(model, attribute)
            try:
                rsetattr(attribute, module_path, value)
            except AttributeError:
                pass
        # Change a value through the module's type, e.g. 'LSTM.dropout'
        elif len(attribute_list) == 1:
            module_name = attribute.split('.')[0]
            attribute = '.'.join(attribute.split('.')[1:])
            for module in model.modules():
                # Only update modules whose class name matches the key prefix
                if type(module).__name__ != module_name:
                    continue
                try:
                    rsetattr(module, attribute, value)
                except AttributeError:
                    pass
        else:
            raise KeyError(f'path to attribute {key} is ambiguous. Please separate objects with a `.` or `__`. '
                           'More information at https://pages.github.uio.no/arthurd/in5550-exam/source/package.html#sentarget-optim')
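The dotted attribute paths above rely on the ``rgetattr``/``rsetattr`` helpers imported from ``sentarget.utils``, whose implementations are not shown on this page. Below is a minimal sketch of what they are assumed to do; the underscore-prefixed names are hypothetical and may differ from the real helpers.

# Hypothetical sketch of recursive getattr/setattr over dotted paths.
# The actual sentarget.utils implementations may differ.
import functools

def _rgetattr(obj, path):
    # Resolve a dotted attribute path, e.g. _rgetattr(model, 'lstm.dropout').
    return functools.reduce(getattr, path.split('.'), obj)

def _rsetattr(obj, path, value):
    # Resolve everything but the last component, then set the final attribute.
    prefix, _, name = path.rpartition('.')
    setattr(_rgetattr(obj, prefix) if prefix else obj, name, value)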
def init_cls(class_instance, config):
    r"""Initialize a class instance from a set of possible values.

    .. note::
        The configuration may contain more parameters than the class needs; unused ones are ignored.

    Args:
        class_instance (class): class to initialize.
        config (dict): possible values of the init parameters.

    Returns:
        the initialized object.
    """
    # Get the init parameters (getargspec was removed in Python 3.11, use getfullargspec)
    arguments = inspect.getfullargspec(class_instance.__init__).args
    # Remove the 'self' argument, which can't be changed
    arguments.pop(0)
    init = {key: value for (key, value) in config.items() if key in arguments}
    return class_instance(**init)
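A short usage sketch of ``init_cls`` with a made-up class (``Point`` is not part of sentarget), showing that configuration keys the constructor does not accept are silently dropped.

# Hypothetical example: extra keys in the configuration are ignored.
class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

point = init_cls(Point, {'x': 1, 'y': 2, 'lr': 0.001})  # 'lr' is not an init parameter
assert (point.x, point.y) == (1, 2)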
def tune_optimizer(optimizer, config):
    r"""Tune the default parameters of an optimizer.

    .. warning::
        The operation directly modifies the optimizer's ``defaults`` dictionary.

    Args:
        optimizer (Optimizer): optimizer to tune.
        config (dict): dictionary of new parameters to set.
    """
    for (key, value) in config.items():
        if key in optimizer.defaults:
            optimizer.defaults[key] = value
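For illustration, a small usage sketch with a PyTorch ``SGD`` optimizer; the parameter tensor and configuration values are made up. Only ``optimizer.defaults`` is updated, so parameter groups created before the call keep the values they were built with.

# Hypothetical example with torch.optim.SGD.
import torch

params = [torch.nn.Parameter(torch.zeros(2, 2))]
optimizer = torch.optim.SGD(params, lr=0.1)
tune_optimizer(optimizer, {'lr': 0.01, 'unknown_key': 42})  # unknown keys are skipped
assert optimizer.defaults['lr'] == 0.01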
