Nx Tip of the Week #11 – While Loops

Some numeric algorithms require sequential operations. In TOTW #9, we talked about one operation you can use to avoid while-loops in specific situations. Unfortunately, you won’t always be able to avoid a while-loop. `Nx` has a `while` construct, implemented as an Elixir macro. `while` takes an initial state, a condition, and a body that must return a value with the same shape as the initial state. It’s essentially a `reduce_while`: it aggregates state for as long as the condition holds:

```elixir
defn count_to_ten() do
  while i = 0, i < 10, do: i + 1
end
```
```elixir
iex> count_to_ten()
#Nx.Tensor<
  s64
  10
>
```
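To make the `reduce_while` comparison concrete, here is a rough plain-Elixir analogue of `count_to_ten` built on `Enum.reduce_while/3` over an infinite stream. It is only a sketch of the semantics (it runs as ordinary BEAM code, not through a `defn` compiler), and the module name `ReduceWhileSketch` is hypothetical. The accumulator plays the role of the loop state, and the condition is checked before each application of the body:

```elixir
defmodule ReduceWhileSketch do
  # Hypothetical module: a plain-Elixir analogue of
  #   while i = 0, i < 10, do: i + 1
  # The accumulator is the loop state; the condition is checked
  # before the body runs, just like `while`.
  def count_to_ten do
    Stream.repeatedly(fn -> nil end)
    |> Enum.reduce_while(0, fn _, i ->
      if i < 10, do: {:cont, i + 1}, else: {:halt, i}
    end)
  end
end

ReduceWhileSketch.count_to_ten()
# => 10
```

Unlike the `defn` version, this returns a plain integer rather than an `s64` tensor.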

The state can even be a container such as a tuple, so you can aggregate multiple things at once:

```elixir
defn count_to_ten_twice() do
  while {i = 0, j = 0}, i < 10, do: {i + 1, j + 1}
end
```
```elixir
iex> count_to_ten_twice()
{#Nx.Tensor<
   s64
   10
 >,
 #Nx.Tensor<
   s64
   10
 >}
```
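The components of a tuple state don’t have to move in lockstep. As another plain-Elixir sketch (again using `Enum.reduce_while/3`, with the hypothetical module name `TupleStateSketch`), here is a counter paired with a running sum, roughly what a state like `{i = 0, sum = 0}` with the body `{i + 1, sum + i}` would compute:

```elixir
defmodule TupleStateSketch do
  # Hypothetical module: plain-Elixir analogue of a two-element loop state,
  # roughly `while {i = 0, sum = 0}, i < 10, do: {i + 1, sum + i}`.
  def count_and_sum do
    Stream.repeatedly(fn -> nil end)
    |> Enum.reduce_while({0, 0}, fn _, {i, sum} ->
      if i < 10, do: {:cont, {i + 1, sum + i}}, else: {:halt, {i, sum}}
    end)
  end
end

TupleStateSketch.count_and_sum()
# => {10, 45}
```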

It’s important to understand that the body of the while loop must return a value with the same shape (and type) as the initial state. For example, if you want to use a while-loop to iteratively build a tensor by writing new values along some axis, you need to know the final shape of the tensor ahead of time, start from a “filler” tensor of that shape, and overwrite it one slice at a time:

```elixir
defn build_a_vector() do
  # Create a "filler" tensor with the final shape, {12}
  initial_tensor = Nx.broadcast(0.0, {12})

  {_, final_tensor} =
    while {i = 0, initial_tensor}, i < 12 do
      val = Nx.random_uniform({1})
      # Update the filler tensor "in-place" at index i
      {i + 1, Nx.put_slice(initial_tensor, [i], val)}
    end

  final_tensor
end
```
```elixir
iex> build_a_vector()
#Nx.Tensor<
  f32[12]
  [0.9620421528816223, 0.37593021988868713, 0.5158101916313171, 0.39656928181648254, 0.6919131875038147, 0.1678706705570221, 0.9522126913070679, 0.5573846101760864, 0.37262946367263794, 0.40950807929039, 0.9263219237327576, 0.45467808842658997]
>
```
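The same “preallocate, then overwrite” pattern can be sketched in plain Elixir, where the invariant is easy to see: every step returns a list of exactly the final length, just as every iteration of the `defn` loop returns a `{12}` tensor. This is only an illustration with a hypothetical module name; `List.replace_at/3` and `:rand.uniform/0` stand in for `Nx.put_slice` and `Nx.random_uniform`:

```elixir
defmodule PreallocateSketch do
  # Hypothetical module: fill a preallocated, fixed-length list slot by slot.
  @size 12

  def build_a_list do
    # "Filler" with the final length, known ahead of time
    initial = List.duplicate(0.0, @size)

    # Each step overwrites position i; the length of `acc` never changes
    Enum.reduce(0..(@size - 1), initial, fn i, acc ->
      List.replace_at(acc, i, :rand.uniform())
    end)
  end
end
```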

The native `defn` loops you implement will be much more efficient than pure Elixir loops performing the same operations, but they’ll still be slow relative to vectorized tensor operations. Because loops can be difficult for JIT compilers like XLA to optimize, you might see some benefit from applying optimizations yourself. For example, you might benefit from unrolling the body of the loop, so that each iteration does more work and fewer iterations are needed:

```elixir
defn unrolled_count_to_ten() do
  while i = 0, i < 10 do
    # Once
    i = i + 1
    # Do the body again
    i = i + 1
    i
  end
end
```

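Just remember that the loop condition is only checked once per iteration, so if you unroll manually and the number of iterations isn’t a multiple of the unroll factor, you’ll need an additional condition check inside the body of the loop to avoid overshooting.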
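To see why the extra check matters, here is a plain-Elixir sketch (hypothetical module and function names, using tail recursion rather than `defn`) of a loop whose body is unrolled twice. Without a guard, an odd bound is overshot; re-checking the condition before the second copy of the body fixes it:

```elixir
defmodule UnrollSketch do
  # Hypothetical module: a while-style loop (check `i < limit`, then run
  # the body) with the body manually unrolled twice.

  # Unguarded: the second copy runs even if the first copy already
  # reached `limit`, so an odd limit is overshot.
  def unguarded_count_to(limit), do: unguarded(0, limit)

  defp unguarded(i, limit) when i < limit do
    i = i + 1
    i = i + 1
    unguarded(i, limit)
  end

  defp unguarded(i, _limit), do: i

  # Guarded: re-check the condition before the second copy of the body.
  def guarded_count_to(limit), do: guarded(0, limit)

  defp guarded(i, limit) when i < limit do
    i = i + 1
    i = if i < limit, do: i + 1, else: i
    guarded(i, limit)
  end

  defp guarded(i, _limit), do: i
end

UnrollSketch.unguarded_count_to(11)
# => 12
UnrollSketch.guarded_count_to(11)
# => 11
```

In `defn`, the guard would presumably be written the same way, as an `if` inside the `while` body; this sketch only demonstrates the control flow.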

If you know your loop is guaranteed to run only a few times, you can also “inline” the loop into your computation using a `transform`:

```elixir
defn inlined_loop() do
  i = 0
  transform(i, fn i ->
    for _ <- 0..9, reduce: i do
      x ->