Nx Tip of the Week #11 - While Loops


Some numeric algorithms require sequential operations. In TOTW #9, we talked about one operation you can use to avoid while-loops in specific situations. Unfortunately, you won’t always be able to avoid a while-loop. Nx has a while construct, implemented as an Elixir macro, that takes an initial state, a condition, and a body. The body must return a value with the same shape as the initial state. It’s essentially a reduce_while, which aggregates state for as long as some condition is satisfied:

defn count_to_ten() do
  while i = 0, i < 10, do: i + 1
end
iex> count_to_ten()
#Nx.Tensor<
  s64
  10
>
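To make the reduce_while comparison concrete, here is a plain-Elixir sketch of the same semantics. It uses only the standard library and not Nx. The WhileSketch.while_loop/3 helper is hypothetical and is not an Nx API. It checks the condition against the current state, applies the body while the condition holds, and returns the final state.

```elixir
defmodule WhileSketch do
  # Hypothetical helper (not an Nx API): a plain-Elixir model of `while`.
  # Both `cond_fun` and `body_fun` receive the current state.
  def while_loop(state, cond_fun, body_fun) do
    # An infinite stream gives reduce_while something to iterate over;
    # the loop ends only when the condition fails and we return {:halt, acc}.
    Stream.repeatedly(fn -> :tick end)
    |> Enum.reduce_while(state, fn _tick, acc ->
      if cond_fun.(acc) do
        {:cont, body_fun.(acc)}
      else
        {:halt, acc}
      end
    end)
  end
end

# Mirrors `while i = 0, i < 10, do: i + 1`
WhileSketch.while_loop(0, &(&1 < 10), &(&1 + 1))
```

Unlike the defn version, this runs eagerly on the BEAM and is never compiled by XLA. It only models how the state flows from one iteration to the next.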

The state can even be a container such as a tuple, so you can aggregate multiple things at once:

defn count_to_ten_twice() do
  while {i = 0, j = 0}, i < 10, do: {i + 1, j + 1}
end
iex> count_to_ten_twice()
{#Nx.Tensor<
   s64
   10
 >,
 #Nx.Tensor<
   s64
   10
 >}
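In the example above, both elements change identically. In practice, the elements of a tuple state usually play different roles, such as a loop counter plus an accumulator. The sketch below is plain Elixir (standard library only, not Nx) and mirrors the pattern while {i = 0, sum = 0}, i < n, do: {i + 1, sum + i}, following the form of the example above. The TupleStateSketch.sum_below/1 function is hypothetical.

```elixir
defmodule TupleStateSketch do
  # Hypothetical example: sum the integers 0..n-1 while counting iterations.
  # The state is a tuple {counter, accumulator}, like a tuple state in Nx.
  def sum_below(n) do
    Stream.repeatedly(fn -> :tick end)
    |> Enum.reduce_while({0, 0}, fn _tick, {i, sum} ->
      if i < n do
        # Body: both elements are updated and returned in the same tuple shape
        {:cont, {i + 1, sum + i}}
      else
        {:halt, {i, sum}}
      end
    end)
  end
end

{_i, sum} = TupleStateSketch.sum_below(10)
```

Here the counter only drives the condition, while the accumulator carries the result you actually want, so you pattern-match it out of the final tuple.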

It’s important to understand that the body of the while loop must return a value with the same shape as the initial state. For example, if you want to use a while-loop to build a tensor iteratively by adding new values along some axis, you need to know the final shape of the tensor ahead of time and start from a placeholder of that shape:

defn build_a_vector() do
  # Create a "filler" tensor with the final shape up front
  initial_tensor = Nx.broadcast(0.0, {12})
  {_, final_tensor} =
    while {i = 0, tensor = initial_tensor}, i < 12 do
      val = Nx.random_uniform({1})
      # Write one value per iteration; the state keeps its {12} shape
      {i + 1, Nx.put_slice(tensor, [i], val)}
    end

  final_tensor
end
iex> build_a_vector()
 #Nx.Tensor<
   f32[12]
   [0.9620421528816223, 0.37593021988868713, 0.5158101916313171, 0.39656928181648254, 0.6919131875038147, 0.1678706705570221, 0.9522126913070679, 0.5573846101760864, 0.37262946367263794, 0.40950807929039, 0.9263219237327576, 0.45467808842658997]
 >
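You can model the preallocate-then-overwrite pattern in plain Elixir to see why it satisfies the shape rule. This sketch uses only the standard library and not Nx. The state's length is fixed before the loop starts, and each iteration overwrites one slot instead of appending. FixedShapeSketch.build/1 is hypothetical. It writes i * 1.0 instead of a random number so the result is deterministic.

```elixir
defmodule FixedShapeSketch do
  # Hypothetical analogue of build_a_vector: preallocate a list of length n,
  # then overwrite slot i on iteration i. The state's length stays n throughout,
  # mirroring how the Nx loop keeps its {12}-shaped tensor.
  def build(n) do
    filler = List.duplicate(0.0, n)
    {_i, result} = loop({0, filler}, n)
    result
  end

  # Loop while i < n; each step returns a state with the same structure.
  defp loop({i, acc}, n) when i < n do
    # Deterministic stand-in for Nx.random_uniform({1})
    loop({i + 1, List.replace_at(acc, i, i * 1.0)}, n)
  end

  defp loop(state, _n), do: state
end

FixedShapeSketch.build(4)
```

Appending to the list on each iteration would also work in plain Elixir. An Nx while body that grew its tensor, however, would break the rule that the body returns the same shape as the initial state.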

The native defn loops you implement will be much more efficient than pure Elixir loops performing the same operations. They’ll still be slow relative to vectorized tensor operations, though, since each iteration must wait for the previous one to finish. Because loops can also be difficult for JIT compilers like XLA to optimize, you might see some benefit from applying optimizations yourself. For example, you might benefit from unrolling the body of the loop so that each iteration does more work:

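defn unrolled_count_to_ten() do
  while i = 0, i < 10 do
    # First copy of the loop body
    i = i + 1
    # Second copy of the body: two increments per iteration.
    # No extra check is needed here because 10 is a multiple of 2.
    i = i + 1
    i
  end
end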

Just remember that if the number of iterations isn’t a multiple of the unroll factor, you might need an additional condition check inside the body of the loop so that it doesn’t overshoot.
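Here is a plain-Elixir sketch of that extra check, using the standard library only and not Nx. The module and function names are hypothetical. The loop counts to n, unrolled by a factor of two, and the second copy of the body re-checks the condition. Inside defn, you would presumably express the same check with if.

```elixir
defmodule UnrollSketch do
  # Rolled loop: one increment per iteration.
  def count_to(n), do: rolled(0, n)

  defp rolled(i, n) when i < n, do: rolled(i + 1, n)
  defp rolled(i, _n), do: i

  # Loop unrolled by a factor of two: each iteration runs two copies of the body.
  def unrolled_count_to(n), do: unrolled(0, n)

  defp unrolled(i, n) when i < n do
    # First copy of the body (the loop condition already guaranteed i < n)
    i = i + 1
    # Second copy, guarded by an extra condition check so that an odd n
    # doesn't make the loop overshoot
    i = if i < n, do: i + 1, else: i
    unrolled(i, n)
  end

  defp unrolled(i, _n), do: i
end

UnrollSketch.unrolled_count_to(11)
```

Without the guarded second copy, counting to 11 would stop at 12. The loop condition is only checked once per (doubled) iteration.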

If you know your loop is guaranteed to run only a few times, you can also “inline” the loop into your computation using transform, which runs regular Elixir code while the defn is being compiled:

defn inlined_loop() do
  i = 0
  transform(i, fn i ->
    for _ <- 0..9, reduce: i do
      x ->
        Nx.add(x, 1)
    end
  end)
end

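This repeats the expression in the loop body a fixed number of times inside the compiled computation, so no while loop remains. Because the repetition happens while the defn is compiled, the number of iterations must be known at compile time. And because every iteration adds another copy of the body to the computation, inlining only works well when that number is relatively small.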
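In Nx, transform performs this expansion inside the defn compiler. The same idea can be shown with an ordinary Elixir macro: the loop count is consumed at compile time, and the body is pasted into the generated code that many times. InlineSketch.repeat/3 below is hypothetical and is not an Nx API. It requires the count to be an integer literal, which makes the compile-time requirement explicit.

```elixir
defmodule InlineSketch do
  # Hypothetical macro (not part of Nx). Expands to `fun.(fun.(... fun.(init)))`
  # with `n` nested calls, so the "loop" is fully unrolled at compile time.
  # `n` must be a non-negative integer literal, because it is read while compiling.
  defmacro repeat(n, init, fun) when is_integer(n) and n >= 0 do
    List.duplicate(:copy, n)
    |> Enum.reduce(init, fn :copy, acc ->
      quote do: unquote(fun).(unquote(acc))
    end)
  end
end

defmodule InlineDemo do
  require InlineSketch

  # Expands to ten nested calls of the anonymous function; no loop at runtime.
  def count_to_ten, do: InlineSketch.repeat(10, 0, fn x -> x + 1 end)
end
```

Passing a variable instead of a literal count fails when InlineDemo is compiled, which mirrors why inlining can't handle iteration counts that are only known at runtime.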

Until next time!