Nx Tip of the Week #7 – Using Nx.Defn.jit

There are two ways in Nx to accelerate your numerical definitions: calling a defn whose compiler is set with a module attribute (@defn_compiler, or @default_defn_compiler as in the module below), or calling Nx.Defn.jit/3. Let’s look at both methods in practice:

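defmodule JIT do
  import Nx.Defn

  # Compile every defn in this module with EXLA
  @default_defn_compiler EXLA

  defn softmax(x) do
    # Subtract the max before exponentiating for numerical stability
    max_val = Nx.reduce_max(x)
    exps = Nx.exp(x - max_val)
    exps / Nx.sum(exps)
  end
end

# Option 1: JIT-compile the function explicitly with EXLA
IO.inspect Nx.Defn.jit(&JIT.softmax/1, [Nx.random_uniform({5})], compiler: EXLA)
# Option 2: call the defn directly; the module attribute selects EXLA
IO.inspect JIT.softmax(Nx.random_uniform({5}))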
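For reference, here is what softmax/1 computes, written as plain Elixir over a list of floats. This is a hypothetical SoftmaxRef module, not part of Nx, and nothing in it is compiled by EXLA; it only illustrates the math. Subtracting the maximum first does not change the result, but it keeps the exponentials from overflowing: without the shift, :math.exp(1002.0) would exceed the range of a 64-bit float.

```elixir
defmodule SoftmaxRef do
  # Plain-Elixir reference for the math in JIT.softmax/1 (no Nx involved)
  def softmax(xs) do
    # Shift by the maximum so the largest exponent is exp(0) = 1.0
    max_val = Enum.max(xs)
    exps = Enum.map(xs, fn x -> :math.exp(x - max_val) end)
    total = Enum.sum(exps)
    # Normalize so the outputs sum to 1
    Enum.map(exps, &(&1 / total))
  end
end

IO.inspect(SoftmaxRef.softmax([1.0, 2.0, 3.0]))
IO.inspect(SoftmaxRef.softmax([1000.0, 1001.0, 1002.0]))
```

Both calls print the same three probabilities, because after the shift both inputs become [-2.0, -1.0, 0.0].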

Both calling Nx.Defn.jit/3 and invoking the softmax/1 numerical definition directly will compile and run your program with EXLA, so why would you use one over the other?

While it’s generally up to you, I’ve found Nx.Defn.jit/3 very useful when integrating numerical definitions with regular Elixir code. As an example, Axon consists of a “low-level” functional API made entirely of numerical definitions. These low-level functions are then glued together by more user-friendly high-level APIs. Axon.Training uses Nx.Defn.jit/3 to glue together steps such as optimizer initialization and model initialization, and to “dynamically” define training step functions from loss functions and models.

Here’s an example snippet that generates a step function from an objective function and an update function:

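# objective_fn and update_fn are captured from the enclosing scope
step_fn = fn model_state, input, target ->
  {params, update_state} = model_state

  # step_fn is an ordinary anonymous function rather than a defn, so
  # value_and_grad is called with its full module name
  {batch_loss, gradients} =
    Nx.Defn.Kernel.value_and_grad(params, &objective_fn.(&1, input, target))

  # Compute the updates and new optimizer state, then apply the updates
  {updates, new_update_state} = update_fn.(gradients, update_state, params)
  {{Axon.Updates.apply_updates(params, updates), new_update_state}, batch_loss}
end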

You can then use Nx.Defn.jit/3 to apply step_fn:

{model_state, batch_loss} = Nx.Defn.jit(step_fn, [model_state, inp, tar], jit_opts)
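To make the shape of step_fn easier to see outside of Nx and Axon, here is a plain-Elixir sketch of the same pattern: a hypothetical StepSketch.build_step_fn/2 takes a value-and-gradient function and an update function and returns a closure with the same {{params, update_state}, loss} return shape. It works on plain floats, and the toy objective’s gradient is written out by hand in place of value_and_grad. This is not Axon’s implementation, and nothing here is compiled by Nx.Defn.jit/3.

```elixir
defmodule StepSketch do
  # Hypothetical helper, not part of Axon: returns a closure shaped like
  # step_fn above, but over plain floats instead of tensors.
  def build_step_fn(value_and_grad_fn, update_fn) do
    fn {params, update_state}, input, target ->
      # Stands in for Nx.Defn.Kernel.value_and_grad: returns {loss, gradient}
      {loss, gradient} = value_and_grad_fn.(params, input, target)
      {update, new_update_state} = update_fn.(gradient, update_state, params)
      # Stands in for Axon.Updates.apply_updates: add the update to params
      {{params + update, new_update_state}, loss}
    end
  end
end

# Toy objective: squared error of the one-parameter model y = w * x,
# returned together with its hand-derived gradient 2 * (w * x - y) * x
value_and_grad = fn w, x, y ->
  err = w * x - y
  {err * err, 2 * err * x}
end

# Plain SGD with a fixed learning rate of 0.1; it carries no state
sgd = fn gradient, state, _params -> {-0.1 * gradient, state} end

step_fn = StepSketch.build_step_fn(value_and_grad, sgd)
{{w, _state}, loss} = step_fn.({0.0, nil}, 1.0, 2.0)
IO.inspect({w, loss})
# prints {0.4, 4.0}
```

Because the step function is just a closure built at runtime, swapping in a different objective or optimizer only means passing different functions to build_step_fn/2, which is the kind of “glue” this tip is about.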

In short, Nx.Defn.jit/3 is most useful when writing higher-level APIs that serve as “glue” between pieces of base functionality.
