Continual Learning with Elastic Weight Consolidation in TensorFlow 2
Based on the paper "Overcoming catastrophic forgetting in neural networks" (Kirkpatrick et al., 2017).
You can view the accompanying Jupyter Notebook here.
The mammalian brain can learn tasks sequentially: we are capable of learning new tasks without forgetting how to perform old ones. Research suggests that retaining task-specific skills relies primarily on task-specific synaptic consolidation, which protects previously acquired knowledge by strengthening the subset of synapses relevant to a specific task. In effect, these synapses remain resistant to change, even when learning new tasks.
In the context of neural networks, continual learning presents a unique set of challenges. Consider a simple classification problem. We are given a dataset $latex \mathcal{D}$, generated from the true data-generating distribution $latex P_{data}$, that consists of examples $latex x_1, x_2, \ldots, x_N \in \mathbf{R}^n$ with corresponding discrete labels $latex L_1, L_2, \ldots, L_N \in \{1, 2, \ldots, K\}$. The goal is to learn a function $latex f: \mathbf{R}^n \rightarrow \{1, 2, \ldots, K\}$ that accurately assigns a label $latex L$ given an example $latex x$. Typically, neural networks parameterize $latex f$ with $latex \Theta$ and optimize a loss function using gradient descent. The final parameterization $latex \Theta^*$ lies on a manifold in the parameter space that yields good performance on the classification task.
Now consider two independent classification tasks A and B with corresponding datasets $latex \mathcal{D}_A$ and $latex \mathcal{D}_B$. The goal is to first learn a parameterization $latex \Theta_A^*$ that yields acceptable performance on task A, followed by learning a parameterization $latex \Theta_{AB}^*$ that yields acceptable performance on task B without losing a significant amount of performance on task A. Immediately after training on task A, the parameterization $latex \Theta_A^*$ lies on a manifold that yields good performance on task A. When training then continues on task B with a gradient-based method, nothing in the task B loss rewards performance on task A, so the parameterization $latex \Theta$ quickly moves away from the manifold yielding good performance on task A towards the manifold yielding good performance on task B. By the end of training, $latex \Theta$ lies only on the manifold yielding good performance on task B. This means the neural network has lost its previous knowledge of task A, a phenomenon known as catastrophic forgetting.
Given the two tasks A and B, there exist two manifolds $latex \mathcal{M}_A$ and $latex \mathcal{M}_B$ in the parameter space that yield good performance on task A and task B, respectively. Depending on the problem, the two manifolds likely overlap to some extent, and their intersection $latex \mathcal{M}_{AB}$ yields good performance on both tasks at the same time. In a multi-task learning setting, you learn task A and task B simultaneously, searching directly for a parameterization $latex \Theta^*$ that lies on $latex \mathcal{M}_{AB}$. In a continual learning problem, your goal is to first find a parameterization $latex \Theta_A^*$ on $latex \mathcal{M}_A$ and then move within $latex \mathcal{M}_A$ towards $latex \mathcal{M}_B$, so that you sequentially arrive at $latex \mathcal{M}_{AB}$.
In a more practical sense, consider a basic neural network that learns to classify digits in MNIST. Rather than learning to classify the digits $latex [0 - 9]$ directly, you train the classifier to map the five even digits (0, 2, 4, 6, 8) to the labels $latex [0 - 4]$ and the five odd digits (1, 3, 5, 7, 9) to the same labels $latex [0 - 4]$. Mathematically, each digit $latex x$ is assigned the label $latex y = \lfloor x / 2 \rfloor$.
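As a quick sanity check of the labeling scheme, the plain-Python sketch below (independent of TensorFlow) maps each digit to its new label and groups the digits into the two tasks. The helper names are hypothetical; even digits form task A and odd digits form task B, matching the split used in the code later in this post.

```python
# Hypothetical helper illustrating the relabeling y = floor(x / 2).
def relabel(digit):
    # Integer floor division implements floor(x / 2) for non-negative digits
    return digit // 2

task_A_digits = [d for d in range(10) if d % 2 == 0]  # even digits: 0, 2, 4, 6, 8
task_B_digits = [d for d in range(10) if d % 2 == 1]  # odd digits: 1, 3, 5, 7, 9

task_A_labels = [relabel(d) for d in task_A_digits]
task_B_labels = [relabel(d) for d in task_B_digits]

print(task_A_labels)  # [0, 1, 2, 3, 4]
print(task_B_labels)  # [0, 1, 2, 3, 4]
```

Both tasks share the label space $latex [0 - 4]$, so a single 5-way output layer serves both tasks, which is exactly why training on one task can overwrite what was learned for the other.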
Given the significant (duh) overlap between these two tasks, solving this problem in the context of multi-task learning isn’t particularly difficult. To demonstrate this, create a new Jupyter Notebook and load up your prerequisites:
```python
import tensorflow as tf
import tensorflow_datasets as tfds

# Load MNIST using TFDS
(mnist_train, mnist_test), mnist_info = tfds.load('mnist', split=['train', 'test'], as_supervised=True, with_info=True)
```
Next, you’ll need to do some simple preprocessing that normalizes the images and transforms the labels according to the new labeling scheme:
```python
def normalize_img(image, label):
    # Scale pixel values from [0, 255] to [0, 1]
    return tf.cast(image, tf.float32) / 255., label

def transform_labels(image, label):
    # New labeling scheme y = floor(x / 2), kept as an integer label
    return image, tf.math.floordiv(label, 2)

def prepare(ds, batch_size=128):
    ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(transform_labels, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache before shuffling so that every epoch sees a freshly shuffled order
    ds = ds.cache()
    ds = ds.shuffle(mnist_info.splits['train'].num_examples)
    ds = ds.batch(batch_size)
    ds = ds.prefetch(tf.data.AUTOTUNE)
    return ds
```
`prepare` accepts a TensorFlow Dataset object, maps `normalize_img` and `transform_labels` over every item in the dataset, and then applies the standard steps for preparing a dataset for training: caching, shuffling, batching, and prefetching.
You'll now want to split the MNIST data into three datasets: a multi-task training dataset that contains interleaved examples from both labeling tasks (both odd and even digits), a task A dataset containing only even digits, and a task B dataset containing only odd digits. The same split is applied to the test set.
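```python
def split_tasks(ds, predicate):
    # Examples satisfying the predicate go to the first dataset, the rest to the second.
    # tf.logical_not keeps the negation a tensor operation inside tf.data's traced function.
    return ds.filter(predicate), ds.filter(lambda img, label: tf.logical_not(predicate(img, label)))

multi_task_train, multi_task_test = prepare(mnist_train), prepare(mnist_test)

# Task A: even digits, task B: odd digits (the split uses the original digit labels)
task_A_train, task_B_train = split_tasks(mnist_train, lambda img, label: label % 2 == 0)
task_A_train, task_B_train = prepare(task_A_train), prepare(task_B_train)

task_A_test, task_B_test = split_tasks(mnist_test, lambda img, label: label % 2 == 0)
task_A_test, task_B_test = prepare(task_A_test), prepare(task_B_test)
```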
Now create a simple function to evaluate a model on a task:
```python
def evaluate(model, test_set):
    acc = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
    for imgs, labels in test_set:
        # The model outputs logits; accuracy only needs the argmax, so no softmax is required
        preds = model.predict_on_batch(imgs)
        acc.update_state(labels, preds)
    return acc.result().numpy()
```
You’ll train your first model on the multi-task dataset and test it independently on both task A and task B:
```python
multi_task_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(5)
])

multi_task_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
multi_task_model.fit(multi_task_train, epochs=6)
```
Notice how quickly the model converges on a successful parameterization. Evaluate your model on both task A and task B:
```python
print("Task A accuracy after training on Multi-Task Problem: {}".format(evaluate(multi_task_model, task_A_test)))
print("Task B accuracy after training on Multi-Task Problem: {}".format(evaluate(multi_task_model, task_B_test)))
```
And you get the following output:
Task A accuracy after only training on Multi-Task Problem: 0.9723913669586182
Task B accuracy after only training on Multi-Task Problem: 0.97418212890625
The model performs well on both task A and task B after training on the multi-task dataset. Now, to demonstrate the difficulties of continual learning, create a new model and train it on task A:
```python
basic_cl_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(5)
])

basic_cl_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
basic_cl_model.fit(task_A_train, epochs=6)
```
Next, evaluate the model on task A:
```python
print("Task A accuracy after training model on only Task A: {}".format(evaluate(basic_cl_model, task_A_test)))
```
Task A accuracy after training model on only Task A: 0.984977662563324
Now, train the same model on task B:
```python
basic_cl_model.fit(task_B_train, epochs=6)
```
And evaluate it on both task A and task B:
```python
print("Task B accuracy after training trained model on Task B: {}".format(evaluate(basic_cl_model, task_B_test)))
print("Task A accuracy after training trained model on Task B: {}".format(evaluate(basic_cl_model, task_A_test)))
```
Task B accuracy after training trained model on Task B: 0.9830508232116699
Task A accuracy after training trained model on Task B: 0.2663418650627136
Notice how the model easily solves task B; however, it loses nearly all of its knowledge about task A. The model still performs slightly better on task A than randomly guessing among the labels $latex [0 - 4]$ (which would score about 20%), but there's certainly room for improvement. So, how can we learn task B without forgetting knowledge from task A?
Remember, all knowledge of task A is encoded in the parameterization $latex \Theta_{A}^*$, which lies on the manifold $latex \mathcal{M}_A$. One approach is to ensure the new parameterization $latex \Theta_{AB}$ never drifts too far from $latex \Theta_{A}^*$. The $latex L^2$ norm of a vector measures its length in Euclidean space, that is, its distance from the origin. Used as a regularization term, the $latex L^2$ norm pulls the parameters of a neural network toward the origin. Applied to continual learning, you can keep $latex \Theta_{AB}$ close to $latex \Theta_{A}^*$ by re-centering this penalty: instead of the distance from the origin, you penalize the $latex L^2$ distance between $latex \Theta_{AB}$ and $latex \Theta_{A}^*$ during training.
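To see what an anchored penalty does in isolation, here is a small NumPy sketch. It assumes a hypothetical quadratic stand-in for the task B loss, $latex \frac{1}{2}\lVert\Theta - \Theta_B\rVert^2$, and uses the squared form of the penalty, $latex \frac{\lambda}{2}\lVert\Theta - \Theta_A^*\rVert^2$. Starting from $latex \Theta_A^*$, gradient descent lands on the compromise $latex (\Theta_B + \lambda\Theta_A^*)/(1 + \lambda)$: every coordinate is pulled back toward $latex \Theta_A^*$ with the same strength, no matter how much it matters for task A.

```python
import numpy as np

def l2_anchored_grad(theta, theta_B, theta_A, lam):
    # Gradient of 0.5 * ||theta - theta_B||^2 (hypothetical toy task B loss)
    # plus (lam / 2) * ||theta - theta_A||^2 (anchor to the task A parameters)
    return (theta - theta_B) + lam * (theta - theta_A)

theta_A = np.array([1.0, -2.0, 0.5])   # hypothetical parameters after task A
theta_B = np.array([3.0, 0.0, 0.5])    # hypothetical optimum of task B alone
lam = 1.0

theta = theta_A.copy()                 # task B training starts from theta_A
for _ in range(200):
    theta -= 0.1 * l2_anchored_grad(theta, theta_B, theta_A, lam)

closed_form = (theta_B + lam * theta_A) / (1 + lam)
print(theta)  # approximately [2.0, -1.0, 0.5]
```

With $latex \lambda = 1$, every coordinate ends halfway between the two optima, whether or not task A actually depends on it.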
Let’s see how this works in practice. You’ll need to create a function that calculates the $latex L^2$ penalty of each parameter in the model as well as a custom training loop to apply this penalty at each update:
```python
def l2_penalty(model, theta_A):
    # Squared L2 distance between the current parameters and the task A snapshot,
    # summed over every trainable variable (kernels and biases).
    # The squared form avoids the undefined gradient of a norm at zero distance.
    penalty = 0
    for i, theta_i in enumerate(model.trainable_variables):
        penalty += tf.math.reduce_sum((theta_i - theta_A[i]) ** 2)
    return 0.5 * penalty

def train_with_l2(model, task_A_train, task_B_train, task_A_test, task_B_test, epochs=6):
    # First fit to task A and keep a snapshot (a true copy) of the task A parameters
    model.fit(task_A_train, epochs=epochs)
    theta_A = {n: tf.constant(p.numpy()) for n, p in enumerate(model.trainable_variables)}
    print("Task A accuracy after training on Task A: {}".format(evaluate(model, task_A_test)))

    # Metrics for the custom training loop (the model outputs logits)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
    loss = tf.keras.metrics.SparseCategoricalCrossentropy('loss', from_logits=True)

    for epoch in range(epochs):
        accuracy.reset_state()
        loss.reset_state()
        for batch, (imgs, labels) in enumerate(task_B_train):
            with tf.GradientTape() as tape:
                preds = model(imgs, training=True)
                # Loss is cross-entropy on task B plus the L2 penalty toward theta_A
                total_loss = model.loss(labels, preds) + l2_penalty(model, theta_A)
            grads = tape.gradient(total_loss, model.trainable_variables)
            model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
            accuracy.update_state(labels, preds)
            loss.update_state(labels, preds)
            print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
                epoch+1, batch+1, loss.result().numpy(), accuracy.result().numpy()), flush=True, end=''
            )
        print("")

    print("Task B accuracy after training trained model on Task B: {}".format(evaluate(model, task_B_test)))
    print("Task A accuracy after training trained model on Task B: {}".format(evaluate(model, task_A_test)))
```
And then run the training loop with a new model:
```python
l2_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(5)
])

l2_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
train_with_l2(l2_model, task_A_train, task_B_train, task_A_test, task_B_test)
```
You’ll notice the model plateaus very early when training on task B, finishing with the following results:
Task B accuracy after training trained model on Task B: 0.5977532267570496
Task A accuracy after training trained model on Task B: 0.9037758708000183
These results aren't bad, but they show some of the limitations of using an $latex L^2$ penalty. While the $latex L^2$ penalty ensures the new model parameterization retains knowledge of task A, it's too severe for successfully learning task B. The regularization term forces every parameter in the model to remain close to the parameterization $latex \Theta_{A}^*$, even those that are unimportant to task A. Rather than constrain every parameter to remain close to $latex \Theta_{A}^*$, we should constrain only the parameters that are important to task A. This selective protection mirrors the task-specific synaptic consolidation in the mammalian brain described at the start of this post.
Consider the following objective formulation of the continual learning problem on task A and task B:
$latex \log p(\Theta | \mathcal{D}) = \log p(\mathcal{D}_B | \Theta) + \log p(\Theta | \mathcal{D}_A) - \log p(\mathcal{D}_B)$
You can derive this formulation by starting with the idea that the learning problem given a dataset $latex \mathcal{D}$ is to find the most likely parameterization $latex \Theta$ given $latex \mathcal{D}$, which is represented as $latex p(\Theta | \mathcal{D})$. It follows from Bayes' Theorem that:
$latex p(\Theta | \mathcal{D}) = \frac{p(\mathcal{D} | \Theta) \cdot p(\Theta)}{p(\mathcal{D})}$
Taking the $latex \log$ of both sides and rearranging using logarithmic identities yields:
$latex \log p(\Theta | \mathcal{D}) = \log p(\mathcal{D} | \Theta) + \log p(\Theta) - \log p (\mathcal{D})$
In the context of continual learning, the overall goal is still to learn $latex p(\Theta | \mathcal{D})$; however, the dataset is assumed to be split into independent parts $latex \mathcal{D}_A$ and $latex \mathcal{D}_B$ that arrive one after the other. The objective is first to learn $latex p(\Theta | \mathcal{D}_A)$, and then to learn $latex p(\Theta | \mathcal{D})$ by training on $latex \mathcal{D}_B$ with the posterior $latex p(\Theta | \mathcal{D}_A)$ acting as the prior. Algebraically, this means your goal is:
$latex p(\Theta | \mathcal{D}) = p(\Theta | \mathcal{D}_A, \mathcal{D}_B)$
then, from Bayes' Theorem, keeping $latex \mathcal{D}_A$ on the conditioning side throughout:
$latex p(\Theta | \mathcal{D}_A, \mathcal{D}_B) = \frac{p(\mathcal{D}_B | \Theta, \mathcal{D}_A) \cdot p(\Theta | \mathcal{D}_A)}{p(\mathcal{D}_B | \mathcal{D}_A)}$
Notice the term $latex p(\mathcal{D}_B | \Theta, \mathcal{D}_A)$. Because the two datasets are assumed independent given $latex \Theta$, once $latex \Theta$ is known, $latex \mathcal{D}_A$ carries no additional information about $latex \mathcal{D}_B$, so this term reduces to $latex p(\mathcal{D}_B | \Theta)$. Likewise, with the datasets independent, the normalizer $latex p(\mathcal{D}_B | \mathcal{D}_A)$ is simply $latex p(\mathcal{D}_B)$; it does not depend on $latex \Theta$, so it has no effect on which $latex \Theta$ is most likely. This lets us rewrite the formula as:
$latex p(\Theta | \mathcal{D}) = \frac{p(\mathcal{D}_B | \Theta) \cdot p(\Theta | \mathcal{D}_A)}{p(\mathcal{D}_B)}$
Taking the $latex \log$ of both sides and rearranging yields:
$latex \log p(\Theta | \mathcal{D}) = \log p(\mathcal{D}_B | \Theta) + \log p(\Theta | \mathcal{D}_A) - \log p(\mathcal{D}_B)$
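This decomposition can be checked numerically. The NumPy sketch below uses a hypothetical toy model: $latex \Theta$ is the bias of a coin, discretized on a grid, and $latex \mathcal{D}_A$ and $latex \mathcal{D}_B$ are two batches of flips that are independent given $latex \Theta$. Computing the posterior from all the data at once gives the same result as first computing $latex p(\Theta | \mathcal{D}_A)$ and then using it as the prior for $latex \mathcal{D}_B$.

```python
import numpy as np

theta = np.linspace(0.01, 0.99, 99)          # grid over the coin bias Theta
prior = np.ones_like(theta) / theta.size     # uniform prior p(Theta)

def likelihood(flips, theta):
    # p(D | Theta) for i.i.d. Bernoulli flips (1 = heads)
    heads = np.sum(flips)
    tails = len(flips) - heads
    return theta ** heads * (1 - theta) ** tails

def normalize(p):
    # Dividing by the sum plays the role of the denominator in Bayes' Theorem
    return p / p.sum()

D_A = np.array([1, 1, 0, 1, 1, 0, 1])   # hypothetical task A data
D_B = np.array([0, 0, 1, 0, 0])         # hypothetical task B data

# Posterior from all the data in one shot: p(Theta | D_A, D_B)
joint = normalize(likelihood(np.concatenate([D_A, D_B]), theta) * prior)

# Sequential: p(Theta | D_A) first, then use it as the prior for D_B
posterior_A = normalize(likelihood(D_A, theta) * prior)
sequential = normalize(likelihood(D_B, theta) * posterior_A)

print(np.allclose(joint, sequential))  # True
```

The normalizing constant only rescales the curve, which is why the $latex \log p(\mathcal{D}_B)$ term can be ignored when searching for the most likely $latex \Theta$.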
It follows from the problem formulation that all knowledge of task A is absorbed into the posterior $latex p(\Theta | \mathcal{D}_A)$, which means $latex p(\Theta | \mathcal{D}_A)$ contains information about which parameters are most important to task A. Unfortunately, the true posterior is intractable, so EWC approximates it using the Fisher Information Matrix $latex F$ of the model with respect to $latex \mathcal{D}_A$, evaluated at $latex \Theta_A^*$. The diagonal entry $latex F_i$ measures how sensitive the log-likelihood of task A is to changes in the parameter $latex \Theta_i$: a parameter with a large $latex F_i$ cannot move far without hurting task A. From $latex F$, we can determine the relative importance of each parameter and constrain the model to bound parameters relative to their importance to task A when training on task B. Computing $latex F$ exactly is also expensive, so it has to be approximated from samples.
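Concretely, the original paper approximates $latex p(\Theta | \mathcal{D}_A)$ with a Gaussian centered at $latex \Theta_A^*$ whose precision is the diagonal of $latex F$ (a Laplace approximation). Under that assumption, the log-posterior for task A is, up to a constant, a sum of per-parameter quadratic terms:

$latex \log p(\Theta | \mathcal{D}_A) \approx -\sum\nolimits_i \frac{1}{2}F_i(\Theta_i - \Theta_{Ai}^*)^2 + \text{const}$

Substituting this into the log-posterior objective above, dropping the constant and the $latex \log p(\mathcal{D}_B)$ term (neither depends on $latex \Theta$), and negating turns maximizing $latex \log p(\Theta | \mathcal{D})$ into minimizing the negative log-likelihood of task B plus a Fisher-weighted quadratic penalty.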
$latex F$ closely relates to the concept of the spring constant $latex k$, which measures the stiffness of a spring: a higher $latex k$ indicates a stiffer spring. The potential energy stored in a spring is:
$latex PE = \frac{1}{2}k(x - x_0)^2$
where $latex x - x_0$ measures the displacement of the spring from its resting position. The further the spring is displaced, the more potential energy it stores, and the stiffer the spring, the more energy a given displacement requires. If we think of each parameter $latex \Theta_i$ as attached to $latex \Theta_{Ai}^*$ by a spring with stiffness $latex F_i$, then parameters that are important to task A are held stiffly in place, while parameters with a small $latex F_i$ are less important and are therefore more elastic: they can deviate further from $latex \Theta_{A}^*$ at little cost. Using this formulation, we get the following regularization term:
$latex \Omega(\Theta) = \sum\nolimits_i \frac{1}{2}F_i(\Theta_i - \Theta_{Ai}^*)^2$
The regularization term is the elastic potential energy of each parameter in $latex \Theta$ relative to $latex \Theta_{A}^*$. In practice, you also add a hyperparameter $latex \lambda$ that sets how important the old task is relative to the new one. Adding this term to the objective function yields:
$latex \mathcal{L}(\Theta) = \mathcal{L}_B(\Theta) + \sum\nolimits_i \frac{\lambda}{2}F_i(\Theta_i - \Theta_{Ai}^*)^2$
where $latex \mathcal{L}_B(\Theta)$ is the total loss on task B. Using this objective during training is known as elastic weight consolidation or EWC.
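The effect of the per-parameter stiffness is easiest to see in a toy setting. The NumPy sketch below assumes a hypothetical quadratic stand-in $latex \frac{1}{2}\lVert\Theta - \Theta_B\rVert^2$ for the task B loss and made-up Fisher values, then minimizes it together with the elastic penalty. Each coordinate settles at $latex (\Theta_{B,i} + \lambda F_i \Theta_{Ai}^*)/(1 + \lambda F_i)$: parameters with a large $latex F_i$ stay close to $latex \Theta_A^*$, while parameters with $latex F_i = 0$ move all the way to the task B optimum.

```python
import numpy as np

def ewc_grad(theta, theta_B, theta_A, F, lam):
    # Gradient of 0.5 * ||theta - theta_B||^2 (hypothetical toy task B loss)
    # plus sum_i (lam / 2) * F_i * (theta_i - theta_A_i)^2 (elastic penalty)
    return (theta - theta_B) + lam * F * (theta - theta_A)

theta_A = np.array([1.0, -2.0, 0.5])   # hypothetical parameters after task A
theta_B = np.array([3.0, 0.0, 2.5])    # hypothetical optimum of task B alone
F = np.array([10.0, 0.0, 1.0])         # hypothetical Fisher diagonal: important, unimportant, moderate
lam = 1.0

theta = theta_A.copy()                 # task B training starts from theta_A
for _ in range(2000):
    theta -= 0.05 * ewc_grad(theta, theta_B, theta_A, F, lam)

closed_form = (theta_B + lam * F * theta_A) / (1 + lam * F)
print(theta)  # approximately [1.18, 0.0, 1.5]
```

Unlike a uniform $latex L^2$ anchor, the penalty here is selective: the unimportant second parameter is free to reach the task B optimum, while the important first parameter barely moves.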
$latex F$ is approximated according to the following formula:
$latex F = \frac{1}{N} \sum\limits_{n=1}^N \nabla_\Theta \log p(x_{A,n} | \Theta_{A}^*) \nabla_\Theta \log p(x_{A,n} | \Theta_{A}^*)^T$
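Each term in the sum is the outer product of a per-example gradient with itself, so $latex F$ is the average of these outer products over $latex N$ examples sampled from $latex \mathcal{D}_A$. EWC keeps only the diagonal, so each $latex F_i$ is the mean of the squared gradient of the log-likelihood with respect to $latex \Theta_i$: $latex F_i = \frac{1}{N} \sum\nolimits_{n=1}^N \left(\frac{\partial}{\partial \Theta_i} \log p(x_{A,n} | \Theta_A^*)\right)^2$. So, how is this implemented in practice?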
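To make the estimator concrete without TensorFlow, here is a NumPy sketch for a hypothetical softmax-regression model with a weight matrix `W` and no bias. For a single example, the gradient of $latex \log p(y | x, W)$ with respect to `W` is the outer product of `x` with `one_hot(y) - softmax(x @ W)`; the diagonal Fisher estimate is the mean of these gradients squared elementwise. The gradients are computed per example and squared before averaging, as in the formula above.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def log_likelihood(W, x, y):
    # log p(y | x, W) for a softmax-regression model with logits x @ W
    return np.log(softmax(x @ W)[y])

def per_example_grad(W, x, y):
    # d log p(y | x, W) / dW = outer(x, one_hot(y) - softmax(x @ W))
    one_hot = np.zeros(W.shape[1])
    one_hot[y] = 1.0
    return np.outer(x, one_hot - softmax(x @ W))

def diagonal_fisher(W, X, Y):
    # F_i = mean over examples of the squared per-example gradient w.r.t. W_i
    grads = np.stack([per_example_grad(W, x, y) for x, y in zip(X, Y)])
    return np.mean(grads ** 2, axis=0)

rng = np.random.default_rng(0)
X = rng.normal(size=(64, 4))      # 64 hypothetical examples with 4 features
W = rng.normal(size=(4, 5))       # hypothetical trained weights, 5 classes
Y = rng.integers(0, 5, size=64)   # hypothetical labels in [0, 4]

F = diagonal_fisher(W, X, Y)
print(F.shape)  # (4, 5): one importance value per weight
```

Squaring the gradient of a whole batch's summed log-likelihood would instead mix examples together and give a different quantity, which is why the per-example structure matters.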
You'll want to start with a function that approximates $latex F$ given a model and $latex \mathcal{D}_A$. The function samples a number of batches from $latex \mathcal{D}_A$, computes the gradient of the log-likelihood with respect to the model's parameters, and then averages the squared gradients:
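```python
def compute_precision_matrices(model, task_set, num_batches=1):
    # Diagonal empirical Fisher: F_i is the mean over examples of the squared
    # per-example gradient of log p(y | x, theta) with respect to theta_i
    task_set = task_set.repeat()
    precision_matrices = {n: tf.zeros_like(p) for n, p in enumerate(model.trainable_variables)}
    num_examples = 0
    for imgs, labels in task_set.take(num_batches):
        # We need gradients of the model parameters
        with tf.GradientTape() as tape:
            # Inference mode, so dropout does not perturb the estimate
            preds = model(imgs, training=False)
            # Log-likelihood of each example's own label, shape (batch_size,)
            log_probs = tf.nn.log_softmax(preds)
            ll = tf.reduce_sum(tf.one_hot(labels, preds.shape[-1]) * log_probs, axis=-1)
        # Per-example gradients: each has shape (batch_size,) + variable shape
        ll_grads = tape.jacobian(ll, model.trainable_variables)
        # Accumulate squared gradients, summed over the examples in the batch
        for n, g in enumerate(ll_grads):
            precision_matrices[n] += tf.math.reduce_sum(g ** 2, axis=0)
        num_examples += int(imgs.shape[0])
    # Divide by the number of examples to get the mean
    return {n: f / num_examples for n, f in precision_matrices.items()}
```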
Next, you’ll need to compute the regularization term $latex \Omega(\Theta)$:
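```python
def compute_elastic_penalty(F, theta, theta_A, alpha=25):
    # Omega(theta) = sum_i (alpha / 2) * F_i * (theta_i - theta_A_i)^2,
    # where alpha plays the role of lambda in the objective above
    penalty = 0
    for i, theta_i in enumerate(theta):
        penalty += tf.math.reduce_sum(F[i] * (theta_i - theta_A[i]) ** 2)
    return 0.5 * alpha * penalty
```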
Implement a loss function which uses the regularization term:
```python
def ewc_loss(labels, preds, model, F, theta_A):
    # Task B cross-entropy plus the elastic penalty anchored at theta_A
    loss_b = model.loss(labels, preds)
    penalty = compute_elastic_penalty(F, model.trainable_variables, theta_A)
    return loss_b + penalty
```
And a custom training loop which ties everything together:
```python
def train_with_ewc(model, task_A_set, task_B_set, task_A_test, task_B_test, epochs=3):
    # First fit to task A and keep a snapshot (a true copy) of the task A parameters
    model.fit(task_A_set, epochs=epochs)
    theta_A = {n: tf.constant(p.numpy()) for n, p in enumerate(model.trainable_variables)}
    # We'll only compute F once, right after training on task A
    F = compute_precision_matrices(model, task_A_set, num_batches=1000)
    print("Task A accuracy after training on Task A: {}".format(evaluate(model, task_A_test)))

    # Now set up the training loop for task B with EWC (the model outputs logits)
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
    loss = tf.keras.metrics.SparseCategoricalCrossentropy('loss', from_logits=True)

    # Train on task B for three times as many epochs as on task A
    for epoch in range(epochs*3):
        accuracy.reset_state()
        loss.reset_state()
        for batch, (imgs, labels) in enumerate(task_B_set):
            with tf.GradientTape() as tape:
                # Make the predictions (training=True so dropout is active)
                preds = model(imgs, training=True)
                # Compute EWC loss
                total_loss = ewc_loss(labels, preds, model, F, theta_A)
            # Compute the gradients of the model's trainable parameters w.r.t. the total loss
            grads = tape.gradient(total_loss, model.trainable_variables)
            # Update the model with the gradients
            model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
            # Report updated loss and accuracy
            accuracy.update_state(labels, preds)
            loss.update_state(labels, preds)
            print("\rEpoch: {}, Batch: {}, Loss: {:.3f}, Accuracy: {:.3f}".format(
                epoch+1, batch+1, loss.result().numpy(), accuracy.result().numpy()), flush=True, end=''
            )
        print("")

    print("Task B accuracy after training trained model on Task B: {}".format(evaluate(model, task_B_test)))
    print("Task A accuracy after training trained model on Task B: {}".format(evaluate(model, task_A_test)))
```
Finally, create a new model and train it using EWC:
```python
ewc_model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(5)
])

ewc_model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
train_with_ewc(ewc_model, task_A_train, task_B_train, task_A_test, task_B_test)
```
With the following results after training:
Task B accuracy after training trained model on Task B: 0.7087110877037048
Task A accuracy after training trained model on Task B: 0.7677628993988037
The model performs acceptably on both task A and task B; however, it’s not all that impressive. I’ll expand on this below.
You should have noticed a few peculiarities about this model and the training process. First, I added dropout to the model as a form of regularization. I found during training that dropout was important when using EWC to ensure the model doesn’t significantly overfit to task A when training on task B. Without dropout, the model struggles to make any progress on task B after training on task A. Additionally, you’ll notice the model is trained 3 times longer on task B than on task A. I found progress on task B was much slower than on task A.
You may find that training with EWC is particularly sensitive to hyperparameters such as the weight of the elastic penalty ($latex \lambda$, the `alpha` argument of `compute_elastic_penalty`). Most of my trials produced worse results using EWC than using an $latex L^2$ penalty. I suspect EWC is sensitive to the final parameterization of the model after training on task A: if that parameterization lies in a difficult region of the parameter space, the model will struggle to optimize on task B.
In theory, EWC is incredibly cool. In practice, it seems overly sensitive and sometimes difficult to train.
If you find any bugs in my implementation or problems with my explanation, feel free to comment and I will correct them!