Source code for sentarget.tuner.tuner

r"""
Hyperparameters optimization using a grid search algorithm.
Basically, you need to provide a set of parameters that will be modified.
The grid search will run on all permutations from the set of parameters provided.
Usually, you modify the hyperparameters and models' modules (ex, dropout etc.).
In addition, if you are using custom losses or optimizer that needs additional arguments / parameters,
you can provide them through the specific dictionaries (see the documentation of ``Tuner``).


Examples:

.. code-block:: python

    # Hyper parameters to tune
    params_hyper = {
                        'epochs': [150],
                        'lr': np.arange(0.001, 0.3, 0.01).tolist(),     # Make sure to convert it to a list (so it can be serialized later)
                    }

    # Parameters affecting the models
    params_model = {
                        'model': [BiLSTM],
                        'hidden_dim': [100, 150, 200, 250],      # Model attribute
                        'n_layers': [1, 2, 3],                   # Model attribute
                        'bidirectional': [False, True],          # Model attribute
                        'LSTM.dropout': [0.2, 0.3, 0.4, 0.6],    # Modify all LSTM dropout
                        # ...
                    }

    params_loss = {
                        'criterion': [CrossEntropyLoss]
                    }

    params_optim = {
                        'optimizer': [Adam]
                    }

    tuner = Tuner(params_hyper, params_model=params_model, params_loss=params_loss, params_optim=params_optim)

    # Grid Search
    tuner.fit(train_iterator, eval_iterator, verbose=True)

"""

import copy
import json
import os
from pathlib import Path
import torch

from sentarget.nn.models import BiLSTM
from .functional import tune, tune_optimizer, init_cls
from sentarget.utils import describe_dict, serialize_dict, permutation_dict


class Tuner:
    r"""
    The ``Tuner`` class is used for hyperparameter tuning. From a set of models
    and parameters to tune, this class will search for the best model's performance.

    .. note::
        To facilitate the search and hyperparameter tuning, it is recommended to use
        the ``sentarget.nn.models.Model`` abstract class as the parent class for all
        of your models.

    * :attr:`params_hyper` (dict): dictionary of hyperparameters to tune.

    * :attr:`results` (list): list of all models' performances.

    """

    def __init__(self, params_hyper=None, params_model=None, params_loss=None, params_optim=None, options=None):
        # Hyper parameters with default values
        self.params_hyper = params_hyper if params_hyper is not None else {}
        self.params_model = params_model if params_model is not None else {}
        self.params_loss = params_loss if params_loss is not None else {}
        self.params_optim = params_optim if params_optim is not None else {}
        # General options
        self.options = {**self._init_options(), **options} if options is not None else self._init_options()
        # Keep track of all performances
        self.results = []
        self._log = None
        self._log_conf = None
        self._log_perf = None
        self.best_model = None

    def _init_options(self):
        return {
            'saves': True,
            'dirsaves': '.saves',
            'compare_on': 'accuracy',
            'verbose': True,
        }

    def _init_hyper(self):
        return {
            'batch_size': 64,
            'epochs': 100
        }
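    # A minimal sketch (with hypothetical values) of how a user-supplied
    # ``options`` dictionary overrides the defaults from ``_init_options``:
    #
    #     tuner = Tuner(params_hyper={'epochs': [10]},
    #                   options={'compare_on': 'f1', 'dirsaves': './runs'})
    #     tuner.options['compare_on']   # 'f1'  (overridden)
    #     tuner.options['saves']        # True  (default kept)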
    def reset(self):
        r"""Reset all parameters to their default values."""
        self.results = []
        self._log = None
        self._log_conf = None
        self._log_perf = None
        self.best_model = None
    def fit(self, train_data, eval_data, verbose=True, saves=False, **kwargs):
        r"""Run the hyperparameter tuning.

        Args:
            train_data (iterator): training dataset.
            eval_data (iterator): dev dataset.
            verbose (bool): if ``True``, display a statistical log at each search.
            saves (bool): if ``True``, saves all trained models.
            dirsaves (string): path to the saving directory.

        Examples::

            >>> from sentarget.tuner import Tuner
            >>> from sentarget.nn.models.lstm import BiLSTM
            >>> from sentarget.nn.models.gru import BiGRU
            >>> # Hyper parameters to tune
            >>> tuner = Tuner(
            ...     params_hyper={
            ...         'epochs': [2, 3],
            ...         'lr': [0.01],
            ...         'vectors': ['model.txt']
            ...     },
            ...     params_model={
            ...         'model': [BiLSTM],
            ...     },
            ...     params_loss={
            ...         'criterion': [torch.nn.CrossEntropyLoss],
            ...         'ignore_index': [0]
            ...     },
            ...     params_optim={
            ...         'optimizer': [torch.optim.Adam]
            ...     }
            ... )
            >>> # train_iterator = torchtext data iterator
            >>> tuner.fit(train_iterator, valid_iterator)

        """
        # Update the options dictionary (explicit arguments take precedence over stored options)
        self.options = {**self.options, 'verbose': verbose, 'saves': saves, **kwargs}
        dirsaves = self.options['dirsaves']
        saves = self.options['saves']
        compare_on = self.options['compare_on']
        verbose = self.options['verbose']

        configs_hyper = permutation_dict(self.params_hyper)
        configs_model = permutation_dict(self.params_model)
        configs_loss = permutation_dict(self.params_loss)
        configs_optim = permutation_dict(self.params_optim)
        # Number of permutations
        H = len(configs_hyper)
        M = len(configs_model)
        L = len(configs_loss)
        O = len(configs_optim)

        self._log = self.log_init(H, M, L, O)
        if verbose:
            print(self._log)

        num_search = 0
        for config_hyper in configs_hyper:
            for config_model in configs_model:
                for config_loss in configs_loss:
                    for config_optim in configs_optim:
                        # General params
                        num_search += 1
                        # Fill in default hyper parameters (e.g. ``batch_size``) that were not provided
                        config_hyper = {**self._init_hyper(), **config_hyper}
                        # Set a batch size to the data
                        train_data.batch_size = config_hyper['batch_size']
                        eval_data.batch_size = config_hyper['batch_size']
                        # Initialize the model from arguments that are in config_model
                        model = init_cls(config_model['model'], config_model)
                        # Change modules values that were not saved as attributes
                        tune(model, config_model)
                        # Load the criterion and optimizer, with their parameters
                        criterion = init_cls(config_loss['criterion'], config_loss)
                        optimizer = init_cls(config_optim['optimizer'],
                                             {'params': model.parameters(), **config_optim})
                        # If the learning rate etc. were provided as hyper parameters
                        tune_optimizer(optimizer, config_hyper)

                        # Update the configuration log
                        self._log_conf = f"Search n°{num_search}: {model.__class__.__name__}\n"
                        self._log_conf += self.log_conf(config_hyper=config_hyper,
                                                        config_model=config_model,
                                                        config_loss=config_loss,
                                                        config_optim=config_optim)
                        self._log_conf += f"\n{model.__repr__()}"
                        self._log += f"\n\n{self._log_conf}"
                        if verbose:
                            print(f"\n{self._log_conf}")

                        # Train the model
                        best_model = model.fit(train_data, eval_data,
                                               criterion=criterion,
                                               optimizer=optimizer,
                                               epochs=config_hyper['epochs'],
                                               verbose=False,
                                               compare_on=compare_on,
                                               **kwargs)
                        results = {'performance': model.performance,
                                   'hyper': config_hyper,
                                   'model': config_model,
                                   'optimizer': self.params_optim,
                                   'criterion': self.params_loss}
                        self.results.append(serialize_dict(results))

                        # Update the current best model
                        if (self.best_model is None
                                or best_model.performance['eval'][compare_on] >
                                self.best_model.performance['eval'][compare_on]):
                            self.best_model = copy.deepcopy(best_model)

                        # Update the current performance log
                        self._log_perf = model.log_perf()
                        self._log += "\n" + self._log_perf
                        if verbose:
                            print(self._log_perf)

                        # Save the current checkpoint
                        if saves:
                            dirpath = os.path.join(dirsaves, 'gridsearch', f"search_{num_search}")
                            name = f"{model.__class__.__name__}{num_search}_best"
                            # Save the best model and its checkpoint
                            best_model.save(name=name, dirpath=dirpath)
                            # Save the last model's state
                            name = f"{model.__class__.__name__}{num_search}"
                            model.save(name=name, dirpath=dirpath, checkpoint=False)
                            # Save the model's results and configs
                            self._save_current_results(dirpath=dirpath)
                            self._save_current_log(dirpath=dirpath)
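    # ``tune`` and ``tune_optimizer`` come from ``.functional`` and are not shown
    # on this page. A minimal sketch of what the dotted-key convention used above
    # (e.g. ``'LSTM.dropout'``) is assumed to do on a ``torch.nn.Module``:
    #
    #     for name, module in model.named_modules():
    #         for key, value in config_model.items():
    #             cls_name, _, attr = key.rpartition('.')
    #             if cls_name and module.__class__.__name__ == cls_name:
    #                 setattr(module, attr, value)   # e.g. set ``dropout`` on every LSTM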
    def log_init(self, hyper, model, loss, optim):
        """Generate a general configuration log.

        Args:
            hyper (int): number of hyper parameters permutations.
            model (int): number of model parameters permutations.
            loss (int): number of loss parameters permutations.
            optim (int): number of optimizer parameters permutations.

        Returns:
            string: general log.

        """
        log = "GridSearch(\n"
        log += f"  (options): Parameters({describe_dict(self.options)})\n"
        log += f"  (session): Permutations(hyper={hyper}, model={model}, " \
               f"loss={loss}, optim={optim}, total={hyper * model * loss * optim})\n"
        log += ")"
        return log
    def log_conf(self, config_hyper={}, config_model={}, config_loss={}, config_optim={}, *args, **kwargs):
        """Generate a configuration log from the generated set of configuration files.

        Args:
            config_hyper (dict): hyper parameters configuration file.
            config_model (dict): model parameters configuration file.
            config_loss (dict): loss parameters configuration file.
            config_optim (dict): optimizer parameters configuration file.

        Returns:
            string: configuration file representation.

        """
        log = "Configuration(\n"
        log += f"  (hyper): Variables({describe_dict(config_hyper, *args, **kwargs)})\n"
        log += f"  (model): Parameters({describe_dict(config_model, *args, **kwargs)})\n"
        log += f"  (criterion): {config_loss['criterion'].__name__}({describe_dict(config_loss, *args, **kwargs)})\n"
        log += f"  (optimizer): {config_optim['optimizer'].__name__}({describe_dict(config_optim, *args, **kwargs)})\n"
        log += ')'
        return log
    def _save_current_results(self, filename='results.json', dirpath='.saves'):
        Path(dirpath).mkdir(parents=True, exist_ok=True)
        with open(os.path.join(dirpath, filename), 'w') as outfile:
            json.dump(serialize_dict(self.results[-1]), outfile)

    def _save_current_log(self, filename='log.txt', dirpath='.saves'):
        Path(dirpath).mkdir(parents=True, exist_ok=True)
        with open(os.path.join(dirpath, filename), 'w') as outfile:
            outfile.write(self._log_conf + "\n" + self._log_perf)
    def save(self, filename='gridsearch.json', dirpath='.saves'):
        r"""Save the performances, by default as a JSON file.

        Args:
            filename (string): name of the file to save.
            dirpath (string): path to the saving directory.

        """
        Path(dirpath).mkdir(parents=True, exist_ok=True)
        data = {'results': self.results}
        path = os.path.join(dirpath, filename)
        with open(path, 'w') as outfile:
            json.dump(data, outfile)
        with open(os.path.join(dirpath, "log.txt"), 'w') as outfile:
            outfile.write(self._log)
        # Saving the best model
        filename = f"best_{self.best_model.__class__.__name__}.pt"
        self.best_model.save(filename=filename, dirpath=dirpath, checkpoint=True)
        # And its log / performances
        self._save_current_results(filename='best_results.json', dirpath=dirpath)
        self._save_current_log(filename='best_log.txt', dirpath=dirpath)
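
# Example (hypothetical paths): reloading a grid-search summary written by
# ``Tuner.save`` and picking the best run by eval accuracy. The nested layout of
# each entry follows the ``results`` dict built in ``fit`` above.
#
#     import json
#
#     with open('.saves/gridsearch.json') as f:
#         results = json.load(f)['results']
#     best = max(results, key=lambda r: r['performance']['eval']['accuracy'])
#     print(best['hyper'], best['model'])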
