Nx Tip of the Week #8 - Using Nx.Defn.aot/3
Last week, we discussed using Nx.Defn.jit/3 to JIT compile and run numerical definitions. Nx also supports ahead-of-time compilation using Nx.Defn.aot/3. In this post, we’ll briefly look at how to use ahead-of-time compilation, and why you’d want to do it in the first place.
Ahead-of-time compilation allows you to compile your numerical definitions into compact NIFs and execute them without needing the entire EXLA compiler and runtime. Nx.Defn.aot/3 defines a module to interact with the NIF for you, so most of the work is out of your hands. Consider the following softmax example:
defmodule MyDefn do
  import Nx.Defn

  defn softmax(x) do
    max_val = Nx.reduce_max(x)
    Nx.exp(x - max_val) / Nx.sum(Nx.exp(x - max_val))
  end
end
You can ahead-of-time compile and export this function into a separate module:
Nx.Defn.aot(
  MyModule,
  [{:softmax, &MyDefn.softmax/1, [Nx.template({100}, {:f, 32})]}],
  compiler: EXLA
)
Here you specify the name of the export module, the desired name of your function in the export module, the actual function to AOT compile, and the expected input types and shapes (as templates). You can AOT compile as many functions as you’d like into a single module. This will generate both a NIF and a module MyModule to interact with the NIF, and then load the module into the runtime. You can then invoke your ahead-of-time compiled function:
MyModule.softmax(Nx.random_uniform({100}))
#Nx.Tensor<
  f32[100]
  [...]
>
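Since the second argument is a list, you can compile several entries into the same module, including multiple shape specializations of the same defn. Here’s a sketch (the :softmax_2d export name is hypothetical):

Nx.Defn.aot(
  MyModule,
  [
    # Two exports of the same defn, each specialized to its own template.
    {:softmax, &MyDefn.softmax/1, [Nx.template({100}, {:f, 32})]},
    {:softmax_2d, &MyDefn.softmax/1, [Nx.template({10, 10}, {:f, 32})]}
  ],
  compiler: EXLA
)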
In practice, you’d likely want to export the module and NIF for use in a separate project. Fortunately, you can do that with Nx.Defn.export_aot/4:
Nx.Defn.export_aot(
  "priv",
  MyModule,
  [{:softmax, &MyDefn.softmax/1, [Nx.template({100}, {:f, 32})]}],
  compiler: EXLA
)
If you compile and run this, you’ll notice a shared object and a compiled module in the priv directory of your project. You can then import and use these in a separate project:
if File.exists?("priv/MyModule.nx.aot") do
  defmodule MyModule do
    Nx.Defn.import_aot("priv", __MODULE__)
  end
else
  IO.warn("Skipping MyModule because aot definition was not found")
end
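Assuming the import succeeds, the module exposes the same compiled functions as in the original project, so the call looks identical to before:

MyModule.softmax(Nx.random_uniform({100}))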
It’s even possible to AOT compile entire neural networks with trained parameters:
IO.puts("AOT-compiling a trained neural network that predicts a batch")

# final_params and MNIST.predict/2 come from a previously trained MNIST
# model; the template matches a batch of thirty 784-pixel images.
Nx.Defn.aot(
  MNIST.Trained,
  [{:predict, &MNIST.predict(final_params, &1), [Nx.template({30, 784}, {:f, 32})]}],
  compiler: EXLA
)

IO.puts("The result of the first batch against the AOT-compiled one")
IO.inspect(MNIST.Trained.predict(hd(train_images)))
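For a side-by-side check, you can also run the original definition on the same batch (a sketch reusing final_params and train_images from the training code above):

# Should produce the same predictions as MNIST.Trained.predict/1 above.
IO.inspect(MNIST.predict(final_params, hd(train_images)))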
It’s also possible to cross-compile AOT modules; however, most of the cross-compilation needs to be done manually.
But how does AOT compilation compare to JIT compilation? From a performance perspective, AOT compilation is competitive:
Name                      ips        average  deviation         median         99th %
xla jit-cpu f32        265.86        3.76 ms    ±10.38%        3.71 ms        4.89 ms
xla aot-cpu f32        158.22        6.32 ms    ±25.74%        6.27 ms        9.35 ms
elixir f32               2.95      338.55 ms     ±2.89%      336.93 ms      362.43 ms

Comparison:
xla jit-cpu f32        265.86
xla aot-cpu f32        158.22 - 1.68x slower +2.56 ms
elixir f32               2.95 - 90.01x slower +334.79 ms
You’ll notice AOT compilation is only modestly slower than JIT compilation (about 1.7x in this benchmark), and still over 50x faster than pure Elixir.
The footprint of an AOT compiled NIF is incredibly small compared to the footprint of the EXLA NIF. On my machine, the EXLA shared object compiled with CUDA is 612MB. The AOT compiled NIF in the example above is 392KB - a reduction of over 1500x. Most of the AOT compiled NIF’s size can be attributed to statically linking a compiled function runtime, so additional AOT compiled functions should add comparatively little.
There are, of course, tradeoffs to using AOT compiled functions. First, you need to know the type and shape of your inputs ahead of time. Depending on your needs, this can be a pretty serious limitation. Second, AOT compilation is only supported for CPUs, so you can’t take advantage of GPU or TPU acceleration.
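One partial workaround for the fixed-shape requirement is to compile several templates of the same definition into one module and have callers pad or bucket inputs to the nearest compiled shape. A sketch, with hypothetical softmax_<n> export names:

Nx.Defn.aot(
  MyModule,
  # One export per supported input size; callers pick (or pad to)
  # the closest compiled shape at runtime.
  for n <- [32, 64, 128] do
    {:"softmax_#{n}", &MyDefn.softmax/1, [Nx.template({n}, {:f, 32})]}
  end,
  compiler: EXLA
)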
Overall, AOT compilation can be an excellent choice for deploying a model - especially at the edge. You can experiment on a more powerful machine before exporting the smaller compiled module to an edge device. I hope this gives you the tools you need to experiment with AOT compilation.