# Nx Tip of the Week #1 – Using transforms

Note: This is an idea I had after learning a significant amount from the Abseil C++ Tips of the Week during my work on EXLA. I’ll keep writing these as long as there’s interest. If there’s anything in particular you’d like to read about, feel free to let me know!

`Nx` is an exciting new project that aims to make mathematical computing practical within the Elixir ecosystem. Fundamentally, `Nx` is a tensor manipulation (or array programming) library similar to NumPy, TensorFlow, or PyTorch. `Nx` introduces a new type of function definition, `defn`, which supports a subset of the Elixir programming language tailored specifically to numerical computations. When numerical definitions are invoked, they’re transformed into expressions (internally `Nx.Defn.Expr`) that represent the AST, or computation graph, of the numerical definition. Compilers (like `EXLA`) then manipulate these expressions to produce executables that run natively on accelerators.
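To make “invoking a function builds an expression” concrete, here’s a toy sketch in plain Elixir. It is not how `Nx.Defn.Expr` works internally: the `ToyExpr` module and all of its functions are hypothetical, and an expression is just a nested `{op, args}` tuple. Calling `tanh_power/2` on parameters records the computation as a graph instead of producing a number, and a separate `eval/2` step loosely plays the role of the compiler that later runs it.

```elixir
defmodule ToyExpr do
  # Hypothetical toy, not Nx: an expression is a nested tuple {op, args}.
  # Nothing is computed here; each call only records the operation.

  def parameter(name), do: {:parameter, [name]}
  def tanh(x), do: {:tanh, [x]}
  def power(x, n), do: {:power, [x, n]}
  def add(x, y), do: {:add, [x, y]}

  # A "numerical definition" written against the toy API.
  def tanh_power(a, b), do: add(tanh(a), power(b, 2))

  # Running the recorded graph requires concrete values for the parameters.
  def eval({:parameter, [name]}, env), do: Map.fetch!(env, name)
  def eval({:tanh, [x]}, env), do: :math.tanh(eval(x, env))
  def eval({:power, [x, n]}, env), do: :math.pow(eval(x, env), n)
  def eval({:add, [x, y]}, env), do: eval(x, env) + eval(y, env)
end
```

`ToyExpr.tanh_power(ToyExpr.parameter(:a), ToyExpr.parameter(:b))` returns `{:add, [{:tanh, [{:parameter, [:a]}]}, {:power, [{:parameter, [:b]}, 2]}]}`, and evaluating that graph with `%{a: 0.0, b: 3.0}` gives `9.0`.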

The subset of supported Elixir code within `defn` can, at times, feel restrictive; however, there are a number of ways to overcome these perceived limitations. One such example is `transform/2`. From the docs:

```
transform(arg, fun)

Defines a transform that executes the given `fun` with `arg`
when building `defn` expressions.
```

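You can invoke `transform/2` from within `defn` to call out to any Elixir function. The function runs while the expression is being built, so it receives expressions (and build-time values such as shapes) rather than concrete tensor data. As an example, you can use `transform/2` to inspect the underlying expression in a numerical definition: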

```elixir
defn tanh_power(a, b) do
  res = Nx.tanh(a) + Nx.power(b, 2)
  # Prints the expression built so far; the return value is discarded
  transform(res, &IO.inspect/1)
  res
end
```

Invoking `tanh_power/2` will print:

```
#Nx.Defn.Expr<
  parameter a
  parameter c
  b = tanh [ a ] ()
  d = power [ c, 2 ] ()
  e = add [ b, d ] ()
>
```

Note: the empty parentheses would normally contain each node’s type and shape, which are determined at the time of invocation; they’re omitted here. There’s also a macro available within `defn`, `inspect_expr`, that implements this transform for you.

Transforms prove particularly useful when doing type or shape checks. For example, you can use transforms to assert that input shapes are equal:

```elixir
defn cross_entropy_loss(y_true, y_pred) do
  # Runs while the expression is built, so it receives plain shape tuples
  transform(
    {Nx.shape(y_true), Nx.shape(y_pred)},
    fn
      {shape, shape} -> :ok
      _ -> raise ArgumentError, "shapes are not equal"
    end
  )

  -Nx.mean(Nx.sum(y_true * Nx.log(y_pred), axes: [-1]))
end
```

You can take that one step further and package everything in a macro:

```elixir
defmacro assert_equal_shapes(expr1, expr2) do
  quote do
    # Hand both shapes to the check while the expression is being built
    Nx.Defn.Kernel.transform(
      {Nx.shape(unquote(expr1)), Nx.shape(unquote(expr2))},
      &assert_equal_shapes_impl/1
    )
  end
end

# transform/2 passes a single argument: the {shape1, shape2} tuple
defp assert_equal_shapes_impl({shape, shape}), do: :ok

defp assert_equal_shapes_impl({s1, s2}) do
  raise ArgumentError,
        "expected shapes to be equal, got #{inspect(s1)} != #{inspect(s2)}"
end

defn cross_entropy_loss(y_true, y_pred) do
  assert_equal_shapes(y_true, y_pred)
  -Nx.mean(Nx.sum(y_true * Nx.log(y_pred), axes: [-1]))
end
```

And you’ll get some nice shape validation on your functions:

```
iex> y_true = Nx.tensor([[0, 1], [1, 0]], type: {:f, 32})
iex> y_pred = Nx.tensor([[0.6811, 0.5565], [0.6551, 0.4551], [0.5422, 0.2648]], type: {:f, 32})
iex> cross_entropy_loss(y_true, y_pred)
** (ArgumentError) expected shapes to be equal, got {2, 2} != {3, 2}
```

You can even use values returned from transforms. For example, say you need to calculate the shape of a tensor created with `Nx.random_uniform` based on the shape of an input tensor:

```elixir
defn dense_layer(input) do
  weight_shape =
    transform(Nx.shape(input), fn {_batch_size, in_units} ->
      {in_units, 32}
    end)

  weight = Nx.random_uniform(weight_shape, type: Nx.type(input))
  Nx.dot(input, weight)
end
```

When invoked:

```
iex> t1 = Nx.tensor([[1.0, 2.0, 3.0]], type: {:f, 32})
iex> t2 = Nx.tensor([[1.0, 2.0, 3.0, 4.0]], type: {:f, 32})
iex> dense_layer(t1)
#Nx.Tensor<
  f32
  [...]
>
iex> dense_layer(t2)
#Nx.Tensor<
  f32
  [...]
>
```

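Notice how the dot product works for both inputs, even though their shapes are different! The transform runs when the expression is built for each input, so the weight shape is always computed from the shape of the input it will be multiplied with.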
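Here’s the shape arithmetic behind that, sketched in plain Elixir outside of `defn`. `DenseShapes` and its functions are hypothetical helpers that mirror the transform in `dense_layer/1` (including its hard-coded 32 output units) plus the usual rank-2 rule for `Nx.dot/2`: the inner dimensions must match.

```elixir
defmodule DenseShapes do
  # Hypothetical helper mirroring the transform in dense_layer/1:
  # a {batch_size, in_units} input needs an {in_units, units} weight.
  def weight_shape({_batch_size, in_units}, units \\ 32), do: {in_units, units}

  # Shape of a rank-2 dot product: the inner dimensions must match.
  def dot_shape({rows, inner}, {inner, cols}), do: {rows, cols}

  def dot_shape(left, right) do
    raise ArgumentError, "cannot dot #{inspect(left)} with #{inspect(right)}"
  end

  # Output shape of the dense layer for a given input shape.
  def output_shape(input_shape, units \\ 32) do
    dot_shape(input_shape, weight_shape(input_shape, units))
  end
end
```

For `t1` (shape `{1, 3}`) and `t2` (shape `{1, 4}`), `output_shape/1` returns `{1, 32}` in both cases. A weight fixed at `{3, 32}` would only work for `t1`: `dot_shape({1, 4}, {3, 32})` raises.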

As a final example illustrating the power of transforms, we’ll look at `grad`. `grad` is actually implemented as a transform:

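```elixir
defmacro grad(var_or_vars, expr) do
  # grad_var!/1 (not shown) validates that each argument is a variable
  var_or_vars =
    case var_or_vars do
      # Tuples of three or more elements, e.g. {x, y, z}
      {:{}, meta, vars} -> {:{}, meta, Enum.map(vars, &grad_var!/1)}
      # Two-element tuples are represented literally in the Elixir AST
      {left, right} -> {grad_var!(left), grad_var!(right)}
      var -> grad_var!(var)
    end

  quote do
    Nx.Defn.Kernel.transform(
      {unquote(var_or_vars), unquote(expr)},
      &Nx.Defn.Grad.transform/1
    )
  end
end
```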
Essentially, the `grad` transform takes an expression and the variables to differentiate with respect to, and hands both to `Nx.Defn.Grad.transform/1`. That function traverses the expression graph and recursively applies a gradient rule to each operation it encounters, producing a new expression that computes the gradient. The implementation is surprisingly simple, but the result is amazingly powerful.

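To give a feel for what “recursively applying gradient rules” means, here’s a toy symbolic differentiator in plain Elixir. It is a deliberate simplification and not how `Nx.Defn.Grad` is implemented: the hypothetical `ToyGrad` module works on nested `{op, args}` tuples, differentiates with respect to one scalar variable at a time, and only knows a handful of operations. Each `grad/2` clause is one rule, and each rule recurses into its operands.

```elixir
defmodule ToyGrad do
  # Hypothetical toy, not Nx.Defn.Grad. Expressions are numbers (constants)
  # or tuples {op, args}, e.g. {:add, [x, y]} or {:parameter, [:a]}.

  # d(expr)/d(var): one clause per gradient rule, applied recursively.
  def grad(c, _var) when is_number(c), do: 0
  def grad({:parameter, [name]}, var), do: if(name == var, do: 1, else: 0)
  def grad({:add, [x, y]}, var), do: {:add, [grad(x, var), grad(y, var)]}
  def grad({:subtract, [x, y]}, var), do: {:subtract, [grad(x, var), grad(y, var)]}

  # Product rule: (x * y)' = x' * y + x * y'
  def grad({:multiply, [x, y]}, var) do
    {:add, [{:multiply, [grad(x, var), y]}, {:multiply, [x, grad(y, var)]}]}
  end

  # Chain rule: tanh(x)' = (1 - tanh(x)^2) * x'
  def grad({:tanh, [x]}, var) do
    {:multiply, [{:subtract, [1, {:power, [{:tanh, [x]}, 2]}]}, grad(x, var)]}
  end

  # Power rule for a constant exponent: (x^n)' = n * x^(n - 1) * x'
  def grad({:power, [x, n]}, var) when is_number(n) do
    {:multiply, [{:multiply, [n, {:power, [x, n - 1]}]}, grad(x, var)]}
  end

  # Evaluate an expression given concrete parameter values.
  def eval(c, _env) when is_number(c), do: c
  def eval({:parameter, [name]}, env), do: Map.fetch!(env, name)
  def eval({:add, [x, y]}, env), do: eval(x, env) + eval(y, env)
  def eval({:subtract, [x, y]}, env), do: eval(x, env) - eval(y, env)
  def eval({:multiply, [x, y]}, env), do: eval(x, env) * eval(y, env)
  def eval({:tanh, [x]}, env), do: :math.tanh(eval(x, env))
  def eval({:power, [x, n]}, env), do: :math.pow(eval(x, env), n)
end
```

Differentiating `{:add, [{:tanh, [{:parameter, [:a]}]}, {:power, [{:parameter, [:b]}, 2]}]}`, i.e. tanh(a) + b², with respect to `:b` and evaluating at `b = 3.0` gives `6.0`, which matches the analytical derivative 2b.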
I hope these few relatively simple examples illustrate the power of `transform/2`. If you have any issues with my explanation or find any problems with the code, feel free to let me know. Happy coding!