Axon: Deep Learning in Elixir


ax·on /ˈakˌsän/ noun

the long threadlike part of a nerve cell along which impulses are conducted from the cell body to other cells.

Today I am excited to publicly announce Axon, a library for creating neural networks in Elixir. Axon is still pre-release; however, I believe it’s reached a point where it’s ready for experimentation and input from the open-source community. In this post I will cover Axon’s API, discuss some of the design decisions, and lay out future plans for the library.

Overview

One of Axon’s main goals is to strike a balance between ease of use and flexibility. On one end, we hope programmers with zero experience in deep learning find Axon approachable and easy to use. On the other, we hope experienced practitioners and researchers find Axon to be a productive, flexible, and refreshing functional take on deep learning frameworks.

Axon consists of the following components:

  1. Functional API - A low-level API of Nx numerical definitions (defn) on which all other APIs build.
  2. Model Creation API - A high-level model creation API which manages model initialization and application.
  3. Optimization API - An API for creating and using first-order optimization techniques.
  4. Training API - An API for quickly training models.

From the beginning, we’ve tried to create abstractions that enable easy integration while maintaining a level of separation between each component. You should be able to use any of the APIs without depending on the others. By decoupling the APIs in this way, Axon gives you full control over each aspect of creating and training a neural network.

Functional API

At the lowest level, Axon consists of a number of modules with functional implementations of common methods in deep learning:

  • Axon.Activations - Element-wise activation functions.
  • Axon.Initializers - Model parameter initialization functions.
  • Axon.Layers - Common deep learning layer implementations.
  • Axon.Losses - Common loss functions.
  • Axon.Metrics - Training metrics such as accuracy, absolute error, precision, etc.

This API largely resembles torch.nn.functional from PyTorch or tf.nn in TensorFlow. The functional implementations are very bare and don’t come with the conveniences provided in higher-level APIs. In exchange, you get maximum flexibility and control over your models and model training.

All of the methods in the functional API are implemented as Nx numerical definitions (defn). That means you can use any Nx compiler or backend to accelerate Axon. Additionally, you can arbitrarily compose methods in the Axon functional API with your own numerical definitions. Axon works entirely on Nx tensors, so any library built on top of Nx is likely to integrate well with Axon.

Because Axon’s high-level APIs build on top of the functional API, the same benefits apply. You can use any Nx compiler or backend to accelerate model training or inference, with the possibility for things like AOT compilation.

Model Creation

The goal of the model creation API is to eliminate most of the boilerplate associated with creating, initializing, and applying models. Additionally, we wanted to build and represent models in a way that’s easy to export to other formats such as ONNX or TensorFlow Lite. Axon addresses these issues using an Axon struct:

defstruct [:id, :name, :output_shape, :parent, :op, :params, :opts]

The Axon struct represents a model’s computation graph. Each method in the Axon model creation API accepts an Axon struct and returns a new Axon struct, with input layers representing the “atomic” operations in model creation. An example model looks something like:

model =
  Axon.input({nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)

With this approach, you can use regular Elixir functions to represent model building blocks, and compose them any way you see fit:

defmodule MyModel do
  def residual(x, units) do
    x
    |> Axon.dense(units, activation: :relu)
    |> Axon.add(x)
  end
  
  def model() do
    Axon.input({nil, 784})
    |> Axon.dense(128, activation: :relu)
    |> residual(128)
    |> Axon.dense(10, activation: :softmax)
  end
end

Because the underlying model is just an Elixir struct, model serialization is no harder than traversing the struct and converting Axon nodes into the equivalent in whatever other format you’d like. At the moment, we’re still working out the details of model serialization, and would love feedback on potential needs and use cases.
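To illustrate why a parent-linked struct makes export a simple traversal, here is a minimal sketch in plain Elixir. MiniGraph is a hypothetical, stripped-down stand-in for the Axon struct (only :op, :output_shape, and :parent), and the list of maps it produces stands in for whatever target format you would actually emit; this is not Axon's serialization code.

```elixir
defmodule MiniGraph do
  # Hypothetical stand-in for the Axon struct: each node points at its parent,
  # so a sequential model is a chain ending at the input layer.
  defstruct [:op, :output_shape, :parent]

  def input(shape), do: %__MODULE__{op: :input, output_shape: shape, parent: nil}

  def dense(%__MODULE__{output_shape: {batch, _}} = parent, units) do
    %__MODULE__{op: :dense, output_shape: {batch, units}, parent: parent}
  end

  # Walk from the output node toward the input. Prepending each node as we go
  # leaves the list in execution order (input first) once we reach the root.
  def serialize(node), do: collect(node, [])

  defp collect(nil, acc), do: acc

  defp collect(%__MODULE__{} = node, acc) do
    collect(node.parent, [%{op: node.op, shape: node.output_shape} | acc])
  end
end

MiniGraph.input({nil, 784})
|> MiniGraph.dense(128)
|> MiniGraph.dense(10)
|> MiniGraph.serialize()
# returns [%{op: :input, shape: {nil, 784}}, %{op: :dense, shape: {nil, 128}},
#          %{op: :dense, shape: {nil, 10}}]
```

A real exporter would also need to handle layers with multiple parents (such as add) and carry parameters along, but the shape of the work is the same recursive walk.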

An added benefit of using a regular Elixir struct is easy customization of model inspection using the Inspect protocol:

model =
  Axon.input({nil, 784})
  |> Axon.dense(128)
  |> Axon.dense(10, activation: :softmax)

IO.inspect model

Outputs:

-----------------------------------------------
                     Model
===============================================
 Layer                 Shape        Parameters
===============================================
 input_1 (input)       {nil, 784}   0
 dense_2 (dense)       {nil, 128}   100480
 dense_3 (dense)       {nil, 10}    1290
 softmax_4 (softmax)   {nil, 10}    0
-----------------------------------------------
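The Parameters column follows directly from layer shapes. Assuming a dense layer holds a kernel of shape {in_features, units} plus a bias of shape {units} (the counts in the table are consistent with that), a quick plain-Elixir check looks like this; DenseParams is an illustrative name, not part of Axon:

```elixir
# Parameter count of a dense layer, assuming a kernel of shape
# {in_features, units} plus a bias of shape {units}.
defmodule DenseParams do
  def count(in_features, units), do: in_features * units + units
end

IO.puts(DenseParams.count(784, 128)) # prints 100480 (dense_2)
IO.puts(DenseParams.count(128, 10))  # prints 1290 (dense_3)
```

Input and activation layers contribute no parameters, which is why input_1 and softmax_4 show 0.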

Axon also provides a few conveniences for working with models. First, we chose to take the philosophy that a model’s only concerns are initialization and application; the model shouldn’t be concerned at all with details like training. Axon provides the macros Axon.init/2 and Axon.predict/4 for initializing and applying models:

model =
  Axon.input({nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)

params = Axon.init(model, compiler: EXLA)

Axon.predict(model, params, input, compiler: EXLA)

Both Axon.init/2 and Axon.predict/4 can be used from anywhere: inside defn as well as in regular Elixir code:

defmodule MyModel do
  # defn is a macro from Nx.Defn and must be imported before use
  import Nx.Defn

  def model() do
    Axon.input({nil, 784})
    |> Axon.dense(128, activation: :relu)
    |> Axon.dropout(rate: 0.5)
    |> Axon.dense(10, activation: :softmax)
  end

  defn loss(params, inputs, targets) do
    preds = Axon.predict(model(), params, inputs) # treated as Expr
    Axon.Losses.categorical_cross_entropy(targets, preds)
  end
end

params = Axon.init(MyModel.model(), compiler: EXLA) # JIT compiled and executed

If you prefer, you can obtain the initialization and application functions yourself using Axon.compile/1:

model =
  Axon.input({nil, 784})
  |> Axon.dense(128, activation: :relu)
  |> Axon.dropout(rate: 0.5)
  |> Axon.dense(10, activation: :softmax)

{init_fn, predict_fn} = Axon.compile(model)

Note that to accelerate these functions, you’ll need to compile them yourself with Nx.Defn.jit/3.

The model API also makes it easy to apply existing numerical definitions and Nx code at any point with the Axon.nx layer:

defmodule MyModel do
  # defn is a macro from Nx.Defn and must be imported before use
  import Nx.Defn

  defn mish(x) do
    x * Nx.tanh(Axon.Activations.softplus(x))
  end

  def model() do
    Axon.input({nil, 784})
    |> Axon.dense(128)
    |> Axon.nx(&mish/1)
    |> Axon.nx(fn x -> Nx.max(x, 0) end)
    |> Axon.dense(10, activation: :softmax)
  end
end

Axon currently has support for:

  • Linear layers (dense)
  • Dropout layers (dropout, feature_alpha_dropout, alpha_dropout, spatial_dropout)
  • Convolutional layers (conv, depthwise_conv, separable_conv2d, separable_conv3d)
  • Normalization layers (batch_norm, layer_norm, group_norm, instance_norm)
  • Pooling layers (max_pool, avg_pool, lp_pool, adaptive_max_pool, adaptive_avg_pool)
  • Activation layers (every function in Axon.Activations)
  • Utilities/combinators (flatten, add, multiply, subtract, concatenate)

with plans to support recurrent layers, attention layers, and many more. Our goal is to maintain an API that is productive, extensible, and on par with other modern deep learning frameworks. If you need functionality that isn’t on the roadmap, feel free to open an issue.

Model Optimization

Axon’s model optimization API follows the approach of DeepMind’s Optax: it provides low-level constructs for creating advanced optimizers, and then builds high-level optimizers on top of those constructs.

Axon represents an optimizer as a tuple {init_fn/1, update_fn/2}. init_fn/1 accepts a model’s parameters and initializes the optimizer’s state. update_fn/2 accepts “updates” (most commonly gradients in gradient-based optimization) and an optimizer state, and returns transformed updates and a new optimizer state.
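To make the {init_fn, update_fn} contract concrete, here is a minimal sketch in plain Elixir (no Nx). It assumes parameters and updates are lists of floats standing in for tuples of tensors; the module name TupleOptimizer and the momentum rule are illustrative choices, not Axon's implementation.

```elixir
# Illustrative only: lists of floats stand in for tuples of Nx tensors.
defmodule TupleOptimizer do
  # Returns {init_fn, update_fn} for SGD with momentum.
  def momentum(learning_rate, decay \\ 0.9) do
    # init_fn/1: one zero-valued velocity per parameter.
    init_fn = fn params -> Enum.map(params, fn _ -> 0.0 end) end

    # update_fn/2: takes gradients and the optimizer state and
    # returns {transformed_updates, new_state}.
    update_fn = fn grads, velocity ->
      velocity = Enum.zip_with(grads, velocity, fn g, v -> decay * v + g end)
      updates = Enum.map(velocity, fn v -> -learning_rate * v end)
      {updates, velocity}
    end

    {init_fn, update_fn}
  end

  # Applying updates is a separate step: params + updates, element-wise.
  def apply_updates(params, updates), do: Enum.zip_with(params, updates, &(&1 + &2))
end

{init_fn, update_fn} = TupleOptimizer.momentum(0.1)
params = [1.0, 2.0]
state = init_fn.(params)
{updates, _new_state} = update_fn.([1.0, -2.0], state)
TupleOptimizer.apply_updates(params, updates)
```

Keeping the optimizer as two pure functions means the caller owns the state and decides when to apply updates, which is what lets the same tuple plug into any training loop.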

At the lowest level, Axon implements a number of update functions in Axon.Updates. Each update function acts as a combinator: it accepts a tuple of init_fn/1 and update_fn/2 as its first argument and returns a tuple of modified init_fn and update_fn. This means you can arbitrarily compose updates to build complex optimizers. As an example, you can use the Axon.Updates API to implement the Adam optimizer like:

def adam(learning_rate, opts \\ []) do
  Axon.Updates.scale_by_adam(opts)
  |> Axon.Updates.scale(-learning_rate)
end
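The combinator mechanism can be sketched without Nx. In the plain-Elixir analogue below, every name (MiniUpdates, identity/0, trace/2, scale/2, and the private stateful/stateless helpers) is hypothetical; updates are lists of floats, and unlike Axon's versions these update functions take only updates and state, not parameters. Each combinator wraps the previous {init_fn, update_fn} pair, so a chain reads in the same order as adam/2 above.

```elixir
defmodule MiniUpdates do
  # Start of every chain: no state, updates pass through unchanged.
  def identity do
    {fn _params -> {} end, fn updates, state -> {updates, state} end}
  end

  # Stateless stage: multiply every update by step (e.g. -learning_rate).
  def scale(combinator, step) do
    stateless(combinator, fn updates -> Enum.map(updates, &(&1 * step)) end)
  end

  # Stateful stage: exponentially decaying sum of past updates (momentum-like).
  def trace(combinator, decay) do
    stateful(
      combinator,
      fn params -> Enum.map(params, fn _ -> 0.0 end) end,
      fn updates, acc ->
        acc = Enum.zip_with(updates, acc, fn g, a -> decay * a + g end)
        {acc, acc}
      end
    )
  end

  # Run the parent transformation first, then apply fun; state is untouched.
  defp stateless({parent_init, parent_update}, fun) do
    update_fn = fn updates, state ->
      {updates, state} = parent_update.(updates, state)
      {fun.(updates), state}
    end

    {parent_init, update_fn}
  end

  # Pair the parent's state with this stage's own state.
  defp stateful({parent_init, parent_update}, init_fun, apply_fun) do
    init_fn = fn params -> {parent_init.(params), init_fun.(params)} end

    update_fn = fn updates, {parent_state, own_state} ->
      {updates, parent_state} = parent_update.(updates, parent_state)
      {updates, own_state} = apply_fun.(updates, own_state)
      {updates, {parent_state, own_state}}
    end

    {init_fn, update_fn}
  end
end

{init_fn, update_fn} =
  MiniUpdates.identity()
  |> MiniUpdates.trace(0.9)
  |> MiniUpdates.scale(-0.1)
```

Because each stateful stage nests its own state next to its parent's, no stage ever sees another stage's state, which is what makes arbitrary composition safe.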

If you find the transformations in Axon.Updates too high-level, you can implement custom transformations using Axon.Updates.stateful/3 and Axon.Updates.stateless/2. Axon.Updates.stateful/3 represents a stateful transformation:

def scale_by_stddev(combinator \\ identity(), opts) do
  {initial, opts} = Keyword.pop(opts, :initial_scale, 0.0)

  stateful(
    combinator,
    &init_scale_by_stddev(&1, initial),
    &apply_scale_by_stddev(&1, &2, &3, opts)
  )
end

defnp init_scale_by_stddev(params, value) do
  mu = zeros_like(params)
  nu = fulls_like(params, value)
  {mu, nu}
end

defnp apply_scale_by_stddev(x, {mu, nu}, _params, opts \\ []) do
  opts = keyword!(opts, decay: 0.9, eps: 1.0e-8)
  decay = opts[:decay]
  eps = opts[:eps]

  mu = update_moment(x, mu, decay, 1)
  nu = update_moment(x, nu, decay, 2)

  x =
    transform({x, mu, nu, eps}, fn {x, mu, nu, eps} ->
      [Tuple.to_list(x), Tuple.to_list(mu), Tuple.to_list(nu)]
      |> Enum.zip()
      |> Enum.map(fn {g, z, t} -> g * Nx.rsqrt(-Nx.power(z, 2) + t + eps) end)
      |> List.to_tuple()
    end)

  {x, {mu, nu}}
end
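Reading the code above, and assuming update_moment/4 computes the usual exponential moving average of the update raised to the given power (as Optax's helper of the same name does; its definition is not shown here), the transformation for each update g_t with decay β and eps ε is:

```latex
\begin{aligned}
\mu_t &= \beta\,\mu_{t-1} + (1-\beta)\,g_t, & \mu_0 &= 0,\\
\nu_t &= \beta\,\nu_{t-1} + (1-\beta)\,g_t^{2}, & \nu_0 &= \text{initial\_scale},\\
u_t &= \frac{g_t}{\sqrt{\nu_t - \mu_t^{2} + \epsilon}}.
\end{aligned}
```

Since ν_t − μ_t² is a running estimate of the variance of the updates, u_t rescales each update by an estimated standard deviation, which is where the name scale_by_stddev comes from.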

Axon.Updates.stateless/2 represents a stateless transformation:

def scale(combinator \\ identity(), step_size) do
  stateless(combinator, &apply_scale(&1, &2, step_size))
end

defnp apply_scale(x, _params, step) do
  transform(
    {x, step},
    fn {updates, step} ->
      updates
      |> Tuple.to_list()
      |> Enum.map(&Nx.multiply(&1, step))
      |> List.to_tuple()
    end
  )
end

You can then compose your custom updates arbitrarily with those in Axon.Updates:

def my_optimizer(learning_rate, opts \\ []) do
  Axon.Updates.scale_by_adam(opts)
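  # pass opts explicitly: with a default first argument, scale_by_stddev/1
  # would bind the piped combinator to opts instead
  |> scale_by_stddev([])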
  |> scale(-learning_rate)
end

Axon uses the updates API to build a number of high-level optimizers:

  • Adabelief
  • Adagrad
  • Adam
  • Adamw
  • Fromage
  • Lamb
  • Noisy SGD
  • Radam
  • RMSProp
  • SGD
  • Yogi

It’s important to note that the optimization API does not directly depend on Axon models. You can use the API to optimize any differentiable objective function.
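To illustrate that point in plain Elixir, the sketch below minimizes f(w) = (w1 − 3)² + (w2 + 1)² with an optimizer in the {init_fn, update_fn} form and no model at all. The module AnyObjective and its functions are hypothetical, and the gradient is written by hand because plain Elixir has no automatic differentiation (with Nx you would let defn compute it).

```elixir
# Hypothetical example: no model, just f(w) = (w1 - 3)^2 + (w2 + 1)^2.
defmodule AnyObjective do
  # Hand-written gradient of f.
  def grad([w1, w2]), do: [2 * (w1 - 3), 2 * (w2 + 1)]

  # Plain SGD in the {init_fn, update_fn} form; it needs no state.
  def sgd(learning_rate) do
    {fn _params -> nil end,
     fn grads, state -> {Enum.map(grads, &(-learning_rate * &1)), state} end}
  end

  # Repeatedly transform the gradient and add the resulting updates to params.
  def minimize(params, steps, {init_fn, update_fn}) do
    {params, _state} =
      Enum.reduce(1..steps, {params, init_fn.(params)}, fn _step, {params, state} ->
        {updates, state} = update_fn.(grad(params), state)
        {Enum.zip_with(params, updates, &(&1 + &2)), state}
      end)

    params
  end
end

AnyObjective.minimize([0.0, 0.0], 100, AnyObjective.sgd(0.1))
# converges toward [3.0, -1.0]
```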

In the future, we plan to support integration with learning rate schedules and explore more advanced optimization approaches, including differentially private SGD and second-order methods.

Training

The purpose of the training API is to provide conveniences and common routines for implementing training loops. In his PhD thesis, Dougal Maclaurin (who created Autograd alongside Matt Johnson and David Duvenaud) writes:

The goal of Autograd is to make gradients effortless. If you can write a loss function, Autograd should be able to give you its gradient.

He also describes the usefulness of obtaining gradients in optimizing the parameters of an objective function. Essentially, Autograd simplifies the task of writing a machine learning algorithm down to the task of writing a differentiable objective function. This is a philosophy we maintain in Axon’s training API. If you can write a parameterized, differentiable objective function, and pair that with data, you can make use of Axon’s training API.

The API is partly inspired by the excellent PyTorch Lightning library. At the time of this writing, the Axon training API consists of two methods: Axon.Training.step and Axon.Training.train. In practice, you can use these methods to train an Axon model like this:

model =
  Axon.input({nil, 784})
  |> Axon.dense(128)
  |> Axon.dense(10, activation: :softmax)

final_params =
  model
  |> Axon.Training.step(:categorical_cross_entropy, Axon.Optimizers.adamw(0.005))
  |> Axon.Training.train(inputs, targets, epochs: 10, compiler: EXLA)

Axon.Training.step represents a single training step. It returns a tuple {init_fn, step_fn}, representing the training initialization function and the training step function respectively. The Axon.Training.step/3 form used in this example is just a convenience around Axon.Training.step/2, which accepts two tuples:

step({init_model_fn, objective_fn}, {init_update_fn, update_fn}) :: {init_fn, step_fn}

init_model_fn is combined with init_update_fn into a single initialization function which initializes training state.

objective_fn is a parameterized, differentiable objective function which accepts model parameters, inputs, and labels, and returns a loss. Recall that Axon optimizers are just tuples as well, so {init_update_fn, update_fn} has the same form as an Axon optimizer. All of these functions are combined to produce a single step function which is applied to each batch during training.

Axon.Training.train has the following form:

train({init_fn, step_fn}, inputs, targets, opts \\ []) :: model_state

Both inputs and targets are Enumerables containing batches of input and target tensors respectively. Note that the format of inputs and targets is likely to change in the future as we move towards a unified representation of datasets in the Nx ecosystem.

Axon.Training.train implements a common training loop which initializes the training state and iterates through the training set for some given number of epochs. It returns the final training state for serialization and potential use in inference workloads.
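The step/2 and train/4 forms described above can be mirrored in plain Elixir to show how the pieces fit together. This is a sketch, not Axon's code: MiniTraining is hypothetical, parameters are a list of floats, init_model_fn takes no arguments, and central finite differences stand in for the automatic differentiation Axon gets from Nx.

```elixir
defmodule MiniTraining do
  # Combine {init_model_fn, objective_fn} and an optimizer tuple
  # into {init_fn, step_fn}.
  def step({init_model_fn, objective_fn}, {init_update_fn, update_fn}) do
    init_fn = fn ->
      params = init_model_fn.()
      {params, init_update_fn.(params)}
    end

    step_fn = fn {params, opt_state}, inputs, targets ->
      grads = numerical_grad(fn p -> objective_fn.(p, inputs, targets) end, params)
      {updates, opt_state} = update_fn.(grads, opt_state)
      {Enum.zip_with(params, updates, &(&1 + &2)), opt_state}
    end

    {init_fn, step_fn}
  end

  # Initialize once, then apply step_fn to every batch for each epoch.
  # Returns the final training state {params, optimizer_state}.
  def train({init_fn, step_fn}, inputs, targets, opts \\ []) do
    epochs = Keyword.get(opts, :epochs, 1)
    batches = Enum.zip(inputs, targets)

    Enum.reduce(1..epochs, init_fn.(), fn _epoch, state ->
      Enum.reduce(batches, state, fn {x, y}, state -> step_fn.(state, x, y) end)
    end)
  end

  # Central finite differences: a stand-in for automatic differentiation.
  defp numerical_grad(f, params, h \\ 1.0e-6) do
    for {_value, i} <- Enum.with_index(params) do
      plus = List.update_at(params, i, &(&1 + h))
      minus = List.update_at(params, i, &(&1 - h))
      (f.(plus) - f.(minus)) / (2 * h)
    end
  end
end

# Fit y = w * x + b with mean squared error; data lies on y = 2x + 1.
objective = fn [w, b], xs, ys ->
  errors = Enum.zip_with(xs, ys, fn x, y -> w * x + b - y end)
  Enum.sum(Enum.map(errors, &(&1 * &1))) / length(errors)
end

sgd = {fn _params -> nil end, fn grads, s -> {Enum.map(grads, &(-0.05 * &1)), s} end}

{[w, b], _opt_state} =
  MiniTraining.step({fn -> [0.0, 0.0] end, objective}, sgd)
  |> MiniTraining.train([[0.0, 1.0], [2.0, 3.0]], [[1.0, 3.0], [5.0, 7.0]], epochs: 1000)
```

The loop itself never inspects the model or the optimizer; it only threads state through step_fn, which is why any objective and any optimizer tuple can be swapped in.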

Currently, the Axon training API is rather limited; however, there are plans to extend it. In the immediate future, we plan to support:

Additionally, we would love to explore more advanced things like distributed training. We are also seeking ways to improve the performance of our training loops by running them entirely on native accelerators.

Wrapping Up

Axon is still very young, with much work to do before it’s ready for release. You’ll likely encounter sharp edges, bugs, and confusing errors. We would love your help in making Axon better by experimenting with the API, reporting any issues you find, providing feedback on any of our open issues, or contributing to the project.

Finally, I would be remiss if I did not acknowledge some of the libraries that served as inspiration for Axon: