Shortcuts

Source code for sentarget.nn.solver

r"""
A ``Solver`` is an object used for training, evaluating and testing a model.
The performance is stored in a dictionary, both for training and testing.
In addition, the best model occurred during training is stored,
as well as its checkpoint to re-load a model at a specific epoch.

Example:

.. code-block:: python

    import torch.nn as nn
    import torch.optim as optim

    model = nn.Sequential(nn.Linear(10, 100), nn.Sigmoid(), nn.Linear(100, 5), nn.ReLU())
    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss(ignore_index = LABEL_PAD_IDX)

    solver = BiLSTMSolver(model, optimizer=optimizer, criterion=criterion)

    # epochs = number of training loops
    # train_iterator = Iterator, DataLoader... Training data
    # eval_iterator = Iterator, DataLoader... Eval data
    solver.fit(train_iterator, eval_iterator, epochs=epochs)

"""

from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import torch.optim as optim

# Data science
import os
from pathlib import Path
import time
import copy
from sentarget.utils import append2dict, describe_dict


[docs]class Solver(ABC): r"""Train and evaluate model. * :attr:`model` (Module): model to optimize or test. * :attr:`checkpoint` (dict): checkpoint of the best model tested. * :attr:`criterion` (Loss): loss function. * :attr:`optimizer` (Optimizer): optimizer for weights and biases. * :attr:`performance` (dict): dictionary where performances are stored. * ``'train'`` (dict): training dictionary. * ``'eval'`` (dict): testing dictionary. Args: model (Module): model to optimize or test. criterion (Loss): loss function. optimizer (Optimizer): optimizer for weights and biases. """ def __init__(self, model, criterion=None, optimizer=None): # Defaults attributes self.model = model self.criterion = nn.CrossEntropyLoss() if criterion is None else criterion self.optimizer = optim.Adam(model.parameters()) if optimizer is None else optimizer # Performances self.best_model = None self.performance = None self.checkpoint = None self.reset()
[docs] def reset(self): """Reset the performance dictionary.""" self.best_model = None self.checkpoint = {'epoch': None, 'model_name': None, 'model_state_dict': None, 'optimizer_name': None, 'criterion_name': None, 'optimizer_state_dict': None, 'train': None, 'eval': None } self.performance = { "train": {}, "eval": {} }
[docs] @abstractmethod def train(self, iterator, *args, **kwargs): r"""Train one time the model on iterator data. Args: iterator (Iterator): iterator containing batch samples of data. Returns: dict: the performance and metrics of the training session. """ raise NotImplementedError
[docs] @abstractmethod def evaluate(self, iterator, *args, **kwargs): r"""Evaluate one time the model on iterator data. Args: iterator (Iterator): iterator containing batch samples of data. Returns: dict: the performance and metrics of the training session. """ raise NotImplementedError
def _update_checkpoint(self, epoch, results_train=None, results_eval=None): r"""Update the model's checkpoint. Keep track of its epoch, state, optimizer, and performances. In addition, it saves the current model in `best_model`. Args: epoch (int): epoch at the current training state. results_train (dict, optional): metrics for the training session at epoch. The default is None. results_eval (dict, optional): metrics for the evaluation session at epoch. The default is None. """ self.best_model = copy.deepcopy(self.model) self.checkpoint = {'epoch': epoch, 'model_name': self.best_model.__class__.__name__, 'model_state_dict': self.best_model.state_dict(), 'optimizer_name': self.optimizer.__class__.__name__, 'criterion_name': self.criterion.__class__.__name__, 'train': results_train, 'eval': results_eval }
[docs] def save(self, filename=None, dirpath=".", checkpoint=True): r"""Save the best torch model. Args: filename (str, optional): name of the model. The default is "model.pt". dirpath (str, optional): path to the desired foldre location. The default is ".". checkpoint (bool, optional): ``True`` to save the model at the best checkpoint during training. """ if checkpoint: # Get the name and other relevant information model_name = self.checkpoint['model_name'] epoch = self.checkpoint['epoch'] filename = f"model_{model_name}_epoch{epoch}.pt" if filename is None else filename # Save in the appropriate directory, and create it if it doesn't exists Path(dirpath).mkdir(parents=True, exist_ok=True) # Save the best model path = os.path.join(dirpath, filename) torch.save(self.best_model, path) # Save its checkpoint checkname = f"checkpoint_{filename.split('.')[-2].split('_')[1]}_epoch{epoch}.pt" checkpath = os.path.join(dirpath, checkname) torch.save(self.checkpoint, checkpath) else: model_name = self.checkpoint['model_name'] filename = f"model_{model_name}.pt" if filename is None else filename torch.save(self.model, filename)
[docs] def get_accuracy(self, y_tilde, y): r"""Compute accuracy from predicted classes and gold labels. Args: y_tilde (Tensor): 1D tensor containing the predicted classes for each predictions in the batch. This tensor should be computed through `get_predicted_classes(y_hat)` method. y (Tensor): gold labels. Note that y_tilde an y must have the same shape. Returns: float: the mean of correct answers. Examples:: >>> y = torch.tensor([0, 1, 4, 2, 1, 3, 2, 1, 1, 3]) >>> y_tilde = torch.tensor([0, 1, 2, 2, 1, 3, 2, 4, 4, 3]) >>> solver.get_accuracy(y_tilde, y) 0.7 """ assert y_tilde.shape == y.shape, "predicted classes and gold labels should have the same shape" correct = (y_tilde == y).astype(float) # convert into float for division acc = correct.sum() / len(correct) return acc
[docs] def fit(self, train_iterator, eval_iterator, *args, epochs=10, **kwargs): r"""Train and evaluate a model X times. During the training, both training and evaluation results are saved under the `performance` attribute. Args: train_iterator (Iterator): iterator containing batch samples of data. eval_iterator (Iterator): iterator containing batch samples of data. epochs (int): number of times the model will be trained. verbose (bool, optional): if ``True`` display a progress bar and metrics at each epoch. The default is ``True``. Examples:: >>> solver = MySolver(model, criterion=criterion, optimizer=optimizer) >>> # Train & eval EPOCHS times >>> EPOCHS = 10 >>> solver.fit(train_iterator, eval_iterator, epochs=EPOCHS, verbose=True) Epoch: 1/10 Training: 100% | [==================================================] Evaluation: 100% | [==================================================] Stats Training: | Loss: 0.349 | Acc: 84.33% | Prec.: 84.26% Stats Evaluation: | Loss: 0.627 | Acc: 72.04% | Prec.: 72.22% >>> # ... 
""" # By default, print a log each epoch verbose = True if 'verbose' not in {*kwargs} else kwargs['verbose'] # Keep track of the best model best_eval_accuracy = 0 start_time = time.time() # Train and evaluate the model epochs times for epoch in range(epochs): if verbose: print("Epoch:\t{0:3d}/{1}".format(epoch + 1, epochs)) # Train and evaluate the model results_train = self.train(train_iterator, *args, **kwargs) results_eval = self.evaluate(eval_iterator, *args, **kwargs) # Update the eval dictionary by adding the results at the # current epoch append2dict(self.performance["train"], results_train) append2dict(self.performance["eval"], results_eval) if verbose: print("\t Stats Train: | " + describe_dict(results_train)) print("\t Stats Eval: | " + describe_dict(results_eval)) print() # We copy in memory the best model if best_eval_accuracy < self.performance["eval"]["accuracy"][-1]: best_eval_accuracy = self.performance["eval"]["accuracy"][-1] self._update_checkpoint(epoch + 1, results_train, results_eval) self.performance['time'] = time.time() - start_time

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Check the GitHub page and contribute to the project

View GitHub