Continual Learning with Elastic Weight Consolidation in TensorFlow 2

Based on the paper “Overcoming catastrophic forgetting in neural networks” (Kirkpatrick et al., 2017).

You can view the accompanying Jupyter Notebook here.

The mammalian brain can learn tasks sequentially: we are capable of learning new tasks without forgetting how to perform old ones. Research suggests that retaining task-specific skills relies primarily on task-specific synaptic consolidation, which protects previously acquired knowledge by strengthening a proportion of the synapses relevant to a specific task. In effect, these synapses become resistant to change, even when learning new tasks.

In the context of neural networks, continual learning presents a unique set of challenges. Consider a simple classification problem. Given a dataset \mathcal{D}, generated from the true data-generating distribution P_{data}, that consists of examples \{x_1, x_2, \ldots, x_N\} with x_i \in \mathbf{R}^n and corresponding discrete labels \{L_1, L_2, \ldots, L_N\} with L_i \in \{1, 2, \ldots, K\}, the goal is to learn a function f: \mathbf{R}^n \rightarrow \{1, 2, \ldots, K\} that accurately assigns a label L to an example x. Typically, a neural network parameterizes f with \Theta and minimizes a loss function using gradient descent. The final parameterization \Theta^* lies on a manifold in parameter space whose points yield good performance on the classification task.

Now consider two independent classification tasks, A and B, with corresponding datasets \mathcal{D}_A and \mathcal{D}_B. The goal is first to learn a parameterization \Theta_A^* that yields acceptable performance on task A, and then to learn a parameterization \Theta_{AB}^* that yields acceptable performance on task B without losing a significant amount of performance on task A. Immediately after training on task A, \Theta_A^* lies on a manifold that yields good performance on task A. When training continues on task B with a gradient-based method, however, \Theta quickly moves away from that manifold toward the manifold that yields good performance on task B, because nothing in the task-B loss rewards staying on the first one. By the end of training, \Theta lies only on the manifold that yields good performance on task B. The network has effectively lost its previous knowledge of task A, a phenomenon known as catastrophic forgetting.
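A toy sketch (not from the original notebook) makes the drift concrete. Two quadratic losses stand in for tasks A and B, each with a hypothetical preferred point, `target_A` and `target_B`. Plain gradient descent on task A and then on task B ends at task B's minimum, and the loss on task A rises accordingly.

```python
import numpy as np

# Hypothetical toy losses: each task prefers a different point in a 2-D parameter space
target_A = np.array([1.0, 0.0])
target_B = np.array([0.0, 1.0])

def loss(theta, target):
    # Squared distance to the task's preferred parameters
    return float(np.sum((theta - target) ** 2))

def grad(theta, target):
    return 2.0 * (theta - target)

def train(theta, target, lr=0.1, steps=200):
    # Plain gradient descent on a single task's loss
    for _ in range(steps):
        theta = theta - lr * grad(theta, target)
    return theta

theta = train(np.zeros(2), target_A)   # learn task A
loss_A_before = loss(theta, target_A)
theta = train(theta, target_B)         # then learn task B, with no memory of task A
loss_A_after = loss(theta, target_A)

print(round(loss_A_before, 6), round(loss_A_after, 6))  # 0.0 2.0
```

Nothing in the second call to `train` refers to task A, so the parameters simply end up wherever task B prefers.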

Given the two tasks A and B, there exist two manifolds \mathcal{M}_A and \mathcal{M}_B in parameter space that yield good performance on task A and task B, respectively. Depending on the problem, the two manifolds likely overlap, so there is a third manifold \mathcal{M}_{AB} (their intersection) that yields good performance on both tasks at once. In multi-task learning, you learn task A and task B simultaneously, searching directly for a parameterization \Theta^* on \mathcal{M}_{AB}. In continual learning, your goal is first to find a parameterization \Theta_A^* on \mathcal{M}_A and then to move within \mathcal{M}_A toward \mathcal{M}_B until you reach \mathcal{M}_{AB}, so that the two tasks are learned sequentially.

In a more practical sense, consider a basic neural network that learns to classify MNIST digits. Rather than learning to classify the digits [0 - 9] simultaneously, however, you train the classifier to map the five even digits (0, 2, 4, 6, 8) to labels [0 - 4] (task A) and the five odd digits (1, 3, 5, 7, 9) to labels [0 - 4] (task B). Mathematically, each digit x is assigned a label y = floor(x / 2).
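Here is the relabeling spelled out in plain Python as a quick sanity check. Assigning even digits to task A and odd digits to task B mirrors the `label % 2 == 0` predicate used later; both tasks share the same five output labels.

```python
# Digits 0-9 and the relabeling y = floor(x / 2)
digits = list(range(10))
labels = {x: x // 2 for x in digits}

# Task A: even digits, task B: odd digits
task_A = {x: y for x, y in labels.items() if x % 2 == 0}
task_B = {x: y for x, y in labels.items() if x % 2 == 1}

print(task_A)  # {0: 0, 2: 1, 4: 2, 6: 3, 8: 4}
print(task_B)  # {1: 0, 3: 1, 5: 2, 7: 3, 9: 4}
```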

Given the significant (duh) overlap between these two tasks, solving this problem in the context of multi-task learning isn’t particularly difficult. To demonstrate this, create a new Jupyter Notebook and load up your prerequisites:

import tensorflow as tf
import tensorflow_datasets as tfds

# Load MNIST using TFDS
(mnist_train, mnist_test), mnist_info = tfds.load('mnist', split=['train', 'test'], as_supervised=True, with_info=True)

Next, you’ll need to do some simple preprocessing that normalizes the images and transforms the labels according to the new labeling scheme:

def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label

def transform_labels(image, label):
  # Integer division implements y = floor(x / 2) and keeps the label an integer
  return image, label // 2

def prepare(ds, batch_size=128):
  ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.map(transform_labels, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Cache before shuffling so every epoch sees a fresh shuffle order
  ds = ds.cache()
  ds = ds.shuffle(mnist_info.splits['train'].num_examples)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  return ds

prepare accepts a tf.data.Dataset and maps normalize_img and transform_labels over every element. It then applies the standard caching, shuffling, batching, and prefetching steps to prepare the dataset for training.

You’ll now want to split MNIST into three datasets: a multi-task dataset that contains interleaved examples from both labeling tasks (even and odd digits), a dataset with only even digits (task A), and a dataset with only odd digits (task B).

def split_tasks(ds, predicate):
  # Examples that satisfy predicate go to the first dataset, the rest to the second
  return ds.filter(predicate), ds.filter(lambda img, label: tf.logical_not(predicate(img, label)))

multi_task_train, multi_task_test = prepare(mnist_train), prepare(mnist_test)
task_A_train, task_B_train = split_tasks(mnist_train, lambda img, label: label % 2 == 0)
task_A_train, task_B_train = prepare(task_A_train), prepare(task_B_train)
task_A_test, task_B_test = split_tasks(mnist_test, lambda img, label: label % 2 == 0)
task_A_test, task_B_test = prepare(task_A_test), prepare(task_B_test)

Now create a simple function to evaluate a model on a task:

def evaluate(model, test_set):
  acc = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
  for i, (imgs, labels) in enumerate(test_set):
    preds = model.predict_on_batch(imgs)
    acc.update_state(labels, preds)
  return acc.result().numpy()
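evaluate passes raw logits straight to SparseCategoricalAccuracy. That works because the metric only compares the index of the largest prediction with the integer label, so a softmax would not change the result. A NumPy sketch of the same computation, with made-up labels and logits:

```python
import numpy as np

def sparse_categorical_accuracy(labels, logits):
    # Predicted class = index of the largest logit; softmax would not change it
    preds = np.argmax(logits, axis=-1)
    return float(np.mean(preds == labels))

labels = np.array([0, 3, 4, 1])
logits = np.array([
    [2.0, 0.1, -1.0, 0.0, 0.5],   # argmax 0 -> correct
    [0.0, 0.2, 0.1, 1.5, -0.3],   # argmax 3 -> correct
    [1.0, 0.0, 0.0, 0.0, 0.9],    # argmax 0 -> wrong (label is 4)
    [-0.5, 3.0, 0.0, 0.0, 0.0],   # argmax 1 -> correct
])
print(sparse_categorical_accuracy(labels, logits))  # 0.75
```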

You’ll train your first model on the multi-task dataset and test it independently on both task A and task B:

multi_task_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

multi_task_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

multi_task_model.fit(multi_task_train, epochs=6)

Notice how quickly the model converges on a successful parameterization. Evaluate your model on both task A and task B:

print("Task A accuracy after training on Multi-Task Problem: {}".format(evaluate(multi_task_model, task_A_test)))
print("Task B accuracy after training on Multi-Task Problem: {}".format(evaluate(multi_task_model, task_B_test)))

And you get the following output:

Task A accuracy after training on Multi-Task Problem: 0.9723913669586182
Task B accuracy after training on Multi-Task Problem: 0.97418212890625

The model performs well on both task A and task B after training on the multi-task dataset. Now, to demonstrate the difficulties of continual learning, create a new model and train it on task A:

basic_cl_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

basic_cl_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

basic_cl_model.fit(task_A_train, epochs=6)

Next, evaluate the model on task A:

print("Task A accuracy after training model on only Task A: {}".format(evaluate(basic_cl_model, task_A_test)))

Output:

Task A accuracy after training model on only Task A: 0.984977662563324

Now, train the same model on task B:

basic_cl_model.fit(task_B_train, epochs=6)

And evaluate it on both task A and task B:

print("Task B accuracy after training trained model on Task B: {}".format(evaluate(basic_cl_model, task_B_test)))
print("Task A accuracy after training trained model on Task B: {}".format(evaluate(basic_cl_model, task_A_test)))

Output:

Task B accuracy after training trained model on Task B: 0.9830508232116699
Task A accuracy after training trained model on Task B: 0.2663418650627136

Notice how the model easily solves task B; however, it loses nearly all of its knowledge about task A. The model still performs slightly better on task A than random guessing over the labels [0 - 4] (20% accuracy), but there’s certainly room for improvement. So, how can we learn task B without forgetting what was learned on task A?

Remember, all knowledge of task A is encoded in the parameterization \Theta_{A}^*, which lies on the manifold \mathcal{M}_A. One approach is to ensure the new parameterization \Theta_{AB} never drifts too far from \Theta_{A}^*. The L^2 norm of a vector measures its length in Euclidean space, that is, its distance from the origin. When used as a regularization term, the L^2 norm (usually squared) penalizes the parameters of a neural network for moving away from the origin, keeping them centered around it. Applied to continual learning, you can instead keep \Theta_{AB} close to \Theta_{A}^* by adding the L^2 norm of the difference, \|\Theta_{AB} - \Theta_{A}^*\|, as a regularization term during training.

Let’s see how this works in practice. You’ll need to create a function that calculates the L^2 penalty of each parameter in the model as well as a custom training loop to apply this penalty at each update:

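def l2_penalty(model, theta_A):
  penalty = 0
  for i, theta_i in enumerate(model.trainable_variables):
    # tf.norm returns the (unsquared) Euclidean norm of this variable's
    # deviation from its value after training on task A
    _penalty = tf.norm(theta_i - theta_A[i])
    penalty += _penalty
  return 0.5*penalty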

def train_with_l2(model, task_A_train, task_B_train, task_A_test, task_B_test, epochs=6):
  # First we're going to fit to task A and retain a copy of parameters trained on Task A
  model.fit(task_A_train, epochs=epochs)
  theta_A = {n: p.value() for n, p in enumerate(model.trainable_variables.copy())}

  print("Task A accuracy after training on Task A: {}".format(evaluate(model, task_A_test)))
  
  # Metrics for the custom training loop
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
  # The model outputs logits, so the loss metric needs from_logits=True
  loss = tf.keras.metrics.SparseCategoricalCrossentropy('loss', from_logits=True)

  for epoch in range(epochs):
    accuracy.reset_states()
    loss.reset_states()
    for batch, (imgs, labels) in enumerate(task_B_train):
      with tf.GradientTape() as tape:
        preds = model(imgs)
        # Loss is crossentropy loss with regularization term for each parameter
        total_loss = model.loss(labels, preds) + l2_penalty(model, theta_A)
      grads = tape.gradient(total_loss, model.trainable_variables)
      model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
      
      accuracy.update_state(labels, preds)
      loss.update_state(labels, preds)
      print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
          epoch+1, batch+1, loss.result().numpy(), accuracy.result().numpy()), flush=True, end=''
         )
    print("")
  
  print("Task B accuracy after training trained model on Task B: {}".format(evaluate(model, task_B_test)))
  print("Task A accuracy after training trained model on Task B: {}".format(evaluate(model, task_A_test)))

And then run the training loop with a new model:

l2_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

l2_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

train_with_l2(l2_model, task_A_train, task_B_train, task_A_test, task_B_test)

You’ll notice the model plateaus very early when training on task B, finishing with the following results:

Task B accuracy after training trained model on Task B: 0.5977532267570496
Task A accuracy after training trained model on Task B: 0.9037758708000183

These results aren’t bad, but they show the limitations of an L^2 penalty. While the penalty ensures the new parameterization retains knowledge of task A, it’s too restrictive to learn task B successfully. The regularization term forces every parameter in the model to remain close to \Theta_{A}^*, even those that are unimportant to task A. Rather than constraining every parameter, we should constrain only those parameters that are important to task A. This selective protection is much closer to the task-specific synaptic consolidation the mammalian brain appears to use.
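A two-parameter toy problem (hypothetical losses, not MNIST) makes this limitation concrete. Task A only cares about the first parameter, with loss (theta[0] - 1)^2, while task B wants theta[0] + theta[1] = 2, so the point (1, 1) solves both. Starting from theta_A = (1, 0), a uniform penalty pulls both parameters back equally hard. This NumPy sketch uses the squared distance (lam / 2) * ||theta - theta_A||^2, a common variant of the penalty rather than the unsquared norm in l2_penalty; lam = 4 is an arbitrary choice.

```python
import numpy as np

theta_A = np.array([1.0, 0.0])   # parameters after training on task A
lam = 4.0                         # hypothetical penalty strength

def loss_A(theta):
    return (theta[0] - 1.0) ** 2          # only theta[0] matters for task A

def loss_B(theta):
    return (theta[0] + theta[1] - 2.0) ** 2

def grad_total(theta):
    # Gradient of loss_B plus a uniform squared-L2 pull toward theta_A
    g_B = 2.0 * (theta[0] + theta[1] - 2.0) * np.ones(2)
    return g_B + lam * (theta - theta_A)

theta = theta_A.copy()
for _ in range(2000):
    theta -= 0.05 * grad_total(theta)

print(np.round(theta, 4), round(loss_A(theta), 4), round(loss_B(theta), 4))
# -> [1.25 0.25] 0.0625 0.25
```

Both parameters move by the same amount, so task A pays for the update while task B is still only partly solved. A larger lam protects task A better but leaves task B worse off, a trade-off similar to the plateau in the MNIST run above.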

Consider the following objective formulation of the continual learning problem on task A and task B:

\log p(\Theta | \mathcal{D}) = \log p(\mathcal{D}_B | \Theta) + \log p(\Theta | \mathcal{D}_A) - \log p(\mathcal{D}_B)

You can derive this formulation by starting from the idea that learning, given a dataset \mathcal{D}, means finding the most likely parameterization \Theta given \mathcal{D}, i.e., maximizing p(\Theta | \mathcal{D}). It follows from Bayes’ Theorem that:

p(\Theta | \mathcal{D}) = \frac{p(\mathcal{D} | \Theta) \cdot p(\Theta)}{p(\mathcal{D})}

Taking the \log of both sides and rearranging using logarithmic identities yields:

\log p(\Theta | \mathcal{D}) = \log p(\mathcal{D} | \Theta) + \log p(\Theta) - \log p (\mathcal{D})

In the context of continual learning, the overall goal is still to learn p(\Theta | \mathcal{D}); however, the dataset is assumed to be split into independent parts \mathcal{D}_A and \mathcal{D}_B. The objective is first to learn p(\Theta | \mathcal{D}_A), and then to learn p(\Theta | \mathcal{D}) by training on \mathcal{D}_B with the posterior p(\Theta | \mathcal{D}_A) serving as the prior. Informally, your goal is:

p(\Theta | \mathcal{D}) = p(p(\Theta | \mathcal{D}_A) | \mathcal{D}_B)

then, from Bayes’ Theorem:

p(\Theta | \mathcal{D}) = \frac{p(\mathcal{D}_B | p(\Theta | \mathcal{D}_A)) \cdot p(\Theta | \mathcal{D}_A)}{p(\mathcal{D}_B)}

Notice the term p(\mathcal{D}_B | p(\Theta | \mathcal{D}_A)). If we start from a model trained on \mathcal{D}_A, the parameterization \Theta estimates the posterior p(\Theta | \mathcal{D}_A), which means we can rewrite the formula as:

p(\Theta | \mathcal{D}) = \frac{p(\mathcal{D}_B | \Theta) \cdot p(\Theta | \mathcal{D}_A)}{p(\mathcal{D}_B)}

Taking the \log of both sides and rearranging yields:

\log p(\Theta | \mathcal{D}) = \log p(\mathcal{D}_B | \Theta) + \log p(\Theta | \mathcal{D}_A) - \log p(\mathcal{D}_B)
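The nested expression p(p(\Theta | \mathcal{D}_A) | \mathcal{D}_B) is informal notation. A more standard route to the same equation applies Bayes’ Theorem to \mathcal{D}_B while conditioning on \mathcal{D}_A throughout, and assumes \mathcal{D}_A and \mathcal{D}_B are conditionally independent given \Theta, so that p(\mathcal{D}_B | \Theta, \mathcal{D}_A) = p(\mathcal{D}_B | \Theta):

```latex
\begin{aligned}
p(\Theta | \mathcal{D}_A, \mathcal{D}_B)
  &= \frac{p(\mathcal{D}_B | \Theta, \mathcal{D}_A) \cdot p(\Theta | \mathcal{D}_A)}{p(\mathcal{D}_B | \mathcal{D}_A)}
   = \frac{p(\mathcal{D}_B | \Theta) \cdot p(\Theta | \mathcal{D}_A)}{p(\mathcal{D}_B | \mathcal{D}_A)} \\
\log p(\Theta | \mathcal{D})
  &= \log p(\mathcal{D}_B | \Theta) + \log p(\Theta | \mathcal{D}_A) - \log p(\mathcal{D}_B | \mathcal{D}_A)
\end{aligned}
```

Here \mathcal{D} = \mathcal{D}_A \cup \mathcal{D}_B. The last term is a normalizing constant that does not depend on \Theta (it plays the same role as \log p(\mathcal{D}_B) above), so it drops out when optimizing over \Theta.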

It follows from this formulation that all knowledge of task A is absorbed into the posterior p(\Theta | \mathcal{D}_A), which therefore contains information about which parameters are most important to task A. Given the true posterior, you could compute the Fisher Information Matrix F with respect to \mathcal{D}_A; its diagonal entry F_i measures how much information the data carry about parameter \Theta_i. From F, you could determine the relative importance of each parameter and, when training on task B, constrain each parameter in proportion to its importance to task A. Unfortunately, recovering p(\Theta | \mathcal{D}_A) is intractable, so we have to approximate both the posterior and F.
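Concretely, the original EWC paper approximates the posterior as a Gaussian centered at \Theta_A^* whose precision is the diagonal of F (a Laplace-style approximation). Taking the log turns the intractable term \log p(\Theta | \mathcal{D}_A) into a quadratic penalty, up to a constant that does not depend on \Theta:

```latex
p(\Theta | \mathcal{D}_A) \approx \mathcal{N}\left(\Theta_A^*, \operatorname{diag}(F)^{-1}\right)
\;\;\Longrightarrow\;\;
\log p(\Theta | \mathcal{D}_A) \approx -\sum\nolimits_i \frac{1}{2} F_i (\Theta_i - \Theta_{Ai}^*)^2 + \text{const}
```

Maximizing \log p(\Theta | \mathcal{D}) therefore amounts to minimizing the task-B loss plus this quadratic term, which is the same elastic penalty that the spring analogy arrives at.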

F closely relates to the concept of the spring constant k, which measures the stiffness of a spring: a higher k indicates a stiffer spring. The potential energy stored in a spring is:

PE = \frac{1}{2}k(x - x_0)^2

where x - x_0 is the displacement of the spring from its rest position. The further the spring is displaced, the more potential energy it stores, and the stiffer the spring, the more energy a given displacement requires. If we treat each parameter of \Theta as a spring anchored at \Theta_{A}^*, we want each parameter \Theta_i to deviate from \Theta_{A}^* only as much as its importance to task A allows. Parameters with a smaller F_i matter less to task A and are therefore more elastic. Using this formulation, we get the following regularization term:

\Omega(\Theta) = \sum\nolimits_i \frac{1}{2}F_i(\Theta_i - \Theta_{Ai}^*)^2

The regularization term is the total elastic potential energy of the parameters in \Theta, each anchored at its value in \Theta_{A}^*. In practice, you also add a hyperparameter \lambda that sets how important the regularization term is relative to the task-B loss. Adding this term to the objective function yields:

\mathcal{L}(\Theta) = \mathcal{L}_B(\Theta) + \sum\nolimits_i \frac{\lambda}{2}F_i(\Theta_i - \Theta_{Ai}^*)^2

where \mathcal{L}_B(\Theta) is the loss on task B alone. Training with this objective is known as elastic weight consolidation, or EWC.
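A two-parameter toy problem (hypothetical losses, not MNIST) shows what the Fisher weighting buys. Task A only depends on theta[0] through (theta[0] - 1)^2, while task B wants theta[0] + theta[1] = 2, so (1, 1) solves both. A hypothetical diagonal Fisher F = (1, 0) marks theta[0] as stiff and theta[1] as fully elastic. This NumPy sketch minimizes the EWC objective by gradient descent starting from theta_A = (1, 0); lam = 4 is an arbitrary choice.

```python
import numpy as np

theta_A = np.array([1.0, 0.0])   # parameters after training on task A
F = np.array([1.0, 0.0])         # hypothetical diagonal Fisher: only theta[0] matters to task A
lam = 4.0                         # hypothetical regularization strength (lambda)

def loss_A(theta):
    return (theta[0] - 1.0) ** 2

def loss_B(theta):
    return (theta[0] + theta[1] - 2.0) ** 2

def ewc_penalty(theta):
    # sum_i (lambda / 2) * F_i * (theta_i - theta_A_i)^2
    return 0.5 * lam * np.sum(F * (theta - theta_A) ** 2)

def grad_total(theta):
    g_B = 2.0 * (theta[0] + theta[1] - 2.0) * np.ones(2)
    # Each parameter is pulled back in proportion to its own stiffness F_i
    return g_B + lam * F * (theta - theta_A)

theta = theta_A.copy()
for _ in range(2000):
    theta -= 0.05 * grad_total(theta)

print(np.round(theta, 4), round(loss_A(theta), 4), round(loss_B(theta), 4))
# -> [1. 1.] 0.0 0.0
```

Because F_1 = 0, the second parameter absorbs the entire task-B update and the optimizer lands on the shared solution (1, 1). In a real network F is only estimated and rarely exactly zero, so the separation is softer than in this toy.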

F is approximated according to the following formula:

F = \frac{1}{N} \sum\limits_{i=1}^N \nabla_\Theta \log p(x_{A,i} | \Theta_{A}^*) \nabla_\Theta \log p(x_{A,i} | \Theta_{A}^*)^T

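In words, F is approximated by averaging, over N examples sampled from \mathcal{D}_A, the outer product of the log-likelihood gradient with itself. EWC only uses the diagonal of this matrix, so F_i is simply the mean of the squared gradient with respect to \Theta_i. So, how is this implemented in practice?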

You’ll want to start with a function that approximates F given a model and \mathcal{D}_A. The function samples a number of batches from \mathcal{D}_A, calculates the gradient of the log-likelihood with respect to the model’s parameters, and then takes the mean of the squared gradients:

def compute_precision_matrices(model, task_set, num_batches=1):
  # Diagonal (empirical) Fisher: for each parameter, the mean over examples of the
  # squared gradient of log p(y | x, theta), evaluated at the current (task A) weights
  task_set = task_set.repeat()
  precision_matrices = {n: tf.zeros_like(p) for n, p in enumerate(model.trainable_variables)}

  for imgs, labels in task_set.take(num_batches):
    # We need gradients of model params
    with tf.GradientTape() as tape:
      # The model outputs logits; log_softmax turns them into log-probabilities
      log_probs = tf.nn.log_softmax(model(imgs))
      # Log-likelihood of each example's observed label, shape [batch]
      one_hot = tf.one_hot(tf.cast(labels, tf.int32), depth=log_probs.shape[-1])
      ll = tf.reduce_sum(log_probs * one_hot, axis=-1)
    # jacobian keeps one gradient per example: shape [batch] + parameter shape
    ll_grads = tape.jacobian(ll, model.trainable_variables)
    # Compute F_i as the mean over examples of the squared per-example gradients
    for n, g in enumerate(ll_grads):
      precision_matrices[n] += tf.math.reduce_mean(g ** 2, axis=0) / num_batches

  return precision_matrices
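For a model small enough to differentiate by hand, the diagonal Fisher has a closed form, which makes a handy reference when checking a TensorFlow implementation. The NumPy sketch below assumes a hypothetical linear softmax classifier with a single weight matrix W, uses the observed labels (the empirical Fisher), and keeps only the diagonal, as EWC does. diagonal_fisher and the random data are illustrative, not part of the original notebook.

```python
import numpy as np

def softmax(z):
    # Subtract the row max for numerical stability
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def diagonal_fisher(W, X, y):
    """Diagonal empirical Fisher of p(y | x) = softmax(W @ x), evaluated at W.

    For one example, d log p(y | x) / dW = outer(onehot(y) - softmax(W @ x), x),
    so F is the mean of these per-example gradients squared (same shape as W).
    """
    num_classes = W.shape[0]
    probs = softmax(X @ W.T)                  # (N, K) predicted probabilities
    onehot = np.eye(num_classes)[y]           # (N, K) observed labels
    per_example_grads = (onehot - probs)[:, :, None] * X[:, None, :]   # (N, K, n)
    return np.mean(per_example_grads ** 2, axis=0)                     # (K, n)

# Hypothetical stand-ins: W plays the role of Theta_A^*, (X, y) a sample from D_A
rng = np.random.default_rng(0)
W = rng.normal(size=(5, 3))
X = rng.normal(size=(100, 3))
y = rng.integers(0, 5, size=100)

F = diagonal_fisher(W, X, y)
print(F.shape)  # (5, 3)
```

The gradients are squared per example and only then averaged; squaring the gradient of a batch-summed log-likelihood gives a different, generally much smaller, estimate.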

Next, you’ll need to compute the regularization term \Omega(\Theta):

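def compute_elastic_penalty(F, theta, theta_A, alpha=25):
  # alpha plays the role of lambda in the objective above
  penalty = 0
  for i, theta_i in enumerate(theta):
    # F_i * (theta_i - theta_A_i)^2, summed over every weight in this variable
    _penalty = tf.math.reduce_sum(F[i] * (theta_i - theta_A[i]) ** 2)
    penalty += _penalty
  return 0.5*alpha*penalty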

Implement a loss function which uses the regularization term:

def ewc_loss(labels, preds, model, F, theta_A):
  loss_b = model.loss(labels, preds)
  penalty = compute_elastic_penalty(F, model.trainable_variables, theta_A)
  return loss_b + penalty

And a custom training loop which ties everything together:

def train_with_ewc(model, task_A_set, task_B_set, task_A_test, task_B_test, epochs=3):
  # First we're going to fit to task A and retain a copy of parameters trained on Task A
  model.fit(task_A_set, epochs=epochs)
  theta_A = {n: p.value() for n, p in enumerate(model.trainable_variables.copy())}
  # We'll only compute Fisher once, you can do it whenever
  F = compute_precision_matrices(model, task_A_set, num_batches=1000)

  print("Task A accuracy after training on Task A: {}".format(evaluate(model, task_A_test)))

  # Now we set up the training loop for task B with EWC
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
  # The model outputs logits, so the loss metric needs from_logits=True
  loss = tf.keras.metrics.SparseCategoricalCrossentropy('loss', from_logits=True)

  for epoch in range(epochs*3):
    accuracy.reset_states()
    loss.reset_states()

    for batch, (imgs, labels) in enumerate(task_B_set):
      with tf.GradientTape() as tape:
        # Make the predictions
        preds = model(imgs)
        # Compute EWC loss
        total_loss = ewc_loss(labels, preds, model, F, theta_A)
      # Compute the gradients of model's trainable parameters wrt total loss
      grads = tape.gradient(total_loss, model.trainable_variables)
      # Update the model with gradients
      model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
      # Report updated loss and accuracy
      accuracy.update_state(labels, preds)
      loss.update_state(labels, preds)
      print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
          epoch+1, batch+1, loss.result().numpy(), accuracy.result().numpy()), flush=True, end=''
         )
    print("")

  print("Task B accuracy after training trained model on Task B: {}".format(evaluate(model, task_B_test)))
  print("Task A accuracy after training trained model on Task B: {}".format(evaluate(model, task_A_test)))

Finally, create a new model and train it using EWC:

ewc_model = tf.keras.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.5),
  tf.keras.layers.Dense(5)
])

ewc_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

train_with_ewc(ewc_model, task_A_train, task_B_train, task_A_test, task_B_test)

With the following results after training:

Task B accuracy after training trained model on Task B: 0.7087110877037048
Task A accuracy after training trained model on Task B: 0.7677628993988037

The model performs acceptably on both task A and task B; however, it’s not all that impressive. I’ll expand on this below.

You should have noticed a few peculiarities about this model and the training process. First, I added dropout to the model as a form of regularization. I found during training that dropout was important when using EWC to keep the model from significantly overfitting to task A. Without dropout, the model struggles to make any progress on task B after training on task A. Note that the custom loop calls model(imgs) without training=True, so Keras runs the Dropout layer in inference mode there; dropout is only active while fitting task A with model.fit. Additionally, you’ll notice the model is trained three times longer on task B than on task A, because I found progress on task B was much slower than on task A.

You may find that training with EWC is particularly sensitive to hyperparameters such as the weight of the elastic penalty (alpha in compute_elastic_penalty). Most of my trials produced worse results with EWC than with an L^2 penalty. My guess is that EWC is sensitive to the final parameterization of the model after training on task A: if that parameterization lies in a difficult region of the parameter space, the model will struggle to optimize on task B.

In theory, EWC is incredibly cool. In practice, it seems overly sensitive and sometimes difficult to train.

If you find any bugs in my implementation or problems with my explanation, feel free to comment and I will correct them!

6 thoughts on “Continual Learning with Elastic Weight Consolidation in TensorFlow 2”

  1. Very nice work! How do you think the Fisher Information Matrix of the parameters should be updated in order to train a new task C? (Assuming that the training data of task A are no longer available)

    1. Good question! Given your parameters at this point are a composite of both tasks, I think it’d be difficult to retain information about task A if that training data was no longer available. You could use data from task B, but I believe you’d very quickly lose information about task A – however, given that the current parameterization is theoretically good for task A and B, it’s possible that only using task B is good enough because those parameters important to task B are also important for task A.

  2. Great work! This is the best article I have read on this subject – believe me, I have read a lot of them. I love how you explained the key concept, relating it to the potential energy of an extended spring. The implementation code was also straight to the point and easy to follow. I particularly found how you concluded the article intriguing – depending on the final parameterization region of the prior task, applying L2 penalty could lead to a better result than EWC.

    1. Thank you for your kind words! I’m glad you enjoyed it. I spent a lot of time attempting to get results that were worth showing, and the trial and error led me to some of those guesses. EWC itself is definitely still very cool, but I am aware of some more recent SOTA work on continual learning that I have been meaning to write about for a while now (specifically: https://arxiv.org/abs/2004.00070).

  3. Thanks for the swift response and for sharing; I will read through the paper. I was wondering if we could connect. I’m currently experimenting with the idea of extending EWC or a similar technique to distributed optimization with heterogeneous data. It would be nice to share ideas. I looked through LinkedIn to see if I could find you, with no luck.

  4. Hey @seanmoriarity, I did a tf2 implementation of this technique. It worked quite impressively; I didn’t need to apply dropout when training task B. Little to no hyperparameter tuning was required: I used the Adam optimizer, set the learning rate at 0.001 and the penalty weight at 0.1. Task A and B were MNIST and its permuted version, respectively. Here is the code and some results on GitHub: https://github.com/stijani/elastic-weight-consolidation-tf2
