Nx Tip of the Week #7 - Using Nx.Defn.jit


There are two ways in Nx to accelerate your numerical definitions: setting the @default_defn_compiler module attribute so that the module's defn functions are compiled when you call them, or calling Nx.Defn.jit/3. Let's take a look at both methods in practice:

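defmodule JIT do
  import Nx.Defn

  # Every defn in this module is compiled with EXLA when invoked
  @default_defn_compiler EXLA

  defn softmax(x) do
    # Subtract the max before exponentiating for numerical stability
    exp_x = Nx.exp(x - Nx.reduce_max(x))
    exp_x / Nx.sum(exp_x)
  end
end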

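# Option 1: compile and run the function explicitly with Nx.Defn.jit/3
IO.inspect Nx.Defn.jit(&JIT.softmax/1, [Nx.random_uniform({5})], compiler: EXLA)
# Option 2: call the defn directly; @default_defn_compiler selects EXLA
IO.inspect JIT.softmax(Nx.random_uniform({5}))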

Both calling Nx.Defn.jit/3 and invoking the softmax/1 numerical definition directly will compile and run your program with EXLA, so why would you use one over the other?

While the choice is generally up to you, I've found Nx.Defn.jit/3 very useful when integrating numerical definitions with regular Elixir code. For example, Axon consists of a "low-level" functional API made entirely of numerical definitions. These low-level functions are then glued together by more user-friendly high-level APIs. Axon.Training uses Nx.Defn.jit/3 to glue together steps such as optimizer initialization and model initialization, and to "dynamically" define training step functions from loss functions and models.

Here’s an example snippet that generates a step function from an objective function and an update function:

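# objective_fn and update_fn are supplied by the surrounding code
step_fn = fn model_state, input, target ->
  {params, update_state} = model_state

  # Loss for this batch and its gradients with respect to params
  {batch_loss, gradients} =
    Nx.Defn.Kernel.value_and_grad(params, &objective_fn.(&1, input, target))

  # Turn gradients into updates, then apply them to params
  {updates, new_update_state} = update_fn.(gradients, update_state, params)
  {{Axon.Updates.apply_updates(params, updates), new_update_state}, batch_loss}
end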

You can then use Nx.Defn.jit/3 to apply step_fn:

# jit_opts holds the compiler options, e.g. [compiler: EXLA]
{model_state, batch_loss} = Nx.Defn.jit(step_fn, [model_state, input, target], jit_opts)
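To see how a jitted step function plugs into ordinary Elixir control flow, here is a self-contained sketch of the same pattern on a toy problem. It assumes a Mix project with :nx and :exla installed at versions where Nx.Defn.jit/3 takes (fun, args, opts) as in this post; later Nx releases reworked this API, so check the docs for your version. StepFactory, build_step/2 and objective_fn are hypothetical names, and a plain gradient-descent update stands in for Axon's update functions.

```elixir
# Assumes :nx and :exla deps at versions where Nx.Defn.jit/3 takes
# (fun, args, opts), as used in this post.
defmodule StepFactory do
  # Hypothetical helper: builds a plain gradient-descent step function
  # from any objective function and a learning rate.
  def build_step(objective_fn, learning_rate) do
    fn params, input, target ->
      # Loss and gradient of the objective with respect to params
      {loss, grad} =
        Nx.Defn.Kernel.value_and_grad(params, &objective_fn.(&1, input, target))

      # Gradient descent update: params - learning_rate * grad
      {Nx.subtract(params, Nx.multiply(grad, learning_rate)), loss}
    end
  end
end

# Hypothetical objective: mean squared error of a linear model y = x . w
objective_fn = fn w, x, y ->
  diff = Nx.subtract(Nx.dot(x, w), y)
  Nx.mean(Nx.multiply(diff, diff))
end

step_fn = StepFactory.build_step(objective_fn, 0.1)

x = Nx.tensor([[1.0], [2.0], [3.0]])
y = Nx.tensor([2.0, 4.0, 6.0])

# Regular Elixir drives the loop; each step is compiled and run by EXLA
{w, loss} =
  Enum.reduce(1..50, {Nx.tensor([0.0]), nil}, fn _step, {w, _loss} ->
    Nx.Defn.jit(step_fn, [w, x, y], compiler: EXLA)
  end)

IO.inspect({w, loss})
```

Because the data satisfies y = 2x, w should move toward 2.0. Since build_step/2 is plain Elixir, swapping in a different objective or learning rate needs no new defn; Enum.reduce/3 handles the loop while each step runs through the compiler.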

In short, Nx.Defn.jit/3 is especially useful when writing higher-level APIs that glue together or integrate lower-level numerical functionality.