Nx Tip of the Week #9 – Window Functions

With the release of Nx 0.1.0, I thought I should continue these posts. My time is limited, so these will be a little briefer than before. Each week I'll highlight a small aspect of the Nx API with some code examples.

When you first get started with an array programming library like Nx (or NumPy, JAX, TensorFlow, etc.), it can be difficult to break the habit of writing loops or recursive functions to iterate over parts of a tensor. Some of the Nx code I've seen in other projects makes excessive use of Nx's while construct. In some situations this is unavoidable (such as with Recurrent Neural Networks); however, if you can rework an algorithm to avoid while, you will be much better off. I will cover while in a later post. For now, you should just know that the body of a while loop is always executed sequentially and will roundtrip data back to the CPU even when running on GPU.

One type of operation I typically see people rely on while loops for is a look-back or look-ahead: an operation that relies on information from a tensor at position i and at position i + 1 (or i - 1). A common class of functions that require a look-back or look-ahead is cumulative functions. For example, a cumulative sum computes a running sum at each position i in a tensor. If you have a tensor that represents revenue over N months, then the value of its cumulative sum at position i represents the aggregate revenue of all months up to and including month i:

iex> revenue_by_month = Nx.tensor([100, 200, 100, 50, 150])
iex> cumsum(revenue_by_month)
  [100, 300, 400, 450, 600]
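As a sequential reference point (plain Elixir lists, not Nx tensors), the same running sum can be written as a look-back in which each output adds the current month to the previous output. Elixir's Enum.scan/2 does exactly this; the variable name running_total is just for illustration.

```elixir
revenue_by_month = [100, 200, 100, 50, 150]

# Each step adds the current element to the previous running total,
# i.e. output[i] = output[i - 1] + input[i]. This is inherently sequential.
running_total = Enum.scan(revenue_by_month, &(&1 + &2))
# => [100, 300, 400, 450, 600]
```

That step-by-step dependency on i - 1 is what makes a while loop look like the natural tool.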

You can probably see why it's tempting to implement this with a while loop, iterating over the tensor and indexing at i and i - 1. An alternative way to compute it is with Nx.window_sum/3:

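  defmodule Cumulative do
    import Nx.Defn

    # Cumulative sum along `axis`, computed with a single padded window_sum
    defn cumsum(tensor, opts \\ []) do
      opts = keyword!(opts, axis: 0)
      axis = opts[:axis]

      # The window parameters depend only on the (static) shape,
      # so they are built with regular Elixir code inside transform/2
      {padding_config, strides, window_shape} =
        transform({tensor, axis}, fn {tensor, axis} ->
          rank = Nx.rank(tensor)
          n = elem(Nx.shape(tensor), axis)

          # Pad n - 1 values before the data, on the aggregate axis only
          padding_config =
            for i <- 0..(rank - 1) do
              if i == axis, do: {n - 1, 0}, else: {0, 0}
            end

          # Slide one position at a time along every axis
          strides = List.duplicate(1, rank)

          # The window spans all n elements of the axis, and 1 elsewhere
          window_shape =
            List.duplicate(1, rank)
            |> List.to_tuple()
            |> put_elem(axis, n)

          {padding_config, strides, window_shape}
        end)

      Nx.window_sum(tensor, window_shape, strides: strides, padding: padding_config)
    end
  end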

Nx offers a collection of window functions which compute aggregations over sliding windows in a tensor: Nx.window_sum/3, Nx.window_mean/3, Nx.window_max/3, Nx.window_min/3, and Nx.window_reduce/4. Axon uses window functions to implement pooling, but they're also very useful for computing functions that rely on look-backs or look-aheads. The cumulative sum above pads the input with N - 1 zeros at the start of the aggregate axis (N being the size of that axis), then slides a window of size N over the padded tensor with a stride of 1, so the window ending at position i covers every element up to and including i. If you have an operation that follows a similar pattern, consider using a window function.
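To make the pad-and-slide mechanics concrete, here is a minimal plain-Elixir sketch (no Nx) of what the padded window sum computes on a 1-D list. The module WindowSketch is hypothetical and lists stand in for tensors; it illustrates the arithmetic only and is not how Nx implements window_sum.

```elixir
defmodule WindowSketch do
  # Illustrative only: mirrors a padded window_sum on a plain 1-D list.
  def cumsum([]), do: []

  def cumsum(list) do
    n = length(list)

    # Pad n - 1 zeros on the left, like the {n - 1, 0} padding config.
    padded = List.duplicate(0, n - 1) ++ list

    # Slide a window of size n with stride 1. :discard drops the trailing
    # partial windows, leaving exactly n full windows.
    padded
    |> Enum.chunk_every(n, 1, :discard)
    |> Enum.map(&Enum.sum/1)
  end
end

WindowSketch.cumsum([100, 200, 100, 50, 150])
# => [100, 300, 400, 450, 600]
```

For the revenue example, the first window is [0, 0, 0, 0, 100] and the last is the full [100, 200, 100, 50, 150], which is why no explicit indexing is needed.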

As the window you need to look at grows larger, using a window aggregate function makes the operation easier to reason about and helps you avoid hard-to-debug indexing and off-by-one errors.
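As a sketch of a shorter look-back, a k-month moving average needs no manual indexing with Nx.window_mean/3. This assumes a standalone script with network access so Mix.install can fetch Nx (the "~> 0.1" version constraint is an assumption), and the module Lookback and function moving_average are hypothetical names.

```elixir
# Assumption: network access so Mix.install can fetch Nx.
Mix.install([{:nx, "~> 0.1"}])

defmodule Lookback do
  # k-month moving average: output j averages months j .. j + k - 1.
  def moving_average(tensor, k) do
    # Window of size k along axis 0, default stride of 1, and the default
    # :valid padding, so only complete windows are produced (n - k + 1 values).
    Nx.window_mean(tensor, {k})
  end
end

revenue_by_month = Nx.tensor([100, 200, 100, 50, 150])
Lookback.moving_average(revenue_by_month, 3)
```

For the revenue data this should give a float tensor of roughly [133.33, 116.67, 100.0]; changing the window size only means changing k, not rewriting index arithmetic.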

See the Nx documentation for the full list of window functions and their options.
