Nx Tip of the Week #12 - Nx.to_heatmap
Sometimes you want to quickly visualize the contents of a tensor. For example, when working with the MNIST dataset, you might want to make sure you've sliced it up correctly. A quick way to visualize a single color channel of an image is Nx.to_heatmap:
Nx.to_heatmap(img)
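Here is a minimal sketch of that sanity check. It assumes a batch of MNIST-style images in a channels-first {batch, channels, height, width} layout, matching the Axon input shape used later in this post. A deterministic Nx.iota tensor stands in for real MNIST data, so the variable names and values are illustrative only.

```elixir
Mix.install([:nx])

# Stand-in for a batch of MNIST images: 8 images, 1 channel, 28x28 pixels.
# Real code would load the dataset here instead of generating a gradient.
images = Nx.iota({8, 1, 28, 28})

# Take the first image and its only color channel, giving a {28, 28} tensor.
img = images[0][0]

# Nx.to_heatmap/2 returns a heatmap struct; inspecting it renders the tensor.
# ansi_enabled: false forces the plain-text fallback instead of ANSI colors,
# which is handy when output goes to a log file rather than a terminal.
img
|> Nx.to_heatmap(ansi_enabled: false)
|> IO.inspect()
```

In a terminal with ANSI support you can drop the option and get the colored rendering.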
Nx.to_heatmap returns a heatmap struct, and inspecting it (in IEx, Livebook, or with IO.inspect) prints a nice console representation of the tensor as a heatmap. This is especially useful when you're quickly debugging and don't want to bring in additional dependencies such as VegaLite. As a neat little trick, you can use Nx.to_heatmap to visualize the intermediate activations of a convolutional neural network in Axon:
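# `def` must live inside a module; the module name here is illustrative
defmodule ActivationViz do
  def visualize_conv_activations(input_to_visualize, conv_layer) do
    fn state ->
      # Trainer loops keep the current parameters in step_state[:model_state]
      model_state = state.step_state[:model_state]
      # Run the input through the conv layers only, then take the max over the
      # channel axis (axis 1) to collapse the 32 feature maps into one map per image
      activations =
        conv_layer
        |> Axon.predict(model_state, input_to_visualize)
        |> Nx.reduce_max(axes: [1])
      # Convert to a heatmap and print it
      IO.inspect(Nx.to_heatmap(activations))
      {:continue, state}
    end
  end
end
conv_base =
  Axon.input({nil, 1, 28, 28})
  |> Axon.conv(32, kernel_size: {3, 3})
model =
  conv_base
  |> Axon.flatten()
  |> Axon.dense(10, activation: :softmax)
# sample_images: a small batch of images to visualize, e.g. a {1, 1, 28, 28} tensor
# data: the training batches of {images, labels}; both are assumed to be defined
model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
|> Axon.Loop.handle(:epoch_completed, ActivationViz.visualize_conv_activations(sample_images, conv_base))
|> Axon.Loop.run(data, epochs: 10)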
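To see what the Nx.reduce_max(axes: [1]) step in the handler does on its own, here is a small sketch that uses a made-up activation tensor in place of real Axon output. It assumes the channels-first layout above: 32 filters with a {3, 3} kernel and valid padding over a {1, 1, 28, 28} input produce a {1, 32, 26, 26} tensor. Taking the max over axis 1 collapses the 32 feature maps into a single 26x26 map per image.

```elixir
Mix.install([:nx])

shape = {1, 32, 26, 26}

# Made-up activations: channel index times column index, so higher channels
# respond more strongly and every channel grows toward the right edge.
activations = Nx.multiply(Nx.iota(shape, axis: 1), Nx.iota(shape, axis: 3))

# Keep the strongest response at each pixel across all 32 channels.
collapsed = Nx.reduce_max(activations, axes: [1])

# One 26x26 map per image in the batch, printed as a heatmap.
collapsed
|> Nx.to_heatmap(ansi_enabled: false)
|> IO.inspect()
```

With this fake data the strongest channel is always the last one (index 31), so each row of the collapsed map reads 0, 31, 62, and so on, and the heatmap shows a left-to-right gradient.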
If you look at some of the Axon examples, you'll see Nx.to_heatmap used all over the place. It's very useful for quick visualizations of images, activations, and more.