Transfer Learning

Transfer learning refers to reusing a pre-trained deep learning model on a dataset similar to the one it was originally trained on. Pre-trained models have weights (also called coefficients or parameters) determined beforehand by training and optimizing them on very large datasets. These models achieve some of the highest test accuracies on the dataset types they were trained for, and many were the de-facto standard, state-of-the-art (SOTA) models when first introduced.

These models are built on variations of either convolutional neural networks (CNNs) or recurrent neural networks (RNNs). CNNs and RNNs work extremely well on data with correlated patterns among nearby input values, and are therefore mostly used on such datasets, e.g. images, language, music, video, time series, and text.

A few examples: VGG16, ResNet50, and InceptionV3 for images; BERT, GPT-2, XLNet, and ELMo for NLP.

Typically these CNNs and RNNs consist of numerous trainable layers, with each layer consisting of multiple nodes (or channels), and each node combining numerous weights (or parameters). For instance, the VGG16 model has 16 weight layers; the first layer has 64 channels, and each of those channels has 28 parameters (a 3x3x3 kernel plus a bias). In total, the model has ~138M parameters to be trained. More complex models can have over 10B parameters.
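
As a quick check of that arithmetic, here is a minimal sketch assuming TensorFlow/Keras is installed (the ImageNet weights download on first use):

    from tensorflow.keras.applications import VGG16

    model = VGG16(weights="imagenet")

    first_conv = model.get_layer("block1_conv1")
    print(first_conv.filters)         # 64 channels in the first layer
    print(first_conv.count_params())  # 1792 = 64 * (3*3*3 weights + 1 bias) = 64 * 28
    print(model.count_params())       # 138,357,544 -- roughly 138M parameters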

How should transfer learning be used? The answer depends on two factors: 1. dataset size and 2. dataset similarity to the one the model was pre-trained on.

  • Small, similar dataset - no training of the pre-trained parameters required (use the model as a fixed feature extractor)
  • Small, dissimilar dataset - train the top-layer parameters (how many will vary with the data & model)
  • Large, similar dataset - train the top-layer parameters (how many will vary with the data & model)
  • Large, dissimilar dataset - train the top as well as the inner-layer parameters (see the sketch after this list)
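
A minimal Keras sketch of these scenarios, assuming a hypothetical 10-class image task (the model choice and the new head are illustrative, not a definitive recipe):

    from tensorflow.keras import layers, models
    from tensorflow.keras.applications import VGG16

    # Load the convolutional base without the ImageNet classification head.
    base = VGG16(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
    base.trainable = False  # small/similar data: keep pre-trained weights fixed

    # Attach a new task-specific head (assumed 10 classes here).
    model = models.Sequential([
        base,
        layers.GlobalAveragePooling2D(),
        layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="categorical_crossentropy",
                  metrics=["accuracy"])

    # For larger or less similar data, also unfreeze the top convolutional
    # block so its parameters are trained (re-compile after changing flags).
    base.trainable = True
    for layer in base.layers:
        layer.trainable = layer.name.startswith("block5")
    model.compile(optimizer="adam", loss="categorical_crossentropy",
                  metrics=["accuracy"])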

A fuller understanding comes from working through the code implementations below.

Python Implementation: 

  • tensorflow.keras.applications -> ResNet50, InceptionV3, VGG16
  • torchvision.models -> resnet50, vgg16, inception_v3
  • transformers.AutoModel -> from_pretrained('bert-base-uncased')
  • transformers.XLNetTokenizer -> from_pretrained('xlnet-base-cased'); transformers.GPT2Tokenizer -> from_pretrained('gpt2')
  • sentence_transformers -> SentenceTransformer('stsb-bert-base')
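
As a minimal sketch of the transformers entry point above (the example sentence and mean-pooling choice are illustrative):

    import torch
    from transformers import AutoModel, AutoTokenizer

    # Load a pre-trained BERT checkpoint and its matching tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased")

    inputs = tokenizer("Transfer learning reuses pre-trained weights.",
                       return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean-pool the token embeddings into a single sentence vector.
    embedding = outputs.last_hidden_state.mean(dim=1)
    print(embedding.shape)  # torch.Size([1, 768])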

