Nx Tip of the Week #1 - Using transforms

Note: This series is an idea I had after learning a great deal from the Abseil C++ Tips of the Week during my work on EXLA. I’ll keep writing these tips as long as there’s interest. If there’s anything in particular you’d like to read about, feel free to let me know!

Nx is an exciting new project that aims to make mathematical computing practical within the Elixir ecosystem. Fundamentally, Nx is a tensor manipulation (or array programming) library similar to NumPy, TensorFlow, or PyTorch. Nx introduces a new type of function definition, defn, which accepts a subset of the Elixir programming language tailored specifically to numerical computations. When numerical definitions are invoked, they’re transformed into expressions (internally Nx.Defn.Expr) that represent the AST, or computation graph, of the numerical definition. Compilers (like EXLA) manipulate these expressions to produce executables that run natively on accelerators.

The subset of supported Elixir code within defn can, at times, feel restrictive; however, there are a number of ways to overcome these perceived limitations. One such example is transform/2. From the docs:

transform(arg, fun)

Defines a transform that executes the given `fun` with `arg`
when building `defn` expressions.

You can invoke transform/2 from within defn to call out to any Elixir function. As an example, you can use transform/2 to inspect the underlying expression in a numerical definition:

defn tanh_power(a, b) do
  res = Nx.tanh(a) + Nx.power(b, 2)
  transform(res, &IO.inspect/1)
  res
end

Invoking tanh_power/2 will print:

#Nx.Defn.Expr<
  parameter a
  parameter c
  b = tanh [ a ] ()
  d = power [ c, 2 ] ()
  e = add [ b, d ] ()
>

Note: the parentheses would normally contain the type and shape of each expression, which are determined at the time of invocation; they are omitted here. There’s also a macro, inspect_expr, available within defn that implements this transform.

Transforms prove particularly useful when doing type or shape checks. For example, you can use transforms to assert that input shapes are equal:

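defn cross_entropy_loss(y_true, y_pred) do
  # The check runs while the expression is being built, so mismatched
  # shapes fail before any computation happens.
  transform({Nx.shape(y_true), Nx.shape(y_pred)},
    fn
      {s1, s2} when s1 == s2 -> :ok
      _ -> raise ArgumentError, "shapes are not equal"
    end
  )
  # Cross-entropy takes the log of the predictions, not of the labels.
  -Nx.mean(y_true * Nx.log(y_pred))
end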

You can take that one step further and package everything in a macro:

defmacro assert_equal_shapes(expr1, expr2) do
  quote do
    Nx.Defn.Kernel.transform(
      {Nx.shape(unquote(expr1)), Nx.shape(unquote(expr2))},
      &assert_equal_shapes_impl/1
    )
  end
end

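# transform/2 passes a single argument, so match on the {s1, s2} tuple
# to agree with the &assert_equal_shapes_impl/1 capture above.
defp assert_equal_shapes_impl({s1, s2}) when s1 == s2, do: :ok
defp assert_equal_shapes_impl({s1, s2}) do
  raise ArgumentError, "expected shapes to be equal," <>
                       " got #{inspect(s1)} != #{inspect(s2)}"
end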

defn cross_entropy_loss(y_true, y_pred) do
  assert_equal_shapes(y_true, y_pred)
  -Nx.mean(y_true * Nx.log(y_pred))
end

And you’ll get some nice shape validation on your functions:

iex> y_true = Nx.tensor([[0, 1], [1, 0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[0.6811, 0.5565], [0.6551, 0.4551], [0.5422, 0.2648]], type: {:f, 32})
iex> cross_entropy_loss(y_true, y_pred)
** (ArgumentError) expected shapes to be equal, got {2, 2} != {3, 2}
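
Because the function handed to transform/2 is ordinary Elixir operating on shape tuples, the check itself can be tried without Nx or any tensors. The following is a minimal sketch under that assumption; the module name ShapeCheck is hypothetical and not part of Nx:

```elixir
defmodule ShapeCheck do
  # Hypothetical helper module; the name is not part of Nx.
  # It takes one {shape1, shape2} tuple because transform/2 hands its
  # callback exactly one argument.
  def assert_equal_shapes({s1, s2}) when s1 == s2, do: :ok

  def assert_equal_shapes({s1, s2}) do
    raise ArgumentError,
          "expected shapes to be equal, got #{inspect(s1)} != #{inspect(s2)}"
  end
end

# Shapes are plain tuples, so the check can be exercised directly.
:ok = ShapeCheck.assert_equal_shapes({{2, 2}, {2, 2}})

# Capture the error message for the mismatched case instead of crashing.
message =
  try do
    ShapeCheck.assert_equal_shapes({{2, 2}, {3, 2}})
  rescue
    e in ArgumentError -> Exception.message(e)
  end

IO.puts(message)
```

Keeping the check in a plain function like this also makes it easy to unit test separately from any numerical definition that uses it.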

You can even use values returned from transforms. For example, say you need to calculate the shape of a tensor created with Nx.random_uniform based on the shape of an input tensor:

defn dense_layer(input) do
  weight_shape = transform(Nx.shape(input),
    fn {_batch_size, in_units} ->
      {in_units, 32}
    end)
  weight = Nx.random_uniform(weight_shape, type: Nx.type(input))
  Nx.dot(input, weight)
end

When invoked:

iex> t1 = Nx.tensor([[1.0, 2.0, 3.0]], type: {:f, 32})
iex> t2 = Nx.tensor([[1.0, 2.0, 3.0, 4.0]], type: {:f, 32})
iex> dense_layer(t1)
#Nx.Tensor<
  f32[1][32]
  [...]
>
iex> dense_layer(t2)
#Nx.Tensor<
  f32[1][32]
  [...]
>

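Notice how the dot product works for both inputs, even though their shapes are different! The transform computes a weight shape of {3, 32} for t1 and {4, 32} for t2, so the inner dimensions always line up.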
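
The shape calculation inside that transform is just a function from one tuple to another, so it can be checked on its own. Here is a small sketch in plain Elixir; the module name DenseShape and the out_units parameter are hypothetical, introduced only for illustration:

```elixir
defmodule DenseShape do
  # Hypothetical helper; the default of 32 output units mirrors the
  # dense_layer example, and out_units is an illustrative addition.
  def weight_shape({_batch_size, in_units}, out_units \\ 32) do
    {in_units, out_units}
  end
end

# The two input shapes used in the iex session.
IO.inspect(DenseShape.weight_shape({1, 3}))
IO.inspect(DenseShape.weight_shape({1, 4}))
```

For an input of shape {1, 3} the weight shape comes out as {3, 32}, and for {1, 4} it comes out as {4, 32}, which is why both dot products produce a {1, 32} result.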

As a final example illustrating the power of transforms, we’ll look at grad. grad is actually implemented as a transform:

defmacro grad(var_or_vars, expr) do
  var_or_vars =
    case var_or_vars do
      {:{}, meta, vars} -> {:{}, meta, Enum.map(vars, &grad_var!/1)}
      {left, right} -> {grad_var!(left), grad_var!(right)}
      var -> grad_var!(var)
    end

  quote do
    Nx.Defn.Kernel.transform(
      {unquote(var_or_vars), unquote(expr)},
      &Nx.Defn.Grad.transform/1
    )
  end
end

Essentially, the grad macro takes the variables to differentiate with respect to, along with an expression. It then differentiates the underlying expression with respect to those variables using Nx.Defn.Grad.transform/1. The gradient transform traverses the expression recursively, applying a set of predefined gradient rules to each operation it encounters. The implementation is surprisingly simple, but the result is amazingly powerful.
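
To give a feel for what “traversing an expression and applying gradient rules” means, here is a toy sketch in plain Elixir. It is not how Nx.Defn.Grad is implemented and does not use Nx.Defn.Expr; the ToyGrad module and its tuple-based expression format are hypothetical, invented only for this illustration:

```elixir
defmodule ToyGrad do
  # Toy expression tree (NOT Nx.Defn.Expr): numbers, variables as atoms,
  # and {:add, l, r}, {:mul, l, r}, {:tanh, x} tuples.

  # Constants differentiate to 0; a variable differentiates to 1 with
  # respect to itself and to 0 with respect to anything else.
  def grad(n, _v) when is_number(n), do: 0
  def grad(v, v) when is_atom(v), do: 1
  def grad(other, _v) when is_atom(other), do: 0

  # Sum rule.
  def grad({:add, l, r}, v), do: {:add, grad(l, v), grad(r, v)}

  # Product rule.
  def grad({:mul, l, r}, v) do
    {:add, {:mul, grad(l, v), r}, {:mul, l, grad(r, v)}}
  end

  # Chain rule: d/dv tanh(x) = (1 - tanh(x)^2) * dx/dv.
  def grad({:tanh, x}, v) do
    t = {:tanh, x}
    {:mul, {:add, 1, {:mul, -1, {:mul, t, t}}}, grad(x, v)}
  end

  # Evaluates an expression given a map of variable bindings.
  def eval(n, _env) when is_number(n), do: n
  def eval(v, env) when is_atom(v), do: Map.fetch!(env, v)
  def eval({:add, l, r}, env), do: eval(l, env) + eval(r, env)
  def eval({:mul, l, r}, env), do: eval(l, env) * eval(r, env)
  def eval({:tanh, x}, env), do: :math.tanh(eval(x, env))
end

# tanh(a) + b * b, loosely mirroring tanh_power/2 from earlier.
expr = {:add, {:tanh, :a}, {:mul, :b, :b}}

# The derivative with respect to b is 2 * b, so this prints 6.0 at b = 3.0.
d_db = ToyGrad.grad(expr, :b)
IO.inspect(ToyGrad.eval(d_db, %{a: 0.5, b: 3.0}))
```

Each clause of grad/2 is one differentiation rule, and the recursion over the tuples plays the role of the expression traversal described above. A real implementation also has to handle shapes, broadcasting, and many more operations.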

I hope these few relatively simple examples illustrate the power of transform/2. If you have any issues with my explanation or find any problems with the code, feel free to let me know. Happy coding!