# Nx Tip of the Week #1 - Using transforms

*Note: This is an idea I had after learning a significant amount from the Abseil C++ Tips of the Week during my work on EXLA. I’ll keep writing them as long as there’s interest. If there’s anything in particular you’d like to read about, feel free to let me know!*

`Nx` is an exciting new project that aims to make mathematical computing practical within the Elixir ecosystem. Fundamentally, `Nx` is a *tensor manipulation* or *array programming* library similar to NumPy, TensorFlow, or PyTorch. `Nx` introduces a new type of function definition, `defn`, which accepts a subset of the Elixir programming language tailored specifically to numerical computations. When numerical definitions are invoked, they’re transformed into *expressions* (internally `Nx.Defn.Expr`) which represent the AST or computation graph of the numerical definition. These expressions are manipulated by compilers (like `EXLA`) to produce executables that run natively on accelerators.

The subset of supported Elixir code within `defn` can, at times, feel restrictive; however, there are a number of ways to overcome these perceived limitations. One such example is `transform/2`. From the docs:

```
transform(arg, fun)
Defines a transform that executes the given `fun` with `arg`
when building `defn` expressions.
```

You can invoke `transform/2` from within `defn` to call out to any Elixir function. As an example, you can use `transform/2` to inspect the underlying expression in a numerical definition:

```
defn tanh_power(a, b) do
  res = Nx.tanh(a) + Nx.power(b, 2)
  transform(res, &IO.inspect/1)
  res
end
```

Invoking `tanh_power/2` will print:

```
#Nx.Defn.Expr<
parameter a
parameter c
b = tanh [ a ] ()
d = power [ c, 2 ] ()
e = add [ b, d ] ()
>
```

*Note: the parentheses would normally contain type and shape information, which is missing here because it is determined at the time of invocation. There’s also a macro, `inspect_expr`, available within `defn` that implements this transform.*

Transforms prove particularly useful when doing type or shape checks. For example, you can use transforms to assert that input shapes are equal:

```
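defn cross_entropy_loss(y_true, y_pred) do
  # Runs while the expression is being built, so a mismatch raises
  # before any numerical work happens.
  transform(
    {Nx.shape(y_true), Nx.shape(y_pred)},
    fn
      {s1, s2} when s1 == s2 -> :ok
      {_s1, _s2} -> raise ArgumentError, "shapes are not equal"
    end
  )

  # Cross-entropy takes the log of the predictions, not of the labels
  -Nx.mean(y_true * Nx.log(y_pred))
end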
```

You can take that one step further and package everything in a macro:

```
defmacro assert_equal_shapes(expr1, expr2) do
  quote do
    Nx.Defn.Kernel.transform(
      {Nx.shape(unquote(expr1)), Nx.shape(unquote(expr2))},
      &assert_equal_shapes_impl/1
    )
  end
end

# transform/2 passes its argument as a single value, so the
# implementation takes one {s1, s2} tuple (arity 1, matching the capture).
defp assert_equal_shapes_impl({s1, s2}) when s1 == s2, do: :ok

defp assert_equal_shapes_impl({s1, s2}) do
  raise ArgumentError,
        "expected shapes to be equal," <>
          " got #{inspect(s1)} != #{inspect(s2)}"
end

defn cross_entropy_loss(y_true, y_pred) do
  assert_equal_shapes(y_true, y_pred)
  # Cross-entropy takes the log of the predictions, not of the labels
  -Nx.mean(y_true * Nx.log(y_pred))
end
```

And you’ll get some nice shape validation on your functions:

```
iex> y_true = Nx.tensor([[0, 1], [1, 0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[0.6811, 0.5565], [0.6551, 0.4551], [0.5422, 0.2648]], type: {:f, 32})
iex> cross_entropy_loss(y_true, y_pred)
** (ArgumentError) expected shapes to be equal, got {2, 2} != {3, 2}
```
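Because the callback handed to `transform/2` is ordinary Elixir that receives plain shape tuples, the check itself can be tried out without Nx at all. The following is a minimal sketch under that assumption; the `ShapeCheck` module and its `assert_equal!/1` function are hypothetical names used only for illustration and are not part of Nx:

```elixir
defmodule ShapeCheck do
  # Hypothetical helper, not part of Nx. It takes one {shape1, shape2}
  # tuple, the same single-argument form a transform callback receives.
  def assert_equal!({s1, s2}) when s1 == s2, do: :ok

  def assert_equal!({s1, s2}) do
    raise ArgumentError,
          "expected shapes to be equal, got #{inspect(s1)} != #{inspect(s2)}"
  end
end

# Matching shapes pass the check.
:ok = ShapeCheck.assert_equal!({{2, 2}, {2, 2}})

# Mismatched shapes raise; the error is rescued here only to show the message.
message =
  try do
    ShapeCheck.assert_equal!({{2, 2}, {3, 2}})
  rescue
    e in ArgumentError -> Exception.message(e)
  end

IO.puts(message)
# prints: expected shapes to be equal, got {2, 2} != {3, 2}
```

Keeping the check in a plain function like this also makes it easy to unit test separately from any numerical definition that uses it.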

You can even use values returned from transforms. For example, say you need to calculate the shape of a tensor created with `Nx.random_uniform` based on the shape of an input tensor:

```
defn dense_layer(input) do
  weight_shape =
    transform(Nx.shape(input), fn {_batch_size, in_units} ->
      {in_units, 32}
    end)

  weight = Nx.random_uniform(weight_shape, type: Nx.type(input))
  Nx.dot(input, weight)
end
```

When invoked:

```
iex> t1 = Nx.tensor([[1.0, 2.0, 3.0]], type: {:f, 32})
iex> t2 = Nx.tensor([[1.0, 2.0, 3.0, 4.0]], type: {:f, 32})
iex> dense_layer(t1)
#Nx.Tensor<
f32[1][32]
[...]
>
iex> dense_layer(t2)
#Nx.Tensor<
f32[1][32]
[...]
>
```

Notice how the dot product works for both inputs, even though their shapes are different!
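To see why both calls succeed, it can help to follow the shape arithmetic by hand. The sketch below uses plain tuples instead of tensors; `weight_shape` mirrors the anonymous function passed to `transform/2` above, while `dot_shape` is a hypothetical helper (not an Nx function) that encodes the usual rule that multiplying `{m, k}` by `{k, n}` yields `{m, n}`:

```elixir
# Stand-in for the function passed to transform/2: maps an input shape
# {batch_size, in_units} to the weight shape {in_units, 32}.
weight_shape = fn {_batch_size, in_units} -> {in_units, 32} end

# Hypothetical helper: result shape of a 2-D dot product. The repeated
# variable k means the inner dimensions must match for the clause to apply.
dot_shape = fn {m, k}, {k, n} -> {m, n} end

for input_shape <- [{1, 3}, {1, 4}] do
  w = weight_shape.(input_shape)
  IO.inspect({input_shape, w, dot_shape.(input_shape, w)})
end
# prints:
# {{1, 3}, {3, 32}, {1, 32}}
# {{1, 4}, {4, 32}, {1, 32}}
```

Since the weight shape is derived from the input shape, the inner dimensions always agree and the output is `{1, 32}` in both cases.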

As a final example illustrating the power of transforms, we’ll look at `grad`. `grad` is actually implemented as a transform:

```
defmacro grad(var_or_vars, expr) do
  var_or_vars =
    case var_or_vars do
      {:{}, meta, vars} -> {:{}, meta, Enum.map(vars, &grad_var!/1)}
      {left, right} -> {grad_var!(left), grad_var!(right)}
      var -> grad_var!(var)
    end

  quote do
    Nx.Defn.Kernel.transform(
      {unquote(var_or_vars), unquote(expr)},
      &Nx.Defn.Grad.transform/1
    )
  end
end
```

Essentially, the `grad` transform takes an expression and the variables to differentiate with respect to. It then differentiates the underlying expression with respect to those variables using `Nx.Defn.Grad.transform/1`. The gradient transform traverses the expression and recursively rewrites each sub-expression using a set of defined gradient rules. The implementation is surprisingly simple, but the result is amazingly powerful.
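To give a feel for what "recursively transforming an expression with gradient rules" means, here is a toy symbolic differentiator in plain Elixir. It is only a sketch of the general idea and is not how `Nx.Defn.Grad` is implemented; the `ToyGrad` module and its tuple-based expression format are assumptions made for this illustration:

```elixir
defmodule ToyGrad do
  # Toy expression tree (an assumption for illustration, not Nx.Defn.Expr):
  #   number | {:var, name} | {:add, l, r} | {:mul, l, r} | {:tanh, x}

  # Constants have a zero derivative.
  def grad(n, _var) when is_number(n), do: 0

  # d(var)/d(var) = 1; any other variable is treated as a constant.
  def grad({:var, name}, var), do: if(name == var, do: 1, else: 0)

  # Sum rule
  def grad({:add, l, r}, var), do: {:add, grad(l, var), grad(r, var)}

  # Product rule
  def grad({:mul, l, r}, var),
    do: {:add, {:mul, grad(l, var), r}, {:mul, l, grad(r, var)}}

  # Chain rule with d/dx tanh(x) = 1 - tanh(x)^2
  def grad({:tanh, x}, var) do
    t = {:tanh, x}
    {:mul, {:add, 1, {:mul, -1, {:mul, t, t}}}, grad(x, var)}
  end

  # Evaluate an expression tree against a map of variable bindings.
  def eval(n, _env) when is_number(n), do: n
  def eval({:var, name}, env), do: Map.fetch!(env, name)
  def eval({:add, l, r}, env), do: eval(l, env) + eval(r, env)
  def eval({:mul, l, r}, env), do: eval(l, env) * eval(r, env)
  def eval({:tanh, x}, env), do: :math.tanh(eval(x, env))
end

# f(a, b) = tanh(a) + b * b, mirroring tanh_power/2 from earlier
expr = {:add, {:tanh, {:var, :a}}, {:mul, {:var, :b}, {:var, :b}}}

# df/db = 2b, so at b = 3.0 the gradient evaluates to 6.0
db = ToyGrad.grad(expr, :b)
IO.inspect(ToyGrad.eval(db, %{a: 0.5, b: 3.0}))
# prints: 6.0
```

Each `grad/2` clause handles one node type and recurses into its children, which is the same overall shape as a rule-based gradient transform, just without tensors, shapes, or any simplification of the resulting expression.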

I hope these few relatively simple examples illustrate the power of `transform/2`. If you have any issues with my explanation or find any problems with the code, feel free to let me know. Happy coding!