Nx Tip of the Week #5 - Named Tensors

Note: The original named tensors article, Tensor Considered Harmful, covers these ideas in far more depth and explains them better than I can. I recommend reading it as well.

One of my biggest frustrations with NumPy and TensorFlow is working with axes. Take, for example, this TensorFlow implementation of the Mean Squared Error formula:

import tensorflow as tf
mse = lambda x, y: tf.math.reduce_mean(tf.math.square(x - y), axis=-1)

While this is a relatively simple function, I often struggle to reason about what axis=-1 is really doing. It refers to the last dimension, but nothing in the code says what that dimension represents, so the intent isn’t immediately clear. Enter named tensors.

Nx gives you the ability to name the dimensions of your tensors upon creation:

iex> Nx.tensor([[1, 2, 3]], names: [:batch, :data])
#Nx.Tensor<
  s64[batch: 1][data: 3]
  [
    [1, 2, 3]
  ]
>
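Each name labels exactly one dimension, in order, and a nil name leaves that dimension unnamed. To make that pairing concrete, here is a small dependency-free sketch that rebuilds the header Nx prints (s64[batch: 1][data: 3]) from a type, a shape, and a list of names. NamesSketch and header/3 are hypothetical names used only for illustration; this is not how Nx implements its inspect output.

```elixir
# Hypothetical sketch, not Nx's implementation: pair each dimension size
# with its name and render it the way the Nx inspect header looks.
defmodule NamesSketch do
  # The guard requires exactly one name per dimension; a mismatch raises
  # FunctionClauseError in this sketch.
  def header(type, shape, names) when tuple_size(shape) == length(names) do
    dims =
      shape
      |> Tuple.to_list()
      |> Enum.zip(names)
      |> Enum.map_join(fn
        # A nil name renders as a bare size, e.g. [3]
        {size, nil} -> "[#{size}]"
        # A named dimension renders as [name: size], e.g. [batch: 1]
        {size, name} -> "[#{name}: #{size}]"
      end)

    type <> dims
  end
end

IO.puts(NamesSketch.header("s64", {1, 3}, [:batch, :data]))
# prints s64[batch: 1][data: 3]
IO.puts(NamesSketch.header("s64", {1, 3}, [:batch, nil]))
# prints s64[batch: 1][3]
```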

Attaching names to dimensions makes your code more immediately understandable. For example, the same MSE implementation in Nx can look like:

iex> mse = fn x, y -> Nx.mean(Nx.power(Nx.subtract(x, y), 2), axes: [:data]) end

It’s immediately clear that your goal is to take the mean of the squared difference along the :data dimension. Names also let you enforce some restrictions on the format of data coming into your functions: the function above raises an error if the tensors passed to it have no :data dimension.
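To see why the name matters more than the position, here is a dependency-free sketch of the same idea using plain nested lists instead of Nx tensors. It looks up the reduction axis by name, so the same call works no matter where :data sits, and it raises when the name is missing. NamedMSESketch, axis_index!/2, and mse/3 are hypothetical helpers, and the {rows, names} pair is an assumed stand-in for a named 2-D tensor.

```elixir
# Hypothetical sketch: a 2-D "tensor" is {rows, names}, where rows is a
# list of equally long lists and names has one atom per dimension.
defmodule NamedMSESketch do
  # Resolve a dimension name to its position, raising if it is absent.
  def axis_index!(names, name) do
    case Enum.find_index(names, &(&1 == name)) do
      nil -> raise ArgumentError, "name #{inspect(name)} not found in #{inspect(names)}"
      index -> index
    end
  end

  # Mean of squared differences along the dimension called `name`.
  # The repeated `names` variable in the head means both inputs must
  # carry identical names, otherwise FunctionClauseError is raised.
  def mse({x, names}, {y, names}, name) do
    squared =
      Enum.zip(x, y)
      |> Enum.map(fn {row_x, row_y} ->
        Enum.zip(row_x, row_y) |> Enum.map(fn {a, b} -> (a - b) * (a - b) end)
      end)

    # Averaging along axis 1 averages each row; averaging along axis 0
    # averages each column, i.e. each row of the transposed lists.
    groups =
      case axis_index!(names, name) do
        1 -> squared
        0 -> squared |> Enum.zip() |> Enum.map(&Tuple.to_list/1)
      end

    Enum.map(groups, fn group -> Enum.sum(group) / length(group) end)
  end
end

# The same data laid out as [:batch, :data] and as [:data, :batch].
x = [[1, 2, 3], [4, 5, 6]]
y = [[2, 3, 4], [4, 5, 9]]
xt = [[1, 4], [2, 5], [3, 6]]
yt = [[2, 4], [3, 5], [4, 9]]

IO.inspect(NamedMSESketch.mse({x, [:batch, :data]}, {y, [:batch, :data]}, :data))
# prints [1.0, 3.0]
IO.inspect(NamedMSESketch.mse({xt, [:data, :batch]}, {yt, [:data, :batch]}, :data))
# prints [1.0, 3.0]
```

Both calls give the same result even though :data is the last dimension in one layout and the first in the other.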

Named tensors make other operations easier to reason about as well. Imagine you want to flip an image horizontally. There are two common image data formats: channels first and channels last. Normally, you would have to write a function that either enforces one format explicitly or accepts the format as an argument and handles each case accordingly. With named tensors, you can simply access the correct dimension by its name:

defn flip_left_right(x) do
  x |> Nx.reverse(axes: [:width])
end

Whether the image comes in with names: [:batch, :channels, :height, :width] or [:batch, :height, :width, :channels], the function works either way, because Nx resolves each axis name to its position within the given tensor.
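Name-based normalization boils down to looking up where a name sits in each tensor's own list of names. The dependency-free sketch below models this with nested lists: it resolves :width to a position and reverses the lists at that depth, so the same flip works for both layouts. AxisByNameSketch, axis_index!/2, reverse_axis/2, and flip_left_right/1 are hypothetical helpers, and representing a tensor as a {nested_list, names} pair is an assumption made for illustration; Nx.reverse itself works on real tensors.

```elixir
# Hypothetical sketch: a tensor is {nested_list, names}.
defmodule AxisByNameSketch do
  # Resolve a name to its position in this tensor's names.
  def axis_index!(names, name) do
    case Enum.find_index(names, &(&1 == name)) do
      nil -> raise ArgumentError, "name #{inspect(name)} not found in #{inspect(names)}"
      index -> index
    end
  end

  # Reverse the nested list along a positional axis: descend `axis`
  # levels, then reverse the list found there.
  def reverse_axis(list, 0), do: Enum.reverse(list)
  def reverse_axis(list, axis), do: Enum.map(list, &reverse_axis(&1, axis - 1))

  # Flip horizontally by reversing whichever axis is named :width.
  def flip_left_right({data, names}) do
    {reverse_axis(data, axis_index!(names, :width)), names}
  end
end

# One 2x3 single-channel image, stored in each of the two layouts.
channels_first = {[[[[1, 2, 3], [4, 5, 6]]]], [:batch, :channels, :height, :width]}
channels_last = {[[[[1], [2], [3]], [[4], [5], [6]]]], [:batch, :height, :width, :channels]}

# In both results, each row of pixels is reversed along :width.
IO.inspect(AxisByNameSketch.flip_left_right(channels_first))
IO.inspect(AxisByNameSketch.flip_left_right(channels_last))
```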

Now imagine you want to ensure your images are always in channels-first format. You can do this by transposing with names:

defn channels_first(x) do
  x |> Nx.transpose(axes: [:batch, :channels, :height, :width])
end

Once again, the function succeeds regardless of the image’s layout, as long as the tensor carries the :batch, :channels, :height, and :width names. It also acts as an automatic check on the input’s dimensions: a tensor without those names, such as anything that isn’t an image, makes the transpose fail.
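Transposing by names amounts to computing a positional permutation: for each target name, find where it currently lives. The dependency-free sketch below does exactly that and, like the channels_first example, rejects inputs that don’t carry the expected names. TransposeByNameSketch, permutation!/2, and permute_shape/2 are hypothetical, the {8, 32, 32, 3} shape is made up for illustration, and the sketch assumes every dimension has a distinct, non-nil name.

```elixir
# Hypothetical sketch: turn a list of target names into the positional
# permutation a transpose would use. Assumes distinct, non-nil names.
defmodule TransposeByNameSketch do
  def permutation!(names, target_names) do
    # The target must be a reordering of exactly the same names.
    if Enum.sort(names) != Enum.sort(target_names) do
      raise ArgumentError,
            "cannot transpose names #{inspect(names)} to #{inspect(target_names)}"
    end

    # For each target name, the index where it currently sits.
    Enum.map(target_names, fn name -> Enum.find_index(names, &(&1 == name)) end)
  end

  # Apply a permutation to a shape tuple.
  def permute_shape(shape, permutation) do
    permutation |> Enum.map(&elem(shape, &1)) |> List.to_tuple()
  end
end

target = [:batch, :channels, :height, :width]
perm = TransposeByNameSketch.permutation!([:batch, :height, :width, :channels], target)
IO.inspect(perm)
# prints [0, 3, 1, 2]
IO.inspect(TransposeByNameSketch.permute_shape({8, 32, 32, 3}, perm))
# prints {8, 3, 32, 32}
```

A channels-first input produces the identity permutation [0, 1, 2, 3], so it passes through unchanged.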

Broadcasting

Named tensors also provide some safety with respect to broadcasting. Imagine you have two tensors:

iex> t1 = Nx.tensor([[1, 2, 3]], names: [:batch, :data])
#Nx.Tensor<
  s64[batch: 1][data: 3]
  [
    [1, 2, 3]
  ]
>

iex> t2 = Nx.tensor([[1, 2, 3]], names: [:x, :y])
#Nx.Tensor<
  s64[x: 1][y: 3]
  [
    [1, 2, 3]
  ]
>

Notice that the shapes of these two tensors are compatible: if they were unnamed, adding, subtracting, or multiplying them would succeed. But does that really make sense? How does the :batch dimension line up with :x? Semantically, it doesn’t. Nx agrees, and adding these tensors together fails:

iex> Nx.add(t1, t2)
** (ArgumentError) cannot merge names :data, :y
    (nx 0.1.0-dev) lib/nx/shape.ex:1001: Nx.Shape.merge_names!/2
    (nx 0.1.0-dev) lib/nx/shape.ex:220: Nx.Shape.binary_broadcast/6
    (nx 0.1.0-dev) lib/nx/shape.ex:200: Nx.Shape.binary_broadcast/4
    (nx 0.1.0-dev) lib/nx.ex:2424: Nx.element_wise_bin_op/4

To perform binary operations on named tensors, the names must align: for each pair of dimensions being broadcast together, the names must either match, or one of them must be nil, which acts as a wildcard. When broadcasting between named and unnamed dimensions, the resulting tensor takes the non-nil name:

iex> t1 = Nx.tensor([[1, 2, 3]], names: [:batch, nil])
#Nx.Tensor<
  s64[batch: 1][3]
  [
    [1, 2, 3]
  ]
>

iex> t2 = Nx.tensor([[1, 2, 3]], names: [nil, :data])
#Nx.Tensor<
  s64[1][data: 3]
  [
    [1, 2, 3]
  ]
>

iex> Nx.add(t1, t2)
#Nx.Tensor<
  s64[batch: 1][data: 3]
  [
    [1, 2, 3]
  ]
>
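The alignment rule can be captured in a few lines. The dependency-free sketch below merges two lists of names dimension by dimension, treating nil as a wildcard and, assuming right-aligned (NumPy-style) broadcasting, matching from the trailing dimension. MergeNamesSketch, merge_name!/2, and merge_names!/2 are hypothetical; the sketch only checks names, not dimension sizes, and it is not the actual Nx.Shape.merge_names!/2 from the stack trace.

```elixir
# Hypothetical sketch of the name-alignment rule (names only, sizes are
# not checked). Dimensions are matched from the right and nil acts as a
# wildcard.
defmodule MergeNamesSketch do
  # Equal names (including nil with nil) survive unchanged.
  def merge_name!(name, name), do: name
  # A nil name takes on the other tensor's name.
  def merge_name!(nil, name), do: name
  def merge_name!(name, nil), do: name
  # Two different non-nil names cannot be merged.
  def merge_name!(left, right) do
    raise ArgumentError, "cannot merge names #{inspect(left)}, #{inspect(right)}"
  end

  def merge_names!(left, right) do
    rank = max(length(left), length(right))
    # In this sketch, missing leading dimensions of the lower-rank
    # tensor count as unnamed.
    pad = fn names -> List.duplicate(nil, rank - length(names)) ++ names end

    Enum.zip(Enum.reverse(pad.(left)), Enum.reverse(pad.(right)))
    |> Enum.map(fn {l, r} -> merge_name!(l, r) end)
    |> Enum.reverse()
  end
end

IO.inspect(MergeNamesSketch.merge_names!([:batch, nil], [nil, :data]))
# prints [:batch, :data]

try do
  MergeNamesSketch.merge_names!([:batch, :data], [:x, :y])
rescue
  e in ArgumentError -> IO.puts(e.message)
end
# prints cannot merge names :data, :y
```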

Most operations have specific name rules that validate the operation can be performed correctly on the given tensors.
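As one example of such a rule, a reduction like the Nx.mean call in the MSE function removes the reduced dimensions by default, and their names go with them. The dependency-free sketch below models only that name bookkeeping. ReduceNamesSketch and reduced_names!/2 are hypothetical, and the sketch ignores options such as keeping the reduced axes.

```elixir
# Hypothetical sketch of a reduction's name rule: validate the requested
# names, then drop them (and their dimensions) from the result.
defmodule ReduceNamesSketch do
  def reduced_names!(names, axes) do
    # Every requested axis must name an existing dimension.
    Enum.each(axes, fn axis ->
      unless axis in names do
        raise ArgumentError, "name #{inspect(axis)} not found in #{inspect(names)}"
      end
    end)

    # The surviving dimensions keep their names, in their original order.
    Enum.reject(names, &(&1 in axes))
  end
end

IO.inspect(ReduceNamesSketch.reduced_names!([:batch, :data], [:data]))
# prints [:batch]
```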

Named tensors are relatively new, but they are already gaining traction in other frameworks and libraries such as PyTorch, xarray, and einops. Their potential within Nx has definitely not been fully explored. I highly recommend reading the original named tensors proposal and exploring named tensors in other libraries for inspiration on what is possible within Nx.