Nx Tip of the Week #2 – Tensor Operations for Elixir Programmers

In Elixir, it’s common to manipulate data using the Enum module. Enum provides a set of library functions for working with types that implement the Enumerable protocol. The Enum module is a productive interface for manipulating lists, maps, sets, etc. However, learning how to think about tensor manipulation using Nx can be a bit difficult when you’re used to thinking about data manipulation using Enum.

Element-wise unary functions

Nx contains a number of element-wise unary functions that are tensor aware. If you were asked to implement a function that returns the exponential of every element in a list, you would probably do something like:

iex> a = [1, 2, 3]
iex> Enum.map(a, fn x -> :math.exp(x) end)
[2.718281828459045, 7.38905609893065, 20.085536923187668]

If you were asked to implement the same function in Nx, it’s tempting to write something like:

iex> a = Nx.tensor([1, 2, 3], type: {:s, 32}, names: [:data])
iex> Nx.map(a, [type: {:f, 32}], fn x -> Nx.exp(x) end)
#Nx.Tensor<
  f32[data: 3]
  [2.718281828459045, 7.38905609893065, 20.085536923187668]
>

Note: You have to be explicit about the output type of Nx.map/3 because it cannot infer the type of the output from the anonymous function.

While this implementation is correct, it’s verbose and will be inefficient with most Nx compilers. Additionally, native Nx backends can’t implement functions like Nx.map/3, because there’s no way to pass an anonymous function running on the Erlang VM to native code. Fortunately, element-wise unary functions like Nx.exp/1 are tensor aware, which means they operate over the entire tensor:

iex> a = Nx.tensor([1, 2, 3], type: {:f, 32}, names: [:data])
iex> Nx.exp(a)
#Nx.Tensor<
  f32[data: 3]
  [2.718281828459045, 7.38905609893065, 20.085536923187668]
>

This comes in handy when working with higher-dimensional tensors:

iex> a = Nx.iota({2, 2, 1, 2, 1, 2}, type: {:f, 32})
iex> Nx.exp(a)
#Nx.Tensor<
  f32[2][2][1][2][1][2]
  [
    [
      [
        [
          [
            [1.0, 2.718281828459045]
          ],
          [
            [7.38905609893065, 20.085536923187668]
          ]
        ]
      ],
      [
        [
          [
            [54.598150033144236, 148.4131591025766]
          ],
          [
            [403.4287934927351, 1096.6331584284585]
          ]
        ]
      ]
    ],
    [
      [
        [
          [
            [2980.9579870417283, 8103.083927575384]
          ],
          [
            [22026.465794806718, 59874.14171519782]
          ]
        ]
      ],
      [
        [
          [
            [162754.79141900392, 442413.3920089205]
          ],
          [
            [1202604.2841647768, 3269017.3724721107]
          ]
        ]
      ]
    ]
  ]
>

An equivalent implementation using Enum and lists would be much more verbose. Whenever possible, avoid Nx.map/3 in favor of the element-wise unary functions.
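For comparison, here’s a minimal sketch of the Enum side: a recursive helper that maps a function over a list of any nesting depth. The module and function names are hypothetical; neither Elixir nor Nx ships them.

```elixir
defmodule DeepMap do
  # Hypothetical helper (not part of Nx or Elixir): applies `fun` to every
  # number in an arbitrarily nested list and keeps the nesting intact,
  # roughly what an element-wise unary function like Nx.exp/1 does for you.
  def deep_map(list, fun) when is_list(list) do
    # Recurse into each element, whether it is a number or another list.
    Enum.map(list, &deep_map(&1, fun))
  end

  def deep_map(number, fun) when is_number(number), do: fun.(number)
end

# A 2x2 nested list; deeper nesting works the same way.
IO.inspect(DeepMap.deep_map([[0, 1], [2, 3]], &:math.exp/1))
```

Even this only handles numbers nested in lists, and it works one element at a time on the BEAM, whereas Nx.exp/1 operates on the whole tensor and can be delegated to a native backend or compiler.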

Element-wise binary functions

If you were asked to implement element-wise addition of two lists in Elixir, you would probably do something similar to:

iex> a = [1, 2, 3]
iex> b = [4, 5, 6]
iex> a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x + y end)
[5, 7, 9]

In Nx, there’s no need to zip: element-wise binary functions always pair up corresponding elements of the two tensors:

iex> a = Nx.tensor([1, 2, 3], type: {:f, 32})
iex> b = Nx.tensor([4, 5, 6], type: {:f, 32})
iex> Nx.add(a, b)
#Nx.Tensor<
  f32[3]
  [5.0, 7.0, 9.0]
>

As with the unary functions, this really pays off in higher dimensions:

iex> a = Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], type: {:f, 32})
iex> b = Nx.tensor([[[2, 3, 4], [5, 6, 7], [8, 9, 10]]], type: {:f, 32})
iex> Nx.add(a, b)
#Nx.Tensor<
  f32[1][3][3]
  [
    [
      [3.0, 5.0, 7.0],
      [9.0, 11.0, 13.0],
      [15.0, 17.0, 19.0]
    ]
  ]
>
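To see what Nx is saving you here, this is roughly what element-wise addition of two nested lists takes with Enum. It’s a sketch with a hypothetical module name, and it assumes both lists have exactly the same nesting:

```elixir
defmodule DeepZip do
  # Hypothetical helper (not part of Nx or Elixir): combines two identically
  # nested lists element by element, roughly what Nx.add/2 does for two
  # tensors of the same shape.
  def deep_zip_with(left, right, fun) when is_list(left) and is_list(right) do
    # Enum.zip/2 silently stops at the shorter list, so check lengths first.
    if length(left) != length(right) do
      raise ArgumentError, "lists must have the same shape"
    end

    left
    |> Enum.zip(right)
    |> Enum.map(fn {l, r} -> deep_zip_with(l, r, fun) end)
  end

  # A list paired with a number matches no clause and raises FunctionClauseError.
  def deep_zip_with(l, r, fun) when is_number(l) and is_number(r), do: fun.(l, r)
end

a = [[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]
b = [[[2, 3, 4], [5, 6, 7], [8, 9, 10]]]
IO.inspect(DeepZip.deep_zip_with(a, b, &+/2))
```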

What about broadcasting? While broadcasting can get complex, it’s easy enough to reason about in the scalar case. Imagine you want to multiply each element in a list by a scalar. With Enum you would do:

iex> broadcast_scalar = fn x, list -> Enum.map(list, &(&1 * x)) end
iex> broadcast_scalar.(5, [1, 2, 3])
[5, 10, 15]

The equivalent in Nx:

iex> broadcast_scalar = &Nx.multiply(&1, &2)
iex> broadcast_scalar.(5, Nx.tensor([1, 2, 3], type: {:f, 32}))
#Nx.Tensor<
  f32[3]
  [5.0, 10.0, 15.0]
>

This is why broadcasting is so easy in the scalar case: Nx simply applies the element-wise function between the scalar and every element of the input tensor. In higher dimensions the rules get trickier, so I suggest reading the broadcasting section of the Nx documentation.
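The scalar case also extends to nested lists with a bit of recursion. This sketch (hypothetical module name) only covers a scalar paired with a list; it does not attempt the general shape-matching rules Nx uses when broadcasting two tensors:

```elixir
defmodule ScalarBroadcast do
  # Hypothetical helper (not part of Nx or Elixir): applies a binary function
  # between a scalar and every number in a possibly nested list. It covers
  # only the scalar case, not broadcasting between two lists of different shapes.
  def broadcast(fun, scalar, list) when is_number(scalar) and is_list(list) do
    Enum.map(list, &broadcast(fun, scalar, &1))
  end

  def broadcast(fun, scalar, x) when is_number(scalar) and is_number(x), do: fun.(scalar, x)
end

IO.inspect(ScalarBroadcast.broadcast(&*/2, 5, [[1, 2, 3], [4, 5, 6]]))
```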

Aggregate operators

Aggregation with Enum is typically done using Enum.reduce/3. For example, to sum all of the elements in a list:

iex> a = [1, 2, 3]
iex> Enum.reduce(a, 0, fn x, acc -> x + acc end)
6

This task gets considerably more verbose with nested lists:

iex> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
iex> Enum.reduce(a, 0,
...>   fn x, acc -> Enum.reduce(x, 0,
...>     fn y, inner_acc ->
...>       y + inner_acc
...>     end) + acc
...>   end)
45

What if you only want the sum of each inner list? You can do:

iex> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
iex> Enum.reduce(a, [],
...>   fn x, acc -> [Enum.reduce(x, 0,
...>     fn y, inner_acc ->
...>       y + inner_acc
...>     end) | acc]
...>   end) |> Enum.reverse()
[6, 15, 24]

But what if you don’t know how deeply the input list is nested? What if you want to sum down the columns instead of across the rows? The Enum code gets complicated quickly. Fortunately, it’s easy in Nx. Nx also has an Nx.reduce function; however, just as with Nx.map, you should prefer built-in aggregate operations like Nx.sum, Nx.product, and Nx.mean over custom reductions. I’ll show one example using Nx.reduce and then stick to the built-in aggregates for the rest of the examples (yes, I’m aware of Enum.sum, but it doesn’t make the problems above much easier):

iex> a = Nx.tensor([1, 2, 3], type: {:f, 32})
iex> Nx.reduce(a, 0, fn x, acc -> Nx.add(x, acc) end)
#Nx.Tensor<
  f32
  6.0
>

What about with higher dimensions? Nx will take care of that for you:

iex> a = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], type: {:f, 32})
iex> Nx.sum(a)
#Nx.Tensor<
  f32
  45.0
>
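With plain Elixir, a full sum over an unknown nesting depth does have a reasonable answer: flatten, then sum. Here is a minimal sketch with a hypothetical module name. Flattening discards the structure, so this approach can’t give you per-row or per-column sums:

```elixir
defmodule DeepSum do
  # Hypothetical helper (not part of Nx or Elixir): sums every number in a
  # list of any nesting depth, roughly what Nx.sum/1 computes with no :axes.
  def sum(nested) when is_list(nested) do
    nested
    |> List.flatten()
    |> Enum.sum()
  end
end

IO.inspect(DeepSum.sum([[1, 2, 3], [4, 5, 6], [7, 8, 9]]))
IO.inspect(DeepSum.sum([[[1], [2, 3]], [[4]]]))
```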

And what if I only want the sum along a specific axis?

iex> a = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], type: {:f, 32})
iex> Nx.sum(a, axes: [1])
#Nx.Tensor<
  f32[3]
  [6.0, 15.0, 24.0]
>
iex> Nx.sum(a, axes: [0])
#Nx.Tensor<
  f32[3]
  [12.0, 15.0, 18.0]
>
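Reproducing those two results with Enum means writing one approach per axis. The following sketch covers only a 2-D list of equal-length rows, and the module is hypothetical:

```elixir
defmodule ListAxes do
  # Hypothetical helper (not part of Nx or Elixir): sums a 2-D list (a list of
  # equal-length rows) along one integer axis, mimicking
  # Nx.sum(t, axes: [1]) and Nx.sum(t, axes: [0]).

  # Axis 1 collapses each row to a single number.
  def sum(rows, 1), do: Enum.map(rows, &Enum.sum/1)

  # Axis 0 collapses the rows: Enum.zip/1 turns the rows into column tuples.
  def sum(rows, 0) do
    rows
    |> Enum.zip()
    |> Enum.map(fn column -> column |> Tuple.to_list() |> Enum.sum() end)
  end
end

a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
IO.inspect(ListAxes.sum(a, 1))
IO.inspect(ListAxes.sum(a, 0))
```

Every additional rank or combination of axes needs more code like this, which is exactly the bookkeeping Nx.sum/2 does for you.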

Aggregate operations like Nx.sum reduce along one or more axes. For example, you can aggregate multiple axes:

iex> a = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], type: {:f, 32})
iex> Nx.sum(a, axes: [0, 1])
#Nx.Tensor<
  f32
  45.0
>

If you don’t pass any axes, Nx aggregates over the entire tensor. With integer axes, it can sometimes be hard to picture which elements get combined; named tensors make it easier to see:

iex> a = Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]],
...>   names: [:x, :y, :z], type: {:f, 32})
#Nx.Tensor<
  f32[x: 1][y: 3][z: 3]
  [
    [
      [1.0, 2.0, 3.0],
      [4.0, 5.0, 6.0],
      [7.0, 8.0, 9.0]
    ]
  ]
>
iex> Nx.sum(a, axes: [:y])
#Nx.Tensor<
  f32[x: 1][z: 3]
  [
    [12.0, 15.0, 18.0]
  ]
>
iex> Nx.sum(a, axes: [:z])
#Nx.Tensor<
  f32[x: 1][y: 3]
  [
    [6.0, 15.0, 24.0]
  ]
>
iex> Nx.sum(a, axes: [:x])
#Nx.Tensor<
  f32[y: 3][z: 3]
  [
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0]
  ]
>

Notice that in each case, the axis that disappears from the result is the one passed in :axes. The :axes option also accepts negative integers, which count back from the last axis; however, you should generally prefer named axes over integer axes wherever possible. We’ll cover named tensors in detail in a later post.
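To make the negative-index rule concrete, here’s a sketch of how an axis given as a name or as a possibly negative integer could be resolved to a position. It’s a hypothetical helper for illustration, not how Nx implements :axes internally:

```elixir
defmodule AxisIndex do
  # Hypothetical helper (not part of Nx or Elixir): resolves an axis given as
  # a name or a (possibly negative) integer to a position in `names`.
  # This is an illustration, not Nx's implementation.
  def normalize(axis, names) when is_atom(axis) do
    case Enum.find_index(names, &(&1 == axis)) do
      nil -> raise ArgumentError, "unknown axis #{inspect(axis)}"
      index -> index
    end
  end

  def normalize(axis, names) when is_integer(axis) do
    rank = length(names)

    cond do
      axis >= 0 and axis < rank -> axis
      # Negative axes count back from the end: -1 is the last axis.
      axis < 0 and axis >= -rank -> rank + axis
      true -> raise ArgumentError, "axis #{axis} out of range for rank #{rank}"
    end
  end
end

IO.inspect(AxisIndex.normalize(:y, [:x, :y, :z]))
IO.inspect(AxisIndex.normalize(-1, [:x, :y, :z]))
```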

Hopefully this gets you primed for reasoning about manipulating tensors in Nx. Until next time, happy coding!
