# Nx Tip of the Week #2 - Tensor Operations for Elixir Programmers

In Elixir, it’s common to manipulate data using the `Enum` module. `Enum` provides a set of library functions for working with types that implement the `Enumerable` protocol. The `Enum` module is a productive interface for manipulating lists, maps, sets, etc. However, learning how to think about tensor manipulation using `Nx` can be a bit difficult when you’re used to thinking about data manipulation using `Enum`.

## Element-wise unary functions

`Nx` contains a number of element-wise unary functions that are *tensor aware*. If you were asked to implement a function that returns the exponential of every element in a list, you would probably do something like:

```
iex> a = [1, 2, 3]
iex> Enum.map(a, fn x -> :math.exp(x) end)
[2.718281828459045, 7.38905609893065, 20.085536923187668]
```

If you were asked to implement the same function in `Nx`, it’s tempting to do something like:

```
iex> a = Nx.tensor([1, 2, 3], type: {:s, 32}, names: [:data])
iex> Nx.map(a, [type: {:f, 32}], fn x -> Nx.exp(x) end)
#Nx.Tensor<
  f32[data: 3]
  [2.718281828459045, 7.38905609893065, 20.085536923187668]
>
```

*Note: You have to be explicit about the output type of Nx.map/3 because it cannot infer the type of the output from the anonymous function.*

While this implementation is correct, it’s verbose, and will be inefficient for most `Nx` compilers. Additionally, native `Nx` backends can’t implement functions like `Nx.map/3` because there’s no way to pass functions on the VM to a native interface. Fortunately, the element-wise unary functions like `Nx.exp/1` are *tensor aware*, which means they operate over the entire tensor:

```
iex> a = Nx.tensor([1, 2, 3], type: {:f, 32}, names: [:data])
iex> Nx.exp(a)
#Nx.Tensor<
  f32[data: 3]
  [2.718281828459045, 7.38905609893065, 20.085536923187668]
>
```

This comes in handy when working with higher-dimensional tensors:

```
iex> a = Nx.iota({2, 2, 1, 2, 1, 2}, type: {:f, 32})
iex> Nx.exp(a)
#Nx.Tensor<
  f32[2][2][1][2][1][2]
  [
    [
      [
        [
          [
            [1.0, 2.718281828459045]
          ],
          [
            [7.38905609893065, 20.085536923187668]
          ]
        ]
      ],
      [
        [
          [
            [54.598150033144236, 148.4131591025766]
          ],
          [
            [403.4287934927351, 1096.6331584284585]
          ]
        ]
      ]
    ],
    [
      [
        [
          [
            [2980.9579870417283, 8103.083927575384]
          ],
          [
            [22026.465794806718, 59874.14171519782]
          ]
        ]
      ],
      [
        [
          [
            [162754.79141900392, 442413.3920089205]
          ],
          [
            [1202604.2841647768, 3269017.3724721107]
          ]
        ]
      ]
    ]
  ]
>
```

An equivalent implementation using `Enum` and lists would be much more verbose. You should avoid usage of `Nx.map/3` in favor of the element-wise unary functions whenever possible.
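Because these functions take and return tensors, they also chain naturally with the pipe operator. The following is a minimal sketch rather than anything from the `Nx` docs: it assumes a standalone script on Elixir 1.12 or later (for `Mix.install/1`), the version requirement is only a guess, and the variable names are made up for illustration.

```elixir
# Assumption: run as a standalone script; the version requirement is illustrative.
Mix.install([{:nx, "~> 0.5"}])

x = Nx.tensor([[-4.0, 9.0], [-16.0, 25.0]], type: {:f, 32})

# Each step operates on the whole tensor, so no Nx.map/3 or anonymous function is needed.
roots = x |> Nx.abs() |> Nx.sqrt()

IO.inspect(Nx.shape(roots))        # {2, 2} (the shape is preserved)
IO.inspect(Nx.to_flat_list(roots)) # [2.0, 3.0, 4.0, 5.0]
```

The same pipeline works unchanged on a tensor of any rank, which is the main reason to reach for these functions first.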

## Element-wise binary functions

If you were asked to implement element-wise addition of two lists in Elixir, you would probably do something similar to:

```
iex> a = [1, 2, 3]
iex> b = [4, 5, 6]
iex> a |> Enum.zip(b) |> Enum.map(fn {x, y} -> x + y end)
[5, 7, 9]
```

In `Nx`, there is no concept of `zip`; element-wise binary functions always pair corresponding elements in the tensors:

```
iex> a = Nx.tensor([1, 2, 3], type: {:f, 32})
iex> b = Nx.tensor([4, 5, 6], type: {:f, 32})
iex> Nx.add(a, b)
#Nx.Tensor<
  f32[3]
  [5.0, 7.0, 9.0]
>
```

Just like the previous example, this becomes clearly useful in higher dimensions:

```
iex> a = Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]], type: {:f, 32})
iex> b = Nx.tensor([[[2, 3, 4], [5, 6, 7], [8, 9, 10]]], type: {:f, 32})
iex> Nx.add(a, b)
#Nx.Tensor<
  f32[1][3][3]
  [
    [
      [3.0, 5.0, 7.0],
      [9.0, 11.0, 13.0],
      [15.0, 17.0, 19.0]
    ]
  ]
>
```

What about broadcasting? While broadcasting can get complex, it’s easy enough to reason about in the scalar case. Imagine you want to multiply each element in a list by a scalar. With `Enum` you would do:

```
iex> broadcast_scalar = fn x, list -> Enum.map(list, & &1*x) end
iex> broadcast_scalar.(5, [1, 2, 3])
[5, 10, 15]
```

The equivalent in `Nx`:

```
iex> broadcast_scalar = &Nx.multiply(&1, &2)
iex> broadcast_scalar.(5, Nx.tensor([1, 2, 3], type: {:f, 32}))
#Nx.Tensor<
  f32[3]
  [5.0, 10.0, 15.0]
>
```

You can see why broadcasting is so easy in the scalar case - it just applies the element-wise function between the scalar and every item in the input tensor. In higher dimensions, broadcasting can get tricky. I suggest reading the broadcasting section of the `Nx` documentation.
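As a small taste of the non-scalar case, here is a hedged sketch of broadcasting a vector and a column against a matrix. It relies on the usual rule that shapes are aligned from the right and that a missing or size-1 dimension is stretched to match. The script setup (`Mix.install/1`, the version requirement) and all variable names are assumptions for illustration.

```elixir
# Assumption: run as a standalone script; the version requirement is illustrative.
Mix.install([{:nx, "~> 0.5"}])

matrix = Nx.tensor([[1, 2, 3], [4, 5, 6]], type: {:f, 32}) # shape {2, 3}
row = Nx.tensor([10, 20, 30], type: {:f, 32})              # shape {3}
column = Nx.tensor([[100], [200]], type: {:f, 32})         # shape {2, 1}

# {3} is treated like {1, 3}, so the row is added to every row of the matrix.
plus_row = Nx.add(matrix, row)

# {2, 1} is stretched along the last axis, so each row gets its own offset.
plus_column = Nx.add(matrix, column)

IO.inspect(Nx.to_flat_list(plus_row))    # [11.0, 22.0, 33.0, 14.0, 25.0, 36.0]
IO.inspect(Nx.to_flat_list(plus_column)) # [101.0, 102.0, 103.0, 204.0, 205.0, 206.0]
```

Both results keep the `{2, 3}` shape of the larger operand.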

## Aggregate operators

Aggregation with `Enum` is typically done using `Enum.reduce/3`. For example, to sum all of the elements in a list:

```
iex> a = [1, 2, 3]
iex> Enum.reduce(a, 0, fn x, acc -> x + acc end)
6
```

This task gets considerably more verbose with nested lists:

```
iex> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
iex> Enum.reduce(a, 0,
...>   fn x, acc -> Enum.reduce(x, 0,
...>     fn y, inner_acc ->
...>       y + inner_acc
...>     end) + acc
...>   end)
45
```

What if you only want the sum of each inner list? You can do:

```
iex> a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
iex> Enum.reduce(a, [],
...>   fn x, acc -> [Enum.reduce(x, 0,
...>     fn y, inner_acc ->
...>       y + inner_acc
...>     end) | acc]
...>   end) |> Enum.reverse()
[6, 15, 24]
```

But what if you don’t know the level of nesting of the input list? What if you want to sum across columns instead of rows? This task can get quite complex. Fortunately, it’s easy in `Nx`. `Nx` also has an `Nx.reduce` function; however, similar to `Nx.map`, you should prefer aggregate operations like `Nx.sum`, `Nx.product`, and `Nx.mean` to custom implementations. I will show one example using `Nx.reduce`, and then favor pre-written aggregate operations for the rest of the examples (yes, I’m aware of `Enum.sum`, but it doesn’t make the above problem that much easier):

```
iex> a = Nx.tensor([1, 2, 3], type: {:f, 32})
iex> Nx.reduce(a, 0, fn x, acc -> Nx.add(x, acc) end)
#Nx.Tensor<
  f32
  6.0
>
```

What about with higher dimensions? `Nx` will take care of that for you:

```
iex> a = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], type: {:f, 32})
iex> Nx.sum(a)
#Nx.Tensor<
  f32
  45.0
>
```

And what if I only want the sum along a specific axis?

```
iex> a = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], type: {:f, 32})
iex> Nx.sum(a, axes: [1])
#Nx.Tensor<
  f32[3]
  [6.0, 15.0, 24.0]
>
iex> Nx.sum(a, axes: [0])
#Nx.Tensor<
  f32[3]
  [12.0, 15.0, 18.0]
>
```
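For comparison, here is roughly what the two questions raised earlier (unknown nesting depth, and sums across columns) look like with plain lists. This is only a sketch: `NestedSum` is a hypothetical module name, and the column version assumes every inner list has the same length.

```elixir
defmodule NestedSum do
  # Hypothetical helper: sums numbers in a list nested to any depth.
  def sum(list) when is_list(list), do: list |> Enum.map(&sum/1) |> Enum.sum()
  def sum(number) when is_number(number), do: number
end

rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

# Rough equivalent of Nx.sum(a): recurse until only numbers are left.
total = NestedSum.sum(rows)

# Rough equivalent of Nx.sum(a, axes: [0]): zip the rows into columns, then sum each.
column_sums =
  rows
  |> Enum.zip()
  |> Enum.map(fn column -> column |> Tuple.to_list() |> Enum.sum() end)

IO.inspect(total)       # 45
IO.inspect(column_sums) # [12, 15, 18]
```

Each additional axis needs its own hand-written traversal here, whereas `Nx.sum` only needs a different `axes` value.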

Aggregate operations like `Nx.sum` reduce along one or more `axes`. For example, you can aggregate multiple axes:

```
iex> a = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], type: {:f, 32})
iex> Nx.sum(a, axes: [0, 1])
#Nx.Tensor<
  f32
  45.0
>
```
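The other aggregates mentioned earlier, `Nx.mean` and `Nx.product`, accept the same `axes` option. The sketch below assumes a standalone script using `Mix.install/1` with an illustrative version requirement; the variable names are made up.

```elixir
# Assumption: run as a standalone script; the version requirement is illustrative.
Mix.install([{:nx, "~> 0.5"}])

a = Nx.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], type: {:f, 32})

row_means = Nx.mean(a, axes: [1])          # mean of each row
column_products = Nx.product(a, axes: [0]) # product down each column
overall_mean = Nx.mean(a)                  # no axes: aggregate everything

IO.inspect(Nx.to_flat_list(row_means))       # [2.0, 5.0, 8.0]
IO.inspect(Nx.to_flat_list(column_products)) # [28.0, 80.0, 162.0]
IO.inspect(Nx.to_number(overall_mean))       # 5.0
```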

If you don’t pass any `axes`, `Nx` will automatically aggregate over the entire tensor. How the aggregation will be performed can sometimes be difficult to reason about, although it’s a bit easier to see when using *named tensors*:

```
iex> a = Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]],
...>   names: [:x, :y, :z], type: {:f, 32})
#Nx.Tensor<
  f32[x: 1][y: 3][z: 3]
  [
    [
      [1.0, 2.0, 3.0],
      [4.0, 5.0, 6.0],
      [7.0, 8.0, 9.0]
    ]
  ]
>
iex> Nx.sum(a, axes: [:y])
#Nx.Tensor<
  f32[x: 1][z: 3]
  [
    [12.0, 15.0, 18.0]
  ]
>
iex> Nx.sum(a, axes: [:z])
#Nx.Tensor<
  f32[x: 1][y: 3]
  [
    [6.0, 15.0, 24.0]
  ]
>
iex> Nx.sum(a, axes: [:x])
#Nx.Tensor<
  f32[y: 3][z: 3]
  [
    [1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0],
    [7.0, 8.0, 9.0]
  ]
>
```

Notice how in each case the axis that disappears is the one provided in `axes`. `axes` also supports negative indexing; however, you should generally prefer using named axes over integer axes wherever possible. We’ll cover named tensors in detail in a later post.
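To make the negative-indexing remark concrete, the sketch below sums the same tensor over its last axis in three ways: by name, by negative index, and by positive index. The script setup (`Mix.install/1`, the version requirement) and the variable names are assumptions for illustration.

```elixir
# Assumption: run as a standalone script; the version requirement is illustrative.
Mix.install([{:nx, "~> 0.5"}])

a =
  Nx.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]],
    names: [:x, :y, :z],
    type: {:f, 32}
  )

by_name = Nx.sum(a, axes: [:z])      # named axis
by_negative = Nx.sum(a, axes: [-1])  # -1 counts from the end, so it is also :z
by_index = Nx.sum(a, axes: [2])      # positive integer index of :z

IO.inspect(Nx.to_flat_list(by_name))     # [6.0, 15.0, 24.0]
IO.inspect(Nx.to_flat_list(by_negative)) # [6.0, 15.0, 24.0]
IO.inspect(Nx.to_flat_list(by_index))    # [6.0, 15.0, 24.0]
IO.inspect(Nx.names(by_name))            # [:x, :y]
```

All three calls reduce the same axis, but only the named version still reads correctly if the tensor later gains or loses a dimension.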

Hopefully this gets you primed for reasoning about manipulating tensors in `Nx`. Until next time, happy coding!