JIT/GPU accelerated deep learning for Elixir with Axon v0.1


I am excited to announce the official v0.1.0 release of Axon and AxonOnnx. A lot has changed (and improved) since the initial public announcement of Axon. In this post I will explore Axon and its internals, and give reasoning for some of the design decisions made along the way.

You can view the official documentation for Axon and AxonOnnx on HexDocs.

What is Axon?

At a high level, Axon is a library for creating and training neural networks. Axon is implemented in pure Elixir and relies on Nx to compile neural networks just-in-time for the CPU or GPU. It consists of a few loosely coupled components:

Functional API

The functional API is a collection of “low-level” implementations of common neural network operations, similar to torch.nn.functional or tf.nn in the Python ecosystem. It offers no conveniences, just implementations. These implementations are all written in defn, so they can be JIT compiled on their own or composed with Nx transformations like grad inside other defn functions.
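As a small illustration, here is a sketch that builds a loss purely from functional-API calls and then differentiates it with Nx's grad. It assumes Axon v0.1 on the default Nx backend; the module name FunctionalExample and the toy tensors are made up for this example, and Mix.install/1 only applies when running as a standalone script or in Livebook.

```elixir
# Only needed in a standalone script or Livebook; Axon pulls in Nx.
Mix.install([{:axon, "~> 0.1.0"}])

defmodule FunctionalExample do
  import Nx.Defn

  # Mean squared error of a single dense layer, built only from functional-API calls.
  defn loss(x, w, b, y) do
    y_pred = Axon.Layers.dense(x, w, b)
    Axon.Losses.mean_squared_error(y, y_pred, reduction: :mean)
  end

  # Because loss/4 is plain defn, it composes with Nx's grad transformation.
  defn loss_grad(x, w, b, y) do
    grad(w, fn w -> loss(x, w, b, y) end)
  end
end

# Illustrative data: one example with two features and one target.
x = Nx.tensor([[1.0, 2.0]])
w = Nx.tensor([[0.5], [0.5]])
b = Nx.tensor([0.0])
y = Nx.tensor([[1.0]])

FunctionalExample.loss_grad(x, w, b, y)
```

Here the prediction is 1.5 against a target of 1.0, so the gradient with respect to w is 2 × 0.5 × x, i.e. [[1.0], [2.0]].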

The Functional API consists of modules such as Axon.Activations, Axon.Initializers, Axon.Layers, Axon.Losses, and Axon.Metrics.

Model Creation API

The model creation API is a high-level API for creating and executing neural networks. The API will be covered in depth in this post, so I’ll omit the details here.

Optimization API

The optimization API is built to mirror the beautiful Optax library. Optax is literally my favorite library that’s not written in Elixir. The idea is to implement optimizers using composable higher-order functions. I highly recommend checking out the Axon.Updates documentation as well as the Optax library.
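Here is a toy, dependency-free sketch of that idea, operating on lists of floats instead of tensors. It is not Axon.Updates' implementation: ToyUpdates and its functions are hypothetical, named to echo Optax's vocabulary. Each transformation is an {init_fn, update_fn} pair, and chain/2 composes two pairs into a new one, so an optimizer becomes a pipeline of small transformations (Elixir 1.12+ for Enum.zip_with/3).

```elixir
defmodule ToyUpdates do
  # Hypothetical Optax-style transformations on plain lists of floats.
  # Each one is {init_fn, update_fn}:
  #   init_fn.(params) -> state
  #   update_fn.(updates, state, params) -> {new_updates, new_state}

  # Multiply every update by a constant step size.
  def scale(step) do
    init = fn _params -> nil end
    update = fn updates, state, _params -> {Enum.map(updates, &(&1 * step)), state} end
    {init, update}
  end

  # Momentum: keep a running velocity v = g + decay * v per parameter.
  def scale_by_momentum(decay) do
    init = fn params -> Enum.map(params, fn _ -> 0.0 end) end

    update = fn updates, velocity, _params ->
      new_velocity = Enum.zip_with(updates, velocity, fn g, v -> g + decay * v end)
      {new_velocity, new_velocity}
    end

    {init, update}
  end

  # Compose two transformations: the updates produced by the first feed the second.
  def chain({init_a, update_a}, {init_b, update_b}) do
    init = fn params -> {init_a.(params), init_b.(params)} end

    update = fn updates, {state_a, state_b}, params ->
      {updates, state_a} = update_a.(updates, state_a, params)
      {updates, state_b} = update_b.(updates, state_b, params)
      {updates, {state_a, state_b}}
    end

    {init, update}
  end

  # Add the final updates to the parameters.
  def apply_updates(params, updates), do: Enum.zip_with(params, updates, &(&1 + &2))
end

# SGD with momentum, built by composition rather than as a monolithic optimizer.
{init_fn, update_fn} = ToyUpdates.chain(ToyUpdates.scale_by_momentum(0.9), ToyUpdates.scale(-0.1))

params = [1.0, 2.0]
state = init_fn.(params)
{updates, _state} = update_fn.([0.5, -1.0], state, params)
ToyUpdates.apply_updates(params, updates)
```

The payoff of this design is that momentum and step size stay independent pieces that can be recombined freely.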

Loop API

The Loop API is an API for writing loops (like training and evaluation loops) in a functional style. Elixir is a functional language, which means we cannot rely on mutable state the way Python frameworks do. To get around this, Axon.Loop constructs loops as reducers over data with some state. The API itself is inspired by the PyTorch Ignite library.

The Loop API is still a work-in-progress, so you should expect significant improvements in subsequent Axon releases.
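A dependency-free sketch of the “loop as a reducer” idea follows. ToyLoop.run/4, train_step, and the data are hypothetical; Axon.Loop's actual API is richer and isn't reproduced here. The point is that the training state (here a map holding one weight, a learning rate, and an iteration counter) is returned from each step and fed into the next, which is exactly what Enum.reduce/3 does. It assumes Elixir 1.12 or later for the stepped range 1..epochs//1.

```elixir
defmodule ToyLoop do
  # Toy "loop as a reducer": NOT the Axon.Loop API, just the underlying idea.
  # Each epoch reduces over the data, threading the loop state through
  # step_fn.(batch, state) instead of mutating anything in place.
  def run(data, init_state, step_fn, epochs) do
    Enum.reduce(1..epochs//1, init_state, fn _epoch, state ->
      Enum.reduce(data, state, step_fn)
    end)
  end
end

# Hypothetical example: fit y = w * x to points generated with w = 2.0,
# using plain gradient descent on a single float parameter.
data = [{1.0, 2.0}, {2.0, 4.0}, {3.0, 6.0}]

train_step = fn {x, y}, %{w: w, lr: lr, iteration: i} = state ->
  # Derivative of the squared error (w * x - y)^2 with respect to w.
  grad = 2 * (w * x - y) * x
  %{state | w: w - lr * grad, iteration: i + 1}
end

final_state = ToyLoop.run(data, %{w: 0.0, lr: 0.05, iteration: 0}, train_step, 50)
```

After 50 epochs, final_state.w is very close to 2.0 and final_state.iteration is 150 (3 batches × 50 epochs).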

What really is a Neural Network?

There are really two interpretations of this question I’d like to explore:

  1. What is a neural network mathematically?
  2. What is a neural network in the eyes of Axon?

Mathematically, a neural network is just a composition of linear and non-linear transformations with some learnable parameters. In Nx, you can implement a neural network relatively easily with defn:

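defmodule FeedForward do
  # FeedForward is an illustrative module name; defn must live inside a module.
  import Nx.Defn

  defn feed_forward_network(x, w1, b1, w2, b2, w3, b3) do
    x
    |> Nx.dot(w1)
    |> Nx.add(b1)
    |> Nx.sigmoid()
    |> Nx.dot(w2)
    |> Nx.add(b2)
    |> Nx.sigmoid()
    |> Nx.dot(w3)
    |> Nx.add(b3)
    |> Nx.sigmoid()
  end
end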

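There’s really nothing more to it! Of course, implementing neural networks directly in Nx quickly introduces a lot of painful boilerplate: every weight and bias has to be created, initialized, and passed around by hand, as the seven arguments to feed_forward_network show. The goal of Axon is to abstract away that boilerplate and make creating and training neural networks a breeze in Elixir.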
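For comparison, here is one way the same three-layer sigmoid network might look with Axon's model creation API. This is a sketch assuming Axon v0.1; the input size of 16 and the hidden sizes of 64 and 32 are arbitrary choices, since feed_forward_network/7 only implies them through its weight shapes. Axon.init/3 creates the parameters that had to be passed by hand in the Nx version.

```elixir
# Only needed in a standalone script or Livebook.
Mix.install([{:axon, "~> 0.1.0"}])

# Three dense layers, each followed by a sigmoid, like feed_forward_network/7.
model =
  Axon.input({nil, 16}, "features")
  |> Axon.dense(64, activation: :sigmoid)
  |> Axon.dense(32, activation: :sigmoid)
  |> Axon.dense(1, activation: :sigmoid)

# Axon creates and initializes the weights and biases for every layer.
params = Axon.init(model)

# A batch of 4 random examples, just to exercise the model.
Axon.predict(model, params, Nx.random_uniform({4, 16}))
```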

So what is a neural network in the eyes of Axon? Axon sees a neural network as an Elixir struct:

  defstruct [
    :id,
    :name,
    :output_shape,
    :parent,
    :parameters,
    :args,
    :op,
    :policy,
    :hooks,
    :opts,
    :op_name
  ]

Of particular importance in this struct are parent, parameters, and op. parent is a list of parent networks, which are also Axon structs. This makes the struct a recursive data structure representing a computation graph, with some additional metadata relevant to specific neural network tasks. parameters is a list of trainable parameters attached to this layer. They’re automatically initialized when the model is initialized, and they take part in the training process within Axon’s internal APIs. op is a function that’s applied to parent and parameters. In layman’s terms, Axon views a neural network as just a function of other “neural networks” (Axon structs) and trainable parameters. In fact, you can “wrap” any function you want into a neural network with Axon.layer:

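defmodule MyLayers do
  # MyLayers is an illustrative module name; defn must live inside a module.
  import Nx.Defn

  defn dense(input, weight, bias, _opts \\ []) do
    input
    |> Nx.dot(weight)
    |> Nx.add(bias)
  end
end

input = Axon.input({nil, 32}, "features")
weight = Axon.param({32, 64}, "weight")
bias = Axon.param({64}, "bias")

Axon.layer(&MyLayers.dense/4, [input, weight, bias])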

Notice I only had to define an input layer and two trainable parameters using Axon’s built-in functions. Using Axon.layer should feel a lot like using Elixir’s apply: you’re just applying a function to some specialized inputs. All but a few of Axon’s built-in layers are implemented in essentially this same manner:

  1. Define an implementation function in Axon.Layers
  2. Wrap the implementation in a layer with a public interface in Axon

It’s just a Graph

The “magic” of Axon is its compiler, which knows how to convert Axon structs into meaningful initialization and prediction functions. Model execution comes in the form of two functions: Axon.init/3 and Axon.predict/4. Axon.init/3 returns a model’s initial parameters:

model = Axon.input({nil, 32}) |> Axon.dense(64)

model_state = Axon.init(model)

For prediction, you need both a model and a compatible model state:

model = Axon.input({nil, 32}) |> Axon.dense(64)

model_state = Axon.init(model)
input = Nx.random_uniform({1, 32})

Axon.predict(model, model_state, input)

Both Axon.init/3 and Axon.predict/4 take additional compilation options; however, it’s recommended you use global configuration rather than compilation options. For example, rather than:

Axon.predict(model, model_state, input, compiler: EXLA)

You should use:

EXLA.set_as_nx_default([:tpu, :cuda, :rocm, :host])

Axon.predict(model, model_state, input)

Axon.init/3 also optionally accepts initial parameters, so you can initialize portions of a model from an existing state (e.g., when fine-tuning a model). This is where Axon.namespace/2 comes in handy. You can “tag” part of a model as belonging to a particular namespace and initialize it without needing to know anything about the underlying architecture:

{bert, bert_params} = get_bert_model()
bert = bert |> Axon.namespace("bert")

model = bert |> Axon.dense(1)

model_state = Axon.init(model, %{"bert" => bert_params})

Axon.namespace/2 is one of the few layers with special meaning in Axon. There’s also Axon.input/2, Axon.constant/3, and Axon.container/3. Input layers are symbolic representations of model inputs. Each input is associated with a unique name used to reference it when passing inputs to a model. For example, if you have multiple inputs, you can give them semantic meanings:

text_features = Axon.input({nil, 32}, "text_features")
cat_features = Axon.input({nil, 32}, "cat_features")

With named inputs, you don’t have to worry about passing things out of order, since you’ll always reference an input by its name:

model = Axon.add(text_features, cat_features)

Axon.predict(model, model_state, %{"text_features" => text_inp, "cat_features" => cat_inp})
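A complete version of this two-input example might look like the following sketch. It assumes Axon v0.1, and the input tensors are invented for illustration. The input map is deliberately written in the opposite order from the layer definitions, since only the names matter.

```elixir
# Only needed in a standalone script or Livebook.
Mix.install([{:axon, "~> 0.1.0"}])

text_features = Axon.input({nil, 32}, "text_features")
cat_features = Axon.input({nil, 32}, "cat_features")

# Element-wise sum of the two named inputs; this model has no trainable parameters.
model = Axon.add(text_features, cat_features)
model_state = Axon.init(model)

# Illustrative inputs: 0.0, 1.0, ..., 31.0 and a row of ones.
text_inp = Nx.iota({1, 32}, type: {:f, 32})
cat_inp = Nx.broadcast(1.0, {1, 32})

# Map order is irrelevant: each tensor is matched to its input by name.
Axon.predict(model, model_state, %{"cat_features" => cat_inp, "text_features" => text_inp})
```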

Axon.constant/3 allows you to introduce constant values into the graph. Be warned that introducing large constants will negatively impact the performance of the model.

Axon.container/3 can accept any valid Nx container. This is particularly useful for giving semantic meaning to outputs:

model = Axon.container(%{
  last_hidden_state: last_hidden_state,
  pooler_output: pooler_output
})
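A slightly fuller sketch of a container output follows, assuming Axon v0.1. The layer sizes and the logits/score keys are invented for illustration: two heads share one hidden layer, and Axon.predict/4 should hand back a map with the same keys as the container.

```elixir
# Only needed in a standalone script or Livebook.
Mix.install([{:axon, "~> 0.1.0"}])

input = Axon.input({nil, 16}, "features")

# Two illustrative heads sharing one hidden layer; sizes are arbitrary.
hidden = Axon.dense(input, 8, activation: :relu)
logits = Axon.dense(hidden, 3)
score = Axon.dense(hidden, 1, activation: :sigmoid)

# Wrap the heads in a map so predictions come back with semantic keys.
model = Axon.container(%{logits: logits, score: score})

params = Axon.init(model)
%{logits: logits_out, score: score_out} = Axon.predict(model, params, Nx.random_uniform({2, 16}))
```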

Every other Axon built-in layer is treated in the same way as custom layers by the compiler. This means that, apart from the few “special layers”, there’s no difference between what you can do with a custom layer and what you can do with a built-in layer. They’re both handled by the same clause in the Axon compiler.

In the Axon interpretation of a neural network, every execution of a graph is a specialized compilation of the graph. In other words, initialization and prediction are just two kinds of compilation. There’s nothing stopping you from implementing your own specialized compilation of an Axon graph in the same way. For example, an older version of Axon implemented a macro, Axon.penalty, which compiled a graph into a regularization function. Axon also implements the Inspect protocol, which itself can be seen as a symbolic compilation of the graph.
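To make the “every execution is a compilation” idea concrete, here is a deliberately tiny, dependency-free sketch. ToyGraph is hypothetical and much simpler than Axon's real struct and compiler (scalars stand in for tensors, and every parameter starts at 1.0), but it shows one recursive data structure being walked three different ways: to initialize parameters, to predict, and to render a symbolic description, loosely analogous to Axon.init/3, Axon.predict/4, and the Inspect implementation.

```elixir
defmodule ToyGraph do
  # Hypothetical toy graph node: parents, named parameters, and an op applied to both.
  # NOT Axon's struct or compiler; floats stand in for tensors.
  defstruct [:name, :op, parents: [], parameters: []]

  # Leaf node: a named model input.
  def input(name), do: %ToyGraph{name: name, op: :input}

  # A "layer" is any function of its parents' outputs and its own parameters.
  def layer(parents, parameters, name, op) do
    %ToyGraph{name: name, op: op, parents: parents, parameters: parameters}
  end

  # Compilation 1: walk the graph and collect initial parameter values (all 1.0 here).
  def init(%ToyGraph{name: name, parents: parents, parameters: params}) do
    own = if params == [], do: %{}, else: %{name => Map.new(params, &{&1, 1.0})}
    Enum.reduce(parents, own, fn parent, acc -> Map.merge(acc, init(parent)) end)
  end

  # Compilation 2: walk the graph and evaluate it on concrete inputs.
  def predict(%ToyGraph{op: :input, name: name}, _state, inputs), do: Map.fetch!(inputs, name)

  def predict(%ToyGraph{name: name, op: op, parents: parents}, state, inputs) do
    args = Enum.map(parents, &predict(&1, state, inputs))
    op.(args, Map.get(state, name, %{}))
  end

  # Compilation 3: a symbolic "compilation" that only renders the graph's structure.
  def render(%ToyGraph{op: :input, name: name}), do: name

  def render(%ToyGraph{name: name, parents: parents}) do
    "#{name}(#{Enum.map_join(parents, ", ", &render/1)})"
  end
end

# Hypothetical two-layer graph: out = w2 * (w * x + b)
x = ToyGraph.input("x")
dense = ToyGraph.layer([x], ["w", "b"], "dense", fn [v], %{"w" => w, "b" => b} -> w * v + b end)
out = ToyGraph.layer([dense], ["w2"], "out", fn [h], %{"w2" => w2} -> w2 * h end)

state = ToyGraph.init(out)
ToyGraph.predict(out, state, %{"x" => 3.0})
# → 4.0
ToyGraph.render(out)
# → "out(dense(x))"
```

A real compiler would also have to deal with nodes shared by several children (this sketch simply re-evaluates them) and emit functions Nx can JIT compile rather than evaluating eagerly.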

Maybe you don’t like my API…

The Axon interpretation of a “model” is intentionally as flexible as possible. All you need to do is build a data structure. This means that if you’re not satisfied with Axon’s model creation API, you can create your own! As long as you finish with an Axon struct, your model will work with the rest of Axon’s components. The Axon struct is really the unifying data structure for every component of the Axon library. I would love to see some cool neural network DSLs pop up that build off of the lower-level components Axon defines.
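As a tiny example of what such a DSL could look like, here is a sketch assuming Axon v0.1. TinyDSL and its list-of-tuples syntax are hypothetical: the DSL describes an MLP as {units, activation} pairs and folds them into ordinary Axon layers. Because the result is a plain Axon struct, Axon.init/3 and Axon.predict/4 work on it unchanged.

```elixir
# Only needed in a standalone script or Livebook.
Mix.install([{:axon, "~> 0.1.0"}])

defmodule TinyDSL do
  # Hypothetical mini-DSL: fold a list of {units, activation} specs into dense layers.
  def mlp(input, specs) do
    Enum.reduce(specs, input, fn {units, activation}, model ->
      Axon.dense(model, units, activation: activation)
    end)
  end
end

model = TinyDSL.mlp(Axon.input({nil, 8}, "features"), [{16, :relu}, {4, :softmax}])

# Nothing DSL-specific is needed from here on.
params = Axon.init(model)
Axon.predict(model, params, Nx.random_uniform({2, 8}))
```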

Converting to Other Formats

Another benefit of the Axon data structure is portability. If you can traverse the Axon graph, you can lower or compile it into a meaningful function or representation, such as ONNX. This is exactly the functionality AxonOnnx provides: you can take a pre-trained model from popular frameworks like PyTorch and TensorFlow, convert it to ONNX, and then import it into Elixir with AxonOnnx.import. For example, you can take any of the ONNX-supported models in HuggingFace Transformers and import them into Axon with ease!

Just export the model you want:

$ python -m transformers.onnx --model=bert-base-cased models/

And load it with AxonOnnx:

{bert, bert_params} = AxonOnnx.import("path/to/bert.onnx")

The ability to import and use external models is an absolute must for any neural network library (especially given the pace of progress in deep learning). AxonOnnx enables Elixir programmers to utilize pre-trained models from the Python ecosystem without needing to implement or train them from scratch.

This also means you can integrate some pretty cool pre-trained models with established projects like Phoenix and LiveView. For example, live_onnx implements a sample ML application using AxonOnnx and LiveView.

You should note that we are still actively working to enable support for all of ONNX’s operations. If you have a model you’d like to see supported, please feel free to open an issue or a PR :)

Future Work

If you look at the issue tracker, you’ll notice there’s still much work to be done; however, the core components of Axon are at a stable point. This means you can use Axon with a reasonable expectation of stability. Moving forward, you can expect the following from Axon:

  • First-class transformer model support
  • More integration with Livebook
  • Mixed precision training
  • Multi-device training

Additionally, I’d like to build out a large collection of Axon examples. If you are looking for a place to get started in the Nx ecosystem, please feel free to open a pull request which demonstrates Axon applied to a unique problem set. If you’re looking for inspiration, check out Keras Examples.

Acknowledgements

I am very grateful to DockYard for their support of the Elixir Machine Learning Ecosystem from the beginning. Additionally, Axon would not be where it is today without the hard work of all of the Nx contributors and the individuals in the Erlang Ecosystem Foundation ML WG. The Elixir community is nothing short of amazing, and I hope Axon can play a small part in seeing the community grow.