Nx Tip of the Week #13 - Hooks


Part of the restrictiveness of defn is that you can't debug it the way you would debug a normal Elixir function. I'm personally a big fan of plain old IO.inspect debugging, but defn builds an expression describing your computation before any real tensor data flows through it, so you can't inspect intermediate tensor values the way you would inspect intermediate values in a regular Elixir function. For example, this code:

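defn my_function(x) do
  # transform/2 runs this function while the defn expression is being built,
  # so IO.inspect/1 receives an expression, not a concrete tensor value
  transform(x, fn x ->
    x |> Nx.exp() |> IO.inspect()
  end)
end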

would print the defn expression (a symbolic description of the computation) to the console rather than the actual result of Nx.exp/1, because the function given to transform runs while the expression is being built. If you don't have a solid understanding of defn and defn expressions, that's okay; just understand that, because of how defn works, calling IO.inspect will not yield the value you're looking for.

Instead, you need to make use of the built-in hook inspect_value:

defn my_function(x) do
  x |> Nx.exp() |> inspect_value()
end

Like IO.inspect, inspect_value returns the value it's passed and prints it to the console; unlike the transform example above, it prints the actual tensor value computed at runtime. inspect_value is built on top of Nx.Defn.Kernel.hook, which allows you to perform side-effecting operations within defn. For example, you can use hooks to log values:

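require Logger

defn my_hooked_function(x) do
  # Logger.info/1 is a macro and can't be captured with &, so wrap it in a function
  x |> Nx.exp() |> hook(fn t -> Logger.info(Kernel.inspect(t)) end)
end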

A hook receives the intermediate value at execution time and calls the given function with it at runtime. If you want to pass runtime values to external processes, you can do that with hooks. You can also name a hook and supply (or override) its function at JIT time:

defn my_hooked_function(x) do
  x |> Nx.exp() |> hook(:my_hook)
end

hooks = %{my_hook: &IO.inspect/1}
Nx.Defn.jit(&my_hooked_function/1, [value], hooks: hooks)
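To make the "pass runtime values to external processes" idea concrete, here is a minimal sketch that forwards the hooked tensor to the calling process through a named hook supplied at JIT time. It assumes the Nx API used in this tip (Nx.Defn.jit/3 taking a list of arguments and a :hooks option; newer Nx releases changed jit's signature), the default evaluator compiler, and a Mix.install pin to an Nx 0.1 release; the module name HookDemo and the hook name :exp_value are hypothetical.

```elixir
# Assumption: an Nx release matching this tip's API (Nx.Defn.jit/3 with an args list)
Mix.install([{:nx, "~> 0.1.0"}])

defmodule HookDemo do
  import Nx.Defn

  # :exp_value is a hypothetical hook name; its function is supplied at JIT time
  defn exp_with_hook(x) do
    x |> Nx.exp() |> hook(:exp_value)
  end
end

parent = self()

# A plain Elixir closure defined outside defn: it runs at runtime, receives
# the concrete tensor, and forwards it to the calling process
hooks = %{exp_value: fn tensor -> send(parent, {:exp_value, tensor}) end}

result = Nx.Defn.jit(&HookDemo.exp_with_hook/1, [Nx.tensor([0.0, 1.0])], hooks: hooks)

# Wait for the value forwarded by the hook (give up after one second)
received =
  receive do
    {:exp_value, tensor} -> tensor
  after
    1_000 -> :timeout
  end

IO.inspect(received, label: "value seen by the hook")
```

Because the hook function is ordinary Elixir defined outside defn, it can do anything a normal function can, such as sending messages to another process.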

Note that a hooked value must be used by a value returned from the defn; otherwise, the hook will never execute.
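The sketch below illustrates that rule under the same assumptions as the previous examples (the Nx 0.1-era Nx.Defn.jit/3 API and the default evaluator). The module name and the hook names :unused and :used are hypothetical: one hook wraps a value that is discarded, the other wraps the returned value, and only the second should fire.

```elixir
# Assumption: an Nx release matching this tip's API (Nx.Defn.jit/3 with an args list)
Mix.install([{:nx, "~> 0.1.0"}])

defmodule UnusedHookDemo do
  import Nx.Defn

  defn add_one(x) do
    # This hooked value is discarded, so it is not part of the returned value
    _discarded = hook(Nx.exp(x), :unused)

    # This hooked value is returned, so its hook runs
    hook(x + 1, :used)
  end
end

parent = self()

# Each hook just reports that it ran by messaging the calling process
hooks = %{
  unused: fn _tensor -> send(parent, :unused_ran) end,
  used: fn _tensor -> send(parent, :used_ran) end
}

Nx.Defn.jit(&UnusedHookDemo.add_one/1, [Nx.tensor([1.0, 2.0])], hooks: hooks)

# Keep only the hook messages that arrived within 100 ms each
ran =
  Enum.filter([:used_ran, :unused_ran], fn name ->
    receive do
      ^name -> true
    after
      100 -> false
    end
  end)

IO.inspect(ran, label: "hooks that ran")
```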

Hooks are implemented on top of the Nx.Stream API. We’ll cover Nx.Stream briefly in a later tip.

Until next time!