Nx Tip of the Week #10 - Using Nx.select


Nx’s API can seem a little more restrictive than NumPy’s due to some of its static shape requirements. For example, boolean indexing is not currently supported because the shape of the result can’t be known until the program actually runs. For those who don’t know, boolean indexing selects values of an array based on a boolean mask. For example, let’s say I wanted to compute the sum of all non-negative values in an array. In NumPy I can do:

>>> import numpy as np
>>> a = np.array([[-1, 2, -3], [4, -5, 6]])
>>> non_negative_a = a[a >= 0]  # np.array([2, 4, 6])
>>> np.sum(non_negative_a)

The same operation isn’t supported in Nx because the output shape depends on the runtime result of the comparison. There are some techniques we could potentially explore to make it work (e.g. dynamic shape support, such as compiling programs with bounded shapes), but for the time being, let’s consider an alternative.
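To see why the output shape depends on the data, here is a small NumPy sketch (NumPy is used only for illustration, and `output_shapes` is a hypothetical helper). Two inputs with the same shape produce boolean-indexed results of different lengths, while a masked result built with `np.where`, which keeps every position and fills the unselected ones with 0, always has the input’s shape.

```python
import numpy as np


def output_shapes(a):
    """Hypothetical helper: shapes of a boolean-indexed result vs. a masked result."""
    mask = a >= 0
    # Boolean indexing keeps only the selected values, so its length depends on the data.
    indexed_shape = a[mask].shape
    # np.where keeps every position (filling with 0), so its shape always matches `a`.
    masked_shape = np.where(mask, a, 0).shape
    return indexed_shape, masked_shape


a1 = np.array([[-1, 2, -3], [4, -5, 6]])
a2 = np.array([[1, 2, 3], [4, -5, 6]])

print(output_shapes(a1))  # ((3,), (2, 3))
print(output_shapes(a2))  # ((5,), (2, 3))
```

A compiler that needs every shape up front can handle the second result but not the first.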

More often than not, a problem that you think requires boolean indexing can be solved with Nx.select. Nx.select builds a tensor from three tensors:

  • A mask (pred)
  • A true tensor (on_true)
  • A false tensor (on_false)
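The element-wise rule behind these three inputs can be sketched in plain Python. The `select` function below is a hypothetical reference for flat lists only, not Nx’s implementation; it assumes any non-zero mask entry counts as true and that a scalar branch is repeated to the mask’s length, mirroring how the Nx example in this post passes a bare 0. NumPy’s `np.where` is shown alongside because it follows the same rule.

```python
import numpy as np


def select(mask, on_true, on_false):
    """Hypothetical pure-Python reference for element-wise selection over flat lists.

    Scalar on_true/on_false values are repeated to the length of the mask.
    """
    n = len(mask)
    if not isinstance(on_true, list):
        on_true = [on_true] * n
    if not isinstance(on_false, list):
        on_false = [on_false] * n
    # A non-zero (truthy) mask entry picks from on_true, otherwise from on_false.
    return [t if m else f for m, t, f in zip(mask, on_true, on_false)]


mask = [1, 0, 1, 0]
on_true = [1, 2, 3, 4]
on_false = [10, 20, 30, 40]

print(select(mask, on_true, on_false))  # [1, 20, 3, 40]
print(select(mask, on_true, 0))         # [1, 0, 3, 0]

# NumPy's np.where applies the same element-wise rule.
assert np.where(np.array(mask) != 0, on_true, on_false).tolist() == select(mask, on_true, on_false)
```

The output always has one entry per mask entry, which is exactly why the result shape is known ahead of time.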

Nx.select chooses values from the true tensor where the corresponding values in the mask are true (non-zero) and values from the false tensor where they are not. By carefully choosing which tensors we use to construct the result, we can usually replace boolean indexing. For the problem outlined above, you know ahead of time that you want the final result to be a sum, so your false tensor can be all zeros, because zeros have no effect on the sum. You don’t even need to build a tensor of zeros; the scalar 0 is broadcast for you:

iex> a = Nx.tensor([[-1, 2, -3], [4, -5, 6]])
iex> non_negative_a = Nx.select(Nx.greater_equal(a, 0), a, 0) # negatives become 0
iex> Nx.sum(non_negative_a)
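The equivalence can be checked in NumPy, with `np.where` standing in for `Nx.select`; this is an illustrative sketch, not Nx code. It also shows one caveat worth keeping in mind: a mean over the masked tensor divides by the full element count, so to average only the selected values you divide the masked sum by the number of true entries in the mask.

```python
import numpy as np

a = np.array([[-1, 2, -3], [4, -5, 6]])
mask = a >= 0

# Boolean indexing: the intermediate result has a data-dependent shape.
indexed_sum = a[mask].sum()
# Select-style: fixed shape, zeros wherever the mask is false.
masked_sum = np.where(mask, a, 0).sum()
assert indexed_sum == masked_sum == 12

# A plain mean over the masked array divides by all 6 elements (12 / 6).
naive_mean = np.where(mask, a, 0).mean()
# Dividing by the number of selected elements gives the mean of the non-negatives.
count = mask.sum()
masked_mean = masked_sum / count
print(masked_sum, count, naive_mean, masked_mean)  # 12 3 2.0 4.0
```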

By carefully choosing the on_false tensor, you’ve guaranteed its values will have no impact on your final result, effectively achieving the same thing as boolean indexing.
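The same idea extends beyond sums: the false value should be the identity element of whatever reduction follows. The NumPy sketch below (again with `np.where` standing in for `Nx.select`, and the `neutral` table being a hypothetical illustration) uses 1 for a product and the integer dtype’s extreme values for max and min, then checks each against boolean indexing. One assumption to note: if the mask is entirely false, the masked version returns the identity value, whereas boolean indexing would produce an empty array.

```python
import numpy as np

a = np.array([[-1, 2, -3], [4, -5, 6]])
mask = a >= 0
info = np.iinfo(a.dtype)

# Hypothetical table: each reduction paired with an on_false value that cannot change it.
neutral = {
    "sum": (np.sum, 0),
    "prod": (np.prod, 1),
    "max": (np.max, info.min),  # smallest representable int never wins a max
    "min": (np.min, info.max),  # largest representable int never wins a min
}

results = {}
for name, (reduce_fn, identity) in neutral.items():
    masked = reduce_fn(np.where(mask, a, identity))
    indexed = reduce_fn(a[mask])
    assert masked == indexed, name
    results[name] = int(masked)

print(results)  # {'sum': 12, 'prod': 48, 'max': 6, 'min': 2}
```

For floating-point tensors, negative and positive infinity would play the role of the max and min identities.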