Nx Tip of the Week #4 - Using Keywords


Numerical definitions can only accept tensors or numbers as positional arguments. You can work around this limitation with keyword lists: pass optional keyword arguments to your numerical definition and read them with the keyword! function. Let's look at some ways this can be useful.

Parameter Initializers

In many ML applications, you start from an initial set of model parameters. These parameters are the starting point for training and can have a major impact on whether or not your model converges. Because initialization matters so much, it's common to use initialization functions beyond the standard random_uniform and random_normal distributions. For example, the biases of a neural network layer are often initialized to a tensor of zeros:

iex> biases = Nx.broadcast(0.0, {2, 2})
#Nx.Tensor<
  f32[2][2]
  [
    [0.0, 0.0],
    [0.0, 0.0]
  ]
>

But what if you want to turn this into a reusable initialization function? Unfortunately, the following doesn't work:

defn zeros(shape) do
  Nx.broadcast(0.0, shape)
end

iex> zeros({2, 2})
** (ArgumentError) defn functions expects either numbers or tensors as arguments. If you want to pass Elixir values, they need to be sent as options and tagged as default arguments. Got: {2, 2}

You can get around this using optional keyword lists:

defn zeros(opts \\ []) do
  opts = keyword!(opts, [:shape])
  Nx.broadcast(0.0, opts[:shape])
end

iex> zeros(shape: {2, 2})
#Nx.Tensor<
  f32[2][2]
  [
    [0.0, 0.0],
    [0.0, 0.0]
  ]
>

Note that you must give the keyword list a default value of [] and extract or assign its values with keyword!. keyword! validates the passed keyword list, raising on keys you didn't declare, and can optionally assign default values:

defn zeros(opts \\ []) do
  # :shape has no default; :type defaults to {:f, 32}
  opts = keyword!(opts, [:shape, type: {:f, 32}])
  Nx.broadcast(Nx.tensor(0, type: opts[:type]), opts[:shape])
end

iex> zeros(shape: {2, 2})
#Nx.Tensor<
  f32[2][2]
  [
    [0.0, 0.0],
    [0.0, 0.0]
  ]
>

iex> zeros(shape: {2, 2}, type: {:bf, 16})
#Nx.Tensor<
  bf16[2][2]
  [
    [0.0, 0.0],
    [0.0, 0.0]
  ]
>
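Inside defn, keyword! performs this validation for you. To experiment with the same contract in plain Elixir (for example in iex, outside any numerical definition), the standard library's Keyword.validate!/2, available since Elixir 1.13, works in much the same way: bare atoms are allowed keys without defaults, {key, default} pairs fill in defaults, and unknown keys raise ArgumentError. The ZerosOpts module below is hypothetical and only for illustration:

```elixir
defmodule ZerosOpts do
  # Hypothetical helper mirroring the option handling in zeros/1 above.
  # :shape is allowed but has no default; :type defaults to {:f, 32}.
  # Unknown keys (for example a typo like :shap) raise ArgumentError.
  def normalize(opts) do
    Keyword.validate!(opts, [:shape, type: {:f, 32}])
  end
end

# :type is filled in with the default {:f, 32}
ZerosOpts.normalize(shape: {2, 2})

# An explicitly passed :type overrides the default
ZerosOpts.normalize(shape: {2, 2}, type: {:bf, 16})
```

In Keyword.validate!/2, a key listed without a default is allowed but not required, so a missing :shape simply stays absent and opts[:shape] returns nil.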

Passing Flags

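Sometimes you want to pass booleans or atoms that select a code path. You can do this by combining keywords with transform, which runs an ordinary Elixir function while the defn expression is being built, so that function can pattern match on the atom itself: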

defn my_function(x, opts \\ []) do
  opts = keyword!(opts, [mode: :sin])
  transform({x, opts[:mode]},
    fn
      {x, :sin} -> Nx.sin(x)
      {x, :cos} -> Nx.cos(x)
      {x, :tan} -> Nx.tan(x)
    end
  )
end

iex> Test.my_function(1, mode: :sin)
#Nx.Tensor<
  f32
  0.8414709568023682
>
iex> Test.my_function(1, mode: :cos)
#Nx.Tensor<
  f32
  0.5403022766113281
>
iex> Test.my_function(1, mode: :tan)
#Nx.Tensor<
  f32
  1.5574077367782593
>

With booleans, you can use a macro to turn your options into something defn's if can branch on:

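# Converts a compile-time boolean (true/false) into 1 or 0 so that
# defn's if can branch on it. Macros must be defined before they are
# used (or live in another module that you require).
defmacro to_predicate(term) do
  quote do
    Nx.Defn.Kernel.transform(
      unquote(term),
      fn term -> if term, do: 1, else: 0 end
    )
  end
end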

defn my_function(x, opts \\ []) do
  opts = keyword!(opts, [add_one?: true])
  if to_predicate(opts[:add_one?]) do
    x + 1
  else
    x
  end
end

iex> my_function(1, add_one?: true)
#Nx.Tensor<
  s64
  2
>

iex> my_function(1, add_one?: false)
#Nx.Tensor<
  s64
  1
>

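The macro converts true and false to 1 and 0, respectively, turning the option into a number that defn's if can branch on. Note that defn treats scalar 1 as true and scalar 0 as false, which differs from plain Elixir, where 0 is truthy and only false and nil are falsy. Depending on the value of add_one?, your function at compile time looks something like this (shown for add_one?: true):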

defn my_function(x) do
  if 1 do
    x + 1
  else
    x
  end
end

Notice that the constant 1 is inlined. Most tensor compilers can see this and optimize the if away entirely, so you get the readability without a runtime performance hit. However, there are some things to consider when using keywords.
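The truthiness difference is easy to check in plain Elixir. The sketch below uses a hypothetical Truthiness module: to_predicate/1 applies the same true/false to 1/0 conversion that the macro runs at compile time, and elixir_branch/1 shows that an ordinary Elixir if takes the do branch for 0:

```elixir
defmodule Truthiness do
  # Same conversion the to_predicate/1 macro applies at compile time:
  # any Elixir-truthy value becomes 1; nil and false become 0.
  def to_predicate(term), do: if(term, do: 1, else: 0)

  # Plain Elixir if: only nil and false are falsy, so 0 takes the do branch.
  def elixir_branch(value), do: if(value, do: :do_branch, else: :else_branch)
end

Truthiness.to_predicate(true)   # 1
Truthiness.to_predicate(false)  # 0
Truthiness.to_predicate(nil)    # 0
Truthiness.elixir_branch(0)     # :do_branch, whereas defn treats scalar 0 as false
```

The nil case is worth keeping in mind: if a flag ends up as nil (for example, an option with no default that the caller never passed), the conversion quietly yields 0 rather than an error.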

Performance Considerations

Looking back at the example above, notice that you are actually compiling two different programs depending on the value of add_one?. Numerical definitions are JIT compiled, and (at least with the EXLA compiler) the compiled programs are cached based on argument shapes to avoid unnecessary recompilations. Because keyword values are inlined into the program at compile time, a different keyword value means a different program. With a boolean, this isn't a big deal: at most you pay for one extra compilation. Compilation can be expensive, though, so you want to reuse compiled computations as much as possible.

When using keywords, if you have a value that’s constantly changing, you will force recompilation with the new value every time. For example:

defn my_function(opts \\ []) do
  opts = keyword!(opts, [:value])
  opts[:value]
end

iex> my_function(value: 0.1)
#Nx.Tensor<
  f32
  0.1
>

iex> my_function(value: 0.2)
#Nx.Tensor<
  f32
  0.2
>

iex> my_function(value: 0.3)
#Nx.Tensor<
  f32
  0.3
>

These three calls actually compile and use three different functions! Here that isn't too bad, since compiling such a small function is cheap. However, you'll often want to use keywords to pass hyperparameters such as learning rates or decay rates. If you change one of these rates at any point during training, you force a recompilation of the whole computation because of a relatively minor part of the function. In these cases, consider passing the values explicitly as tensor arguments instead of keywords.
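To make the caching argument concrete without depending on EXLA internals, here is a toy model in plain Elixir. ToyJitCache is hypothetical and is not how EXLA caches programs; it only assumes what this post describes: a program is compiled once per distinct combination of argument shapes/types and keyword option values.

```elixir
defmodule ToyJitCache do
  # Hypothetical toy model of a JIT cache, NOT how EXLA is implemented.
  # The cache key combines the argument specs (type and shape of each
  # tensor argument) with the keyword options, since option values are
  # inlined into the compiled program.
  def key(arg_specs, opts), do: {arg_specs, Enum.sort(opts)}

  # Counts how many compilations a sequence of calls would trigger.
  # Each call is a {arg_specs, opts} tuple.
  def count_compilations(calls) do
    {_cache, compiled} =
      Enum.reduce(calls, {MapSet.new(), 0}, fn {arg_specs, opts}, {cache, n} ->
        k = key(arg_specs, opts)

        if MapSet.member?(cache, k) do
          # Cache hit: reuse the already compiled program
          {cache, n}
        else
          # Cache miss: "compile" and remember the key
          {MapSet.put(cache, k), n + 1}
        end
      end)

    compiled
  end
end

# Value passed as a keyword: every new value is a new cache key.
keyword_calls = for v <- [0.1, 0.2, 0.3], do: {[], [value: v]}
ToyJitCache.count_compilations(keyword_calls)   # 3

# Value passed as a scalar f32 tensor argument: only its type and shape
# ({{:f, 32}, {}}) enter the key, so all three calls share one program.
tensor_calls = for _v <- [0.1, 0.2, 0.3], do: {[{{:f, 32}, {}}], []}
ToyJitCache.count_compilations(tensor_calls)    # 1
```

In the keyword version each new learning rate would be a new key; in the tensor version only the scalar's type and shape reach the key, so the value can change freely between calls.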

I hope this gives you some ways to work around argument limitations in numerical definitions.