Nx Tip of the Week #6 - Compiler or Backend?
I’ve recently seen some confusion with respect to compilers and backends. This post is intended to clear up some of that confusion. TLDR: If performance matters, benchmark and decide. If you need flexibility or want to prototype quickly without sacrificing speed, backends are a good choice. If you need AOT compilation or your programs are very computationally intensive, compilers work better. Library authors should always write library functions in `defn` to leave this choice to the user.
The `Nx` library is a standalone Elixir project that defines the main API for working with tensors; however, `Nx` is very much a “batteries-excluded” library. The power of `Nx` is in its flexibility. While the library contains pure Elixir implementations of every function in the main API, `Nx` is designed to integrate with compilers and backends that provide highly optimized, native tensor manipulation routines. It’s important to understand the distinction between an `Nx` compiler and an `Nx` backend. At the time of this writing, EXLA is the only available compiler, and Torchx is the only available backend. We will frame the difference with respect to these two implementations; however, the differences will likely hold true for other backends and compilers as well.
For those familiar with paradigms in other frameworks such as TensorFlow and PyTorch, the distinction between compilers and backends is similar to the distinction between graph mode and eager execution. Let’s try to understand this distinction with a simple example - softmax:
```elixir
@defn_compiler EXLA
defn softmax_compiler(x) do
  max_val = Nx.reduce_max(x)
  Nx.exp(x - max_val) / Nx.sum(Nx.exp(x - max_val))
end

def softmax_backend(x) do
  Nx.default_backend(Torchx.Backend)
  # Outside of defn, Elixir's arithmetic operators are not overloaded
  # for tensors, so we call the Nx functions explicitly.
  max_val = Nx.reduce_max(x)
  exp_x = Nx.exp(Nx.subtract(x, max_val))
  Nx.divide(exp_x, Nx.sum(exp_x))
end
```
Softmax is a function that accepts a list of numbers and returns an equally sized list of probabilities, scaled according to their magnitude relative to the other numbers in the list. You can see the implementation of softmax is quite simple. One thing to note in the implementation above is the presence of `max_val`, which acts as a guard against the common numerical stability issues of underflow and overflow. Subtracting `max_val` has no impact on the result of the softmax function; it only ensures intermediate values don’t grow too large or too small to fit in a finite-sized floating point or integer representation.
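To make the stability issue concrete, here is a plain-Elixir sketch (no `Nx` involved, just lists and `:math.exp/1`) of softmax with and without the max subtraction:

```elixir
# Softmax over a plain list, with and without the max_val guard.
# Erlang's :math.exp/1 raises an ArithmeticError on overflow, so the
# naive version blows up on large inputs while the stable one is fine.
naive_softmax = fn xs ->
  exps = Enum.map(xs, &:math.exp/1)
  total = Enum.sum(exps)
  Enum.map(exps, &(&1 / total))
end

stable_softmax = fn xs ->
  max_val = Enum.max(xs)
  exps = Enum.map(xs, fn x -> :math.exp(x - max_val) end)
  total = Enum.sum(exps)
  Enum.map(exps, &(&1 / total))
end

naive_softmax.([1.0, 2.0])          # fine for small inputs
stable_softmax.([1.0, 2.0])         # same result as the naive version
stable_softmax.([1000.0, 1000.0])   # => [0.5, 0.5]
# naive_softmax.([1000.0, 1000.0])  # would raise: exp(1000.0) overflows
```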
You’ll notice the implementations above are subtly different in two ways:

- The first function uses `defn`, designating it as a numerical definition that should be treated as such. The second function is a regular Elixir function.
- The first function adds the attribute `@defn_compiler EXLA`, telling Elixir to compile and run the definition with the EXLA compiler. The second function uses a call to `Nx.default_backend/1`, telling Elixir to dispatch `Nx` calls to their `Torchx` implementations.
If you compile and run this example, you should notice that, aside from minor differences in precision, each function arrives at approximately the same result. So how exactly are the subtle differences in implementation actually manifesting themselves at runtime? And why would you want to prefer a compiler over a backend or vice-versa? To understand, we’ll need to discuss what’s happening under the hood.
Implementations of the `Nx` API in the `Nx` module are actually meta-implementations that dispatch to third-party implementations of the same function at runtime. These meta-implementations act like contracts for functions in the `Nx` API, normalizing arguments and types, checking shape compatibility, and calculating output shapes before dispatching to an actual implementation. So you can arbitrarily switch backends and be reasonably confident your `Nx` code will behave the same (in reality, there are some cases where implementation-specific details may affect you, but they are documented).
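As a rough mental model (this is a toy sketch, not the actual `Nx` internals; the module names here are invented), the meta-implementation does the contract work and then hands off to a configurable backend module:

```elixir
# Toy dispatch: the "meta-implementation" validates its inputs, then
# delegates to whichever backend module is selected. Nx does something
# conceptually similar, with much richer shape and type checks.
defmodule ToyNative do
  def add(a, b), do: a + b
end

defmodule ToyNx do
  # Contract work (argument checks) happens here, before dispatch.
  def add(a, b, backend \\ ToyNative) when is_number(a) and is_number(b) do
    backend.add(a, b)
  end
end

ToyNx.add(1, 2) # => 3
```

Swapping `ToyNative` for another module implementing `add/2` would change the execution engine without touching any calling code, which is the property that lets you switch `Nx` backends freely.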
In the `Torchx` softmax implementation above, `Nx` will dispatch to the `Torchx` NIF implementations of `reduce_max`, `exp`, `sum`, `subtract`, and `divide`, returning control to the VM between each call and realizing the result of each intermediate computation before starting the next. Because it’s a regular Elixir function, you can arbitrarily mix `Nx` code with Elixir code:
```elixir
def softmax_backend(x) do
  Nx.default_backend(Torchx.Backend)
  max_val = Nx.reduce_max(x) |> IO.inspect()

  {:ok, exp_x} =
    case :ok do
      :ok -> {:ok, Nx.exp(Nx.subtract(x, max_val))}
      _ -> {:error, "BAD"}
    end

  if true do
    Nx.divide(exp_x, Nx.sum(exp_x))
  else
    {:error, "BAD"}
  end
end
```
This has the obvious benefit of flexibility: `defn` supports a much more restrictive subset of the language (although it can be trivially extended). However, this flexibility comes at a performance cost. As I mentioned before, backends return control to the Elixir program between subsequent calls to `Nx` functions. With numerically intensive programs that make many calls to `Nx` functions, this dispatching adds up and often manifests itself as significant performance and memory bottlenecks. So how does this differ from the compiler implementation?
`defn` adds another layer of indirection to program execution. While calls to `Nx` outside of `defn` dispatch to backends, which eagerly evaluate the result of the computation, calls to `Nx` within `defn` dispatch to an `Expr` backend, which builds a call graph or evaluation trace of the definition. Internally, the expression looks something like:
```elixir
#Nx.Tensor<
  f32
  Nx.Defn.Expr
  parameter a s64
  b = reduce_max [ a, axes: nil, keep_axes: false ] s64
  c = subtract [ a, b ] s64
  d = exp [ c ] f32
  e = subtract [ a, b ] s64
  f = exp [ e ] f32
  g = sum [ f, axes: nil, keep_axes: false ] f32
  h = divide [ d, g ] f32
>
```
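To see why a trace like this is useful, here is a toy interpreter (pure Elixir, far simpler than `Nx.Defn.Expr`) that walks a softmax-style expression tree node by node. A compiler would instead translate the whole tree at once:

```elixir
# A toy expression interpreter: the computation exists as data (a tree),
# so it can be walked node by node, or handed to a compiler as a single
# unit. Only a few operations are modeled here, over plain numbers.
defmodule ToyExpr do
  def eval({:param, name}, env), do: Map.fetch!(env, name)
  def eval({:exp, e}, env), do: :math.exp(eval(e, env))
  def eval({:subtract, a, b}, env), do: eval(a, env) - eval(b, env)
end

# exp(a - b), the core of the numerically stable softmax above
expr = {:exp, {:subtract, {:param, :a}, {:param, :b}}}
ToyExpr.eval(expr, %{a: 3.0, b: 3.0}) # => 1.0
```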
You can think of this expression as a kind of numerical assembly code that defines the low-level instructions for your program. When you invoke `softmax_compiler`, Elixir sends this expression to a compiler, in this example the EXLA compiler, which just-in-time compiles a specialized program based on your input shape, type, and operations. This program is then cached, so subsequent calls to the same function with the same input shapes and types don’t need to recompile the original program. Notably, rather than dispatching and returning between calls to `reduce_max`, `sum`, `exp`, `divide`, and `subtract`, JIT-compiled programs are treated as single units of execution. So calls to `softmax_compiler` invoke a single specialized executable for your numerical definition. In other words, the calls to `reduce_max`, `sum`, `exp`, `divide`, and `subtract` are fused into a single call.
Staging the computation in this way opens the door to (potentially significant) performance and memory optimizations; you can learn more about these optimizations by researching XLA or other tensor compilers. However, these enhancements are not free. One thing you’ll notice when working with numerical definitions is that the first invocation of a `defn` is much slower than subsequent calls. That’s because the first invocation needs to do the work of compiling the program, which, depending on the size of the computation, may take a long time relative to program execution. The compiled program is cached for subsequent calls; however, if you change the shape or type of your inputs, the numerical definition needs to compile a new version of your program specific to those shapes and types. If your input shapes or types are constantly changing, your bottleneck will almost certainly be compilation time.
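Conceptually, the cache is keyed on something like the input shapes and types. Here is a toy sketch of that idea (an invented module, not the real `defn` cache) using an `Agent`:

```elixir
# Toy shape/type-keyed cache: "compiling" only happens the first time a
# given {type, shape} key is seen; later calls with the same key reuse
# the cached program, while a new shape forces a fresh compile.
defmodule ToyCompileCache do
  def start_link, do: Agent.start_link(fn -> %{} end, name: __MODULE__)

  def fetch(key, compile_fun) do
    Agent.get_and_update(__MODULE__, fn cache ->
      case cache do
        %{^key => program} ->
          {{:cached, program}, cache}

        _ ->
          program = compile_fun.()
          {{:compiled, program}, Map.put(cache, key, program)}
      end
    end)
  end
end

ToyCompileCache.start_link()
{:compiled, _} = ToyCompileCache.fetch({:f32, {2, 3}}, fn -> :program_a end)
{:cached, _} = ToyCompileCache.fetch({:f32, {2, 3}}, fn -> :program_a end)
# A different shape misses the cache and "compiles" again:
{:compiled, _} = ToyCompileCache.fetch({:f32, {4, 3}}, fn -> :program_b end)
```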
An additional pitfall of the compiled approach is the strictness of the syntax within `defn`. All inputs must be tensors (unless using keywords), you must match on tuple inputs, and you can only use a limited subset of Elixir. Although you can overcome most of these limitations with transforms, doing so requires some overhead and may be confusing.
Despite the pitfalls and syntactic limitations of `defn`, it offers something backends alone cannot: you can only use `grad` from within `defn`. This is because `Nx` differentiation directly manipulates `defn` expressions. Does this mean you cannot use `grad` at all with backends like `Torchx`? Fortunately, no. The default “compiler” is essentially an expression interpreter, `Nx.Defn.Evaluator`. You can use `grad` and other transforms and still rely on a backend like `Torchx`; however, the additional interpreter overhead will likely lead to bottlenecks.
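The reason differentiation needs an expression is that `grad` rewrites the computation itself. Here is a toy illustration (a tiny invented polynomial language, not Nx’s actual autodiff) of differentiating a tree and then evaluating the result:

```elixir
# Toy symbolic differentiation: d/dx is a tree-to-tree rewrite, which is
# only possible when the computation is represented as data, the same
# role the defn expression plays for grad.
defmodule ToyGrad do
  def d({:x}), do: {:const, 1}
  def d({:const, _}), do: {:const, 0}
  def d({:add, a, b}), do: {:add, d(a), d(b)}
  def d({:mul, a, b}), do: {:add, {:mul, d(a), b}, {:mul, a, d(b)}}

  def eval({:x}, x), do: x
  def eval({:const, c}, _x), do: c
  def eval({:add, a, b}, x), do: eval(a, x) + eval(b, x)
  def eval({:mul, a, b}, x), do: eval(a, x) * eval(b, x)
end

# d/dx of x * x is 2x; evaluated at x = 3 that's 6.
expr = {:mul, {:x}, {:x}}
ToyGrad.eval(ToyGrad.d(expr), 3) # => 6
```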
So when should you use a compiler like `EXLA`, and when a backend like `Torchx`? Like most things, it’s situation-dependent. Backends like `Torchx` offer flexibility and facilitate rapid prototyping, with pleasing performance gains over the pure Elixir implementations. Compilers like `EXLA` unlock even more performance and memory optimizations, and open the door to things like ahead-of-time compilation. Generally, if you need the flexibility of intermixing `Nx` code with regular Elixir code, using a backend is probably the more convenient option. Additionally, some `Nx` programs are so small that they won’t meaningfully benefit from a compiler’s optimizations.
When dealing with compute-intensive, purely numerical programs, compilers are usually the better option. For example, you’ll likely see better performance when training deep neural networks with a compiler than with a backend. If performance is your priority, benchmarking will almost certainly lead you to the right decision. There are certainly instances where a compiler will actually hurt your performance, especially if your programs require excessive recompilation or are not well-suited to tensor-based implementations.
All of this is not to say compilers and backends are mutually exclusive. You may decide to rapidly experiment with a backend like `Torchx`, and then gradually migrate to a compiler for the performance benefits.
On a final note, I will add a caveat for library developers who want to build packages on top of `Nx`: you should almost exclusively work inside numerical definitions, to ensure you leave the choice of compiler or backend to the user. As a library developer, you can create packages with numerical definitions that are unconcerned with backend or compiler implementation details, which allows users to experiment with what works best for their use case. If you write your library largely with regular Elixir functions, users won’t be able to call your functions inside their own numerical definitions, and you take the compilation option completely out of their hands.
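For example, a library module might look something like this (the module name is invented; this is a sketch that assumes the `Nx` package is available, fetched here via `Mix.install` for a standalone script):

```elixir
# A library module written entirely in defn: it says nothing about
# compilers or backends, so users can plug in EXLA, Torchx, or the
# pure Elixir defaults themselves.
Mix.install([:nx])

defmodule MyLib.Activations do
  import Nx.Defn

  defn softmax(x) do
    max_val = Nx.reduce_max(x)
    Nx.exp(x - max_val) / Nx.sum(Nx.exp(x - max_val))
  end
end
```

A user could then call `MyLib.Activations.softmax/1` from their own `defn` compiled with EXLA, or with `Torchx` set as the default backend, without any changes to the library.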
I hope this clears up some of the confusion about the difference between an `Nx` backend and an `Nx` compiler, and helps you make the right decision.