Meta-Learning: Learning to Learn Quickly and Efficiently

Effective machine learning models usually depend on a vast training dataset. Human learning contrasts sharply with this, showing remarkable speed and efficiency: children can tell cats from birds after just a few encounters, and someone who knows how to ride a bicycle can quickly grasp how to ride a motorcycle, often with minimal instruction. This raises a pivotal question: can we engineer machine learning models that mirror this human-like ability to rapidly acquire new concepts and skills from limited examples? This is the central challenge that meta-learning, also known as “learning to learn,” sets out to address.

The goal of a meta-learning model is to adapt and generalize well to novel tasks and environments that it never encountered during training. This adaptation, essentially a condensed learning session, happens at test time with only limited exposure to the new task’s configuration. Ultimately, the adapted model should be able to perform the new task proficiently. This is why meta-learning is frequently referred to as “learning to learn.”

These “tasks” can encompass any well-defined family of machine learning problems, from supervised learning to reinforcement learning and beyond. Consider these illustrative examples of meta-learning tasks:

  • A classifier, initially trained on images excluding cats, can accurately identify cats in images after being shown only a few examples of cat pictures.
  • A game-playing AI can rapidly become proficient at a new game, even if it has never encountered it before.
  • A miniature robot, trained in a flat environment, can successfully perform its designated task on an uphill surface during testing.

Defining the Meta-Learning Problem

In this discussion, we will primarily focus on meta-learning in the context of supervised learning tasks, particularly image classification. While meta-learning with reinforcement learning problems, often termed “Meta Reinforcement Learning,” is a rich and fascinating area, it will not be covered in this article.

A Simplified Perspective

A robust meta-learning model is trained across a diverse array of learning tasks and optimized for the best performance over a distribution of tasks, including tasks unseen during training. Each task is associated with a dataset $\mathcal{D}$ containing both feature vectors and their true labels. The optimal model parameters are:

$$ \theta^* = \arg\min_\theta \mathbb{E}_{\mathcal{D}\sim p(\mathcal{D})} [\mathcal{L}_\theta(\mathcal{D})] $$

This formulation bears a striking resemblance to standard learning tasks. However, the key distinction is that here, a dataset itself is treated as a single data sample.

Few-shot classification is a prime example of meta-learning within supervised learning. Here the dataset $\mathcal{D}$ is typically split into two parts: a support set $S$ for learning and a prediction set $B$ for training or testing, written $\mathcal{D}=\langle S, B\rangle$. A common setting is K-shot N-class classification, where the support set contains K labeled examples for each of the N classes.

Fig. 1. Example of 4-shot 2-class image classification. (Image thumbnails from Pinterest)

Training Mimicking Testing Conditions

A dataset $\mathcal{D}$ is composed of pairs of feature vectors and labels, $\mathcal{D} = \{(\mathbf{x}_i, y_i)\}$, with each label belonging to a known label set $\mathcal{L}^\text{label}$. Let’s assume our classifier $f_\theta$, parameterized by $\theta$, outputs the probability $P_\theta(y\vert\mathbf{x})$ that a data point with feature vector $\mathbf{x}$ belongs to class $y$.

The optimal parameters should maximize the probability of the true labels across multiple training batches $B \subset \mathcal{D}$:

$$ \begin{aligned} \theta^* &= {\arg\max}_{\theta} \mathbb{E}_{(\mathbf{x}, y)\in \mathcal{D}}[P_\theta(y \vert \mathbf{x})] & \\ \theta^* &= {\arg\max}_{\theta} \mathbb{E}_{B\subset \mathcal{D}}[\sum_{(\mathbf{x}, y)\in B}P_\theta(y \vert \mathbf{x})] & \scriptstyle{\text{; trained with mini-batches.}} \end{aligned} $$

In few-shot classification, the primary objective is to minimize prediction errors on data samples with unknown labels, given a small support set for “fast learning.” To align the training process with the inference phase, we simulate datasets with a subset of labels to prevent the model from being exposed to all labels at once. We also adjust the optimization procedure to foster rapid learning:

  1. Label Subset Sampling: Randomly select a subset of labels, $L\subset\mathcal{L}^\text{label}$.
  2. Dataset Sampling: Sample a support set $S^L \subset \mathcal{D}$ and a training batch $B^L \subset \mathcal{D}$. Both contain only data points with labels from the sampled label set $L$, i.e., $y \in L, \forall (x, y) \in S^L, B^L$.
  3. Support Set as Input: The support set $S^L$ is fed in as part of the model input.
  4. Optimization with Mini-Batch: Use the mini-batch $B^L$ to compute the loss and update the model parameters via backpropagation, just as in standard supervised learning.
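The two sampling steps can be sketched in plain Python. This is a minimal illustration, assuming the dataset is a list of `(x, y)` pairs; the helper name `sample_episode` and the choice to keep the support set and the batch disjoint are our own assumptions, not taken from any specific paper's code.

```python
import random

def sample_episode(dataset, n_way, k_shot, n_query, rng=random):
    """Sample one few-shot episode (L, S^L, B^L) from a list of (x, y) pairs.

    Hypothetical helper for illustration; support and batch are kept disjoint.
    """
    # Group examples by label.
    by_label = {}
    for x, y in dataset:
        by_label.setdefault(y, []).append(x)

    # Step 1: sample a label subset L among labels with enough examples.
    eligible = sorted(y for y, xs in by_label.items() if len(xs) >= k_shot + n_query)
    labels = rng.sample(eligible, n_way)

    # Step 2: sample S^L and B^L, both restricted to labels in L.
    support, batch = [], []
    for y in labels:
        xs = rng.sample(by_label[y], k_shot + n_query)
        support += [(x, y) for x in xs[:k_shot]]
        batch += [(x, y) for x in xs[k_shot:]]
    return labels, support, batch

# Toy dataset: 10 classes with 20 placeholder "images" each; 5-way 1-shot episode.
data = [(f"img{c}_{i}", c) for c in range(10) for i in range(20)]
labels, support, batch = sample_episode(data, n_way=5, k_shot=1, n_query=3, rng=random.Random(0))
```

Each returned `(support, batch)` pair plays the role of $(S^L, B^L)$ and counts as a single training example for the meta-learner.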

Each sampled pair of datasets $(S^L, B^L)$ can be viewed as a single data point. The model is trained to generalize effectively to different datasets. The red symbols in the formula below highlight the additions for meta-learning compared to standard supervised learning:

$$ \theta^* = \arg\max_\theta \color{red}{\mathbb{E}_{L\subset\mathcal{L}}[} \mathbb{E}_{\color{red}{S^L \subset\mathcal{D}, }B^L \subset\mathcal{D}} [\sum_{(\mathbf{x}, y)\in B^L} P_\theta(y \vert \mathbf{x}\color{red}{, S^L})] \color{red}{]} $$

This concept shares similarities with using pre-trained models in image classification (like ImageNet) or language modeling (using large text corpora) when only a limited number of task-specific data samples are available. Meta-learning takes this idea a step further: instead of fine-tuning for a single downstream task, it optimizes the model to excel at many, if not all, tasks within a distribution.

Learner and Meta-Learner

Another prevalent perspective on meta-learning involves breaking down the model update process into two distinct stages:

  • Learner: A classifier $f_\theta$, parameterized by $\theta$, is trained to perform a specific task.
  • Meta-Learner: An optimizer $g_\phi$, parameterized by $\phi$, learns how to update the learner’s parameters using the support set $S$, producing updated parameters $\theta' = g_\phi(\theta, S)$.

In the final optimization step, both $\theta$ and $\phi$ are updated to maximize the objective:

$$ \mathbb{E}_{L\subset\mathcal{L}}[ \mathbb{E}_{S^L \subset\mathcal{D}, B^L \subset\mathcal{D}} [\sum_{(\mathbf{x}, y)\in B^L} P_{g_\phi(\theta, S^L)}(y \vert \mathbf{x})]] $$

Common Meta-Learning Approaches

There are three primary categories of meta-learning approaches: metric-based, model-based, and optimization-based. Oriol Vinyals provides an excellent summary in his presentation at the Meta-Learning Symposium @ NIPS 2018:

| Approach | Key Idea | Model of $P_\theta(y \vert \mathbf{x})$ |
|---|---|---|
| Model-based | RNN; Memory | $f_\theta(\mathbf{x}, S)$ |
| Metric-based | Metric Learning | $\sum_{(\mathbf{x}_i, y_i) \in S} k_\theta(\mathbf{x}, \mathbf{x}_i)y_i$ (*) |
| Optimization-based | Gradient Descent | $P_{g_\phi(\theta, S^L)}(y \vert \mathbf{x})$ |

(*) Here, $k_\theta$ is a kernel function that measures the similarity between $\mathbf{x}_i$ and $\mathbf{x}$.

In the following sections, we will delve into classic models within each of these approaches.

Metric-Based Meta-Learning

Metric-based meta-learning hinges on the principle of similarity, akin to nearest-neighbor algorithms such as k-NN classification, k-means clustering, and kernel density estimation. The predicted probability of a label $y$ is a weighted sum of the labels of support set samples, where the weights are given by a kernel function $k_\theta$ that measures the similarity between data samples.

$$ P_\theta(y \vert \mathbf{x}, S) = \sum_{(\mathbf{x}_i, y_i) \in S} k_\theta(\mathbf{x}, \mathbf{x}_i)y_i $$

The effectiveness of metric-based meta-learning largely depends on learning an appropriate kernel. Metric learning aligns perfectly with this goal, focusing on learning a metric or distance function between objects. A “good” metric is task-dependent and should effectively represent the relationships within the task space, thereby facilitating problem-solving.

The models discussed below explicitly learn embedding vectors for input data and utilize these embeddings to design suitable kernel functions.

Convolutional Siamese Neural Network

The Siamese Neural Network architecture comprises two identical twin networks that are jointly trained to learn the relationship between pairs of input data samples. These twin networks share weights and parameters, effectively using the same embedding network to learn efficient embeddings that reveal relationships between data point pairs.

Koch, Zemel & Salakhutdinov (2015) introduced a Siamese neural network approach for one-shot image classification. This network is initially trained on a verification task: determining whether two input images belong to the same class. It outputs the probability of two images being from the same class. During testing, the Siamese network evaluates all pairs formed by a test image and each image in the support set. The final prediction is the class of the support image that yields the highest probability.

Fig. 2. Architecture of a convolutional Siamese neural network for few-shot image classification.

  1. Convolutional Siamese networks encode images into feature vectors using an embedding function $f_\theta$ built from convolutional layers.
  2. The L1-distance between the two embeddings is $\vert f_\theta(\mathbf{x}_i) - f_\theta(\mathbf{x}_j) \vert$.
  3. This distance is converted into a probability $p$ that the two images belong to the same class, via a linear feedforward layer followed by a sigmoid.
  4. The loss is cross-entropy, because the label is binary.

$$ \begin{aligned} p(\mathbf{x}_i, \mathbf{x}_j) &= \sigma(\mathbf{W}\vert f_\theta(\mathbf{x}_i) - f_\theta(\mathbf{x}_j) \vert) \\ \mathcal{L}(B) &= -\sum_{(\mathbf{x}_i, \mathbf{x}_j, y_i, y_j)\in B} \Big[\mathbf{1}_{y_i=y_j}\log p(\mathbf{x}_i, \mathbf{x}_j) + (1-\mathbf{1}_{y_i=y_j})\log (1-p(\mathbf{x}_i, \mathbf{x}_j))\Big] \end{aligned} $$

Training batches $B$ can be augmented with image distortions to improve robustness. L1 distance is used here, but other distance metrics such as L2 or cosine distance can also work, as long as they are differentiable so that backpropagation still applies.

Given a support set $S$ and a test image $mathbf{x}$, the predicted class is determined by:

$$ \hat{c}_S(\mathbf{x}) = c(\arg\max_{\mathbf{x}_i \in S} P(\mathbf{x}, \mathbf{x}_i)) $$

where $c(\mathbf{x})$ is the class label of image $\mathbf{x}$ and $\hat{c}(.)$ is the predicted label.
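To make the pairwise scoring and the nearest-support prediction concrete, here is a minimal sketch in plain Python. It assumes the embedding `f` is already given (an identity map below, standing in for the trained CNN $f_\theta$) and adds a bias term `b` to the linear layer, which the equation above leaves implicit; all function names are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def pair_prob(f, W, b, xi, xj):
    """p(x_i, x_j) = sigmoid(W . |f(x_i) - f(x_j)| + b)."""
    dist = [abs(a - c) for a, c in zip(f(xi), f(xj))]  # element-wise L1 distance
    return sigmoid(sum(w * d for w, d in zip(W, dist)) + b)

def pair_loss(p, same):
    """Binary cross-entropy for one pair; same is True when y_i == y_j."""
    return -math.log(p) if same else -math.log(1.0 - p)

def one_shot_predict(f, W, b, x, support):
    """Label of the support image with the highest same-class probability."""
    _, best_label = max(support, key=lambda s: pair_prob(f, W, b, x, s[0]))
    return best_label

def identity(v):
    return v

# Negative weights so that a larger distance means a lower same-class probability.
W, b = [-1.0, -1.0], 2.0
support = [((0.0, 0.0), "cat"), ((5.0, 5.0), "bird")]
print(one_shot_predict(identity, W, b, (0.5, 0.2), support))  # cat
```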

This approach assumes that the learned embedding generalizes well enough to measure distances between images of unseen categories. Transfer learning rests on the same assumption: a model pre-trained on, say, ImageNet is expected to help with other image tasks. However, the benefit of a pre-trained model shrinks as the new task diverges from the one it was originally trained on.

Matching Networks

Matching Networks (Vinyals et al., 2016) learn a classifier $c_S$ for any given small support set $S=\{(\mathbf{x}_i, y_i)\}_{i=1}^k$ (k-shot classification). This classifier defines a probability distribution over output labels $y$ for a test example $\mathbf{x}$. As in other metric-based models, the classifier output is a sum of support sample labels weighted by an attention kernel $a(\mathbf{x}, \mathbf{x}_i)$, which should be proportional to the similarity between $\mathbf{x}$ and $\mathbf{x}_i$.

Fig. 3. Matching Networks architecture. (Image source: original paper)

$$ c_S(\mathbf{x}) = P(y \vert \mathbf{x}, S) = \sum_{i=1}^k a(\mathbf{x}, \mathbf{x}_i) y_i \text{, where }S=\{(\mathbf{x}_i, y_i)\}_{i=1}^k $$

The attention kernel relies on two embedding functions, $f$ and $g$, to encode the test sample and the support set samples, respectively. The attention weight between two data points is the cosine similarity, $\text{cosine}(.)$, between their embedding vectors, normalized by softmax:

$$ a(\mathbf{x}, \mathbf{x}_i) = \frac{\exp(\text{cosine}(f(\mathbf{x}), g(\mathbf{x}_i)))}{\sum_{j=1}^k\exp(\text{cosine}(f(\mathbf{x}), g(\mathbf{x}_j)))} $$
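A minimal sketch of the attention kernel and the weighted label sum, assuming the embeddings $f(\mathbf{x})$ and $g(\mathbf{x}_i)$ have already been computed (the simple-embedding case) and that labels are integer class indices; `matching_predict` is a hypothetical name.

```python
import math

def cosine(u, v, eps=1e-8):
    """Cosine similarity between two vectors (eps guards against zero norms)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + eps)

def matching_predict(fx, g_support, labels, n_classes):
    """Return P(y | x, S) as a list indexed by class."""
    # Attention a(x, x_i): softmax over cosine similarities.
    scores = [math.exp(cosine(fx, gi)) for gi in g_support]
    z = sum(scores)
    probs = [0.0] * n_classes
    for s, y in zip(scores, labels):
        probs[y] += s / z  # a(x, x_i) times the one-hot label y_i
    return probs

probs = matching_predict([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]], [0, 1, 0], n_classes=2)
```

Because cosine similarity lies in $[-1, 1]$, the exponentials stay small and no extra numerical stabilization is needed here.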

Simple Embedding

In its basic form, the embedding function is a neural network that takes a single data sample as input. It’s possible to use the same function for both $f$ and $g$ ($f=g$).

Full Context Embeddings

Embedding vectors are crucial for constructing effective classifiers. Using only a single data point as input may not be sufficient to effectively capture the entire feature space. To address this, Matching Networks propose enhancing embedding functions by incorporating the entire support set $S$ as additional input. This allows the learned embedding to adapt based on relationships within the support set.

  • $g_\theta(\mathbf{x}_i, S)$ uses a bidirectional LSTM to encode $\mathbf{x}_i$ in the context of the entire support set $S$.

  • $f_\theta(\mathbf{x}, S)$ encodes the test sample $\mathbf{x}$ using an LSTM with read attention over the support set $S$.

    1. The test sample is first processed by a simple neural network, such as a CNN, to extract basic features, $f'(\mathbf{x})$.
    2. An LSTM is then trained with a read attention vector over the support set, integrated into the hidden state:

    $$ \begin{aligned} \hat{\mathbf{h}}_t, \mathbf{c}_t &= \text{LSTM}(f'(\mathbf{x}), [\mathbf{h}_{t-1}, \mathbf{r}_{t-1}], \mathbf{c}_{t-1}) \\ \mathbf{h}_t &= \hat{\mathbf{h}}_t + f'(\mathbf{x}) \\ \mathbf{r}_{t-1} &= \sum_{i=1}^k a(\mathbf{h}_{t-1}, g(\mathbf{x}_i)) g(\mathbf{x}_i) \\ a(\mathbf{h}_{t-1}, g(\mathbf{x}_i)) &= \text{softmax}(\mathbf{h}_{t-1}^\top g(\mathbf{x}_i)) = \frac{\exp(\mathbf{h}_{t-1}^\top g(\mathbf{x}_i))}{\sum_{j=1}^k \exp(\mathbf{h}_{t-1}^\top g(\mathbf{x}_j))} \end{aligned} $$

    3. After $K$ “read” steps, the embedding is $f(\mathbf{x}, S)=\mathbf{h}_K$.

This embedding method is called “Full Context Embeddings (FCE).” Interestingly, FCE improves performance on harder tasks such as few-shot classification on miniImageNet, but makes little difference on simpler tasks such as Omniglot.

The training process in Matching Networks is designed to mirror inference at test time: each training episode samples a label subset and a support set, just as a test episode would. Matching training and testing conditions in this way is one of the paper’s main contributions.

$$ \theta^* = \arg\max_\theta \mathbb{E}_{L\subset\mathcal{L}}[ \mathbb{E}_{S^L \subset\mathcal{D}, B^L \subset\mathcal{D}} [\sum_{(\mathbf{x}, y)\in B^L} P_\theta(y\vert\mathbf{x}, S^L)]] $$

Relation Network

The Relation Network (RN) (Sung et al., 2018) shares similarities with the Siamese network but incorporates key distinctions:

  1. Relationship Prediction via CNN: Instead of a simple L1 distance, the relationship between inputs is predicted by a CNN classifier $g_\phi$. The relation score between $\mathbf{x}_i$ and $\mathbf{x}_j$ is $r_{ij} = g_\phi([\mathbf{x}_i, \mathbf{x}_j])$, where $[.,.]$ denotes concatenation.
  2. MSE Loss Function: RN uses Mean Squared Error (MSE) loss instead of cross-entropy, because it predicts relation scores, which is closer to regression than to binary classification. The loss is $\mathcal{L}(B) = \sum_{(\mathbf{x}_i, \mathbf{x}_j, y_i, y_j)\in B} (r_{ij} - \mathbf{1}_{y_i=y_j})^2$.
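The MSE objective can be sketched as follows. In the paper the relation module $g_\phi$ is a CNN; here a hypothetical constant function stands in for it, purely to show how the targets $\mathbf{1}_{y_i=y_j}$ enter the loss.

```python
def relation_mse_loss(pairs, relation_fn):
    """L(B) = sum of (r_ij - 1[y_i == y_j])^2 over pairs (x_i, x_j, y_i, y_j)."""
    loss = 0.0
    for xi, xj, yi, yj in pairs:
        r_ij = relation_fn(list(xi) + list(xj))  # relation score on the concatenation [x_i, x_j]
        target = 1.0 if yi == yj else 0.0
        loss += (r_ij - target) ** 2
    return loss

def constant_relation(v):
    """Hypothetical stand-in for the CNN g_phi: always predicts 0.8."""
    return 0.8

pairs = [((1, 2), (1, 3), "a", "a"), ((1, 2), (9, 9), "a", "b")]
loss = relation_mse_loss(pairs, constant_relation)  # (0.8 - 1)^2 + (0.8 - 0)^2 = 0.68
```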

Fig. 4. Relation Network architecture for a 5-way 1-shot problem with one query example. (Image source: original paper)

(Note: Be aware of another “Relation Network” for relational reasoning by DeepMind to avoid confusion.)

Prototypical Networks

Prototypical Networks (Snell, Swersky & Zemel, 2017) use an embedding function $f_\theta$ to map each input to an $M$-dimensional feature vector. For each class $c \in \mathcal{C}$, the prototype feature vector $\mathbf{v}_c$ is the mean of the embedded support samples belonging to that class.

$$ \mathbf{v}_c = \frac{1}{|S_c|} \sum_{(\mathbf{x}_i, y_i) \in S_c} f_\theta(\mathbf{x}_i) $$

Fig. 5. Prototypical networks in few-shot and zero-shot scenarios. (Image source: original paper)

For a given test input $mathbf{x}$, the probability distribution over classes is a softmax function applied to the negative distances between the test data embedding and the prototype vectors.

$$ P(y=c\vert\mathbf{x})=\text{softmax}(-d_\varphi(f_\theta(\mathbf{x}), \mathbf{v}_c)) = \frac{\exp(-d_\varphi(f_\theta(\mathbf{x}), \mathbf{v}_c))}{\sum_{c' \in \mathcal{C}}\exp(-d_\varphi(f_\theta(\mathbf{x}), \mathbf{v}_{c'}))} $$

Here, $d_\varphi$ is a differentiable distance function; the paper uses squared Euclidean distance.

The loss function is the negative log-likelihood: $\mathcal{L}(\theta) = -\log P_\theta(y=c\vert\mathbf{x})$.
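Prototype computation, the softmax over negative squared Euclidean distances, and the negative log-likelihood fit in a few lines of plain Python. This sketch assumes the support samples are already embedded by $f_\theta$ and given as `(vector, label)` pairs; the helper names are hypothetical.

```python
import math

def prototypes(embedded_support):
    """v_c = mean of the embedded support vectors with label c."""
    sums, counts = {}, {}
    for v, y in embedded_support:
        acc = sums.setdefault(y, [0.0] * len(v))
        for k, vk in enumerate(v):
            acc[k] += vk
        counts[y] = counts.get(y, 0) + 1
    return {y: [s / counts[y] for s in acc] for y, acc in sums.items()}

def sq_euclidean(u, v):
    return sum((a - b) ** 2 for a, b in zip(u, v))

def proto_probs(fx, protos):
    """P(y = c | x) = softmax over classes of -d(f(x), v_c)."""
    logits = {c: -sq_euclidean(fx, v) for c, v in protos.items()}
    top = max(logits.values())  # subtract the max logit for numerical stability
    exps = {c: math.exp(l - top) for c, l in logits.items()}
    z = sum(exps.values())
    return {c: e / z for c, e in exps.items()}

def proto_nll(fx, y, protos):
    """Negative log-likelihood of the true class y."""
    return -math.log(proto_probs(fx, protos)[y])

protos = prototypes([([0.0, 0.0], "a"), ([2.0, 0.0], "a"), ([10.0, 10.0], "b")])
p = proto_probs([1.0, 1.0], protos)  # the query sits next to prototype "a" = [1.0, 0.0]
```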

Model-Based Meta-Learning

Model-based meta-learning approaches make no assumption about the form of $P_\theta(y\vert\mathbf{x})$. Instead, they rely on models designed specifically for fast learning, that is, models that update their parameters rapidly within a few training steps. This rapid adaptation can come from the model’s internal architecture or be controlled by a separate meta-learner model.

Memory-Augmented Neural Networks

A class of model architectures, including Neural Turing Machines and Memory Networks, leverages external memory storage to enhance neural network learning. These are known as Memory-Augmented Neural Networks (MANN). Unlike recurrent neural networks that rely solely on internal memory (like vanilla RNNs or LSTMs), MANNs have an explicit storage buffer. This external memory makes it easy to incorporate new information rapidly and helps the network avoid forgetting it later.

MANNs are well-suited for meta-learning as they are designed to quickly encode new information and adapt to new tasks with just a few examples. Taking the Neural Turing Machine (NTM) as a foundation, Santoro et al. (2016) proposed modifications to the training setup and memory retrieval mechanisms (or “addressing mechanisms”) to optimize for meta-learning. For a deeper understanding of NTMs, refer to the NTM section in my other post.

NTMs combine a controller neural network with external memory. The controller learns to read from and write to memory locations using soft attention, while the memory acts as a knowledge repository. Attention weights are determined by an addressing mechanism combining content-based and location-based addressing.

Fig. 6. Neural Turing Machine (NTM) architecture. The memory $\mathbf{M}_t$ at time $t$ is an $N \times M$ matrix with $N$ rows, each an $M$-dimensional vector.

MANN for Meta-Learning

To effectively apply MANNs to meta-learning, training must encourage the memory to quickly capture and encode new task information while ensuring that stored representations are easily and stably accessible.

In the training setup of Santoro et al. (2016), the memory is forced to hold information for longer, until the corresponding label is presented. This is done by presenting the true label $y_t$ with a one-step offset, $(\mathbf{x}_{t+1}, y_t)$: the true label for the input at time $t$ arrives as part of the input at time $t+1$.

Fig. 7. Task setup in MANN for meta-learning (Image source: original paper).

This setup motivates the MANN to memorize new dataset information, as it needs to retain the current input until its label is presented later. Then, it must retrieve the stored information to make accurate predictions.
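Building the offset input sequence is simple. A minimal sketch, assuming a placeholder `null_label` for the first step, where no previous label exists (the placeholder and the helper name are our own assumptions):

```python
def offset_labels(xs, ys, null_label=None):
    """Pair each input x_t with the previous step's label y_{t-1}."""
    prev = [null_label] + list(ys[:-1])  # y_t is only revealed one step later, at t + 1
    return list(zip(xs, prev))

steps = offset_labels(["x1", "x2", "x3"], [0, 1, 0])
# [("x1", None), ("x2", 0), ("x3", 1)]
```

At step $t$ the model must predict the label of $\mathbf{x}_t$ before seeing it, so it can only succeed on a repeated class if it stored what it saw earlier.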

Next, we examine how the memory update mechanism is designed for efficient information retrieval and storage in MANNs for meta-learning.

Addressing Mechanism for Meta-Learning

Beyond the training process, MANNs for meta-learning utilize a specialized, purely content-based addressing mechanism.

» Memory Read Operation: Read attention is solely based on content similarity.

First, the controller generates a key feature vector $\mathbf{k}_t$ at time $t$ as a function of the input $\mathbf{x}$. As in NTMs, a read weighting vector $\mathbf{w}_t^r$ of $N$ elements is computed from the cosine similarity between the key vector and each memory row, normalized by softmax. The read vector $\mathbf{r}_t$ is a weighted sum of the memory rows:

$$ \mathbf{r}_t = \sum_{i=1}^N w_t^r(i)\mathbf{M}_t(i) \text{, where } w_t^r(i) = \text{softmax}(\frac{\mathbf{k}_t \cdot \mathbf{M}_t(i)}{\|\mathbf{k}_t\| \cdot \|\mathbf{M}_t(i)\|}) $$

where $\mathbf{M}_t$ is the memory matrix at time $t$ and $\mathbf{M}_t(i)$ is its $i$-th row.
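The content-based read can be sketched as below, assuming the memory is a list of equal-length rows. `read_memory` is a hypothetical name, and the small `eps` in the cosine denominator is our addition to avoid division by zero for all-zero rows.

```python
import math

def cosine(u, v, eps=1e-8):
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + eps)

def read_memory(key, memory):
    """Content-based read: returns the read weights w_t^r and the read vector r_t."""
    sims = [cosine(key, row) for row in memory]
    top = max(sims)
    exps = [math.exp(s - top) for s in sims]  # softmax, shifted for numerical stability
    z = sum(exps)
    w = [e / z for e in exps]
    # r_t = sum_i w_t^r(i) * M_t(i)
    r = [sum(w[i] * memory[i][k] for i in range(len(memory))) for k in range(len(memory[0]))]
    return w, r

w, r = read_memory([1.0, 1.0], [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])
```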

» Memory Write Operation: The write addressing mechanism, responsible for storing new information into memory, operates similarly to cache replacement policies. The Least Recently Used Access (LRUA) writer is designed for MANNs to optimize performance in meta-learning scenarios. LRUA prefers writing new content to either the least used or the most recently used memory location.

While LRUA is used in MANNs, other cache replacement algorithms could potentially be more effective depending on the specific use case. Furthermore, learning memory usage patterns and addressing strategies dynamically might be more advantageous than predefining them.

LRUA’s preference is implemented in a differentiable manner:

  1. Usage Weight Update: The usage weight $\mathbf{w}^u_t$ at time $t$ is the sum of the current read and write weights plus a decayed copy of the previous usage weight, $\gamma \mathbf{w}^u_{t-1}$, where $\gamma$ is a decay factor.
  2. Write Weight Calculation: The write weight is an interpolation between the previous read weight (favoring the “last used location”) and the previous least-used weight (favoring a “rarely used location”). The interpolation is controlled by $\sigma(\alpha)$, the sigmoid of a scalar gate parameter $\alpha$.
  3. Least-Used Weight Determination: The least-used weight $\mathbf{w}^{lu}$ is computed from the usage weights $\mathbf{w}_t^u$: an element of $\mathbf{w}^{lu}$ is 1 if its corresponding usage weight is less than or equal to the $n$-th smallest element of $\mathbf{w}_t^u$, and 0 otherwise.

$$ \begin{aligned} \mathbf{w}_t^u &= \gamma \mathbf{w}_{t-1}^u + \mathbf{w}_t^r + \mathbf{w}_t^w \\ \mathbf{w}_t^r &= \text{softmax}(\text{cosine}(\mathbf{k}_t, \mathbf{M}_t(i))) \\ \mathbf{w}_t^w &= \sigma(\alpha)\mathbf{w}_{t-1}^r + (1-\sigma(\alpha))\mathbf{w}^{lu}_{t-1} \\ \mathbf{w}_t^{lu} &= \mathbf{1}_{w_t^u(i) \leq m(\mathbf{w}_t^u, n)} \text{, where }m(\mathbf{w}_t^u, n)\text{ is the }n\text{-th smallest element in vector }\mathbf{w}_t^u\text{.} \end{aligned} $$

Finally, the least-used memory location, computed from the previous usage weights (i.e., the location indicated by $\mathbf{w}_{t-1}^{lu}$), is set to zero, and then each memory row is updated:

$$ \mathbf{M}_t(i) = \mathbf{M}_{t-1}(i) + w_t^w(i)\mathbf{k}_t, \forall i $$
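Putting the write equations together, one LRUA step might look like this. It is a sketch with hypothetical names on plain lists: the current read weight $\mathbf{w}_t^r$ is passed in rather than recomputed, and the row that gets zeroed is taken from the previous least-used weight $\mathbf{w}^{lu}_{t-1}$, which is our reading of how Santoro et al. order the zeroing and the write.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lrua_write(memory, key, w_r, w_r_prev, w_lu_prev, w_u_prev, alpha, gamma, n):
    """One LRUA write step on list-of-rows memory; returns (memory, w_w, w_u, w_lu)."""
    g = sigmoid(alpha)
    # w_t^w: interpolate between the last-read and the least-used locations.
    w_w = [g * r + (1.0 - g) * lu for r, lu in zip(w_r_prev, w_lu_prev)]
    # Zero the least-used row(s), using the previous least-used weight.
    memory = [[0.0] * len(row) if lu == 1 else list(row) for row, lu in zip(memory, w_lu_prev)]
    # M_t(i) = M_{t-1}(i) + w_t^w(i) * k_t
    memory = [[m + ww * k for m, k in zip(row, key)] for row, ww in zip(memory, w_w)]
    # w_t^u = gamma * w_{t-1}^u + w_t^r + w_t^w
    w_u = [gamma * u + r + ww for u, r, ww in zip(w_u_prev, w_r, w_w)]
    # w_t^lu: 1 where usage <= the n-th smallest usage, else 0.
    nth = sorted(w_u)[n - 1]
    w_lu = [1 if u <= nth else 0 for u in w_u]
    return memory, w_w, w_u, w_lu

memory, w_w, w_u, w_lu = lrua_write(
    memory=[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], key=[1.0, 0.0],
    w_r=[0.0, 1.0, 0.0], w_r_prev=[1.0, 0.0, 0.0], w_lu_prev=[0, 0, 1],
    w_u_prev=[1.0, 0.5, 0.0], alpha=0.0, gamma=0.9, n=1)
```

With $\sigma(\alpha) = 0.5$, half of the key goes to the last-read row and half to the freshly zeroed least-used row.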

Meta Networks

Meta Networks (MetaNet) (Munkhdalai & Yu, 2017) are meta-learning models designed with architecture and training processes optimized for rapid generalization across tasks.

Fast Weights

MetaNet’s rapid generalization capability is attributed to the concept of “fast weights.” While there’s extensive literature on fast weights, a universally accepted definition remains somewhat elusive. Generally, fast weights are parameters in neural networks that are learned through a faster mechanism than traditional gradient descent. Typically, neural network weights are updated using stochastic gradient descent (SGD), a relatively slow process. A faster approach involves using one neural network to predict the parameters of another. These predicted parameters are termed “fast weights,” while the standard SGD-based weights are referred to as “slow weights.”

In MetaNet, loss gradients serve as meta-information to guide models in learning fast weights. Slow and fast weights are then combined to generate predictions within the neural network.

Fig. 8. Combining slow and fast weights in a Multi-Layer Perceptron (MLP). $\bigoplus$ denotes element-wise addition. (Image source: original paper).
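One way to read Fig. 8 in code: a layer transforms its input with slow and fast weights in parallel, applies a ReLU to each branch, and sums the two activations element-wise. This is a minimal sketch under that reading, with hypothetical names; in MetaNet the fast weights would come from $F_w$ or $G_v$ rather than being fixed by hand.

```python
def relu(v):
    return [max(0.0, a) for a in v]

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def augmented_layer(W_slow, W_fast, x):
    """Slow and fast weights transform x in parallel; ReLU each branch, then add element-wise."""
    slow = relu(matvec(W_slow, x))
    fast = relu(matvec(W_fast, x))
    return [a + b for a, b in zip(slow, fast)]

out = augmented_layer([[1.0, 0.0], [0.0, -1.0]], [[0.0, 1.0], [1.0, 1.0]], [2.0, 3.0])
# slow branch: relu([2, -3]) = [2, 0]; fast branch: relu([3, 5]) = [3, 5]; sum = [5, 5]
```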

Model Components

Disclaimer: My annotations below may differ slightly from those in the original paper. The paper can be challenging to interpret, but the underlying concept is compelling. I am presenting the idea in my own interpretation.

The core components of MetaNet include:

  • Embedding Function $f_\theta$: Parameterized by $\theta$, this function encodes raw inputs into feature vectors. As in Siamese Neural Networks, these embeddings are trained to tell whether two inputs belong to the same class (verification task).
  • Base Learner Model $g_\phi$: Parameterized by weights $\phi$, this model performs the actual learning task.

These components resemble the Relation Network architecture. However, MetaNet goes further by explicitly modeling fast weights for both functions and integrating them back into the model (as shown in Fig. 8).

To accomplish this, MetaNet includes two additional functions to generate fast weights:

  • Fast Weight Network $F_w$: An LSTM parameterized by $w$ learns fast weights $\theta^+$ for the embedding function $f$. It takes as input the gradients of $f$’s embedding loss on the verification task.
  • Fast Weight Network $G_v$: A neural network parameterized by $v$ learns fast weights $\phi^+$ for the base learner $g$ from its loss gradients. In MetaNet, the learner’s loss gradients are treated as the task’s meta-information.

Now, let’s walk through the training process of Meta Networks. The training data consists of pairs of datasets: a support set $S=\{(\mathbf{x}'_i, y'_i)\}_{i=1}^K$ and a test set $U=\{(\mathbf{x}_i, y_i)\}_{i=1}^L$. We have four networks and four sets of parameters to train: $(\theta, \phi, w, v)$.

Fig. 9. The MetaNet architecture.

Training Process

  1. Representation Learning Phase: For each timestep $t = 1, \dots, K$, sample a random pair of inputs $(\mathbf{x}'_i, y'_i)$ and $(\mathbf{x}'_j, y'_j)$ from the support set $S$. Let $\mathbf{x}_{(t,1)}=\mathbf{x}'_i$ and $\mathbf{x}_{(t,2)}=\mathbf{x}'_j$.

    • a. Compute the embedding loss, e.g., cross-entropy for the verification task:
      $\mathcal{L}^\text{emb}_t = -\big[\mathbf{1}_{y'_i=y'_j} \log P_t + (1 - \mathbf{1}_{y'_i=y'_j})\log(1 - P_t)\big]\text{, where }P_t = \sigma(\mathbf{W}\vert f_\theta(\mathbf{x}_{(t,1)}) - f_\theta(\mathbf{x}_{(t,2)})\vert)$
  2. Task-Level Fast Weight Generation: Compute task-level fast weights for the embedding function: $\theta^+ = F_w(\nabla_\theta \mathcal{L}^\text{emb}_1, \dots, \nabla_\theta \mathcal{L}^\text{emb}_K)$

  3. Example-Level Fast Weights and Memory Update: Iterate through the examples in the support set $S$. For each example $i=1, \dots, K$:

    • a. Base Learner Prediction: The base learner $g_\phi$ outputs a probability distribution: $P(\hat{y}_i \vert \mathbf{x}'_i) = g_\phi(\mathbf{x}'_i)$. Compute the task loss, e.g., cross-entropy or MSE: $\mathcal{L}^\text{task}_i = -\big[y'_i \log g_\phi(\mathbf{x}'_i) + (1- y'_i) \log (1 - g_\phi(\mathbf{x}'_i))\big]$
    • b. Example-Level Fast Weight Computation: Extract meta-information (loss gradients) and compute example-level fast weights for the base learner: $\phi_i^+ = G_v(\nabla_\phi\mathcal{L}^\text{task}_i)$. Store $\phi^+_i$ in the $i$-th location of the “value” memory $\mathbf{M}$.
    • c. Task-Specific Input Representation: Encode the support sample using both slow and fast weights: $r'_i = f_{\theta, \theta^+}(\mathbf{x}'_i)$. Store $r'_i$ in the $i$-th location of the “key” memory $\mathbf{R}$.
  4. Training Loss Construction with Test Set: Use the test set $U=\{(\mathbf{x}_i, y_i)\}_{i=1}^L$ to construct the training loss. Initialize $\mathcal{L}_\text{train}=0$. For each example $j=1, \dots, L$:

    • a. Task-Specific Test Sample Representation: Encode the test sample: $r_j = f_{\theta, \theta^+}(\mathbf{x}_j)$
    • b. Fast Weight Computation via Attention: Compute fast weights by attending to the representations of support set samples stored in memory $\mathbf{R}$, using cosine similarity as the attention function:
      $$ \begin{aligned} a_j &= \text{cosine}(\mathbf{R}, r_j) = [\frac{r'_1\cdot r_j}{\|r'_1\|\cdot\|r_j\|}, \dots, \frac{r'_K\cdot r_j}{\|r'_K\|\cdot\|r_j\|}] \\ \phi^+_j &= \text{softmax}(a_j)^\top \mathbf{M} \end{aligned} $$
    • c. Update Training Loss: $\mathcal{L}_\text{train} \leftarrow \mathcal{L}_\text{train} + \mathcal{L}^\text{task}(g_{\phi, \phi^+_j}(\mathbf{x}_j), y_j)$
  5. Parameter Update: Update all parameters $(\theta, \phi, w, v)$ using $\mathcal{L}_\text{train}$ via backpropagation.

Optimization-Based Meta-Learning

Deep learning models are typically trained using backpropagation of gradients. However, gradient-based optimization is not inherently designed for scenarios with limited training data or for rapid convergence within a few optimization steps. Optimization-based meta-learning algorithms aim to address this by adjusting the optimization process itself to enable effective learning from few examples.

LSTM Meta-Learner

The optimization algorithm itself can be explicitly modeled. Ravi & Larochelle (2017) proposed this approach, terming the model that learns the optimization algorithm as the “meta-learner” and the task-performing model as the “learner.” The meta-learner’s objective is to efficiently update the learner’s parameters using a small support set, enabling the learner to quickly adapt to new tasks.

Let $M_\theta$ denote the learner model with parameters $\theta$, $R_\Theta$ the meta-learner with parameters $\Theta$, and $\mathcal{L}$ the loss function.

Why LSTM?

LSTM networks are chosen as the meta-learner due to:

  1. Gradient-LSTM Similarity: The gradient-based parameter update in backpropagation shares a structural similarity with the cell state update mechanism in LSTMs.
  2. Gradient History Benefit: Utilizing the history of gradients can enhance the gradient update process, analogous to how momentum-based optimization improves convergence.

The standard gradient descent update for the learner parameters at timestep $t$, with learning rate $\alpha_t$, is:

$$ \theta_t = \theta_{t-1} - \alpha_t \nabla_{\theta_{t-1}}\mathcal{L}_t $$

This update rule has the same form as the LSTM cell state update if we set the forget gate $f_t=1$, the input gate $i_t = \alpha_t$, the cell state $c_t = \theta_t$, and the candidate cell state $\tilde{c}_t = -\nabla_{\theta_{t-1}}\mathcal{L}_t$:

$$ \begin{aligned} c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\ &= \theta_{t-1} - \alpha_t\nabla_{\theta_{t-1}}\mathcal{L}_t \end{aligned} $$

Fixing $f_t=1$ and $i_t=\alpha_t$ may not be optimal, so both are made learnable and can adapt to different datasets.

$$ \begin{aligned} f_t &= \sigma(\mathbf{W}_f \cdot [\nabla_{\theta_{t-1}}\mathcal{L}_t, \mathcal{L}_t, \theta_{t-1}, f_{t-1}] + \mathbf{b}_f) & \scriptstyle{\text{; Forget gate: controls how much of the parameter value is retained.}} \\ i_t &= \sigma(\mathbf{W}_i \cdot [\nabla_{\theta_{t-1}}\mathcal{L}_t, \mathcal{L}_t, \theta_{t-1}, i_{t-1}] + \mathbf{b}_i) & \scriptstyle{\text{; Input gate: learnable learning rate at time t.}} \\ \tilde{\theta}_t &= -\nabla_{\theta_{t-1}}\mathcal{L}_t & \\ \theta_t &= f_t \odot \theta_{t-1} + i_t \odot \tilde{\theta}_t & \end{aligned} $$
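The gate equations above can be evaluated coordinate-wise with a small sketch. It covers only the two gates and the cell update, not a full trainable LSTM, and it shares one set of gate parameters across all coordinates of $\theta$; the names and parameter shapes are our own assumptions. With $\mathbf{W}_f = \mathbf{W}_i = 0$, a large $\mathbf{b}_f$, and $\sigma(\mathbf{b}_i) = 0.1$, it reduces to plain gradient descent with learning rate 0.1.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def dot(w, v):
    return sum(a * b for a, b in zip(w, v))

def meta_lstm_step(theta, grad, loss, f_prev, i_prev, Wf, bf, Wi, bi):
    """Coordinate-wise update theta_t = f_t * theta_{t-1} + i_t * (-grad).

    One set of gate parameters (Wf, bf, Wi, bi) is shared by all coordinates;
    each gate sees [grad, loss, theta, previous gate value] for its coordinate.
    """
    new_theta, f_t, i_t = [], [], []
    for th, g, fp, ip in zip(theta, grad, f_prev, i_prev):
        f = sigmoid(dot(Wf, (g, loss, th, fp)) + bf)  # forget gate
        i = sigmoid(dot(Wi, (g, loss, th, ip)) + bi)  # input gate: learned step size
        new_theta.append(f * th + i * (-g))           # candidate cell state is -grad
        f_t.append(f)
        i_t.append(i)
    return new_theta, f_t, i_t

new_theta, f_t, i_t = meta_lstm_step(
    theta=[1.0, -2.0], grad=[0.5, -1.0], loss=0.3, f_prev=[0.0, 0.0], i_prev=[0.0, 0.0],
    Wf=[0.0] * 4, bf=50.0, Wi=[0.0] * 4, bi=math.log(0.1 / 0.9))
# new_theta is approximately [0.95, -1.9], i.e. theta - 0.1 * grad
```

Training $\mathbf{W}_f, \mathbf{b}_f, \mathbf{W}_i, \mathbf{b}_i$ then lets the meta-learner move away from this plain-SGD starting point.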

Model Setup

Fig. 10. Training process of the learner $M_\theta$ and meta-learner $R_\Theta$. (Image source: original paper with added annotations)

The training process mimics the testing phase, the same approach that proved beneficial in Matching Networks. In each training epoch, a dataset $\mathcal{D} = (\mathcal{D}_\text{train}, \mathcal{D}_\text{test}) \in \hat{\mathcal{D}}_\text{meta-train}$ is sampled. Mini-batches from $\mathcal{D}_\text{train}$ are used to update $\theta$ for $T$ rounds, and the final learner parameters $\theta_T$ are then used to train the meta-learner on the test data $\mathcal{D}_\text{test}$.

Two implementation details are crucial:

  1. Parameter Space Compression: The LSTM meta-learner models the parameters of another neural network, so its own parameter space could become very large. To keep it manageable, the meta-learner's parameters are shared across the coordinates of the learner's parameters, following prior work.
  2. Loss and Gradient Independence Assumption: To simplify training, the meta-learner assumes that the loss $\mathcal{L}_t$ and the gradient $\nabla_{\theta_{t-1}} \mathcal{L}_t$ are independent.
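The gated update and the episode structure above can be sketched in plain Python. This is a toy sketch under stated assumptions, not the original implementation:

* The learner is a hypothetical scalar linear model.
* Each gate is a single sigmoid unit over $[\nabla_{\theta_{t-1}}\mathcal{L}_t, \mathcal{L}_t, \theta_{t-1}, \text{previous gate}]$, as in the gate equations.
* The initial previous-gate values ($f_0 = 1$, $i_0 = 0$) are assumed.
* `gated_step` and `episode_loss` are made-up names.

With a scalar learner, sharing parameters across coordinates is trivial. For a parameter vector, the same `meta` values would be applied to every coordinate. Training the meta-learner (minimizing the returned test loss with respect to `meta`) is left out.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def loss_and_grad(theta, batch):
    """Hypothetical learner: scalar model y ~ theta * x with mean squared error."""
    n = len(batch)
    loss = sum(0.5 * (theta * x - y) ** 2 for x, y in batch) / n
    grad = sum((theta * x - y) * x for x, y in batch) / n
    return loss, grad


def gated_step(meta, theta, f_prev, i_prev, batch):
    """One meta-learner update: theta_t = f_t * theta_{t-1} + i_t * (-grad)."""
    loss, grad = loss_and_grad(theta, batch)
    # Gate inputs follow the equations above: [grad, loss, theta_{t-1}, previous gate].
    f = sigmoid(sum(w * v for w, v in zip(meta["W_f"], (grad, loss, theta, f_prev))) + meta["b_f"])
    i = sigmoid(sum(w * v for w, v in zip(meta["W_i"], (grad, loss, theta, i_prev))) + meta["b_i"])
    theta_new = f * theta + i * (-grad)  # candidate value tilde(theta)_t = -grad
    return theta_new, f, i


def episode_loss(meta, theta0, train_batches, test_batch):
    """Run T = len(train_batches) learner updates on D_train, then score theta_T on D_test."""
    theta, f, i = theta0, 1.0, 0.0  # assumed initial gate values
    for batch in train_batches:
        theta, f, i = gated_step(meta, theta, f, i, batch)
    return loss_and_grad(theta, test_batch)[0]


# Zero gate weights with b_f = 20 give f_t ~ 1, and b_i = logit(0.1) gives i_t = 0.1,
# so this step is (almost exactly) plain gradient descent with learning rate 0.1.
meta = {"W_f": [0.0] * 4, "b_f": 20.0, "W_i": [0.0] * 4, "b_i": math.log(0.1 / 0.9)}
batch = [(1.0, 2.0), (2.0, 3.0)]
theta1, _, _ = gated_step(meta, 0.5, 1.0, 0.0, batch)  # about 0.5 - 0.1 * (-2.75) = 0.775
```

Choosing the biases this way recovers the special case $f_t = 1$, $i_t = \alpha_t$ that motivated the LSTM analogy.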

MAML

Model-Agnostic Meta-Learning (MAML) (Finn et al., 2017) is a versatile optimization algorithm applicable to any model trained via gradient descent.

Consider a model $f_\theta$ with parameters $\theta$. For a given task $\tau_i$ and its associated dataset $(\mathcal{D}^{(i)}_\text{train}, \mathcal{D}^{(i)}_\text{test})$, the model parameters can be updated through one or more gradient descent steps (one step shown below):

$$ \theta'_i = \theta - \alpha \nabla_\theta\mathcal{L}^{(0)}_{\tau_i}(f_\theta) $$

where $\mathcal{L}^{(0)}$ is the loss computed on the mini-batch with id (0).

Fig. 11. Diagram of MAML. (Image source: original paper)

To achieve effective generalization across tasks, MAML aims to find an optimal $\theta^*$ that makes task-specific fine-tuning efficient. A new data batch with id (1) is sampled to update the meta-objective. Its loss, $\mathcal{L}^{(1)}$, depends on mini-batch (1). The superscripts in $\mathcal{L}^{(0)}$ and $\mathcal{L}^{(1)}$ only denote different data batches; both refer to the same loss objective for the same task.

$$
\begin{aligned}
\theta^* &= \arg\min_\theta \sum_{\tau_i \sim p(\tau)} \mathcal{L}_{\tau_i}^{(1)} (f_{\theta'_i}) = \arg\min_\theta \sum_{\tau_i \sim p(\tau)} \mathcal{L}_{\tau_i}^{(1)} (f_{\theta - \alpha\nabla_\theta \mathcal{L}_{\tau_i}^{(0)}(f_\theta)}) & \\
\theta &\leftarrow \theta - \beta \nabla_{\theta} \sum_{\tau_i \sim p(\tau)} \mathcal{L}_{\tau_i}^{(1)} (f_{\theta - \alpha\nabla_\theta \mathcal{L}_{\tau_i}^{(0)}(f_\theta)}) & \scriptstyle{\text{; updating rule}}
\end{aligned}
$$

Fig. 12. General MAML algorithm. (Image source: original paper)
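Here is a minimal sketch of this update rule. It assumes a hypothetical family of 1-D regression tasks $y = w_\tau x$, a scalar model $f_\theta(x) = \theta x$, and mean squared error. Because this loss is quadratic in $\theta$, its Hessian is a constant, so the second-order term in the meta-gradient can be written down exactly. In a real network that term would come from automatic differentiation. The task distribution, batch size, and step sizes ($\alpha = 0.1$, $\beta = 0.01$) are arbitrary illustrative choices.

```python
import random


def loss(theta, batch):
    """Hypothetical per-batch loss: mean squared error of the scalar model y ~ theta * x."""
    return sum(0.5 * (theta * x - y) ** 2 for x, y in batch) / len(batch)


def grad_hess(theta, batch):
    """Gradient and Hessian of `loss`; the Hessian is constant for this quadratic loss."""
    n = len(batch)
    sxx = sum(x * x for x, _ in batch) / n
    sxy = sum(x * y for x, y in batch) / n
    return theta * sxx - sxy, sxx


def sample_batch(slope, rng, n=5):
    """Hypothetical task tau_i: noiseless data from y = slope * x."""
    return [(x, slope * x) for x in (rng.uniform(-2.0, 2.0) for _ in range(n))]


def meta_objective(theta, tasks, alpha, rng):
    """sum_i L^(1)_{tau_i}(f_{theta'_i}) with theta'_i = theta - alpha * grad L^(0)_{tau_i}."""
    total = 0.0
    for slope in tasks:
        batch0, batch1 = sample_batch(slope, rng), sample_batch(slope, rng)
        theta_i = theta - alpha * grad_hess(theta, batch0)[0]
        total += loss(theta_i, batch1)
    return total


def meta_gradient(theta, tasks, alpha, rng):
    """Exact derivative of meta_objective w.r.t. theta, including the second-order term."""
    total = 0.0
    for slope in tasks:
        batch0, batch1 = sample_batch(slope, rng), sample_batch(slope, rng)
        g0, h0 = grad_hess(theta, batch0)     # inner step on mini-batch (0)
        theta_i = theta - alpha * g0          # theta'_i
        g1, _ = grad_hess(theta_i, batch1)    # meta-loss gradient on mini-batch (1)
        total += g1 * (1.0 - alpha * h0)      # chain rule: d theta'_i / d theta = 1 - alpha * H
    return total


def maml(theta, alpha=0.1, beta=0.01, iters=500, n_tasks=4, seed=0):
    rng = random.Random(seed)
    for _ in range(iters):
        tasks = [rng.uniform(-1.0, 3.0) for _ in range(n_tasks)]  # assumed p(tau): uniform slopes
        theta -= beta * meta_gradient(theta, tasks, alpha, rng)
    return theta


theta_star = maml(0.0)  # ends up close to the mean task slope (1.0 here)
```

Passing two generators seeded identically to `meta_objective` and `meta_gradient` makes them draw the same batches. A finite-difference derivative of the former should then match the latter.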

First-Order MAML

The meta-optimization step above relies on second derivatives. First-Order MAML (FOMAML) omits them, which gives a simpler and computationally cheaper implementation.

Consider $k$ inner gradient steps ($k\geq1$), starting from the initial meta-parameter $\theta_\text{meta}$:

$$
\begin{aligned}
\theta_0 &= \theta_\text{meta} \\
\theta_1 &= \theta_0 - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_0) \\
\theta_2 &= \theta_1 - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_1) \\
&\dots \\
\theta_k &= \theta_{k-1} - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_{k-1})
\end{aligned}
$$

In the outer loop, a new data batch is sampled to update the meta-objective.

$$
\begin{aligned}
\theta_\text{meta} &\leftarrow \theta_\text{meta} - \beta g_\text{MAML} & \scriptstyle{\text{; meta-objective update}} \\[2mm]
\text{where } g_\text{MAML} &= \nabla_{\theta} \mathcal{L}^{(1)}(\theta_k) & \\[2mm]
&= \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k) \cdot (\nabla_{\theta_{k-1}} \theta_k) \dots (\nabla_{\theta_0} \theta_1) \cdot (\nabla_{\theta} \theta_0) & \scriptstyle{\text{; chain rule}} \\
&= \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k) \cdot \Big( \prod_{i=1}^k \nabla_{\theta_{i-1}} \theta_i \Big) \cdot I & \\
&= \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k) \cdot \prod_{i=1}^k \nabla_{\theta_{i-1}} (\theta_{i-1} - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_{i-1})) & \\
&= \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k) \cdot \prod_{i=1}^k (I - \alpha\nabla_{\theta_{i-1}}(\nabla_\theta\mathcal{L}^{(0)}(\theta_{i-1}))) &
\end{aligned}
$$

The MAML gradient is:

$$ g_\text{MAML} = \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k) \cdot \prod_{i=1}^k (I - \alpha \color{red}{\nabla_{\theta_{i-1}}(\nabla_\theta\mathcal{L}^{(0)}(\theta_{i-1}))}) $$

FOMAML ignores the second-derivative term (shown in red). Each factor in the product then reduces to the identity $I$, so the gradient simplifies to the derivative of the meta-loss at the result of the last inner update:

$$ g_\text{FOMAML} = \nabla_{\theta_k} \mathcal{L}^{(1)}(\theta_k) $$
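The $k$-step chain rule can be checked numerically. The sketch below uses the same kind of hypothetical quadratic loss (scalar model $y \approx \theta x$, mean squared error):

* To get $g_\text{MAML}$, it accumulates the product $\prod_{i=1}^k (1 - \alpha H)$ during the inner loop.
* To get $g_\text{FOMAML}$, it simply drops that product.

For this loss the Hessian does not depend on $\theta$, so every factor is the same; the loop still recomputes it at each step to mirror the formula. The data batches are made up.

```python
def loss(theta, batch):
    """Hypothetical per-batch loss: mean squared error of y ~ theta * x."""
    return sum(0.5 * (theta * x - y) ** 2 for x, y in batch) / len(batch)


def grad_hess(theta, batch):
    """Gradient and (constant) Hessian of `loss`."""
    n = len(batch)
    sxx = sum(x * x for x, _ in batch) / n
    sxy = sum(x * y for x, y in batch) / n
    return theta * sxx - sxy, sxx


def maml_and_fomaml(theta_meta, batch0, batch1, alpha, k):
    """Return (g_MAML, g_FOMAML) after k inner steps on mini-batch (0)."""
    theta, jacobian = theta_meta, 1.0   # jacobian = prod_i d theta_i / d theta_{i-1}
    for _ in range(k):
        g0, h0 = grad_hess(theta, batch0)
        jacobian *= 1.0 - alpha * h0      # factor (I - alpha * Hessian at theta_{i-1})
        theta -= alpha * g0               # theta_i = theta_{i-1} - alpha * gradient
    g1, _ = grad_hess(theta, batch1)      # gradient of L^(1) at theta_k
    return g1 * jacobian, g1


def meta_loss(theta_meta, batch0, batch1, alpha, k):
    """L^(1)(theta_k) as a function of theta_meta."""
    theta = theta_meta
    for _ in range(k):
        theta -= alpha * grad_hess(theta, batch0)[0]
    return loss(theta, batch1)


batch0 = [(1.0, 2.0), (-0.5, -1.2), (2.0, 3.5)]
batch1 = [(0.8, 1.5), (1.5, 3.1)]
g_maml, g_fomaml = maml_and_fomaml(0.0, batch0, batch1, alpha=0.1, k=3)
```

A central finite difference of `meta_loss` around $\theta_\text{meta} = 0$ recovers `g_maml`. `g_fomaml` differs from it by the dropped factor $(1 - \alpha H)^k$.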

Reptile

Reptile (Nichol, Achiam & Schulman, 2018) is a remarkably simple meta-learning optimization algorithm. It is similar to MAML in many ways: both rely on meta-optimization through gradient descent, and both are model-agnostic.

Reptile iteratively performs:

  1. Task Sampling: Randomly selects a task.
  2. Task-Specific Training: Trains the model on the sampled task using multiple gradient descent steps.
  3. Parameter Update: Moves the model weights towards the newly trained parameters.

The algorithm is outlined below. $\text{SGD}(\mathcal{L}_{\tau_i}, \theta, k)$ performs stochastic gradient updates for $k$ steps on the loss $\mathcal{L}_{\tau_i}$, starting from parameter $\theta$, and returns the final parameter vector. The batched version processes multiple tasks per iteration. The Reptile gradient is defined as $(\theta - W)/\alpha$, where $W$ is the parameter vector returned by $\text{SGD}(\cdot)$ and $\alpha$ is the step size used in the SGD operation.

Fig. 13. Batched Reptile algorithm. (Image source: original paper)
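A minimal sketch of batched Reptile in plain Python, under these assumptions:

* The tasks are hypothetical 2-D problems with loss $\frac{1}{2}\|\theta - c_\tau\|^2$ and randomly drawn centers $c_\tau$.
* The outer step size `epsilon` and the other hyperparameters are illustrative choices.
* The outer update used here is $\theta \leftarrow \theta + \epsilon \cdot \frac{1}{n}\sum_i (W_i - \theta)$.

That outer update is the same as a step of size $\epsilon\alpha$ along the averaged Reptile gradient $(\theta - W_i)/\alpha$.

```python
import random


def sgd(grad_fn, theta, k, alpha):
    """SGD(L, theta, k): k gradient steps of size alpha from theta; returns the final parameters W."""
    for _ in range(k):
        theta = [t - alpha * g for t, g in zip(theta, grad_fn(theta))]
    return theta


def sample_task(rng):
    """Hypothetical task: L(theta) = 0.5 * ||theta - c||^2; returns its gradient function."""
    c = [rng.gauss(1.0, 0.5), rng.gauss(-2.0, 0.5)]
    return lambda theta: [t - ci for t, ci in zip(theta, c)]


def batched_reptile(theta, iters=2000, n_tasks=5, k=5, alpha=0.1, epsilon=0.1, seed=0):
    rng = random.Random(seed)
    for _ in range(iters):
        ws = [sgd(sample_task(rng), theta, k, alpha) for _ in range(n_tasks)]
        # Move theta toward the task-adapted weights:
        # theta <- theta + epsilon * mean_i (W_i - theta),
        # i.e. a step of size epsilon * alpha along the mean Reptile gradient (theta - W_i) / alpha.
        theta = [t + epsilon * sum(w[j] - t for w in ws) / n_tasks for j, t in enumerate(theta)]
    return theta


theta_meta = batched_reptile([0.0, 0.0])  # ends up near the mean task optimum (1, -2)
```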

At first glance, Reptile looks a lot like ordinary SGD. The key difference is that the task-specific optimization can take more than one step. This makes $\text{SGD}(\mathbb{E}_\tau[\mathcal{L}_{\tau}], \theta, k)$ diverge from $\mathbb{E}_\tau [\text{SGD}(\mathcal{L}_{\tau}, \theta, k)]$ when $k > 1$.
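The difference between the two expressions is easy to see on two hypothetical 1-D quadratic tasks $\mathcal{L}_\tau(\theta) = \frac{1}{2} a_\tau (\theta - c_\tau)^2$ with different curvatures $a_\tau$. For $k=1$, both sides take a single step along the same averaged gradient. For $k=2$, the second gradient is evaluated at different points, so the results differ.

```python
def sgd_1d(grad_fn, theta, k, alpha):
    """SGD(L, theta, k) for a scalar parameter."""
    for _ in range(k):
        theta -= alpha * grad_fn(theta)
    return theta


# Two hypothetical, equally likely tasks L_tau(theta) = 0.5 * a * (theta - c)^2, given as (a, c).
TASKS = [(1.0, 0.0), (3.0, 2.0)]


def task_grad(a, c):
    return lambda theta: a * (theta - c)


def expected_grad(theta):
    """Gradient of E_tau[L_tau]."""
    return sum(a * (theta - c) for a, c in TASKS) / len(TASKS)


def compare(theta, k, alpha=0.1):
    """Return (SGD(E[L], theta, k), E[SGD(L, theta, k)])."""
    sgd_of_expected = sgd_1d(expected_grad, theta, k, alpha)
    expected_of_sgd = sum(sgd_1d(task_grad(a, c), theta, k, alpha) for a, c in TASKS) / len(TASKS)
    return sgd_of_expected, expected_of_sgd


print(compare(0.0, k=1))
print(compare(0.0, k=2))
```

With $\theta = 0$ and $\alpha = 0.1$, the two sides agree at 0.3 for $k = 1$. For $k = 2$ they give 0.54 and 0.51 (up to floating-point rounding).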

The Optimization Assumption

Assume each task $\tau \sim p(\tau)$ has a manifold of optimal network configurations, $\mathcal{W}_{\tau}^*$. The model $f_\theta$ achieves optimal performance for task $\tau$ when $\theta$ lies on the surface of $\mathcal{W}_{\tau}^*$. To find a solution that generalizes well across tasks, we look for a parameter $\theta$ close to the optimal manifolds of all tasks:

$$ \theta^* = \arg\min_\theta \mathbb{E}_{\tau \sim p(\tau)} [\frac{1}{2} \text{dist}(\theta, \mathcal{W}_\tau^*)^2] $$

Fig. 14. Reptile updates parameters to be closer to the optimal manifolds of different tasks. (Image source: original paper)

We use the L2 distance, and define the distance between a point $\theta$ and a set $\mathcal{W}_\tau^*$ as the distance between $\theta$ and the closest point $W_{\tau}^*(\theta)$ on the manifold:

$$ \text{dist}(\theta, \mathcal{W}_{\tau}^*) = \text{dist}(\theta, W_{\tau}^*(\theta)) \text{, where } W_{\tau}^*(\theta) = \arg\min_{W\in\mathcal{W}_{\tau}^*} \text{dist}(\theta, W) $$

The gradient of the squared Euclidean distance is:

$$
\begin{aligned}
\nabla_\theta[\frac{1}{2}\text{dist}(\theta, \mathcal{W}_{\tau_i}^*)^2]
&= \nabla_\theta[\frac{1}{2}\text{dist}(\theta, W_{\tau_i}^*(\theta))^2] & \\
&= \nabla_\theta[\frac{1}{2}\|\theta - W_{\tau_i}^*(\theta)\|^2] & \\
&= \theta - W_{\tau_i}^*(\theta) & \scriptstyle{\text{; see notes}}
\end{aligned}
$$

Notes: According to the Reptile paper, “the gradient of the squared Euclidean distance between a point $\Theta$ and a set $S$ is the vector $2(\Theta - p)$, where $p$ is the closest point in $S$ to $\Theta$”. The closest point $p$ is itself a function of $\Theta$, yet the calculation ignores its derivative. This is justified when $S$ is a smooth manifold:

* Because $p$ minimizes the distance, $\Theta - p$ is orthogonal to the tangent space of $S$ at $p$.
* Any small change in $p$ lies within that tangent space.

The extra term therefore vanishes.

The update rule for one stochastic gradient step becomes:

$$ \theta \leftarrow \theta - \alpha \nabla_\theta[\frac{1}{2} \text{dist}(\theta, \mathcal{W}_{\tau_i}^*)^2] = \theta - \alpha(\theta - W_{\tau_i}^*(\theta)) = (1-\alpha)\theta + \alpha W_{\tau_i}^*(\theta) $$

The closest point $W_{\tau_i}^*(\theta)$ on the optimal task manifold cannot be computed directly, so Reptile approximates it using $\text{SGD}(\mathcal{L}_\tau, \theta, k)$.
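The idealized update $(1-\alpha)\theta + \alpha W_{\tau_i}^*(\theta)$ can be simulated whenever the closest point is computable. In the sketch below:

* Each hypothetical task's optimal set $\mathcal{W}_\tau^*$ is a line in 2-D, so $W_\tau^*(\theta)$ is an orthogonal projection.
* The three lines are made up and all pass through $(1, 2)$.
* Tasks are cycled in a fixed order instead of being sampled, to keep the run deterministic.

Reptile would replace `project_onto_line` with $k$ steps of SGD on the task loss.

```python
def project_onto_line(theta, normal, offset):
    """Closest point to theta on the line {w : normal . w = offset} (Euclidean distance)."""
    nn = sum(n * n for n in normal)
    r = (sum(n * t for n, t in zip(normal, theta)) - offset) / nn
    return [t - r * n for t, n in zip(theta, normal)]


# Hypothetical tasks whose optimal sets W*_tau are lines through the point (1, 2).
TASK_MANIFOLDS = [((1.0, 0.0), 1.0), ((0.0, 1.0), 2.0), ((1.0, 1.0), 3.0)]


def closest_point_updates(theta, alpha=0.5, iters=300):
    for step in range(iters):
        normal, offset = TASK_MANIFOLDS[step % len(TASK_MANIFOLDS)]  # cycle instead of sampling
        w_star = project_onto_line(theta, normal, offset)            # W*_tau(theta)
        theta = [(1 - alpha) * t + alpha * w for t, w in zip(theta, w_star)]
    return theta


theta_final = closest_point_updates([5.0, -3.0])  # converges to the common point (1, 2)
```

Every line contains $(1, 2)$, so that point drives the expected squared distance to zero. The repeated partial projections converge to it.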

Reptile vs FOMAML

To illustrate the connection between Reptile and MAML, consider the update formula with two gradient steps ($k=2$ in $\text{SGD}(\cdot)$). We use the shorthand $g^{(i)}_j = \nabla_{\theta} \mathcal{L}^{(i)}(\theta_j)$ and $H^{(i)}_j = \nabla^2_{\theta} \mathcal{L}^{(i)}(\theta_j)$, where the losses $\mathcal{L}^{(0)}$ and $\mathcal{L}^{(1)}$ come from different mini-batches:

$$
\begin{aligned}
\theta_0 &= \theta_\text{meta} \\
\theta_1 &= \theta_0 - \alpha\nabla_\theta\mathcal{L}^{(0)}(\theta_0) = \theta_0 - \alpha g^{(0)}_0 \\
\theta_2 &= \theta_1 - \alpha\nabla_\theta\mathcal{L}^{(1)}(\theta_1) = \theta_0 - \alpha g^{(0)}_0 - \alpha g^{(1)}_1
\end{aligned}
$$

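As derived in the FOMAML section, the FOMAML gradient is the gradient of the meta-loss at the parameters produced by the last inner update. Consider MAML and FOMAML with a single inner step ($k=1$), using mini-batch (0) for the inner step and mini-batch (1) for the meta-objective: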

$$
\begin{aligned}
g_\text{FOMAML} &= \nabla_{\theta_1} \mathcal{L}^{(1)}(\theta_1) = g^{(1)}_1 \\
g_\text{MAML} &= \nabla_{\theta_1} \mathcal{L}^{(1)}(\theta_1) \cdot (I - \alpha\nabla^2_{\theta} \mathcal{L}^{(0)}(\theta_0)) = g^{(1)}_1 - \alpha H^{(0)}_0 g^{(1)}_1
\end{aligned}
$$

The Reptile gradient is:

$$ g_\text{Reptile} = (\theta_0 - \theta_2) / \alpha = g^{(0)}_0 + g^{(1)}_1 $$

Summarizing the gradients:

Fig. 15. Reptile versus FOMAML in one meta-optimization loop. (Image source: slides on Reptile by Yoonho Lee.)

$$
\begin{aligned}
g_\text{FOMAML} &= g^{(1)}_1 \\
g_\text{MAML} &= g^{(1)}_1 - \alpha H^{(0)}_0 g^{(1)}_1 \\
g_\text{Reptile} &= g^{(0)}_0 + g^{(1)}_1
\end{aligned}
$$
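These three formulas can be checked on hypothetical 1-D quadratic mini-batch losses $\mathcal{L}^{(i)}(\theta) = \frac{1}{2} a_i (\theta - c_i)^2$, whose Hessian is simply $a_i$. The sketch computes $g_\text{FOMAML}$, $g_\text{MAML}$, and $g_\text{Reptile}$ for one meta-optimization loop. A finite-difference derivative of $\theta_0 \mapsto \mathcal{L}^{(1)}(\theta_1)$ then confirms $g_\text{MAML}$, and $g_\text{Reptile}$ should equal $g^{(0)}_0 + g^{(1)}_1$. The batch values are made up.

```python
def grad(a, c, theta):
    """Gradient of a hypothetical per-batch loss L(theta) = 0.5 * a * (theta - c)^2 (Hessian = a)."""
    return a * (theta - c)


def loss(a, c, theta):
    return 0.5 * a * (theta - c) ** 2


def one_loop_gradients(theta0, batch0, batch1, alpha):
    """g_FOMAML, g_MAML, g_Reptile for one loop; each batch is an (a, c) pair."""
    (a0, c0), (a1, c1) = batch0, batch1
    g0_0 = grad(a0, c0, theta0)          # g^(0)_0
    theta1 = theta0 - alpha * g0_0
    g1_1 = grad(a1, c1, theta1)          # g^(1)_1
    theta2 = theta1 - alpha * g1_1
    g_fomaml = g1_1
    g_maml = g1_1 - alpha * a0 * g1_1    # H^(0)_0 = a0 for this quadratic loss
    g_reptile = (theta0 - theta2) / alpha
    return g_fomaml, g_maml, g_reptile, g0_0, g1_1


def maml_objective(theta0, batch0, batch1, alpha):
    """L^(1)(theta_1) as a function of theta_0, with theta_1 = theta_0 - alpha * g^(0)_0."""
    (a0, c0), (a1, c1) = batch0, batch1
    return loss(a1, c1, theta0 - alpha * grad(a0, c0, theta0))


b0, b1, alpha = (2.0, 1.0), (0.5, -1.0), 0.1
g_fomaml, g_maml, g_reptile, g0_0, g1_1 = one_loop_gradients(0.3, b0, b1, alpha)

h = 1e-5
fd_maml = (maml_objective(0.3 + h, b0, b1, alpha) - maml_objective(0.3 - h, b0, b1, alpha)) / (2 * h)
```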

Expanding $g^{(1)}_1$ using Taylor expansion around $theta_0$:

$$
\begin{aligned}
g_1^{(1)}
&= \nabla_{\theta}\mathcal{L}^{(1)}(\theta_1) \\
&= \nabla_{\theta}\mathcal{L}^{(1)}(\theta_0) + \nabla^2_\theta\mathcal{L}^{(1)}(\theta_0)(\theta_1 - \theta_0) + \frac{1}{2}\nabla^3_\theta\mathcal{L}^{(1)}(\theta_0)(\theta_1 - \theta_0)^2 + \dots & \\
&= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + \frac{\alpha^2}{2}\nabla^3_\theta\mathcal{L}^{(1)}(\theta_0) (g_0^{(0)})^2 + \dots & \scriptstyle{\text{; because }\theta_1-\theta_0=-\alpha g_0^{(0)}} \\
&= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2)
\end{aligned}
$$

Substituting the expanded $g_1^{(1)}$ into the FOMAML and MAML gradients:

$$
\begin{aligned}
g_\text{FOMAML} &= g^{(1)}_1 = g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2) \\
g_\text{MAML} &= g^{(1)}_1 - \alpha H^{(0)}_0 g^{(1)}_1 \\
&= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2) - \alpha H^{(0)}_0 (g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2)) \\
&= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} - \alpha H^{(0)}_0 g_0^{(1)} + \alpha^2 H^{(0)}_0 H^{(1)}_0 g_0^{(0)} + O(\alpha^2) \\
&= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} - \alpha H^{(0)}_0 g_0^{(1)} + O(\alpha^2)
\end{aligned}
$$

The Reptile gradient becomes:

$$
\begin{aligned}
g_\text{Reptile} &= g^{(0)}_0 + g^{(1)}_1 \\
&= g^{(0)}_0 + g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2)
\end{aligned}
$$

Thus, we have the gradients of FOMAML, MAML, and Reptile:

$$
\begin{aligned}
g_\text{FOMAML} &= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2) \\
g_\text{MAML} &= g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} - \alpha H^{(0)}_0 g_0^{(1)} + O(\alpha^2) \\
g_\text{Reptile} &= g^{(0)}_0 + g_0^{(1)} - \alpha H^{(1)}_0 g_0^{(0)} + O(\alpha^2)
\end{aligned}
$$

During training we usually average over many data batches. In this example, mini-batches (0) and (1) are interchangeable, since both are sampled at random from the same task. The expectation $\mathbb{E}_{\tau,0,1}$ is taken over tasks $\tau$ and over the two data batches for each task.

Let:

  • $A = \mathbb{E}_{\tau,0,1} [g_0^{(0)}] = \mathbb{E}_{\tau,0,1} [g_0^{(1)}]$: the average gradient of the task loss. It guides the model parameters toward better task performance.
  • $B = \mathbb{E}_{\tau,0,1} [H^{(1)}_0 g_0^{(0)}] = \frac{1}{2}\mathbb{E}_{\tau,0,1} [H^{(1)}_0 g_0^{(0)} + H^{(0)}_0 g_0^{(1)}] = \frac{1}{2}\mathbb{E}_{\tau,0,1} [\nabla_\theta(g^{(0)}_0 \cdot g_0^{(1)})]$: the direction (gradient) that increases the inner product of gradients from different mini-batches of the same task. It enhances the model's generalization across data.

Both MAML and Reptile therefore optimize for better task performance (guided by $A$) and better generalization (guided by $B$) when the gradient update is approximated by its leading terms, up to first order in $\alpha$.

$$
\begin{aligned}
\mathbb{E}_{\tau,0,1}[g_\text{FOMAML}] &= A - \alpha B + O(\alpha^2) \\
\mathbb{E}_{\tau,0,1}[g_\text{MAML}] &= A - 2\alpha B + O(\alpha^2) \\
\mathbb{E}_{\tau,0,1}[g_\text{Reptile}] &= 2A - \alpha B + O(\alpha^2)
\end{aligned}
$$
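For quadratic losses the third derivative is zero, so the Taylor expansion of $g_1^{(1)}$ is exact and these expectations can be checked directly. As a simple stand-in for $\mathbb{E}_{\tau,0,1}$, the sketch below averages over both orderings of two hypothetical mini-batch losses $\frac{1}{2} a_i (\theta - c_i)^2$ of a single task; this reflects the interchangeability of mini-batches (0) and (1). Under this setup:

* FOMAML matches $A - \alpha B$ exactly.
* Reptile matches $2A - \alpha B$ exactly.
* MAML differs from $A - 2\alpha B$ only by an $\alpha^2$ term, which here equals $\alpha^2 a_0 a_1 A$.

All numeric values are made up.

```python
def gradients(theta0, batch0, batch1, alpha):
    """(g_FOMAML, g_MAML, g_Reptile) for quadratic batch losses 0.5 * a * (theta - c)^2."""
    (a0, c0), (a1, c1) = batch0, batch1
    theta1 = theta0 - alpha * a0 * (theta0 - c0)
    g1_1 = a1 * (theta1 - c1)
    theta2 = theta1 - alpha * g1_1
    return g1_1, g1_1 * (1.0 - alpha * a0), (theta0 - theta2) / alpha


def averaged_over_orderings(theta0, batch0, batch1, alpha):
    """Average over both orderings of the two (interchangeable) mini-batches."""
    fwd = gradients(theta0, batch0, batch1, alpha)
    bwd = gradients(theta0, batch1, batch0, alpha)
    return tuple((u + v) / 2 for u, v in zip(fwd, bwd))


def a_and_b(theta0, batch0, batch1):
    """A = mean of g_0^(0), g_0^(1); B = mean of H^(1)_0 g_0^(0) and H^(0)_0 g_0^(1)."""
    (a0, c0), (a1, c1) = batch0, batch1
    g0, g1 = a0 * (theta0 - c0), a1 * (theta0 - c1)
    return (g0 + g1) / 2, (a1 * g0 + a0 * g1) / 2


theta0, b0, b1, alpha = 0.2, (2.0, 1.0), (0.5, -1.0), 0.05
fomaml, maml, reptile = averaged_over_orderings(theta0, b0, b1, alpha)
A, B = a_and_b(theta0, b0, b1)
# fomaml == A - alpha * B and reptile == 2 * A - alpha * B (up to rounding);
# maml == A - 2 * alpha * B + alpha**2 * a0 * a1 * A for these quadratic losses.
```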

The impact of the ignored $O(\alpha^2)$ term on parameter learning is unclear. However, FOMAML achieves performance comparable to full MAML, which suggests that higher-order derivatives may not be critical for the gradient descent update.

Cited as:

@article{weng2018metalearning, title = "Meta-Learning: Learning to Learn Fast", author = "Weng, Lilian", journal = "lilianweng.github.io", year = "2018", url = "https://lilianweng.github.io/posts/2018-11-30-meta-learning/" }

References

[1] Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. “Human-level concept learning through probabilistic program induction.” Science 350.6266 (2015): 1332-1338.

[2] Oriol Vinyals’ talk on “Model vs Optimization Meta Learning”

[3] Gregory Koch, Richard Zemel, and Ruslan Salakhutdinov. “Siamese neural networks for one-shot image recognition.” ICML Deep Learning Workshop. 2015.

[4] Oriol Vinyals, et al. “Matching networks for one shot learning.” NIPS. 2016.

[5] Flood Sung, et al. “Learning to compare: Relation network for few-shot learning.” CVPR. 2018.

[6] Jake Snell, Kevin Swersky, and Richard Zemel. “Prototypical Networks for Few-shot Learning.” NIPS. 2017.

[7] Adam Santoro, et al. “Meta-learning with memory-augmented neural networks.” ICML. 2016.

[8] Alex Graves, Greg Wayne, and Ivo Danihelka. “Neural turing machines.” arXiv preprint arXiv:1410.5401 (2014).

[9] Tsendsuren Munkhdalai and Hong Yu. “Meta Networks.” ICML. 2017.

[10] Sachin Ravi and Hugo Larochelle. “Optimization as a Model for Few-Shot Learning.” ICLR. 2017.

[11] Chelsea Finn’s BAIR blog on “Learning to Learn”.

[12] Chelsea Finn, Pieter Abbeel, and Sergey Levine. “Model-agnostic meta-learning for fast adaptation of deep networks.” ICML. 2017.

[13] Alex Nichol, Joshua Achiam, John Schulman. “On First-Order Meta-Learning Algorithms.” arXiv preprint arXiv:1803.02999 (2018).

[14] Slides on Reptile by Yoonho Lee.
