Nx Tip of the Week #6 - Compiler or Backend?

Published on

I’ve recently seen some confusion with respect to compilers and backends. This post is intended to clear up some of that confusion. TLDR: If performance matters, benchmark and decide. If you need flexibility or want to prototype quickly and not sacrifice speed, backends are a good choice. If you need AOT compilation or your programs are very computationally intensive, compilers work better. Library writers should always write library functions in defn to leave this choice to the user.

The Nx library is a standalone Elixir project that defines the main API for working with tensors; however, Nx is very much a “batteries-excluded” library. The power of Nx is in its flexibility. While the library contains pure Elixir implementations of every function in the main API, Nx is designed to integrate with compilers and backends that provide highly-optimized, native tensor manipulation routines. It’s important to understand the distinction between an Nx compiler and an Nx backend. At the time of this writing, EXLA is the only available compiler, and Torchx is the only available backend. We will frame the difference with respect to these two implementations; however, these differences will likely hold true for other backends and compilers as well.

For those familiar with paradigms in other frameworks such as TensorFlow and PyTorch, the distinction between compilers and backends is similar to the distinction between graph mode and eager execution. Let’s try to understand this distinction with a simple example - softmax:

@defn_compiler EXLA
defn softmax_compiler(x) do
  max_val = Nx.reduce_max(x)
  Nx.exp(x - max_val) / Nx.sum(Nx.exp(x - max_val))
end

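def softmax_backend(x) do
  Nx.default_backend(Torchx)
  max_val = Nx.reduce_max(x)
  # Outside defn, the - and / operators do not accept tensors,
  # so a regular function calls Nx.subtract/2 and Nx.divide/2.
  shifted = Nx.subtract(x, max_val)
  Nx.divide(Nx.exp(shifted), Nx.sum(Nx.exp(shifted)))
end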

Softmax is a function that accepts a list of numbers, and returns an equally sized list of probabilities - scaled according to their magnitude relative to other numbers in the list. You can see the implementation for softmax is quite simple. One thing to note in the implementation above is the presence of max_val - which acts as a guard against the common numerical stability issues of underflow and overflow. Mathematically, max_val has no impact on the result of the softmax function, because the shift cancels between the numerator and the denominator; it only ensures intermediate values don’t grow too large or too small to fit in a finite-sized floating point or integer representation of the value.
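To make the role of max_val concrete, here is a minimal sketch in plain Elixir, using lists of floats instead of Nx tensors so it runs without any dependencies. The module name SoftmaxSketch and its two functions are hypothetical and exist only for this illustration. It relies on the fact that Erlang's :math.exp/1 raises an ArithmeticError on overflow rather than returning infinity.

```elixir
defmodule SoftmaxSketch do
  # Naive softmax over a plain list of floats: exponentiate, then normalize.
  def naive(xs) do
    exps = Enum.map(xs, fn x -> :math.exp(x) end)
    total = Enum.sum(exps)
    Enum.map(exps, fn e -> e / total end)
  end

  # Stable softmax: subtract the maximum first, so every exponent is <= 0
  # and :math.exp/1 can never overflow.
  def stable(xs) do
    max_val = Enum.max(xs)
    exps = Enum.map(xs, fn x -> :math.exp(x - max_val) end)
    total = Enum.sum(exps)
    Enum.map(exps, fn e -> e / total end)
  end
end

# For small inputs both versions agree up to floating-point rounding.
small = [1.0, 2.0, 3.0]
IO.inspect(SoftmaxSketch.naive(small), label: "naive")
IO.inspect(SoftmaxSketch.stable(small), label: "stable")

# For large inputs the naive version overflows inside :math.exp/1.
large = [1000.0, 1001.0, 1002.0]

naive_result =
  try do
    SoftmaxSketch.naive(large)
  rescue
    ArithmeticError -> :overflow
  end

IO.inspect(naive_result, label: "naive on large inputs")
IO.inspect(SoftmaxSketch.stable(large), label: "stable on large inputs")
```

The shifted version returns the same probabilities for the large inputs as it does for the small ones, because both lists have the same differences between elements.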

You’ll notice the implementations above are subtly different in two ways:

  1. The first function uses defn - marking it as a numerical definition that should be treated as such. The second function is a regular Elixir function.
  2. The first function adds the attribute @defn_compiler EXLA - telling Elixir to compile and run this definition with the EXLA compiler. The second function calls Nx.default_backend(Torchx) - telling Elixir to dispatch Nx calls to their Torchx implementation.

If you compile and run this example, you should notice that, aside from minor differences in precision, each function arrives at approximately the same result. So how exactly are the subtle differences in implementation actually manifesting themselves at runtime? And why would you want to prefer a compiler over a backend or vice-versa? To understand, we’ll need to discuss what’s happening under the hood.

Implementations of the Nx API in the Nx module are actually meta-implementations that dispatch to third-party implementations of the same function at runtime. These meta-implementations act like contracts for functions in the Nx API - normalizing arguments and types, checking shape compatibility, and calculating output shapes before dispatching to an actual implementation. So, you can arbitrarily switch backends and somewhat guarantee your Nx implementations will remain unchanged (in reality there are some cases where implementation specific details may affect you, but they are documented).
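The contract-then-dispatch idea can be sketched in a few lines of plain Elixir. This is a toy model, not how Nx is actually implemented: ToyBackend, ToyListBackend, and ToyTensor are hypothetical names invented for this example, and the "tensor" is just a flat list with a one-dimensional shape.

```elixir
defmodule ToyBackend do
  # Hypothetical contract that every toy backend must implement.
  @callback add([number()], [number()]) :: [number()]
end

defmodule ToyListBackend do
  @behaviour ToyBackend

  # One concrete implementation: element-wise addition over plain lists.
  @impl true
  def add(a, b) do
    Enum.zip(a, b) |> Enum.map(fn {x, y} -> x + y end)
  end
end

defmodule ToyTensor do
  defstruct [:data, :shape, :backend]

  def new(data, backend \\ ToyListBackend) when is_list(data) do
    %ToyTensor{data: data, shape: {length(data)}, backend: backend}
  end

  # Frontend "meta-implementation": check that the shapes are compatible,
  # work out the output shape, and only then dispatch to the backend.
  def add(%ToyTensor{shape: shape, backend: backend} = a, %ToyTensor{shape: shape} = b) do
    %ToyTensor{data: backend.add(a.data, b.data), shape: shape, backend: backend}
  end

  def add(%ToyTensor{shape: s1}, %ToyTensor{shape: s2}) do
    raise ArgumentError, "incompatible shapes: #{inspect(s1)} and #{inspect(s2)}"
  end
end

a = ToyTensor.new([1, 2, 3])
b = ToyTensor.new([10, 20, 30])
sum = ToyTensor.add(a, b)
IO.inspect(sum.data)
```

Because the shape check lives in the frontend, a second backend module only has to implement the add callback; callers of ToyTensor.add/2 would not need to change.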

In the Torchx softmax implementation above, Nx will dispatch to the Torchx NIF implementations of reduce_max, exp, sum, subtract, and divide - returning control to the VM between each call and realizing the result of each intermediate computation before starting the next. Because it’s a regular Elixir function, you can arbitrarily mix Nx code with Elixir code:

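def softmax_backend(x) do
  Nx.default_backend(Torchx)
  # Nx.reduce_max/1 reduces over the whole tensor (Nx.max/2 is element-wise).
  max_val = Nx.reduce_max(x) |> IO.inspect()

  {:ok, exp_x} =
    case :ok do
      :ok -> {:ok, Nx.exp(Nx.subtract(x, max_val))}
      _ -> {:error, "BAD"}
    end

  if true do
    # Outside defn, use Nx.divide/2 rather than the / operator.
    Nx.divide(exp_x, Nx.sum(exp_x))
  else
    {:error, "BAD"}
  end
end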

This has the obvious benefit of flexibility - defn is a much more restrictive subset of the language (although it can be trivially extended). However, this flexibility comes at a performance cost. As I mentioned before, backends return control to the Elixir program between subsequent calls to Nx functions. With numerically intensive programs that rely on a lot of calls to Nx functions - this dispatching adds up and can often manifest itself in significant performance and memory bottlenecks. So how does this differ from the compiler implementation?

defn adds another layer of indirection to the program execution. While calls to Nx outside of defn dispatch to backends which evaluate the result of the computation, calls to Nx within defn dispatch to an Expr backend - which builds a call graph or evaluation trace of the definition. Internally, the expression looks something like:

#Nx.Tensor<
  f32
  
  Nx.Defn.Expr
  parameter a                                        s64
  b = reduce_max [ a, axes: nil, keep_axes: false ]  s64
  c = subtract [ a, b ]                              s64
  d = exp [ c ]                                      f32
  e = subtract [ a, b ]                              s64
  f = exp [ e ]                                      f32
  g = sum [ f, axes: nil, keep_axes: false ]         f32
  h = divide [ d, g ]                                f32
>
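The idea of building an expression instead of computing a value can be illustrated with a toy tracer in plain Elixir. This is only a sketch of the concept: ToyExpr and its tuple-based node format are hypothetical and unrelated to the real Nx.Defn.Expr data structure, and the interpreter works on lists of floats rather than tensors.

```elixir
defmodule ToyExpr do
  # "Tracing" functions: each call returns a node describing the operation
  # instead of computing anything.
  def param(name), do: {:param, name}
  def reduce_max(a), do: {:reduce_max, a}
  def subtract(a, b), do: {:subtract, a, b}
  def exp(a), do: {:exp, a}
  def sum(a), do: {:sum, a}
  def divide(a, b), do: {:divide, a, b}

  # A tiny interpreter that walks the expression tree. Vector operands are
  # lists of floats; reduce_max and sum produce scalars.
  def eval({:param, name}, env), do: Map.fetch!(env, name)
  def eval({:reduce_max, a}, env), do: Enum.max(eval(a, env))
  def eval({:sum, a}, env), do: Enum.sum(eval(a, env))
  def eval({:exp, a}, env), do: Enum.map(eval(a, env), fn v -> :math.exp(v) end)

  def eval({:subtract, a, b}, env) do
    scalar = eval(b, env)
    Enum.map(eval(a, env), fn v -> v - scalar end)
  end

  def eval({:divide, a, b}, env) do
    scalar = eval(b, env)
    Enum.map(eval(a, env), fn v -> v / scalar end)
  end
end

# Building the softmax expression performs no arithmetic at all.
x = ToyExpr.param(:x)
shifted = ToyExpr.subtract(x, ToyExpr.reduce_max(x))
expr = ToyExpr.divide(ToyExpr.exp(shifted), ToyExpr.sum(ToyExpr.exp(shifted)))
IO.inspect(expr, label: "expression")

# Only when the expression is handed to an evaluator does any work happen.
result = ToyExpr.eval(expr, %{x: [1.0, 2.0, 3.0]})
IO.inspect(result, label: "result")
```

Having the whole computation available as data before anything runs is what gives a compiler room to optimize or fuse operations, whereas the toy interpreter here simply walks the tree node by node.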

You can think of this expression as a kind of numerical assembly code which defines the low-level instructions for your program. When you invoke softmax_compiler, Elixir sends this expression to a compiler, in this example the EXLA compiler, which just-in-time compiles a specialized program based on your input shape, type, and operations. This program is then cached, so subsequent calls to the same function with the same input shapes and types don’t need to recompile the original program. Notably, rather than dispatching and returning between calls to reduce_max, sum, exp, divide, and subtract, JIT compiled programs are treated as single units of execution. So calls to softmax_compiler invoke a single specialized executable for your numerical definition. In other words, the calls to reduce_max, sum, exp, divide, and subtract are fused into a single call.

Staging the computation in this way opens up the door to (potentially significant) performance and memory optimizations - you can learn more about these optimizations by researching XLA or other tensor compilers. However, these enhancements are not free. One thing you’ll notice when working with numerical definitions is that the first invocation of a defn is much slower than subsequent calls. That’s because the first invocation needs to do the work of compiling the program - depending on the size of the computation, this may require lots of time relative to program execution. Subsequent calls are cached; however, if you change the shape or type of your inputs, the numerical definition needs to compile a new version of your program specific to the input shape and type. If your input shapes or types are constantly changing, your bottleneck will almost certainly be compilation time.
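The caching behaviour described above can be modelled with a toy cache keyed on input shape and type. This is an illustrative sketch only, not EXLA's actual caching mechanism: ToyJit is a hypothetical module, the "compile" step is a stand-in that just returns an anonymous function, and the type is passed in by hand as an atom.

```elixir
defmodule ToyJit do
  # Hypothetical stand-in for an expensive compilation step. A real compiler
  # would specialize the program for this exact shape and type.
  defp compile({_shape, _type}) do
    fn xs -> Enum.map(xs, fn v -> v * 2 end) end
  end

  # The cache is a map of {shape, type} => compiled function.
  # Returns {result, updated_cache, :compiled | :cache_hit}.
  def call(cache, xs, type) do
    key = {{length(xs)}, type}

    case Map.fetch(cache, key) do
      {:ok, fun} ->
        {fun.(xs), cache, :cache_hit}

      :error ->
        fun = compile(key)
        {fun.(xs), Map.put(cache, key, fun), :compiled}
    end
  end
end

# First call: nothing cached yet, so the program is compiled.
{_, cache, s1} = ToyJit.call(%{}, [1.0, 2.0, 3.0], :f32)
# Same shape and type, different values: the cached program is reused.
{_, cache, s2} = ToyJit.call(cache, [4.0, 5.0, 6.0], :f32)
# A new shape forces a new compilation.
{_, cache, s3} = ToyJit.call(cache, [1.0, 2.0], :f32)
# So does a new type, even with a shape that was seen before.
{_, _cache, s4} = ToyJit.call(cache, [1, 2, 3], :s64)

IO.inspect([s1, s2, s3, s4])
# prints [:compiled, :cache_hit, :compiled, :compiled]
```

Only the second call avoids compilation, which is why workloads with constantly changing shapes or types spend most of their time compiling.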

An additional pitfall of the compiled approach is the strictness of the syntax within defn. All inputs must be tensors (unless using keywords), you must match on tuple inputs, and you can use a limited subset of Elixir. Although you can overcome most limitations with transforms, it requires some overhead, and may be confusing.

In exchange for these pitfalls and syntactic limitations, defn offers something regular functions cannot: grad is only available from within defn. This is because Nx differentiation directly manipulates defn expressions. Does this mean you cannot use grad at all with backends like Torchx? Fortunately, no. The default “compiler” is essentially an expression interpreter - Nx.Defn.Evaluator. You can use grad and other future transforms and still use a backend like Torchx; however, the additional interpretation overhead will likely lead to bottlenecks.

So when should you use a compiler like EXLA or a backend like Torchx? Like most things, it’s situation dependent. Backends like Torchx offer flexibility, and facilitate rapid prototyping, with pleasing performance gains over pure Elixir implementations. Compilers like EXLA unlock even more potential performance and memory optimizations, and open the door to things like ahead-of-time compilation. Generally, if you need the flexibility of intermixing Nx code with your regular Elixir code, using a backend is probably the more convenient option. Additionally, some Nx programs are too small, or structured in such a way, that they won’t see any performance gain from a compiler.

When dealing with compute intensive, purely numerical programs, compilers are usually the better option. For example, you’ll likely realize better performance when training deep neural networks using a compiler over a backend. If performance is your priority, benchmarking will almost certainly lead you to the right decision. There are certainly instances where a compiler will actually hurt your performance, especially if your programs require excessive recompilations or are not well-suited for tensor-based implementations.

All of this is not to say compilers and backends are mutually exclusive. You may find it easiest to experiment rapidly with a backend like Torchx, and then gradually migrate to a compiler solution for the performance benefits.

On a final note, I will add a caveat for library developers who want to build packages on top of Nx: you should work almost exclusively inside numerical definitions. This leaves the choice of compiler or backend to the user. As a library developer, you can create packages with numerical definitions that are unconcerned with the backend or compiler implementation details, which allows users to experiment with what works best for their use-case. If you write your library largely with regular Elixir functions, users won’t be able to call your functions inside numerical definitions, and you take the compilation option completely out of their hands.

I hope this clears up some confusion about the difference between an Nx backend and compiler, and helps you make the right decision.