Nx Tip of the Week #9 - Window Functions
With the release of Nx 0.1.0, I thought I should continue these posts. My time is limited, so these will be briefer than before. Each week I’ll highlight a small aspect of the Nx API with some code examples.
When you first get started with an array programming library like Nx (or NumPy, JAX, TensorFlow, etc.), it can be difficult to get out of the habit of writing loops or recursive functions to iterate over parts of a tensor. Some of the Nx code I’ve seen in other projects tends to make excessive use of Nx’s while construct. In some situations this is unavoidable (such as with Recurrent Neural Networks); however, if you can rework an algorithm to avoid a while loop, you will be much better off. I will cover while in a later post; for now, you should just know that the body of a while loop always executes sequentially and will roundtrip data back to the CPU even when running on a GPU.
One type of operation I typically see people rely on while loops for is performing a look-back or look-ahead. By look-back or look-ahead, I mean an operation that needs information from a tensor at position i and at a neighboring position such as i + 1 or i - 1. A common class of functions that require a look-back or a look-ahead is cumulative functions. For example, a cumulative sum computes a running sum at each position i in a tensor. If you have a tensor that represents revenue over N months, then the value at position i of its cumulative sum is the aggregate revenue of all months up to and including month i:
iex> revenue_by_month = Nx.tensor([100, 200, 100, 50, 150])
iex> cumsum(revenue_by_month)
#Nx.Tensor<
s64[5]
[100, 300, 400, 450, 600]
>
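To pin down exactly what a running sum means, here is the same computation over a plain Elixir list using Enum.scan/2. This is only a reference for checking tensor results against; it is ordinary sequential list code, not the Nx technique this post recommends.

```elixir
# Reference running sum over a plain list (no Nx involved).
# Enum.scan/2 emits the accumulator after each element, so position i
# holds the sum of elements 0..i.
revenue_by_month = [100, 200, 100, 50, 150]
running_total = Enum.scan(revenue_by_month, &+/2)

IO.inspect(running_total)
# => [100, 300, 400, 450, 600]
```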
You can probably see why it’s tempting to implement this with a while loop that iterates over the tensor and indexes at i and i - 1. An alternative is to compute it with Nx.window_sum/3:
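defmodule Cumulative do
  # Module name is illustrative; defn functions must live in a module that imports Nx.Defn.
  import Nx.Defn

  # Cumulative sum along :axis (default 0) computed with a single window aggregate.
  defn cumsum(tensor, opts \\ []) do
    opts = keyword!(opts, axis: 0)
    axis = opts[:axis]

    # Shapes are static inside defn, so the window configuration is built
    # with regular Elixir code inside transform/2.
    {padding_config, strides, window_shape} =
      transform({tensor, axis}, fn {tensor, axis} ->
        n = elem(Nx.shape(tensor), axis)
        # Pad n - 1 zeros before the summed axis; leave every other axis unpadded.
        padding_config =
          for i <- 0..(Nx.rank(tensor) - 1) do
            if i == axis, do: {n - 1, 0}, else: {0, 0}
          end
        # Stride of 1 along every axis.
        strides = List.duplicate(1, Nx.rank(tensor))
        # Window of n along the summed axis and 1 along every other axis.
        window_shape =
          List.duplicate(1, Nx.rank(tensor))
          |> List.to_tuple()
          |> put_elem(axis, n)
        {padding_config, strides, window_shape}
      end)

    Nx.window_sum(tensor, window_shape, strides: strides, padding: padding_config)
  end
end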
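Outside of defn, the same window configuration can be passed to Nx.window_sum/3 directly, which makes it easy to experiment in a script. This is a minimal sketch assuming a standalone Elixir script (Elixir 1.12+ for Mix.install) that pulls in an Nx 0.x release; the 2-D quarters tensor is made-up data used only to show what summing along axis 1 looks like.

```elixir
Mix.install([{:nx, "~> 0.1"}])

revenue_by_month = Nx.tensor([100, 200, 100, 50, 150])
n = Nx.size(revenue_by_month)

# 1-D case: pad n - 1 zeros in front of the axis (none after), then slide a
# window of n elements with stride 1. Output position i sums inputs 0..i.
cumsum_1d =
  Nx.window_sum(revenue_by_month, {n}, strides: [1], padding: [{n - 1, 0}])

IO.inspect(Nx.to_flat_list(cumsum_1d))
# => [100, 300, 400, 450, 600]

# 2-D case (hypothetical data), cumulative sum along axis 1, i.e. within
# each row: the window and padding span axis 1 only; axis 0 gets a window
# of 1 and no padding.
quarters = Nx.tensor([[1, 2, 3], [10, 20, 30]])
cols = elem(Nx.shape(quarters), 1)

cumsum_rows =
  Nx.window_sum(quarters, {1, cols}, strides: [1, 1], padding: [{0, 0}, {cols - 1, 0}])

IO.inspect(Nx.to_flat_list(cumsum_rows))
# => [1, 3, 6, 10, 30, 60]
```

The 2-D call mirrors what the transform above builds for axis: 1: a window of 1 on every axis except the summed one.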
Nx offers a collection of window functions which compute aggregations over sliding windows in a tensor: Nx.window_sum/3, Nx.window_mean/3, Nx.window_max/3, Nx.window_min/3, and Nx.window_reduce/4. Axon uses window functions to implement pooling, but they’re also very useful when you need to compute functions that rely on look-backs or look-aheads. The cumulative sum above pads the front of the aggregated axis with N - 1 zeros (N being the size of that axis), and then slides a window of size N over the padded tensor with a stride of 1. Each window ends at a different position i, so it covers exactly the original elements 0 through i, and the output keeps the input’s length. If you have an operation that requires a similar pattern, you should consider using a window function.
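To make the padding step visible, the sketch below splits the cumulative sum into an explicit Nx.pad/3 followed by a window sum with :valid padding, then moves the padding to the back of the axis to get a look-ahead (a suffix sum) instead of a look-back. It assumes the same standalone-script setup with an Nx 0.x release; splitting the pad out is only for illustration, since passing padding to Nx.window_sum/3 does the same thing in one call.

```elixir
Mix.install([{:nx, "~> 0.1"}])

revenue_by_month = Nx.tensor([100, 200, 100, 50, 150])
n = Nx.size(revenue_by_month)

# Step 1: pad the front of the axis with n - 1 zeros.
# Nx.pad/3 takes a {low, high, interior} tuple per axis.
padded = Nx.pad(revenue_by_month, 0, [{n - 1, 0, 0}])
IO.inspect(Nx.to_flat_list(padded))
# => [0, 0, 0, 0, 100, 200, 100, 50, 150]

# Step 2: slide a window of n elements with stride 1 and no extra padding.
# (2n - 1) - n + 1 = n windows fit, so the output has the input's length,
# and window i covers padded[i..i + n - 1], i.e. original elements 0..i.
cumsum = Nx.window_sum(padded, {n}, strides: [1], padding: :valid)
IO.inspect(Nx.to_flat_list(cumsum))
# => [100, 300, 400, 450, 600]

# Look-ahead: pad the back instead, so window i covers original i..n-1.
padded_back = Nx.pad(revenue_by_month, 0, [{0, n - 1, 0}])
suffix_sum = Nx.window_sum(padded_back, {n}, strides: [1], padding: :valid)
IO.inspect(Nx.to_flat_list(suffix_sum))
# => [600, 500, 300, 200, 150]
```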
As the window you need to look at grows larger, a window aggregate function makes it easier to reason about the operation you’re performing and helps you avoid hard-to-debug indexing and off-by-one errors.
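As an illustration of that point, a trailing sum over the last k months only changes the window size and the front padding; there is no per-element index arithmetic. This is a minimal sketch assuming the same standalone-script setup with an Nx 0.x release; trailing_sum is a hypothetical helper name, and it handles 1-D tensors only.

```elixir
Mix.install([{:nx, "~> 0.1"}])

# Hypothetical helper: trailing k-element sum of a 1-D tensor. Output
# position i sums inputs max(0, i - k + 1)..i; the k - 1 zeros of front
# padding stand in for the months before the first one.
trailing_sum = fn tensor, k ->
  Nx.window_sum(tensor, {k}, strides: [1], padding: [{k - 1, 0}])
end

revenue_by_month = Nx.tensor([100, 200, 100, 50, 150])

IO.inspect(Nx.to_flat_list(trailing_sum.(revenue_by_month, 2)))
# => [100, 300, 300, 150, 200]

IO.inspect(Nx.to_flat_list(trailing_sum.(revenue_by_month, 3)))
# => [100, 300, 400, 350, 300]
```

With k equal to the length of the axis, the helper reduces to the cumulative sum from earlier in the post.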
See the Nx documentation for more details on the window functions.