Nx Tip of the Week #12 - Nx.to_heatmap

Sometimes you want to quickly visualize the contents of a tensor. For example, when working with the MNIST dataset, you might want to make sure you’ve sliced it up correctly. A quick way to visualize images across a single color channel is with Nx.to_heatmap:

Nx.to_heatmap(img)
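The point of the heatmap here is to check your slicing, so the sketch below pulls one image out of a batch before calling Nx.to_heatmap. It does not load MNIST. Instead it builds a synthetic channels-first batch of shape {8, 1, 28, 28} with Nx.iota. The names `images` and `img` are illustrative, and the Nx version in Mix.install is an assumption. The sketch also passes the `:ansi_enabled` option, which the Nx docs describe as forcing ANSI coloring on or off. When coloring is off, values are normalized and printed as the digits 0 to 9.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Synthetic stand-in for an MNIST batch: 8 images, 1 channel, 28x28 pixels.
# Each pixel holds its column index scaled to [0, 1], so every image is a
# left-to-right gradient. Swap in your real batch here.
images =
  Nx.iota({8, 1, 28, 28}, axis: 3, type: :f32)
  |> Nx.divide(27)

# Slice out the first image's only channel: {8, 1, 28, 28} -> {28, 28}
img = images[0][0]
IO.inspect(Nx.shape(img))

# Colored output on ANSI-capable terminals
img |> Nx.to_heatmap() |> IO.inspect()

# With ANSI disabled, values are normalized to the digits 0-9 instead
heatmap = Nx.to_heatmap(img, ansi_enabled: false)
IO.inspect(heatmap)
```

If the printed gradient looks transposed or garbled, the slice or the axis order is the first thing to check.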

Nx.to_heatmap returns an Nx.Heatmap struct. When you inspect it (for example in IEx or with IO.inspect), a heatmap of the tensor is printed to the console, colored with ANSI escape codes when your terminal supports them. This is especially useful when you're debugging quickly and don't want to bring in additional dependencies such as VegaLite. As a neat little trick, you can also use Nx.to_heatmap to visualize the intermediate activations of a convolutional neural network in Axon:

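defmodule ConvVisualizer do
  # ConvVisualizer is just an illustrative module name; `def` must live
  # inside a module. Returns an Axon.Loop event handler that prints a
  # heatmap of `conv_layer`'s activations for `input_to_visualize`.
  def visualize_conv_activations(input_to_visualize, conv_layer) do
    fn state ->
      # The trainer keeps the current parameters in its step state
      params = state.step_state[:model_state]

      # Get the activations for the input, then collapse the channel
      # axis (axis 1) with a max, leaving one map per image
      activations =
        conv_layer
        |> Axon.predict(params, input_to_visualize)
        |> Nx.reduce_max(axes: [1])

      # Convert to a heatmap and print it
      IO.inspect(Nx.to_heatmap(activations))
      {:continue, state}
    end
  end
end

conv_base =
  Axon.input({nil, 1, 28, 28})
  |> Axon.conv(32, kernel_size: {3, 3})

model =
  conv_base
  |> Axon.flatten()
  |> Axon.dense(10, activation: :softmax)

# Placeholders for your own MNIST pipeline: `data` is the training data
# (e.g. a stream of {images, labels} batches) and `sample_images` is a
# small batch of images to visualize, e.g. of shape {1, 1, 28, 28}
model
|> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
|> Axon.Loop.handle(:epoch_completed, ConvVisualizer.visualize_conv_activations(sample_images, conv_base))
|> Axon.Loop.run(data, epochs: 10)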
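To see what the `Nx.reduce_max(axes: [1])` step does to the shapes, here is a small standalone sketch that needs only Nx. It assumes the channels-first layout implied by the {nil, 1, 28, 28} input. It also assumes Axon.conv's default `:valid` padding, under which a 3x3 kernel turns 28x28 into 26x26, so a batch of 4 gives activations of shape {4, 32, 26, 26}. The activations are a synthetic Nx.iota tensor, not real conv output, and the Nx version is an assumption.

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Synthetic stand-in for the conv output: {batch, channels, height, width}.
# Nx.iota numbers every element in order, so within each image the
# highest channel (31) always holds the largest value at each pixel.
activations = Nx.iota({4, 32, 26, 26}, type: :f32)

# Max over the channel axis leaves one 26x26 map per image
collapsed = Nx.reduce_max(activations, axes: [1])
IO.inspect(Nx.shape(collapsed))

# Top-left pixel of image 0 comes from channel 31: 31 * 26 * 26
first_pixel = collapsed[0][0][0] |> Nx.to_number()

# Visualize the first image's collapsed activation map
collapsed[0] |> Nx.to_heatmap() |> IO.inspect()
```

Taking the max over channels shows where any filter fired strongly. Nx.reduce_mean would instead show the average response and is an equally reasonable choice.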

If you look through the Axon examples, you'll see Nx.to_heatmap used all over the place. It's very useful for quick visualizations of images, activations, and more.