Nx Tip of the Week #11 – While Loops

Some numeric algorithms require sequential operations. In TOTW #9, we talked about one operation you can use to avoid while-loops in specific situations. Unfortunately, you won’t always be able to avoid a while-loop. Nx has a while construct, implemented as an Elixir macro. It takes an initial state, a condition, and a body that returns a value with the same shape as the initial state. It’s essentially a reduce_while: it keeps updating the state while the condition is satisfied, then returns the final state:

defn count_to_ten() do
  while i = 0, i < 10, do: i + 1
end

iex> count_to_ten()
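
To make the reduce_while comparison concrete, here is a plain-Elixir model of the same semantics with no Nx involved. WhileModel.run/3 is a hypothetical helper written only for illustration; Nx does not iterate in Elixir like this, it builds a loop into the compiled computation instead.

```elixir
defmodule WhileModel do
  # Hypothetical helper mirroring `while state, condition, do: body`:
  # keep applying `body` to the state while `condition` returns true,
  # then return the final state. Nx does not implement `while` this way.
  def run(initial, condition, body) do
    Stream.repeatedly(fn -> :step end)
    |> Enum.reduce_while(initial, fn :step, state ->
      if condition.(state) do
        {:cont, body.(state)}
      else
        {:halt, state}
      end
    end)
  end
end

# Plain-Elixir counterpart of `while i = 0, i < 10, do: i + 1`
WhileModel.run(0, fn i -> i < 10 end, fn i -> i + 1 end)
# => 10
```

Note that the condition is checked before every iteration, including the first: if it is false for the initial state, the initial state is returned unchanged.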

The state can even be a container such as a tuple, so you can aggregate multiple things at once:

defn count_to_ten_twice() do
  while {i = 0, j = 0}, i < 10, do: {i + 1, j + 1}
end

iex> count_to_ten_twice()
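
As a slightly more useful sketch of tuple state, the hypothetical module below computes Fibonacci numbers, carrying a counter and two running values in the loop state. It assumes a recent Nx release pulled in with Mix.install/1 (the version constraint is an assumption). The argument n is threaded through the state as well, so both the condition and the body can read it.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule Fib do
  import Nx.Defn

  # State: a counter `i`, the two most recent Fibonacci values `a` and `b`,
  # and `n`, which is passed through unchanged on every iteration.
  defn fib(n) do
    {_i, a, _b, _n} =
      while {i = 0, a = 0, b = 1, n}, i < n do
        {i + 1, b, a + b, n}
      end

    a
  end
end

Fib.fib(10) |> Nx.to_number()
# => 55
```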

It’s important to understand that the value returned by the body of the while-loop must have the same shape (and type) as the initial state. For example, if you want to use a while-loop to build a tensor iteratively by adding new values along some axis, you need to know the final shape of the tensor ahead of time and fill it in:

defn build_a_vector() do
  # Create a "filler" tensor that already has the final shape
  initial_tensor = Nx.broadcast(0.0, {12})
  {_, final_tensor} =
    while {i = 0, initial_tensor}, i < 12 do
      val = Nx.random_uniform({1})
      # Update filler tensor "in-place"; its shape never changes
      {i + 1, Nx.put_slice(initial_tensor, [i], val)}
    end
  final_tensor
end

iex> build_a_vector()
#Nx.Tensor<
  f32[12]
  [0.9620421528816223, 0.37593021988868713, 0.5158101916313171, 0.39656928181648254, 0.6919131875038147, 0.1678706705570221, 0.9522126913070679, 0.5573846101760864, 0.37262946367263794, 0.40950807929039, 0.9263219237327576, 0.45467808842658997]
>
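
Because random values make the result hard to check, here is a deterministic variant of the same preallocate-and-fill pattern. The module and function names are hypothetical, and a recent Nx installed with Mix.install/1 is assumed. The result shape {5} is fixed before the loop starts, and each iteration only overwrites one element with Nx.put_slice/3.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule Squares do
  import Nx.Defn

  # The output shape {5} is decided up front; every iteration returns a
  # tensor of that same shape, so the loop state never changes shape.
  defn squares() do
    initial = Nx.broadcast(0, {5})

    {_i, result} =
      while {i = 0, acc = initial}, i < 5 do
        # Write i * i into position i (the slice must be rank 1, like acc)
        {i + 1, Nx.put_slice(acc, [i], Nx.reshape(i * i, {1}))}
      end

    result
  end
end

Squares.squares() |> Nx.to_flat_list()
# => [0, 1, 4, 9, 16]
```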

The native defn loops you implement will be much more efficient than pure Elixir loops performing the same operations, but they’ll still be slow relative to other computations, since each iteration has to wait for the previous one. Loops can also be difficult for JIT compilers like XLA to optimize, so you might actually see some benefit from applying optimizations yourself. For example, you might benefit from unrolling the body of the loop:

defn unrolled_count_to_ten() do
  while i = 0, i < 10 do
    # First copy of the body
    i = i + 1
    # Second copy of the body (safe here because 10 is even)
    i = i + 1
  end
end

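Just remember that you might need to perform an additional condition check inside the body of the loop if you are manually unrolling. The example above only works because 10 is a multiple of the unroll factor of 2; otherwise the second copy of the body could run one time too many.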
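
For instance, counting to an odd number with a body unrolled twice would overshoot by one without a guard. The sketch below, assuming a recent Nx and using a hypothetical count_to/1, re-checks the loop condition with Nx.select/3 before running the second copy of the body.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule Unrolled do
  import Nx.Defn

  # Body unrolled twice; the second copy only takes effect if the loop
  # condition still holds, so an odd `n` does not overshoot.
  defn count_to(n) do
    {i, _n} =
      while {i = 0, n}, i < n do
        # First copy: the while condition already guarantees i < n here
        i = i + 1
        # Second copy: guarded by the same condition
        i = Nx.select(i < n, i + 1, i)
        {i, n}
      end

    i
  end
end

Unrolled.count_to(7) |> Nx.to_number()
# => 7
```

Without the Nx.select/3 guard, count_to(7) would return 8.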

If you know your loop is guaranteed to run only a few times, you can also “inline” the loop into your computation using a transform:

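defn inlined_loop() do
  i = 0
  transform(i, fn i ->
    # Runs as plain Elixir while the computation is built,
    # emitting ten Nx.add/2 calls instead of a loop
    for _ <- 0..9, reduce: i do
      x ->
        Nx.add(x, 1)
    end
  end)
end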

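This repeats the expression in the loop body a fixed number of times inside the actual computation, so no loop is left in the compiled program. Inlining only works if you know how many iterations the loop needs when the computation is built (not at runtime), and it only works well if that number is relatively small, since every iteration adds another copy of the body to the computation.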
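
In more recent Nx releases, transform/2 has, to our knowledge, been deprecated in favor of deftransform and deftransformp, which mark an ordinary Elixir function as callable from defn. Under that assumption, the same inlining trick might look like the sketch below (module and function names are hypothetical, and a recent Nx is assumed).

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule Inlined do
  import Nx.Defn

  defn add_ten(x) do
    add_n_times(x)
  end

  # Runs as ordinary Elixir while add_ten/1 is being traced, so the
  # comprehension emits ten separate Nx.add/2 nodes rather than a loop.
  deftransformp add_n_times(x) do
    for _ <- 1..10, reduce: x do
      acc -> Nx.add(acc, 1)
    end
  end
end

Inlined.add_ten(0) |> Nx.to_number()
# => 10
```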

Until next time!
