Nx Tip of the Week #11 - While Loops
Some numeric algorithms require sequential operations. In TOTW #9, we talked about one operation you can use to avoid while-loops in specific situations. Unfortunately, you won’t always be able to avoid a while-loop. Nx provides a `while` construct for use inside `defn`, implemented as an Elixir macro. `while` takes an initial state, a condition, and a body that returns a value with the same shape as the initial state. It’s essentially a `reduce_while`: it aggregates state for as long as some condition is satisfied:
```elixir
defn count_to_ten() do
  while i = 0, i < 10, do: i + 1
end
```

```
iex> count_to_ten()
#Nx.Tensor<
  s64
  10
>
```
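The `reduce_while` comparison can be made concrete with a plain-Elixir analogue of `count_to_ten/0` built on `Enum.reduce_while/3`. This is only a sketch of the semantics, assuming eager execution on the BEAM rather than compilation inside `defn`. The `WhileSketch` module and its `while_loop/3` function are hypothetical names introduced for illustration and are not part of Nx.

```elixir
defmodule WhileSketch do
  # Hypothetical helper: a plain-Elixir analogue of Nx's `while`, not part of Nx.
  # Keeps applying `body` to the state for as long as `condition` holds,
  # then returns the final state.
  def while_loop(initial_state, condition, body) do
    # An endless stream of placeholders; reduce_while decides when to stop.
    Stream.repeatedly(fn -> nil end)
    |> Enum.reduce_while(initial_state, fn _, state ->
      if condition.(state), do: {:cont, body.(state)}, else: {:halt, state}
    end)
  end
end

# Mirrors `while i = 0, i < 10, do: i + 1`
WhileSketch.while_loop(0, fn i -> i < 10 end, fn i -> i + 1 end)
# => 10
```

As with `while`, if the condition is false for the initial state, the body never runs and the initial state is returned unchanged.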
The state can even be a container such as a tuple, so you can aggregate multiple things at once:
```elixir
defn count_to_ten_twice() do
  while {i = 0, j = 0}, i < 10, do: {i + 1, j + 1}
end
```

```
iex> count_to_ten_twice()
{#Nx.Tensor<
   s64
   10
 >,
 #Nx.Tensor<
   s64
   10
 >}
```
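A tuple state is also a convenient way to carry different kinds of quantities through the same loop. Below is a plain-Elixir sketch, not Nx code, that carries a counter and a running sum together. Every step returns an `{i, sum}` tuple of the same form, mirroring the requirement that the body of `while` return the same container as its state. `TupleStateSketch` and `sum_below/1` are hypothetical names used only for illustration.

```elixir
defmodule TupleStateSketch do
  # Hypothetical plain-Elixir illustration, not part of Nx.
  # The state is an {i, sum} tuple; every step returns a tuple of the same
  # form, the way the body of Nx's `while` must match its initial state.
  def sum_below(limit) do
    Stream.repeatedly(fn -> nil end)
    |> Enum.reduce_while({0, 0}, fn _, {i, sum} = state ->
      if i < limit, do: {:cont, {i + 1, sum + i}}, else: {:halt, state}
    end)
  end
end

TupleStateSketch.sum_below(10)
# => {10, 45}
```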
It’s important to understand that the shape of the value returned by the loop body must match the shape of the initial state. For example, if you want to use a while-loop to build a tensor iteratively by filling in new values along some axis, you need to know the final shape of the tensor ahead of time:
```elixir
defn build_a_vector() do
  # Create a "filler" tensor that already has the final shape {12}
  initial_tensor = Nx.broadcast(0.0, {12})

  {_, final_tensor} =
    while {i = 0, initial_tensor}, i < 12 do
      val = Nx.random_uniform({1})
      # Update the filler tensor "in-place" at index i
      {i + 1, Nx.put_slice(initial_tensor, [i], val)}
    end

  final_tensor
end
```

```
iex> build_a_vector()
#Nx.Tensor<
  f32[12]
  [0.9620421528816223, 0.37593021988868713, 0.5158101916313171, 0.39656928181648254, 0.6919131875038147, 0.1678706705570221, 0.9522126913070679, 0.5573846101760864, 0.37262946367263794, 0.40950807929039, 0.9263219237327576, 0.45467808842658997]
>
```
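The "preallocate, then overwrite one slot per iteration" pattern can be sketched in plain Elixir. Here a fixed-length list plays the role of the filler tensor, and `List.replace_at/3` plays the role of `Nx.put_slice/3`. Both return a new value rather than mutating the old one, and the state keeps the same length on every step. This is an analogy only. `PreallocSketch` and `build_a_list/1` are hypothetical names, and `:rand.uniform/0` stands in for `Nx.random_uniform/1`.

```elixir
defmodule PreallocSketch do
  # Hypothetical plain-Elixir analogue of build_a_vector/0, not Nx code.
  def build_a_list(n) do
    # Preallocate: the state has its final length before the loop starts.
    filler = List.duplicate(0.0, n)

    {_i, filled} =
      Stream.repeatedly(fn -> nil end)
      |> Enum.reduce_while({0, filler}, fn _, {i, acc} = state ->
        if i < n do
          # Overwrite slot i; the list length never changes between steps.
          {:cont, {i + 1, List.replace_at(acc, i, :rand.uniform())}}
        else
          {:halt, state}
        end
      end)

    filled
  end
end

PreallocSketch.build_a_list(12)
```

Appending to the list on each step instead would change the state's length from one iteration to the next. That is exactly what a `while` body is not allowed to do.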
The native `defn` loops you implement will be much more efficient than pure Elixir loops performing the same operations, but they’ll still be slow relative to computations that don’t need to run sequentially. Because loops can be difficult for JIT compilers like XLA to optimize, you might see some benefit from applying optimizations yourself. For example, you might benefit from unrolling the body of the loop:
```elixir
defn unrolled_count_to_ten() do
  while i = 0, i < 10 do
    # First copy of the loop body
    i = i + 1
    # Second copy of the loop body (unrolled)
    i = i + 1
    # Return the updated state
    i
  end
end
```
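Just remember that you might need to perform an additional condition check inside the body of the loop if you are manually unrolling. The loop condition is only checked once per iteration, so if the number of iterations isn’t a multiple of the unroll factor, the extra copies of the body can overshoot the bound.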
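The guard can be illustrated with a plain-Elixir sketch, not Nx code, that unrolls a counting loop by two. The second copy of the body is guarded. Without the guard, an odd bound such as 7 would end at 8. `UnrollSketch` and its functions are hypothetical names. Inside `defn`, the same guard would more likely be expressed with a tensor operation such as `Nx.select/3`, but that variant is not shown here.

```elixir
defmodule UnrollSketch do
  # Hypothetical illustration in plain Elixir, not Nx code.

  # Rolled version: one copy of the body per iteration.
  def count_to(n), do: loop(0, n, fn i -> i + 1 end)

  # Unrolled by two: two copies of the body per iteration. The second copy
  # is guarded because the loop condition is only checked before the first
  # copy; without the guard, an odd `n` would overshoot to `n + 1`.
  def unrolled_count_to(n) do
    loop(0, n, fn i ->
      i = i + 1
      if i < n, do: i + 1, else: i
    end)
  end

  # Generic "while": apply `body` while `i < n`, then return the state.
  defp loop(i, n, body) do
    if i < n, do: loop(body.(i), n, body), else: i
  end
end

UnrollSketch.unrolled_count_to(7)
# => 7
```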
If you know your loop is guaranteed to run only a few times, you can also “inline” the loop into your computation using a `transform`:
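```elixir
defn inlined_loop() do
  i = 0

  # The `for` comprehension runs while the computation is being built,
  # so it emits 10 separate additions instead of a loop.
  transform(i, fn i ->
    for _ <- 0..9, reduce: i do
      x -> Nx.add(x, 1)
    end
  end)
end
```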
This repeats the expression in the loop body a fixed number of times inside the actual computation, so no loop remains in the compiled graph. Inlining only works if you know how many times the loop should run when the computation is compiled. It only works well if that number is relatively small, because every iteration adds another copy of the body to the graph.
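The idea behind inlining can be sketched with Elixir's own quoting. An ordinary loop runs while the expression is being built, so the finished expression contains no loop at all, just `n` repeated additions. This is an analogy to what `transform` does in a `defn` graph, not Nx's implementation. `InlineSketch` and `inlined_increment/1` are hypothetical names.

```elixir
defmodule InlineSketch do
  # Hypothetical illustration using Elixir quoting, not Nx's transform/2.
  # The Enum.reduce loop runs while the expression is built, so the result
  # is a loop-free expression with `n` chained additions.
  def inlined_increment(n) do
    # Macro.var(:x, nil) builds a plain (non-hygienic) variable named x.
    Enum.reduce(1..n//1, Macro.var(:x, nil), fn _, acc ->
      quote(do: unquote(acc) + 1)
    end)
  end
end

ast = InlineSketch.inlined_increment(3)
Macro.to_string(ast)
# => "x + 1 + 1 + 1"
{value, _binding} = Code.eval_quoted(ast, x: 0)
value
# => 3
```

The size of the generated expression grows linearly with `n`, which is the same reason inlining suits only loops with a small, known iteration count.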
Until next time!