Nx Tip of the Week #14 - Slicing and Indexing


You will often want to slice and index into specific parts of a tensor. Nx offers several slicing and indexing routines that cover most of what you are likely to need. Slicing can be a bit tricky because Nx requires static shapes, but you can usually work around the limitations.

First, you can perform generic slices using Nx.slice/4:

a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
Nx.slice(a, [0, 1], [1, 2])

Returns:

#Nx.Tensor<
  s64[1][2]
  [
    [2, 3]
  ]
>

The first list in Nx.slice/4 is a list of start indices. The start indices may be dynamic tensor values, as long as each one has a scalar shape:

Nx.slice(a, [Nx.tensor(0), Nx.tensor(1)], [1, 2])
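
Dynamic start indices matter most inside defn, where the start may only be known at runtime. The following is a minimal sketch under some assumptions: the SliceSketch module and its window function are hypothetical names made up for illustration, and the Mix.install line assumes Nx ~> 0.7 can be fetched (drop it inside an existing Mix project).

```elixir
# Assumption: Nx ~> 0.7 is fetchable; inside a Mix project, remove this line.
Mix.install([{:nx, "~> 0.7"}])

defmodule SliceSketch do
  import Nx.Defn

  # `start` arrives as a runtime tensor. The lengths [2, 2] are literals,
  # so the output shape {2, 2} is fixed before any data is seen.
  defn window(t, start) do
    Nx.slice(t, [0, start], [2, 2])
  end
end

a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
result = SliceSketch.window(a, Nx.tensor(1))

IO.inspect(Nx.shape(result))         # prints {2, 2}
IO.inspect(Nx.to_flat_list(result))  # prints [2, 3, 5, 6]
```

Only the position of the window changes with the input; its size never does.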

The second list gives the length of the slice along each axis. Each length must be a static value known at compile time, because the lengths dictate the final shape of the sliced tensor. Consider a slice that runs past the edge of the tensor:

Nx.slice(a, [1, 2], [2, 3])

Returns:

#Nx.Tensor<
  s64[2][3]
  [
    [1, 2, 3],
    [4, 5, 6]
  ]
>

This might surprise you! You might have expected an out-of-bounds error or something similar, considering that the requested slice runs past the original bounds of the tensor on both axes. In reality, Nx clamps the start indices so that a slice of the requested length fits within the bounds of each dimension.
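
The clamping arithmetic can be worked through in plain Elixir. This is only a model of the behavior described above, not Nx's implementation, and clamp_start is a hypothetical helper introduced for illustration:

```elixir
# Hypothetical helper, not part of Nx: pulls a start index back so that
# `len` elements fit inside an axis of size `dim`.
clamp_start = fn start, len, dim -> start |> max(0) |> min(dim - len) end

shape = [2, 3]    # shape of `a`
starts = [1, 2]   # requested start indices
lengths = [2, 3]  # requested slice lengths

effective =
  [starts, lengths, shape]
  |> Enum.zip()
  |> Enum.map(fn {start, len, dim} -> clamp_start.(start, len, dim) end)

IO.inspect(effective)  # prints [0, 0]
```

Both start indices are pulled back to 0, which is why the slice above returns the whole tensor.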

You can also make use of the Access syntax which builds on top of normal slice operations. For example:

a = Nx.tensor([[1, 2, 3], [4, 5, 6]])
a[[1, 0..1]]

Returns:

#Nx.Tensor<
  s64[2]
  [4, 5]
>

This returns [4, 5] because you are accessing index 1 of the first dimension (the second row) and indices 0 through 1 of the second dimension. In newer versions of Elixir, you can slice an entire axis with ..:

a[[.., 2]]

Returns:

#Nx.Tensor<
  s64[2]
  [3, 6]
>
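
To relate the Access syntax to Nx.slice/4, here is a small sketch that spells out one Access expression by hand. It assumes that an integer index behaves like a length-1 slice followed by a squeeze of that axis, which matches the output shown above. The Mix.install line assumes Nx ~> 0.7 can be fetched.

```elixir
# Assumption: Nx ~> 0.7 is fetchable; inside a Mix project, remove this line.
Mix.install([{:nx, "~> 0.7"}])

a = Nx.tensor([[1, 2, 3], [4, 5, 6]])

# Access syntax: integer index on axis 0, range on axis 1.
via_access = a[[1, 0..1]]

# The same selection by hand: slice one row, then drop the size-1 axis.
via_slice =
  a
  |> Nx.slice([1, 0], [1, 2])
  |> Nx.squeeze(axes: [0])

IO.inspect(Nx.shape(via_access))         # prints {2}
IO.inspect(Nx.to_flat_list(via_access))  # prints [4, 5]
IO.inspect(Nx.to_flat_list(via_slice))   # prints [4, 5]
```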

If Nx.slice/4 and the Access syntax are not flexible enough for you, you can try Nx.take/3, Nx.gather/3, or Nx.take_along_axis/3. For example, a common operation in deep learning is computing a vector representation from sparse values. Given a sequence of integer tokens between 0 and 127, you can convert each token to a vector representation using Nx.take/3:

tokens = Nx.tensor([127, 32, 0, 1, 5, 6])
weights = Nx.random_uniform({128, 32})
Nx.take(weights, tokens)

Returns:

#Nx.Tensor<
  f32[6][32]
  [
    [0.39863643050193787, 0.33112287521362305, 0.531199038028717, 0.3594178259372711, 0.8754940629005432, 0.30342867970466614, 0.9188190698623657, 0.29185304045677185, 0.543312668800354, 0.5064964294433594, 0.7225326299667358, 0.06837604194879532, 0.5449554920196533, 0.2207975834608078, 0.0635833740234375, 0.3370073139667511, 0.6428131461143494, 0.8821378946304321, 0.9932462573051453, 0.8975431323051453, 0.7079696655273438, 0.023084526881575584, 0.4048435091972351, 0.12792034447193146, 0.4222281277179718, 0.21171192824840546, 0.7248737812042236, 0.5454342365264893, 0.2521210312843323, 0.2614332437515259, 0.3105127811431885, 0.03566299006342888],
    [0.3035010099411011, 0.07670660316944122, 0.07924123853445053, 0.161861851811409, 0.14367112517356873, 0.06336789578199387, 0.9437791109085083, 0.5998468399047852, 0.4222017228603363, 0.14000535011291504, 0.12471750378608704, 0.31671205163002014, 0.6216381192207336, 0.4062456488609314, 0.1768452525138855, 0.2160402536392212, 0.9336262345314026, 0.289279043674469, ...],
    ...
  ]
>
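
Random weights make the lookup hard to follow by eye, so here is a smaller, deterministic sketch of the same idea. The 4x3 weights table built with Nx.iota/1 is an assumption chosen so every row is easy to recognize, and the Mix.install line assumes Nx ~> 0.7 can be fetched.

```elixir
# Assumption: Nx ~> 0.7 is fetchable; inside a Mix project, remove this line.
Mix.install([{:nx, "~> 0.7"}])

# Rows are [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11].
weights = Nx.iota({4, 3})
tokens = Nx.tensor([3, 0, 0])

# Each token selects one row of `weights`; repeated tokens repeat the row.
embedded = Nx.take(weights, tokens)

IO.inspect(Nx.shape(embedded))         # prints {3, 3}
IO.inspect(Nx.to_flat_list(embedded))  # prints [9, 10, 11, 0, 1, 2, 0, 1, 2]
```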

Nx.take/3 takes slices at the given indices along an axis of a tensor and concatenates them. In this example, each value in tokens is an index into the first dimension of the weights tensor, so you end up with a sequence of dense vectors built from the original sparse representation.

Another option is Nx.take_along_axis/3, which selects elements along an axis using a tensor of indices. For example, you can combine Nx.take_along_axis/3 with Nx.argsort/2 to sort values along an axis:

a = Nx.tensor([[2, 3, 0, 1, 4, 8], [5, 1, 2, 3, 6, 9]])
indices = Nx.argsort(a, axis: 1)
Nx.take_along_axis(a, indices, axis: 1)

Returns:

#Nx.Tensor<
  s64[2][6]
  [
    [0, 1, 2, 3, 4, 8],
    [1, 2, 3, 5, 6, 9]
  ]
>
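
The same pairing works for other orderings. As a sketch, passing direction: :desc to Nx.argsort/2 sorts each row in descending order, and the result can be checked against Nx.sort/2. The Mix.install line assumes Nx ~> 0.7 can be fetched.

```elixir
# Assumption: Nx ~> 0.7 is fetchable; inside a Mix project, remove this line.
Mix.install([{:nx, "~> 0.7"}])

a = Nx.tensor([[2, 3, 0, 1, 4, 8], [5, 1, 2, 3, 6, 9]])

# Indices that would order each row from largest to smallest.
indices = Nx.argsort(a, axis: 1, direction: :desc)
sorted = Nx.take_along_axis(a, indices, axis: 1)

# Sorting directly should agree with the argsort + take_along_axis route.
direct = Nx.sort(a, axis: 1, direction: :desc)

IO.inspect(Nx.to_flat_list(sorted))
# prints [8, 4, 3, 2, 1, 0, 9, 6, 5, 3, 2, 1]
IO.inspect(Nx.to_flat_list(sorted) == Nx.to_flat_list(direct))  # prints true
```

Keeping the indices around is useful when you need to reorder a second tensor the same way, which Nx.sort/2 alone cannot do.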

The last indexing operation Nx offers is Nx.gather/3. Nx.gather/3 takes a tensor of indices in which each entry along the last axis is a full coordinate that identifies a single value in the source tensor:

t = Nx.tensor([[1, 2], [3, 4]])
Nx.gather(t, Nx.tensor([[1, 1], [0, 1], [1, 0]]))

Returns:

#Nx.Tensor<
  s64[3]
  [4, 2, 3]
>

Notice how the last dimension of the indices tensor has size 2, which matches the rank of the input tensor. Each entry along that dimension is a coordinate into the source tensor: [1, 1], [0, 1], and [1, 0]. The output shape of a gather operation is the shape of the indices tensor with its last dimension dropped, so indices of shape {3, 2} produce an output of shape {3}.
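
The shape rule is easier to see with a higher-rank indices tensor. In this sketch the indices have shape {2, 2, 2}, an arbitrary choice for illustration, so the output should have shape {2, 2}. The Mix.install line assumes Nx ~> 0.7 can be fetched.

```elixir
# Assumption: Nx ~> 0.7 is fetchable; inside a Mix project, remove this line.
Mix.install([{:nx, "~> 0.7"}])

t = Nx.tensor([[1, 2], [3, 4]])

# Shape {2, 2, 2}: a 2x2 grid of [row, column] coordinates into `t`.
indices =
  Nx.tensor([
    [[0, 0], [1, 1]],
    [[0, 1], [1, 0]]
  ])

gathered = Nx.gather(t, indices)

IO.inspect(Nx.shape(indices))           # prints {2, 2, 2}
IO.inspect(Nx.shape(gathered))          # prints {2, 2}
IO.inspect(Nx.to_flat_list(gathered))   # prints [1, 4, 2, 3]
```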