Nx Tip of the Week #7 - Using Nx.Defn.jit
There are actually two ways in Nx to accelerate your numerical definitions: invoking defn functions from a module with a @default_defn_compiler attribute set, or calling Nx.Defn.jit/3. Let's take a look at these two methods in practice:
defmodule JIT do
  import Nx.Defn

  @default_defn_compiler EXLA

  defn softmax(x) do
    max_val = Nx.reduce_max(x)
    Nx.exp(x - max_val) / Nx.sum(Nx.exp(x - max_val))
  end
end
IO.inspect Nx.Defn.jit(&JIT.softmax/1, [Nx.random_uniform({5})], compiler: EXLA)
IO.inspect JIT.softmax(Nx.random_uniform({5}))
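Both calls should print a 5-element tensor whose entries sum to 1. The only real difference is where the compiler is chosen: the Nx.Defn.jit/3 call passes compiler: EXLA explicitly for that invocation, while the direct call to softmax/1 picks it up from the module's @default_defn_compiler attribute.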
Both calling Nx.Defn.jit/3 and invoking the softmax/1 numerical definition directly will result in your program being compiled and run with EXLA, so why would you use one over the other?
While it's generally up to you, I've found Nx.Defn.jit/3 is very useful when integrating numerical definitions with regular Elixir code. As an example, Axon consists of a "low-level" functional API made up entirely of numerical definitions. These low-level functions are then glued together by high-level APIs that are more user-friendly. Axon.Training makes use of Nx.Defn.jit/3 to glue things like optimizer initialization and model initialization together, and to "dynamically" define training step functions from loss functions and models.
Here’s an example snippet that generates a step function from an objective function and an update function:
step_fn = fn model_state, input, target ->
  # Unpack the current parameters and optimizer state
  {params, update_state} = model_state

  # Compute the batch loss and its gradients with respect to the parameters
  {batch_loss, gradients} =
    Nx.Defn.Kernel.value_and_grad(params, &objective_fn.(&1, input, target))

  # Transform the gradients into updates and apply them to the parameters
  {updates, new_update_state} = update_fn.(gradients, update_state, params)
  {{Axon.Updates.apply_updates(params, updates), new_update_state}, batch_loss}
end
You can then use Nx.Defn.jit/3 to apply step_fn:
{model_state, batch_loss} = Nx.Defn.jit(step_fn, [model_state, inp, tar], jit_opts)
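From regular Elixir code, the jitted step is then just another function you can fold over your data. A minimal sketch, assuming batches is an enumerable of {input, target} tensor pairs and jit_opts is something like [compiler: EXLA]:

final_state =
  Enum.reduce(batches, model_state, fn {inp, tar}, state ->
    # Each iteration runs the EXLA-compiled step and threads the state through
    {new_state, batch_loss} = Nx.Defn.jit(step_fn, [state, inp, tar], jit_opts)
    IO.puts("batch loss: #{Nx.to_scalar(batch_loss)}")
    new_state
  end)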
Using Nx.Defn.jit/3 can be very useful when writing higher-level APIs that serve as "glue" around, or an integration of, base numerical functionality.